<font size="5">Import Libraries</font>

In [1]:
import torch
import torch.nn as nn
from torchvision import transforms as T
from PIL import Image
import json
from pathlib import Path
import os
from tqdm import tqdm
import numpy as np
import cv2
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
from torchvision.datasets.coco import CocoCaptions
from IPython.display import clear_output
import clip

  from .autonotebook import tqdm as notebook_tqdm


<font size="5">Setting Dataset & Training Parameters</font>

In [2]:
# Change your input size here
input_image_size = 256

# Change your max text embedding value here
max_text_embedding_val = 1

# Change your batch size here
batch_size = 8

# Change your epoch here
epoch = 5

# Change your train image root path here
train_img_path = "./train2014/"

# Change your train annot json path here
train_annot_path = "./annotations/captions_train2014.json"

# Change your device ("cpu" or "cuda")
device = "cuda"

<font size="5">Data Preprocessing</font>

In [3]:
transform = T.Compose([
    T.Resize((input_image_size, input_image_size)),
    T.ToTensor(),
    T.Normalize(
       mean=[0.485, 0.456, 0.406],
       std=[0.229, 0.224, 0.225]
    )
])

train_data = CocoCaptions(
    root=train_img_path,
    annFile=train_annot_path,
    transform=transform
)

loading annotations into memory...
Done (t=0.94s)
creating index...
index created!


<font size="5">Create Model</font>

In [4]:
# openai pretrained clip - defaults to ViT/B-32
OpenAIClip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = OpenAIClip,
    timesteps = 100,
    cond_drop_prob = 0.2
).to(device)

# diffusion_prior = nn.DataParallel(diffusion_prior)

diffusion_prior.train()

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).to(device)

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = OpenAIClip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings=True
).to(device)

# decoder = nn.DataParallel(decoder)

decoder.train()

Decoder(
  (clip): OpenAIClipAdapter(
    (clip): CLIP(
      (visual): VisionTransformer(
        (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
        (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (transformer): Transformer(
          (resblocks): Sequential(
            (0): ResidualAttentionBlock(
              (attn): MultiheadAttention(
                (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
              )
              (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
              (mlp): Sequential(
                (c_fc): Linear(in_features=768, out_features=3072, bias=True)
                (gelu): QuickGELU()
                (c_proj): Linear(in_features=3072, out_features=768, bias=True)
              )
              (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            )
            (1): ResidualAttentionBlock(
              (attn): Multih

<font size="5">Run training</font>

In [5]:
train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

model, _ = clip.load("ViT-B/32", device=device)

for curr_epoch in range(epoch):
    print("Run training diffusion prior ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        image_list = []
        text_list = []
        
        for curr_idx in iter_idx:
            image, target = train_data[curr_idx]
            image = image.unsqueeze(0).to(device)
            text = clip.tokenize(target).to(device)
            text = model.encode_text(text).to(device)

            text_size = len(text)
            for i in range(text_size):
                image_list.append(image)
            
            text_list.append(text)

        text = torch.cat(text_list, dim=0)

        new_text_list = []

        for text_embed in text_list:
            text_embed -= text_embed.min(1, keepdim=True)[0]
            text_embed /= text_embed.max(1, keepdim=True)[0]
            text_embed *= max_text_embedding_val

            new_text_list.append(text_embed.type(torch.LongTensor).to(device))
            
        text = torch.cat(new_text_list, dim=0).to(device)
        image = torch.cat(image_list, dim=0).to(device)
    
        loss = diffusion_prior(text, image)
        loss.backward()

        loss = decoder(image, text) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
        loss.backward()
        break
    break

# Change your diffusion prior model save path here
diff_save_path = "./diff_prior.pt"

torch.save(diffusion_prior, diff_save_path)

Run training diffusion prior ...
Epoch 1 / 5


  0%|          | 0/10348 [00:05<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 980.00 MiB (GPU 0; 15.88 GiB total capacity; 9.94 GiB already allocated; 413.06 MiB free; 10.16 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# torch.cuda.empty_cache()

# for curr_epoch in range(epoch):
#     print("Run training decoder ...")
#     print(f"Epoch {curr_epoch+1} / {epoch}")
    
#     for batch_idx in tqdm(idx_list):
#         if (batch_idx + batch_size) > train_size - 1:
#             iter_idx = range(batch_idx, train_size, 1)
#         else:
#             iter_idx = range(batch_idx, batch_idx+batch_size, 1)

#         image_list = []
#         text_list = []
        
#         for curr_idx in iter_idx:
#             image, target = train_data[curr_idx]
#             image = image.unsqueeze(0).to(device)
#             # text = clip.tokenize(target).to(device)
#             # text = model.encode_text(text).to(device)

#             # text_size = len(text)
#             # for i in range(text_size):
#             image_list.append(image)
            
#         #     text_list.append(text)

#         # text = torch.cat(text_list, dim=0)

#         # new_text_list = []

#         # for text_embed in text_list:
#         #     text_embed -= text_embed.min(1, keepdim=True)[0]
#         #     text_embed /= text_embed.max(1, keepdim=True)[0]
#         #     text_embed *= max_text_embedding_val

#         #     new_text_list.append(text_embed.type(torch.LongTensor).to(device))
            
#         # text = torch.cat(new_text_list, dim=0).to(device)
#         image = torch.cat(image_list, dim=0).to(device)

#         loss = decoder(image) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
#         loss.backward()
#         break
#     break

<font size="5">Save Trained Model and test on several text input</font>

In [None]:
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).to(device)

dalle2.eval()

# Change the model weight save path here (end with ".pt")
save_path = "./dalle2.pt"

# Change the test result image save path (should be a directory or folder)
test_img_save_path = "./result"

if not os.path.exists(test_img_save_path):
    os.makedirs(test_img_save_path)

torch.save(dalle2.state_dict(), save_path)

test_input = ['Closeup of bins of food that include broccoli and bread.'] # text input for the model (can be more than one)

test_img_tensors = dalle2(
    test_input,
    cond_scale = 2., # classifier free guidance strength (> 1 would strengthen the condition)
)

for test_idx, test_img_tensor in enumerate(test_img_tensors):
    # test_img_tensor -= test_img_tensor.min(1, keepdim=True)[0]
    # test_img_tensor /= test_img_tensor.max(1, keepdim=True)[0]
    test_img = T.ToPILImage()(test_img_tensor)
    test_save_path = os.path.join(test_img_save_path, f"{test_input[test_idx]}.jpg")
    test_img.save(Path(test_save_path))


<font size="5">Load DALLE2 Model (Use if you want to get the trained model)</font>

In [None]:
# device = "cuda"

# # openai pretrained clip - defaults to ViT/B-32
# OpenAIClip = OpenAIClipAdapter()

# prior_network = DiffusionPriorNetwork(
#     dim = 512,
#     depth = 6,
#     dim_head = 64,
#     heads = 8
# )

# diffusion_prior = DiffusionPrior(
#     net = prior_network,
#     clip = OpenAIClip,
#     timesteps = 100,
#     cond_drop_prob = 0.2
# ).to(device)

# # diffusion_prior = nn.DataParallel(diffusion_prior)

# unet1 = Unet(
#     dim = 128,
#     image_embed_dim = 512,
#     cond_dim = 128,
#     channels = 3,
#     dim_mults=(1, 2, 4, 8)
# )

# unet2 = Unet(
#     dim = 16,
#     image_embed_dim = 256,
#     cond_dim = 128,
#     channels = 3,
#     dim_mults = (1, 2, 4, 8, 16)
# )

# decoder = Decoder(
#     unet = (unet1, unet2),
#     image_sizes = (128, 256),
#     clip = OpenAIClip,
#     timesteps = 100,
#     image_cond_drop_prob = 0.1,
#     text_cond_drop_prob = 0.5,
#     condition_on_text_encodings = True  # set this to True if you wish to condition on text during training and sampling
# ).to(device)

# dalle2 = DALLE2(
#     prior = diffusion_prior,
#     decoder = decoder
# ).cuda()

# # Change your model path (".pt" file)
# load_model_path = "./dalle2.pt"

# dalle2.load_state_dict(torch.load(load_model_path))