<font size="5">Import Libraries</font>

In [10]:
import torch
from torchvision import transforms as T
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
import os
from tqdm import tqdm
from dalle_pytorch import DiscreteVAE, DALLE
from dalle_pytorch.tokenizer import SimpleTokenizer
from torchvision.datasets.coco import CocoCaptions

<font size="5">Setting Dataset & Training Parameters</font>

In [None]:
# Change your input size here
input_image_size = 256

# Change your batch size here
batch_size = 1

# Change your epoch here
epoch = 5

# Change your train image root path here
train_img_path = "./train2014/"

# Change your train annot json path here
train_annot_path = "./annotations/captions_train2014.json"

# Change your device ("cpu" or "cuda")
device = "cuda"

# Change your VAE save path here (ends with ".pth")
vae_save_path = "./vae.pth"

# Change your dalle model save path here (ends with ".pth")
dalle_save_path = "./dalle.pth"

# Change the test result image save path (should be a directory or folder)
test_img_save_path = "./result"

if not os.path.exists(test_img_save_path):
    os.makedirs(test_img_save_path)

<font size="5">Data Preprocessing</font>

In [13]:
transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    T.Resize(input_image_size),
    T.CenterCrop(input_image_size),
    T.ToTensor()
])

train_data = CocoCaptions(
    root=train_img_path,
    annFile=train_annot_path,
    transform=transform
)

loading annotations into memory...
Done (t=0.92s)
creating index...
index created!


<font size="5">Create VAE Model</font>

In [14]:
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,           # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,        # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,       # codebook dimension
    hidden_dim = 64,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other
).to(device)

<font size="5">Train VAE Model</font>

In [16]:
train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

tokenizer = SimpleTokenizer()
opt = Adam(
    vae.parameters(),
    lr = 3e-4,
    weight_decay=0.01,
    betas = (0.9, 0.999)
)
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
    verbose=True,
)

for curr_epoch in range(epoch):
    print("Run training discrete vae ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        image_list = []
        
        for curr_idx in iter_idx:
            image, _ = train_data[curr_idx]
            image = image.unsqueeze(0)

            image_list.append(image)

        images = torch.cat(image_list, dim=0).type(torch.FloatTensor).to(device)

        opt.zero_grad()
        loss = vae(images, return_loss = True)
        loss.backward()
        opt.step()

        if batch_idx != 0 and batch_idx % 100 == 0:
            torch.save(vae.state_dict(), vae_save_path)
            sched.step(loss)
        
        if batch_idx % 1000 == 0:
            print(f"loss: {loss.data}")

torch.save(vae.state_dict(), vae_save_path)

Run training discrete vae ...
Epoch 1 / 5


  0%|          | 3/82783 [00:00<1:02:04, 22.23it/s]

loss: 0.4792012572288513


  1%|          | 1005/82783 [00:38<55:30, 24.55it/s] 

loss: 0.2182229459285736


  2%|▏         | 2004/82783 [01:16<54:20, 24.77it/s]  

loss: 0.6841167211532593


  3%|▎         | 2804/82783 [01:47<57:05, 23.34it/s]  

Epoch 00028: reducing learning rate of group 0 to 1.5000e-04.


  4%|▎         | 3005/82783 [01:55<56:06, 23.70it/s]

loss: 0.3768330216407776


  5%|▍         | 4005/82783 [02:33<53:37, 24.49it/s]  

loss: 0.42622512578964233


  6%|▌         | 4906/82783 [03:07<53:59, 24.04it/s]

Epoch 00049: reducing learning rate of group 0 to 7.5000e-05.


  6%|▌         | 5005/82783 [03:11<54:05, 23.96it/s]

loss: 0.33252251148223877


  7%|▋         | 6005/82783 [03:49<51:54, 24.66it/s]

loss: 0.19035694003105164


  8%|▊         | 7006/82783 [04:26<51:35, 24.48it/s]

loss: 0.06842350214719772


 10%|▉         | 8006/82783 [05:05<51:01, 24.43it/s]  

Epoch 00080: reducing learning rate of group 0 to 3.7500e-05.
loss: 0.29023900628089905


 11%|█         | 9005/82783 [05:43<51:21, 23.94it/s]

loss: 0.366156667470932


 12%|█▏        | 10005/82783 [06:21<49:39, 24.43it/s]

loss: 0.3892052173614502


 12%|█▏        | 10104/82783 [06:24<51:03, 23.72it/s]

Epoch 00101: reducing learning rate of group 0 to 1.8750e-05.


 13%|█▎        | 11005/82783 [06:59<49:10, 24.33it/s]

loss: 0.1416725218296051


 15%|█▍        | 12004/82783 [07:36<50:03, 23.57it/s]

loss: 0.2335434854030609


 15%|█▍        | 12205/82783 [07:44<48:20, 24.34it/s]

Epoch 00122: reducing learning rate of group 0 to 9.3750e-06.


 16%|█▌        | 13004/82783 [08:15<49:39, 23.42it/s]

loss: 0.42001184821128845


 17%|█▋        | 14006/82783 [08:53<48:02, 23.86it/s]

loss: 0.49545201659202576


 17%|█▋        | 14304/82783 [09:04<47:16, 24.14it/s]

Epoch 00143: reducing learning rate of group 0 to 4.6875e-06.


 18%|█▊        | 15004/82783 [09:31<46:13, 24.44it/s]

loss: 0.31533607840538025


 19%|█▉        | 16004/82783 [10:08<45:58, 24.20it/s]

loss: 0.16962222754955292


 20%|█▉        | 16406/82783 [10:24<46:09, 23.97it/s]

Epoch 00164: reducing learning rate of group 0 to 2.3437e-06.


 21%|██        | 17004/82783 [10:46<44:41, 24.53it/s]

loss: 0.2574365437030792


 22%|██▏       | 18006/82783 [11:24<43:36, 24.76it/s]

loss: 0.3019389808177948


 22%|██▏       | 18504/82783 [11:43<43:45, 24.48it/s]

Epoch 00185: reducing learning rate of group 0 to 1.1719e-06.


 23%|██▎       | 19006/82783 [12:02<44:07, 24.09it/s]

loss: 0.22093415260314941


 24%|██▍       | 20005/82783 [12:40<44:17, 23.62it/s]

loss: 0.3270815312862396


 25%|██▍       | 20606/82783 [13:03<43:41, 23.72it/s]

Epoch 00206: reducing learning rate of group 0 to 1.0000e-06.


 25%|██▌       | 21006/82783 [13:18<41:53, 24.57it/s]

loss: 0.6719629764556885


 27%|██▋       | 22006/82783 [13:56<42:26, 23.87it/s]

loss: 0.4299389719963074


 28%|██▊       | 23006/82783 [14:34<41:44, 23.87it/s]

loss: 0.19703738391399384


 29%|██▉       | 24002/82783 [15:11<41:07, 23.83it/s]

loss: 0.4664909243583679


 30%|███       | 25005/82783 [15:49<39:12, 24.56it/s]

loss: 0.303005188703537


 31%|███▏      | 26005/82783 [16:27<37:52, 24.99it/s]

loss: 0.21167854964733124


 33%|███▎      | 27003/82783 [17:04<36:10, 25.70it/s]

loss: 0.12718932330608368


 34%|███▍      | 28006/82783 [17:41<36:32, 24.98it/s]

loss: 0.0907534584403038


 35%|███▌      | 29005/82783 [18:19<36:38, 24.46it/s]

loss: 0.4405573308467865


 36%|███▌      | 30004/82783 [18:57<35:03, 25.09it/s]

loss: 0.4872623085975647


 37%|███▋      | 31004/82783 [19:35<36:11, 23.85it/s]

loss: 0.1918952316045761


 39%|███▊      | 32004/82783 [20:13<35:12, 24.04it/s]

loss: 0.4147369861602783


 40%|███▉      | 33005/82783 [20:51<34:22, 24.14it/s]

loss: 0.24289242923259735


 41%|████      | 34005/82783 [21:29<34:35, 23.50it/s]

loss: 0.30056869983673096


 42%|████▏     | 35006/82783 [22:07<33:51, 23.52it/s]

loss: 0.24452996253967285


 43%|████▎     | 36004/82783 [22:45<32:45, 23.80it/s]

loss: 0.2708534598350525


 45%|████▍     | 37005/82783 [23:23<31:37, 24.12it/s]

loss: 0.43998128175735474


 46%|████▌     | 38005/82783 [24:01<30:59, 24.08it/s]

loss: 0.5151618719100952


 47%|████▋     | 39006/82783 [24:39<31:00, 23.52it/s]

loss: 0.12097372859716415


 48%|████▊     | 40005/82783 [25:17<29:33, 24.12it/s]

loss: 0.15393434464931488


 50%|████▉     | 41005/82783 [25:55<28:42, 24.25it/s]

loss: 0.2519097328186035


 51%|█████     | 42005/82783 [26:33<28:12, 24.09it/s]

loss: 0.3121302127838135


 52%|█████▏    | 43006/82783 [27:11<27:08, 24.43it/s]

loss: 0.26039010286331177


 53%|█████▎    | 44006/82783 [27:49<25:48, 25.03it/s]

loss: 0.3563811480998993


 54%|█████▍    | 45006/82783 [28:27<26:03, 24.16it/s]

loss: 0.1872224062681198


 56%|█████▌    | 46006/82783 [29:05<24:46, 24.74it/s]

loss: 0.17189344763755798


 57%|█████▋    | 47003/82783 [29:43<24:20, 24.51it/s]

loss: 0.22667360305786133


 58%|█████▊    | 48005/82783 [30:21<23:35, 24.57it/s]

loss: 0.18011391162872314


 59%|█████▉    | 49004/82783 [30:59<22:44, 24.75it/s]

loss: 0.4304956793785095


 60%|██████    | 50005/82783 [31:37<22:05, 24.72it/s]

loss: 0.3135997951030731


 62%|██████▏   | 51006/82783 [32:15<21:53, 24.19it/s]

loss: 0.6089404225349426


 63%|██████▎   | 52004/82783 [32:53<21:06, 24.30it/s]

loss: 0.226525217294693


 64%|██████▍   | 53006/82783 [33:31<20:12, 24.57it/s]

loss: 0.28646373748779297


 65%|██████▌   | 54006/82783 [34:08<19:23, 24.73it/s]

loss: 0.3154357671737671


 66%|██████▋   | 55004/82783 [34:46<18:58, 24.40it/s]

loss: 0.1944580376148224


 68%|██████▊   | 56006/82783 [35:24<18:12, 24.51it/s]

loss: 0.3177455961704254


 69%|██████▉   | 57005/82783 [36:02<17:22, 24.72it/s]

loss: 0.20344150066375732


 70%|███████   | 58005/82783 [36:40<16:46, 24.62it/s]

loss: 0.26109063625335693


 71%|███████▏  | 59005/82783 [37:18<16:38, 23.81it/s]

loss: 0.40666383504867554


 72%|███████▏  | 60006/82783 [37:56<15:38, 24.28it/s]

loss: 0.32443785667419434


 74%|███████▎  | 61006/82783 [38:34<15:03, 24.10it/s]

loss: 0.4113595485687256


 75%|███████▍  | 62004/82783 [39:12<14:36, 23.71it/s]

loss: 0.21808096766471863


 76%|███████▌  | 63006/82783 [39:50<14:01, 23.50it/s]

loss: 0.486672967672348


 77%|███████▋  | 64006/82783 [40:28<12:30, 25.02it/s]

loss: 0.19768282771110535


 79%|███████▊  | 65006/82783 [41:06<12:16, 24.14it/s]

loss: 0.1865364909172058


 80%|███████▉  | 66006/82783 [41:44<11:21, 24.62it/s]

loss: 0.11751298606395721


 81%|████████  | 67005/82783 [42:22<10:52, 24.19it/s]

loss: 0.41332006454467773


 82%|████████▏ | 68005/82783 [43:00<09:49, 25.06it/s]

loss: 0.16617222130298615


 83%|████████▎ | 69004/82783 [43:37<09:37, 23.87it/s]

loss: 0.22339344024658203


 85%|████████▍ | 70006/82783 [44:16<08:59, 23.67it/s]

loss: 0.07793080061674118


 86%|████████▌ | 71005/82783 [44:54<07:46, 25.23it/s]

loss: 1.0516767501831055


 87%|████████▋ | 72004/82783 [45:31<07:31, 23.86it/s]

loss: 0.382320761680603


 88%|████████▊ | 73004/82783 [46:09<06:32, 24.93it/s]

loss: 0.21938340365886688


 89%|████████▉ | 74004/82783 [46:46<05:54, 24.78it/s]

loss: 0.23967763781547546


 91%|█████████ | 75004/82783 [47:24<05:11, 24.93it/s]

loss: 0.27625763416290283


 92%|█████████▏| 76005/82783 [48:02<04:31, 25.00it/s]

loss: 0.3032580018043518


 93%|█████████▎| 77004/82783 [48:40<04:07, 23.35it/s]

loss: 0.38872236013412476


 94%|█████████▍| 78006/82783 [49:18<03:15, 24.44it/s]

loss: 0.32130393385887146


 95%|█████████▌| 79005/82783 [49:56<02:37, 23.99it/s]

loss: 0.41097089648246765


 97%|█████████▋| 80004/82783 [50:34<01:55, 24.13it/s]

loss: 0.3159750998020172


 98%|█████████▊| 81004/82783 [51:13<01:13, 24.26it/s]

loss: 0.14754584431648254


 99%|█████████▉| 82006/82783 [51:51<00:31, 24.78it/s]

loss: 0.5548713803291321


100%|██████████| 82783/82783 [52:18<00:00, 26.38it/s]


Run training discrete vae ...
Epoch 2 / 5


  0%|          | 3/82783 [00:00<46:24, 29.73it/s]

loss: 0.4032898545265198


  1%|          | 1007/82783 [00:35<49:06, 27.76it/s]

loss: 0.21841195225715637


  2%|▏         | 2005/82783 [01:09<48:34, 27.72it/s]

loss: 0.686534583568573


  4%|▎         | 3006/82783 [01:44<52:44, 25.21it/s]

loss: 0.3833394944667816


  5%|▍         | 4004/82783 [02:18<48:48, 26.90it/s]

loss: 0.42584094405174255


  6%|▌         | 5004/82783 [02:53<49:22, 26.26it/s]

loss: 0.3324825167655945


  7%|▋         | 6004/82783 [03:28<49:03, 26.09it/s]

loss: 0.189140185713768


  8%|▊         | 7005/82783 [04:02<45:16, 27.90it/s]

loss: 0.06934300810098648


 10%|▉         | 8004/82783 [04:37<47:14, 26.38it/s]

loss: 0.29024097323417664


 11%|█         | 9006/82783 [05:12<45:52, 26.80it/s]

loss: 0.3657972812652588


 12%|█▏        | 10005/82783 [05:46<45:42, 26.54it/s]

loss: 0.3882780969142914


 13%|█▎        | 11004/82783 [06:21<45:08, 26.50it/s]

loss: 0.14180468022823334


 15%|█▍        | 12007/82783 [06:56<43:33, 27.08it/s]

loss: 0.23321208357810974


 16%|█▌        | 13006/82783 [07:30<42:03, 27.65it/s]

loss: 0.42162299156188965


 17%|█▋        | 14003/82783 [08:04<44:21, 25.85it/s]

loss: 0.4977042078971863


 18%|█▊        | 15004/82783 [08:39<45:20, 24.91it/s]

loss: 0.31732743978500366


 19%|█▉        | 16006/82783 [09:15<40:52, 27.23it/s]

loss: 0.16976770758628845


 21%|██        | 17005/82783 [09:52<40:06, 27.33it/s]

loss: 0.2566562592983246


 22%|██▏       | 18006/82783 [10:29<42:28, 25.42it/s]

loss: 0.3000287413597107


 23%|██▎       | 19005/82783 [11:06<43:23, 24.50it/s]

loss: 0.22131852805614471


 24%|██▍       | 20004/82783 [11:44<43:51, 23.85it/s]

loss: 0.3282475471496582


 25%|██▌       | 21005/82783 [12:22<42:25, 24.27it/s]

loss: 0.6688704490661621


 27%|██▋       | 22006/82783 [12:59<41:39, 24.31it/s]

loss: 0.42890164256095886


 28%|██▊       | 23005/82783 [13:37<40:26, 24.64it/s]

loss: 0.19660088419914246


 29%|██▉       | 24005/82783 [14:15<39:13, 24.97it/s]

loss: 0.4640920162200928


 30%|███       | 25006/82783 [14:52<37:50, 25.44it/s]

loss: 0.30257344245910645


 31%|███▏      | 26004/82783 [15:29<36:59, 25.58it/s]

loss: 0.21188992261886597


 33%|███▎      | 27006/82783 [16:06<37:40, 24.68it/s]

loss: 0.12677127122879028


 34%|███▍      | 28006/82783 [16:42<36:27, 25.04it/s]

loss: 0.09072913974523544


 35%|███▌      | 29004/82783 [17:20<36:18, 24.69it/s]

loss: 0.4393555521965027


 36%|███▌      | 30006/82783 [17:57<35:49, 24.55it/s]

loss: 0.4895300269126892


 37%|███▋      | 31004/82783 [18:34<35:54, 24.03it/s]

loss: 0.1910790205001831


 39%|███▊      | 32005/82783 [19:12<35:09, 24.07it/s]

loss: 0.4145066440105438


 40%|███▉      | 33005/82783 [19:51<35:06, 23.63it/s]

loss: 0.2434176206588745


 41%|████      | 34004/82783 [20:31<34:46, 23.38it/s]

loss: 0.3004128932952881


 42%|████▏     | 35004/82783 [21:09<33:53, 23.50it/s]

loss: 0.24432465434074402


 43%|████▎     | 36004/82783 [21:48<32:18, 24.14it/s]

loss: 0.2716233730316162


 45%|████▍     | 37005/82783 [22:27<32:10, 23.71it/s]

loss: 0.4387552738189697


 46%|████▌     | 38005/82783 [23:07<32:55, 22.67it/s]

loss: 0.5149331092834473


 47%|████▋     | 39006/82783 [23:45<29:49, 24.47it/s]

loss: 0.12136490643024445


 48%|████▊     | 40004/82783 [24:23<29:19, 24.31it/s]

loss: 0.15372894704341888


 50%|████▉     | 41006/82783 [25:02<28:36, 24.34it/s]

loss: 0.25231844186782837


 51%|█████     | 42004/82783 [25:40<27:22, 24.83it/s]

loss: 0.3127586841583252


 52%|█████▏    | 43006/82783 [26:18<27:15, 24.33it/s]

loss: 0.2601202726364136


 53%|█████▎    | 44004/82783 [26:55<25:44, 25.11it/s]

loss: 0.35730409622192383


 54%|█████▍    | 45004/82783 [27:32<25:59, 24.22it/s]

loss: 0.18755653500556946


 56%|█████▌    | 46006/82783 [28:10<25:44, 23.81it/s]

loss: 0.17212939262390137


 57%|█████▋    | 47006/82783 [28:47<24:07, 24.72it/s]

loss: 0.2276177853345871


 58%|█████▊    | 48006/82783 [29:24<22:41, 25.54it/s]

loss: 0.17984166741371155


 59%|█████▉    | 49006/82783 [30:02<23:35, 23.86it/s]

loss: 0.4309530258178711


 60%|██████    | 50004/82783 [30:40<21:47, 25.07it/s]

loss: 0.3133019804954529


 62%|██████▏   | 51004/82783 [31:17<21:06, 25.10it/s]

loss: 0.6099177598953247


 63%|██████▎   | 52006/82783 [31:53<20:57, 24.48it/s]

loss: 0.2267472743988037


 64%|██████▍   | 53005/82783 [32:32<20:56, 23.69it/s]

loss: 0.28648024797439575


 65%|██████▌   | 54005/82783 [33:10<19:57, 24.03it/s]

loss: 0.31566721200942993


 66%|██████▋   | 55006/82783 [33:49<18:59, 24.37it/s]

loss: 0.19454555213451385


 68%|██████▊   | 56005/82783 [34:28<18:34, 24.03it/s]

loss: 0.3184148669242859


 69%|██████▉   | 57005/82783 [35:06<17:07, 25.09it/s]

loss: 0.20326510071754456


 70%|███████   | 58004/82783 [35:43<17:20, 23.80it/s]

loss: 0.2616921365261078


 71%|███████▏  | 59003/82783 [36:21<17:02, 23.25it/s]

loss: 0.406127393245697


 72%|███████▏  | 60005/82783 [37:00<15:02, 25.23it/s]

loss: 0.32388851046562195


 74%|███████▎  | 61004/82783 [37:38<16:11, 22.42it/s]

loss: 0.41158074140548706


 75%|███████▍  | 62006/82783 [38:15<13:50, 25.02it/s]

loss: 0.2181495726108551


 76%|███████▌  | 63006/82783 [38:53<13:41, 24.09it/s]

loss: 0.48632869124412537


 77%|███████▋  | 64005/82783 [39:30<12:13, 25.62it/s]

loss: 0.19716668128967285


 79%|███████▊  | 65005/82783 [40:07<11:42, 25.31it/s]

loss: 0.1865205019712448


 80%|███████▉  | 66006/82783 [40:44<11:23, 24.56it/s]

loss: 0.11775366961956024


 81%|████████  | 67006/82783 [41:22<10:47, 24.36it/s]

loss: 0.4122638404369354


 82%|████████▏ | 68005/82783 [41:58<09:36, 25.62it/s]

loss: 0.1659170389175415


 83%|████████▎ | 69005/82783 [42:35<09:25, 24.38it/s]

loss: 0.22377154231071472


 85%|████████▍ | 70005/82783 [43:12<08:34, 24.85it/s]

loss: 0.07795368134975433


 86%|████████▌ | 71004/82783 [43:50<07:51, 24.96it/s]

loss: 1.049720287322998


 87%|████████▋ | 72004/82783 [44:29<07:12, 24.93it/s]

loss: 0.38188445568084717


 88%|████████▊ | 73006/82783 [45:08<06:47, 23.98it/s]

loss: 0.2188907116651535


 89%|████████▉ | 74006/82783 [45:46<06:13, 23.49it/s]

loss: 0.24002131819725037


 91%|█████████ | 75004/82783 [46:25<05:19, 24.32it/s]

loss: 0.27596911787986755


 92%|█████████▏| 76005/82783 [47:03<04:51, 23.23it/s]

loss: 0.30283698439598083


 93%|█████████▎| 77006/82783 [47:43<04:12, 22.87it/s]

loss: 0.38826197385787964


 93%|█████████▎| 77249/82783 [47:53<03:46, 24.42it/s]

<font size="5">Create DALLE Model</font>

In [None]:
tokenizer = SimpleTokenizer()

dalle = DALLE(
    dim = 1024,
    vae = vae,                                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,                    # vocab size for text
    text_seq_len = 256,                         # text sequence length
    depth = 1,                                  # should aim to be 64
    heads = 16,                                 # attention heads
    dim_head = 64,                              # attention head dimension
    attn_dropout = 0.1,                         # attention dropout
    ff_dropout = 0.1                            # feedforward dropout
).to(device)

<font size="5">Train DALLE Model</font>

In [None]:
train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

opt = Adam(
    dalle.parameters(),
    lr = 3e-4,
    weight_decay=0.01,
    betas = (0.9, 0.999)
)
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
    verbose=True,
)

for curr_epoch in range(epoch):
    print("Run training dalle ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        image_list = []
        text_list = []
        
        for curr_idx in iter_idx:
            image, target = train_data[curr_idx]
            image = image.unsqueeze(0)
            text = tokenizer.tokenize(target)

            text_size = len(text)
            for i in range(text_size):
                image_list.append(image)
            
            text_list.append(text)

        text = torch.cat(text_list, dim=0).to(device)
        image = torch.cat(image_list, dim=0).to(device)

        opt.zero_grad()
        loss = dalle(text, image, return_loss = True)
        loss.backward()
        opt.step()

        if batch_idx != 0 and batch_idx % 100 == 0:
            torch.save(dalle.state_dict(), dalle_save_path)
            sched.step(loss)
        
        if batch_idx % 1000 == 0:
            print(f"loss: {loss.data}")

torch.save(dalle.state_dict(), dalle_save_path)

<font size="5">Test DALLE model with several inputs</font>

In [None]:
test_inputs = ['Closeup of bins of food that include broccoli and bread.'] # text input for the model (can be more than one)

text = tokenizer.tokenize(test_inputs).to(device)

test_img_tensors = dalle.generate_images(text)

for test_idx, test_img_tensor in enumerate(test_img_tensors):
    test_img = T.ToPILImage()(test_img_tensor)
    test_save_path = os.path.join(test_img_save_path, f"{test_inputs[test_idx]}.jpg")
    test_img.save(Path(test_save_path))