<font size="5">Import Libraries</font>

In [1]:
import torch
import torch.nn as nn
from torchvision import transforms as T
from PIL import Image
import json
from pathlib import Path
import os
from tqdm import tqdm
import numpy as np
import cv2
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter

  from .autonotebook import tqdm as notebook_tqdm


<font size="5">Data Preprocessing</font>

In [2]:
# Image transformations and dataset construction.
#
# Builds two parallel lists with one entry per usable COCO-Stuff annotation:
#   category_index_list[i] -> (1, text_embedding_size) LongTensor derived from the category embedding
#   images_list[i]         -> (1, 3, H, W) float tensor of that annotation's bbox crop
# Both lists are appended together, so index i always refers to the same annotation.

# Change your input size here (must be square, e.g. 224 x 224)
input_size = (256, 256)

# Change your text embedding size here (by default this matches the input size)
text_embedding_size = 256

# Seed the RNG so the randomly initialised category embedding is reproducible across runs.
random_seed = 0
torch.manual_seed(random_seed)

transform = T.Compose([
    T.Resize(input_size),
    T.ToTensor(),
    # NOTE(review): dalle2_pytorch presumably expects images in [0, 1] (its CLIP adapter applies
    # its own normalisation); confirm whether this ImageNet normalisation should be removed.
    T.Normalize(
       mean=[0.485, 0.456, 0.406],
       std=[0.229, 0.224, 0.225]
    )
])

# Load the annotation file; `with` closes the handle even if parsing fails.
with open("./annotations/stuff_train2017.json", "r") as f:
    json_file = json.load(f)

categories = json_file['categories']
raw_annot = json_file['annotations']

word_to_ix = {category['name']: idx for idx, category in enumerate(categories)}

# Map each category id directly to its position in `categories`. Subtracting num_categories
# from the id only works while the ids are contiguous and offset by exactly len(categories).
# Any other id would silently wrap to the wrong category through negative indexing.
category_id_to_ix = {int(category['id']): idx for idx, category in enumerate(categories)}

num_categories = len(categories)
embeds = nn.Embedding(num_categories, text_embedding_size)

# Image folder path
train_path = Path("./train2017/")

# Index the image folder once, mapping the numeric file stem (leading zeros ignored) to its path.
# Globbing "*<image_id>.jpg" per annotation rescanned the folder every time. It also matched
# other images whose id merely ends with the same digits (e.g. id 123 also matched ...1123.jpg).
image_paths = {int(p.stem): p for p in train_path.glob("*.jpg") if p.stem.isdigit()}

category_index_list = []
images_list = []
skipped_annotations = 0

for annot in tqdm(raw_annot):
    img_path = image_paths.get(int(annot['image_id']))
    if img_path is None:
        # Skip the whole annotation so the text and image lists stay aligned.
        skipped_annotations += 1
        continue

    # COCO bboxes are [x, y, width, height]; PIL's crop() takes (left, upper, right, lower).
    x, y, w, h = annot['bbox']
    left, top = int(x), int(y)
    # Keep at least one pixel so degenerate boxes still give a crop that Resize can handle.
    right = max(left + 1, round(x + w))
    bottom = max(top + 1, round(y + h))

    with Image.open(img_path) as img:
        crop = img.convert('RGB').crop((left, top, right, bottom))
    # ToTensor/Normalize already yield float32, so no extra dtype cast is needed.
    images_list.append(transform(crop).unsqueeze(0))

    lookup_tensor = torch.tensor([category_id_to_ix[int(annot['category_id'])]], dtype=torch.long)
    # FIXME(review): casting the float embedding to LongTensor truncates every value toward zero.
    # The result is later passed to the prior as its `text` input. At inference DALLE2 is called
    # with prompt strings instead, which the library presumably tokenizes, so training and sampling
    # likely see very different text inputs. Consider tokenizing the category name instead.
    with torch.no_grad():
        annot_embed = embeds(lookup_tensor)
    category_index_list.append(annot_embed.type(torch.LongTensor))

if skipped_annotations:
    print(f"Skipped {skipped_annotations} annotations whose image was not found in {train_path}")

  0%|          | 72/32801 [00:23<2:58:28,  3.06it/s]

<font size="5">Shift the Text Embeddings So Their Minimum Is 0</font>

In [None]:
# Shift every per-annotation text tensor by one shared offset so that the smallest value
# across the whole dataset becomes 0. If nothing is negative, the tensors are reused as-is
# (in a new list).
text_raw = torch.cat(category_index_list, dim=1)
minimum_text_embedding = torch.min(text_raw)

if minimum_text_embedding >= 0:
    new_category_index_list = category_index_list.copy()
else:
    text_offset = torch.abs(minimum_text_embedding)
    new_category_index_list = [per_annot + text_offset for per_annot in category_index_list]

<font size="5">Run Batch Training</font>

In [None]:
# Train the diffusion prior and both decoder U-Nets on mini-batches of (text, image) pairs.

# Text and image lists must be built in lockstep; refuse to train on misaligned pairs.
assert len(new_category_index_list) == len(images_list), (
    f"text/image count mismatch: {len(new_category_index_list)} vs {len(images_list)}"
)

torch.cuda.empty_cache()
# openai pretrained clip - defaults to ViT/B-32
clip = OpenAIClipAdapter()

size = len(new_category_index_list)

# Change your batch size, number of epochs and learning rates here
batch_size = 32
num_epochs = 1
prior_lr = 3e-4
decoder_lr = 1e-4

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

# NOTE(review): the last entry of image_sizes should equal input_size, since full-size images
# are fed to the decoder below.
decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False  # set this to True if you wish to condition on text during training and sampling
).cuda()

# Previously the loop only called backward(): gradients piled up and no weight was ever updated.
# Each model now has an optimizer that is zeroed and stepped on every batch.
# The shared CLIP weights are excluded so only the prior's own parameters are trained.
clip_param_ids = {id(p) for p in clip.parameters()}
prior_params = [
    p for p in diffusion_prior.parameters()
    if p.requires_grad and id(p) not in clip_param_ids
]
prior_optimizer = torch.optim.Adam(prior_params, lr=prior_lr)

# One optimizer per U-Net, indexed by unet_number - 1. These are taken from decoder.unets
# (the modules the decoder actually trains) rather than the unet1/unet2 handles, which the
# Decoder is not guaranteed to keep as-is.
unet_optimizers = [torch.optim.Adam(unet.parameters(), lr=decoder_lr) for unet in decoder.unets]

idx_list = range(0, size, batch_size)

for epoch in range(num_epochs):
    for idx in tqdm(idx_list, desc=f"epoch {epoch + 1}/{num_epochs}"):
        # Python slices clamp at the end, so the final (possibly shorter) batch needs no special case.
        text = torch.cat(new_category_index_list[idx:idx + batch_size], dim=0).cuda()
        images = torch.cat(images_list[idx:idx + batch_size], dim=0).cuda()

        prior_optimizer.zero_grad()
        loss = diffusion_prior(text, images)
        loss.backward()
        prior_optimizer.step()

        for unet_number in (1, 2):
            unet_optimizer = unet_optimizers[unet_number - 1]
            unet_optimizer.zero_grad()
            loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
            loss.backward()
            unet_optimizer.step()

100%|██████████| 4101/4101 [1:07:04<00:00,  1.02it/s]


In [None]:
# Wrap the trained prior and decoder, save the weights, then sample and save a few test images.
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).cuda()

# Change the model weight save path here (end with ".pt")
save_path = "./dalle2.pt"

# Change the test result image save path (should be a directory or folder)
test_img_save_path = "./result"

os.makedirs(test_img_save_path, exist_ok=True)

torch.save(dalle2.state_dict(), save_path)

test_imgs_tensor = dalle2(
    ['blanket', 'branch'], # text input for the model (can be more than one)
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

for test_idx, test_img_tensor in enumerate(test_imgs_tensor):
    # Assumes each sample is a (3, H, W) RGB tensor with values in [0, 1]; TODO confirm for
    # your dalle2_pytorch version.
    # cv2.imwrite needs an (H, W, 3) uint8 array in BGR order. Writing the raw floats produced
    # near-black files with red/blue swapped. Using the tensor's own shape, instead of reshaping
    # to input_size, also avoids scrambling the pixels if the sample size differs.
    rgb = (
        (test_img_tensor.detach().clamp(0, 1) * 255)
        .round()
        .byte()
        .permute(1, 2, 0)
        .contiguous()
        .cpu()
        .numpy()
    )
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    test_save_path = os.path.join(test_img_save_path, f"test_{test_idx}.jpg")

    # imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(test_save_path, bgr):
        raise IOError(f"Failed to write test image to {test_save_path}")


sampling loop time step: 100%|██████████| 100/100 [00:01<00:00, 74.25it/s]
sampling loop time step: 100%|██████████| 100/100 [00:23<00:00,  4.27it/s]
sampling loop time step: 100%|██████████| 100/100 [00:10<00:00,  9.66it/s]
2it [00:37, 18.87s/it]


<font size="5">Load a Saved DALLE2 Model (run this cell to restore trained weights)</font>

In [None]:
# Rebuild the DALLE2 wrapper and restore previously saved weights into it.
# NOTE(review): this reuses `diffusion_prior` and `decoder` from the training cell, so at least
# the model-construction part of that cell must already have run in this kernel.
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).cuda()

# Change your model path (".pt" file)
load_model_path = "./dalle2.pt"

# Load the tensors onto the CPU first; load_state_dict then copies them into the model's
# parameters on its own device. This avoids a second full copy on the GPU and works even if
# the checkpoint was saved from a different device.
# torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(load_model_path, map_location="cpu")
dalle2.load_state_dict(state_dict)

<All keys matched successfully>