<font size="5">Import Libraries</font>

In [1]:
import torch
import torch.nn as nn
from torchvision import transforms as T
from PIL import Image
import json
from pathlib import Path
import os
from tqdm import tqdm
import numpy as np
import cv2
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
from torchvision.datasets.coco import CocoCaptions
from IPython.display import clear_output
import clip

  from .autonotebook import tqdm as notebook_tqdm


<font size="5">Setting Dataset & Training Parameters</font>

In [2]:
# ------------------------------------------------------------------
# Tunable settings used by the rest of the notebook.
# ------------------------------------------------------------------

# Side length (pixels) every training image is resized to.
input_image_size = 512

# Width of the text embedding. No other cell in this notebook reads it;
# it is kept for reference (512 is also the `dim` given to the prior network).
text_embedding_size = 512

# Samples per optimisation step, and number of passes over the dataset.
batch_size, epoch = 1, 1

# COCO 2014 training images and their caption annotation file.
train_img_path = "./train2014/"
train_annot_path = "./annotations/captions_train2014.json"

# Compute device: either "cpu" or "cuda".
device = "cuda"

<font size="5">Data Preprocessing</font>

In [3]:
# Resize each image to a fixed square and convert it to a float tensor in [0, 1].
# Deliberately no mean/std normalisation: dalle2_pytorch's Decoder rescales
# [0, 1] images to [-1, 1] itself, and OpenAIClipAdapter applies CLIP's own
# normalisation before embedding. ImageNet-normalised inputs would be
# normalised twice and fall outside the range both components expect.
transform = T.Compose([
    T.Resize((input_image_size, input_image_size)),
    T.ToTensor(),
])

# Each item is (image_tensor, list_of_caption_strings).
train_data = CocoCaptions(
    root=train_img_path,
    annFile=train_annot_path,
    transform=transform
)

loading annotations into memory...
Done (t=0.93s)
creating index...
index created!


<font size="5">Create Model</font>

In [4]:
# OpenAI's pretrained CLIP (ViT-B/32 by default). The same adapter instance
# is handed to both the diffusion prior and the decoder below.
OpenAIClip = OpenAIClipAdapter()

# ---- Diffusion prior ----
prior_network = DiffusionPriorNetwork(dim=512, depth=6, dim_head=64, heads=8)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=OpenAIClip,
    timesteps=100,
    cond_drop_prob=0.2,
)
diffusion_prior = diffusion_prior.to(device)
# Multi-GPU option: diffusion_prior = nn.DataParallel(diffusion_prior)

# ---- Decoder: a cascade of two U-Nets ----
# Keyword arguments shared by both U-Nets.
_unet_common = dict(image_embed_dim=512, cond_dim=128, channels=3)

unet1 = Unet(dim=128, dim_mults=(1, 2, 4, 8), **_unet_common)
unet2 = Unet(dim=16, dim_mults=(1, 2, 4, 8, 16), **_unet_common)

decoder = Decoder(
    unet=(unet1, unet2),
    image_sizes=(256, 512),  # one size per U-Net, in the same order as `unet`
    clip=OpenAIClip,
    timesteps=100,
    image_cond_drop_prob=0.1,
    text_cond_drop_prob=0.5,
    # Set True to also condition on text encodings during training and sampling.
    condition_on_text_encodings=False,
)
decoder = decoder.to(device)
# Multi-GPU option: decoder = nn.DataParallel(decoder)

<font size="5">Run Training</font>

In [5]:
# Optimisation settings (typical starting values; tune as needed).
prior_lr = 3e-4
decoder_lr = 1e-4

train_size = len(train_data)

# CLIP is shared by the prior and the decoder (it was passed to both) and is
# not what we are training, so keep its weights out of both optimisers.
clip_param_ids = {id(p) for p in OpenAIClip.parameters()}


def _trainable_params(module):
    """Return the parameters of `module` that require grad, excluding CLIP's."""
    return [
        p for p in module.parameters()
        if p.requires_grad and id(p) not in clip_param_ids
    ]


# Without an optimizer step, loss.backward() would only accumulate gradients
# and the model weights would never change.
prior_optimizer = torch.optim.Adam(_trainable_params(diffusion_prior), lr=prior_lr)
decoder_optimizer = torch.optim.Adam(_trainable_params(decoder), lr=decoder_lr)

for curr_epoch in range(epoch):
    print("Run training ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")

    for batch_start in tqdm(range(0, train_size, batch_size)):
        # Sample indices for this batch; the final batch is clamped to the end
        # of the dataset instead of restarting from an unrelated offset.
        batch_indices = range(batch_start, min(batch_start + batch_size, train_size))

        unique_images = []  # one image per sample, for the decoder
        paired_images = []  # each image repeated once per caption, for the prior
        token_list = []     # CLIP token ids, one row per caption, for the prior

        for sample_idx in batch_indices:
            image, captions = train_data[sample_idx]
            image = image.unsqueeze(0)
            # The prior's CLIP adapter embeds the text itself, so it needs raw
            # CLIP token ids. Do not pass CLIP embeddings cast to integers.
            # truncate=True avoids a RuntimeError on captions over 77 tokens.
            tokens = clip.tokenize(captions, truncate=True)

            unique_images.append(image)
            paired_images.append(image.repeat(len(tokens), 1, 1, 1))
            token_list.append(tokens)

        text = torch.cat(token_list, dim=0).to(device)
        prior_images = torch.cat(paired_images, dim=0).to(device)
        decoder_images = torch.cat(unique_images, dim=0).to(device)

        # Prior update on (caption, image) pairs; skipped if no captions.
        if len(text) > 0:
            prior_optimizer.zero_grad()
            loss = diffusion_prior(text, prior_images)
            loss.backward()
            prior_optimizer.step()

        # Decoder update. The decoder is called without text here, so each
        # image is used once rather than once per caption.
        decoder_optimizer.zero_grad()
        for unet_number in (1, 2):
            loss = decoder(decoder_images, unet_number = unet_number)
            loss.backward()
        decoder_optimizer.step()

        if device == "cuda":
            torch.cuda.empty_cache()

Run training ...
Epoch 1 / 1


  0%|          | 208/82783 [09:24<61:52:06,  2.70s/it]

In [None]:
# Bundle the trained prior and decoder into a single text-to-image model.
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).to(device)

# Change the model weight save path here (end with ".pt")
save_path = "./dalle2.pt"

# Change the test result image save path (should be a directory or folder)
test_img_save_path = "./result"
os.makedirs(test_img_save_path, exist_ok=True)

torch.save(dalle2.state_dict(), save_path)


def _to_bgr_uint8(img_tensor):
    """Convert one generated image tensor into an array cv2.imwrite can save.

    Assumes `img_tensor` is channels-first RGB (3, H, W) with values in
    [0, 1] (TODO: confirm for the installed dalle2_pytorch version); values
    are clamped to guard against slight overshoot. Returns an (H, W, 3)
    uint8 array in BGR order, which is what OpenCV expects. Passing the raw
    float RGB data to cv2.imwrite does not produce a correct JPEG.
    """
    rgb = img_tensor.detach().float().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    rgb_u8 = np.round(rgb * 255).astype(np.uint8)
    return cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)


# Sampling only, so gradients are not needed.
with torch.no_grad():
    test_imgs_tensor = dalle2(
        ['glistening morning dew on a flower petal'], # text input for the model (can be more than one)
        cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
    )

for test_idx, test_img_tensor in enumerate(test_imgs_tensor):
    bgr_img = _to_bgr_uint8(test_img_tensor)
    test_save_path = os.path.join(test_img_save_path, f"test_{test_idx}.jpg")
    if not cv2.imwrite(test_save_path, bgr_img):
        raise IOError(f"Failed to write {test_save_path}")
    print(test_save_path, bgr_img.shape)


<font size="5">Load DALLE2 Model (optional: reload previously saved weights instead of retraining)</font>

In [None]:
# Set to True to reload previously saved weights (e.g. instead of retraining).
load_pretrained = False

# Change your model path (".pt" file)
load_model_path = "./dalle2.pt"

if load_pretrained:
    dalle2 = DALLE2(
        prior = diffusion_prior,
        decoder = decoder
    ).to(device)
    # map_location lets a checkpoint saved on GPU load on the configured device.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    dalle2.load_state_dict(torch.load(load_model_path, map_location=device))