<font size="5">Import Libraries</font>

In [20]:
import torch
from torch.optim.lr_scheduler import ExponentialLR
from torchvision import transforms as T
from pathlib import Path
from PIL import Image
import os
from tqdm import tqdm
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dalle2_pytorch.optimizer import get_optimizer
from torchvision.datasets.coco import CocoCaptions
import pandas as pd

<font size="5">Setting Dataset & Training Parameters</font>

In [3]:
# Change your input size here
input_image_size = 256

# Change your batch size here
batch_size = 1

# Change your epoch here
epoch = 5

# Change your train image root path here
train_img_path = "./train2014/"

# Change your train annot csv path here
train_annot_path = "./coco_annotations/captions_train2014.csv"

# Change your device ("cpu" or "cuda")
device = "cuda"

# Change your diffusion prior model save path here (end with ".pth")
diff_save_path = "./diff_prior.pth"

# Change your diffusion prior model save path here (end with ".pth")
decoder_save_path = "./decoder.pth"

# Change the model weight save path here (end with ".pth")
dalle2_save_path = "./dalle2.pth"

# Change the test result image save path (should be a directory or folder)
test_img_save_path = "./result"

if not os.path.exists(test_img_save_path):
    os.makedirs(test_img_save_path)

<font size="5">Data Preprocessing</font>

In [19]:
transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    T.Resize(input_image_size),
    T.CenterCrop(input_image_size),
    T.ToTensor()
])

train_csv= pd.read_csv(train_annot_path)

train_csv = train_csv.drop_duplicates()
train_csv = train_csv.dropna()

<font size="5">Create Model</font>

In [4]:
# openai pretrained clip - defaults to ViT/B-32
OpenAIClip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = OpenAIClip,
    timesteps = 100,
    cond_drop_prob = 0.2
).to(device)

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).to(device)

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = OpenAIClip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings=True
).to(device)

if os.path.exists(diff_save_path):
    diffusion_prior.load_state_dict(torch.load(diff_save_path))

if os.path.exists(decoder_save_path):
    decoder.load_state_dict(torch.load(decoder_save_path))

<font size="5">Run training</font>

In [6]:
train_size = len(train_csv)
idx_list = range(0, train_size, batch_size)

tokenizer = SimpleTokenizer()
opt = get_optimizer(diffusion_prior.parameters())
sched = ExponentialLR(opt, gamma=0.01)

for curr_epoch in range(epoch):
    print("Run training diffusion prior ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        image_list = []
        text_list = []
        
        for curr_idx in iter_idx:
            image_name = train_csv.loc[curr_idx]['file_name']
            image_path = os.path.join(train_img_path, image_name)
            image = Image.open(image_path)
            image = transform(image)
            image = image.unsqueeze(0).to(device)

            target = list(train_csv.loc[curr_idx]['caption'])
            text = tokenizer.tokenize(target).to(device)

            text_size = len(text)
            for i in range(text_size):
                image_list.append(image)
            
            text_list.append(text)

        text = torch.cat(text_list, dim=0).to(device)
        image = torch.cat(image_list, dim=0).to(device)
    
        loss = diffusion_prior(text, image)
        opt.zero_grad()
        loss.backward()
        opt.step()

        if batch_idx != 0 and batch_idx % 100 == 0:
            torch.save(diffusion_prior.state_dict(), diff_save_path)
            sched.step()

        if batch_idx % 1000 == 0:
            print(f"loss: {loss.data}")

torch.save(diffusion_prior.state_dict(), diff_save_path)

NameError: name 'train_data' is not defined

In [12]:
opt = get_optimizer(decoder.parameters())
sched = ExponentialLR(opt, gamma=0.01)

for curr_epoch in range(epoch):
    print("Run training decoder ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        image_list = []
        text_list = []
        
        for curr_idx in iter_idx:
            image_name = train_csv.loc[curr_idx]['file_name']
            image_path = os.path.join(train_img_path, image_name)
            image = Image.open(image_path)
            image = transform(image)
            image = image.unsqueeze(0).to(device)

            target = list(train_csv.loc[curr_idx]['caption'])
            text = tokenizer.tokenize(target).to(device)

            text_size = len(text)
            for i in range(text_size):
                image_list.append(image)
            
            text_list.append(text)
            
        text = torch.cat(text_list, dim=0).to(device)
        image = torch.cat(image_list, dim=0).to(device)

        loss = decoder(image, text) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
        opt.zero_grad()
        loss.backward()
        opt.step()

        if batch_idx != 0 and batch_idx % 100 == 0:
            torch.save(decoder.state_dict(), decoder_save_path)
            sched.step()
        
        if batch_idx % 1000 == 0:
            print(f"loss: {loss.data}")

torch.save(decoder.state_dict(), decoder_save_path)

Run training decoder ...
Epoch 1 / 5


  0%|          | 1/82783 [00:00<18:45:50,  1.23it/s]

loss: 0.8478520512580872


  1%|          | 1001/82783 [11:28<41:51:19,  1.84s/it]

loss: 0.1316903829574585


  2%|▏         | 2001/82783 [22:58<39:40:01,  1.77s/it]

loss: 0.14820896089076996


  4%|▎         | 3001/82783 [34:23<40:20:06,  1.82s/it]

loss: 0.17632226645946503


  5%|▍         | 4001/82783 [45:49<42:53:26,  1.96s/it]

loss: 0.2155010998249054


  6%|▌         | 5001/82783 [57:13<42:07:00,  1.95s/it]

loss: 0.2644650340080261


  7%|▋         | 6001/82783 [1:08:38<45:45:41,  2.15s/it]

loss: 0.16106680035591125


  8%|▊         | 7001/82783 [1:20:46<47:52:21,  2.27s/it]

loss: 0.1432100236415863


 10%|▉         | 8001/82783 [1:32:14<38:56:53,  1.87s/it]

loss: 0.21223637461662292


 11%|█         | 9001/82783 [1:43:39<40:24:51,  1.97s/it]

loss: 0.21979470551013947


 12%|█▏        | 10001/82783 [1:55:07<42:38:04,  2.11s/it]

loss: 0.1237279400229454


 13%|█▎        | 11001/82783 [2:06:32<38:25:33,  1.93s/it]

loss: 0.15691719949245453


 14%|█▍        | 12001/82783 [2:18:00<36:09:38,  1.84s/it]

loss: 0.16451862454414368


 16%|█▌        | 13001/82783 [2:29:27<36:19:21,  1.87s/it]

loss: 0.21377217769622803


 17%|█▋        | 14001/82783 [2:40:55<42:32:29,  2.23s/it]

loss: 0.14279809594154358


 18%|█▊        | 15001/82783 [2:52:23<35:55:51,  1.91s/it]

loss: 0.20181158185005188


 19%|█▉        | 16001/82783 [3:03:48<34:33:28,  1.86s/it]

loss: 0.18692629039287567


 21%|██        | 17001/82783 [3:15:15<35:08:05,  1.92s/it]

loss: 0.13630299270153046


 22%|██▏       | 18001/82783 [3:26:42<34:04:24,  1.89s/it]

loss: 0.14057797193527222


 23%|██▎       | 19001/82783 [3:38:07<33:41:22,  1.90s/it]

loss: 0.13574248552322388


 24%|██▍       | 20001/82783 [3:49:34<34:21:51,  1.97s/it]

loss: 0.190187007188797


 25%|██▌       | 21001/82783 [4:00:57<30:48:02,  1.79s/it]

loss: 0.1508261114358902


 27%|██▋       | 22001/82783 [4:12:23<31:48:19,  1.88s/it]

loss: 0.18532687425613403


 28%|██▊       | 23001/82783 [4:23:50<34:37:13,  2.08s/it]

loss: 0.15921562910079956


 29%|██▉       | 24001/82783 [4:35:16<31:35:17,  1.93s/it]

loss: 0.13520236313343048


 30%|███       | 25001/82783 [4:46:41<29:10:12,  1.82s/it]

loss: 0.1500079482793808


 31%|███▏      | 26001/82783 [4:58:09<30:17:02,  1.92s/it]

loss: 0.16632650792598724


 33%|███▎      | 27001/82783 [5:09:34<28:30:00,  1.84s/it]

loss: 0.24948136508464813


 34%|███▍      | 28001/82783 [5:21:02<29:34:50,  1.94s/it]

loss: 0.19952097535133362


 35%|███▌      | 29001/82783 [5:32:28<28:23:45,  1.90s/it]

loss: 0.13934557139873505


 36%|███▌      | 30001/82783 [5:43:53<26:48:18,  1.83s/it]

loss: 0.16201719641685486


 37%|███▋      | 31001/82783 [5:55:19<26:20:42,  1.83s/it]

loss: 0.1404748409986496


 39%|███▊      | 32001/82783 [6:06:44<26:45:26,  1.90s/it]

loss: 0.18309907615184784


 39%|███▊      | 32030/82783 [6:07:02<9:02:07,  1.56it/s] 

<font size="5">Save Trained Model and test on several text input</font>

In [None]:
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).to(device)

torch.save(dalle2.state_dict(), dalle2_save_path)

test_input = ['Closeup of bins of food that include broccoli and bread.'] # text input for the model (can be more than one)

test_img_tensors = dalle2(
    test_input,
    cond_scale = 2., # classifier free guidance strength (> 1 would strengthen the condition)
)

for test_idx, test_img_tensor in enumerate(test_img_tensors):
    test_img = T.ToPILImage()(test_img_tensor)
    test_save_path = os.path.join(test_img_save_path, f"{test_input[test_idx]}.jpg")
    test_img.save(Path(test_save_path))