<font size="5">Import Libraries</font>

In [1]:
import torch
from torch.optim.lr_scheduler import ExponentialLR
from torchvision import transforms as T
from pathlib import Path
import os
from tqdm import tqdm
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dalle2_pytorch.optimizer import get_optimizer
from torchvision.datasets.coco import CocoCaptions

  from .autonotebook import tqdm as notebook_tqdm


<font size="5">Setting Dataset & Training Parameters</font>

In [10]:
# Change your input size here
input_image_size = 256

# Change your batch size here
batch_size = 4

# Change your epoch here
epoch = 5

# Change your train image root path here
train_img_path = "./train2014/"

# Change your train annot json path here
train_annot_path = "./coco_annotations/captions_train2014.json"

# Change your device ("cpu" or "cuda")
device = "cuda"

# Change your diffusion prior model save path here (end with ".pth")
diff_save_path = "./diff_prior.pth"

# Change your diffusion prior model save path here (end with ".pth")
decoder_save_path = "./decoder.pth"

# Change the model weight save path here (end with ".pth")
dalle2_save_path = "./dalle2.pth"

# Change the test result image save path (should be a directory or folder)
test_img_save_path = "./result"

if not os.path.exists(test_img_save_path):
    os.makedirs(test_img_save_path)

<font size="5">Data Preprocessing</font>

In [11]:
transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    T.Resize(input_image_size),
    T.CenterCrop(input_image_size),
    T.ToTensor()
])

train_data = CocoCaptions(
    root=train_img_path,
    annFile=train_annot_path,
    transform=transform
)

loading annotations into memory...
Done (t=1.13s)
creating index...
index created!


<font size="5">Create Model</font>

In [4]:
# openai pretrained clip - defaults to ViT/B-32
OpenAIClip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = OpenAIClip,
    timesteps = 100,
    cond_drop_prob = 0.2
).to(device)

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).to(device)

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = OpenAIClip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings=True
).to(device)

if os.path.exists(diff_save_path):
    diffusion_prior.load_state_dict(torch.load(diff_save_path))

if os.path.exists(decoder_save_path):
    decoder.load_state_dict(torch.load(decoder_save_path))

<font size="5">Run training</font>

In [6]:
train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

tokenizer = SimpleTokenizer()
opt = get_optimizer(diffusion_prior.parameters())
sched = ExponentialLR(opt, gamma=0.01)

for curr_epoch in range(epoch):
    print("Run training diffusion prior ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        batch_len = 0
        total_loss = torch.tensor(0., device=device)
        
        for curr_idx in iter_idx:
            image, target = train_data[curr_idx]
            image = image.unsqueeze(0).to(device)

            texts = tokenizer.tokenize(target).to(device)

            for text in texts:
                if total_loss == torch.tensor(0., device=device):
                    total_loss = diffusion_prior(text.unsqueeze(0), image)
                else:
                    total_loss += diffusion_prior(text.unsqueeze(0), image)
                batch_len += 1
                
        avg_loss = total_loss / batch_len

        opt.zero_grad()
        avg_loss.backward()
        opt.step()

        if batch_idx != 0 and batch_idx % 100 == 0:
            torch.save(diffusion_prior.state_dict(), diff_save_path)
            sched.step()

        if batch_idx % 100 == 0:
            print(f"average loss: {avg_loss.data}")

torch.save(diffusion_prior.state_dict(), diff_save_path)

NameError: name 'train_data' is not defined

In [12]:
opt = get_optimizer(decoder.parameters())
sched = ExponentialLR(opt, gamma=0.01)

for curr_epoch in range(epoch):
    print("Run training decoder ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        batch_len = 0
        total_loss = torch.tensor(0., device=device)
        
        for curr_idx in iter_idx:
            image, target = train_data[curr_idx]
            image = image.unsqueeze(0).to(device)

            texts = tokenizer.tokenize(target).to(device)

            for text in texts:
                if total_loss == torch.tensor(0., device=device):
                    total_loss = decoder(text.unsqueeze(0), image)
                else:
                    total_loss += decoder(text.unsqueeze(0), image)
                batch_len += 1
                
        avg_loss = total_loss / batch_len

        opt.zero_grad()
        avg_loss.backward()
        opt.step()

        if batch_idx != 0 and batch_idx % 100 == 0:
            torch.save(decoder.state_dict(), decoder_save_path)
            sched.step()
        
        if batch_idx % 100 == 0:
            print(f"average loss: {avg_loss.data}")

torch.save(decoder.state_dict(), decoder_save_path)

Run training decoder ...
Epoch 1 / 5


  0%|          | 1/82783 [00:00<18:45:50,  1.23it/s]

loss: 0.8478520512580872


  1%|          | 1001/82783 [11:28<41:51:19,  1.84s/it]

loss: 0.1316903829574585


  2%|▏         | 2001/82783 [22:58<39:40:01,  1.77s/it]

loss: 0.14820896089076996


  4%|▎         | 3001/82783 [34:23<40:20:06,  1.82s/it]

loss: 0.17632226645946503


  5%|▍         | 4001/82783 [45:49<42:53:26,  1.96s/it]

loss: 0.2155010998249054


  6%|▌         | 5001/82783 [57:13<42:07:00,  1.95s/it]

loss: 0.2644650340080261


  7%|▋         | 6001/82783 [1:08:38<45:45:41,  2.15s/it]

loss: 0.16106680035591125


  8%|▊         | 7001/82783 [1:20:46<47:52:21,  2.27s/it]

loss: 0.1432100236415863


 10%|▉         | 8001/82783 [1:32:14<38:56:53,  1.87s/it]

loss: 0.21223637461662292


 11%|█         | 9001/82783 [1:43:39<40:24:51,  1.97s/it]

loss: 0.21979470551013947


 12%|█▏        | 10001/82783 [1:55:07<42:38:04,  2.11s/it]

loss: 0.1237279400229454


 13%|█▎        | 11001/82783 [2:06:32<38:25:33,  1.93s/it]

loss: 0.15691719949245453


 14%|█▍        | 12001/82783 [2:18:00<36:09:38,  1.84s/it]

loss: 0.16451862454414368


 16%|█▌        | 13001/82783 [2:29:27<36:19:21,  1.87s/it]

loss: 0.21377217769622803


 17%|█▋        | 14001/82783 [2:40:55<42:32:29,  2.23s/it]

loss: 0.14279809594154358


 18%|█▊        | 15001/82783 [2:52:23<35:55:51,  1.91s/it]

loss: 0.20181158185005188


 19%|█▉        | 16001/82783 [3:03:48<34:33:28,  1.86s/it]

loss: 0.18692629039287567


 21%|██        | 17001/82783 [3:15:15<35:08:05,  1.92s/it]

loss: 0.13630299270153046


 22%|██▏       | 18001/82783 [3:26:42<34:04:24,  1.89s/it]

loss: 0.14057797193527222


 23%|██▎       | 19001/82783 [3:38:07<33:41:22,  1.90s/it]

loss: 0.13574248552322388


 24%|██▍       | 20001/82783 [3:49:34<34:21:51,  1.97s/it]

loss: 0.190187007188797


 25%|██▌       | 21001/82783 [4:00:57<30:48:02,  1.79s/it]

loss: 0.1508261114358902


 27%|██▋       | 22001/82783 [4:12:23<31:48:19,  1.88s/it]

loss: 0.18532687425613403


 28%|██▊       | 23001/82783 [4:23:50<34:37:13,  2.08s/it]

loss: 0.15921562910079956


 29%|██▉       | 24001/82783 [4:35:16<31:35:17,  1.93s/it]

loss: 0.13520236313343048


 30%|███       | 25001/82783 [4:46:41<29:10:12,  1.82s/it]

loss: 0.1500079482793808


 31%|███▏      | 26001/82783 [4:58:09<30:17:02,  1.92s/it]

loss: 0.16632650792598724


 33%|███▎      | 27001/82783 [5:09:34<28:30:00,  1.84s/it]

loss: 0.24948136508464813


 34%|███▍      | 28001/82783 [5:21:02<29:34:50,  1.94s/it]

loss: 0.19952097535133362


 35%|███▌      | 29001/82783 [5:32:28<28:23:45,  1.90s/it]

loss: 0.13934557139873505


 36%|███▌      | 30001/82783 [5:43:53<26:48:18,  1.83s/it]

loss: 0.16201719641685486


 37%|███▋      | 31001/82783 [5:55:19<26:20:42,  1.83s/it]

loss: 0.1404748409986496


 39%|███▊      | 32001/82783 [6:06:44<26:45:26,  1.90s/it]

loss: 0.18309907615184784


 40%|███▉      | 33001/82783 [6:18:06<28:51:06,  2.09s/it]

loss: 0.21096989512443542


 41%|████      | 34001/82783 [6:29:29<25:37:51,  1.89s/it]

loss: 0.19746223092079163


 42%|████▏     | 35001/82783 [6:40:55<24:10:29,  1.82s/it]

loss: 0.1665845811367035


 43%|████▎     | 36001/82783 [6:52:24<26:02:14,  2.00s/it]

loss: 0.1433427929878235


 45%|████▍     | 37001/82783 [7:03:50<23:57:25,  1.88s/it]

loss: 0.1420016884803772


 46%|████▌     | 38001/82783 [7:15:15<24:03:37,  1.93s/it]

loss: 0.146202951669693


 47%|████▋     | 39001/82783 [7:26:38<22:00:31,  1.81s/it]

loss: 0.17722220718860626


 48%|████▊     | 40001/82783 [7:38:07<21:51:49,  1.84s/it]

loss: 0.14618706703186035


 50%|████▉     | 41001/82783 [7:49:31<21:34:42,  1.86s/it]

loss: 0.12324140220880508


 51%|█████     | 42001/82783 [8:00:56<20:20:37,  1.80s/it]

loss: 0.18844908475875854


 52%|█████▏    | 43001/82783 [8:12:20<19:33:44,  1.77s/it]

loss: 0.1459890753030777


 53%|█████▎    | 44001/82783 [8:23:46<20:39:50,  1.92s/it]

loss: 0.17045307159423828


 54%|█████▍    | 45001/82783 [8:35:11<20:10:30,  1.92s/it]

loss: 0.17558257281780243


 56%|█████▌    | 46001/82783 [8:46:37<18:48:56,  1.84s/it]

loss: 0.12048511952161789


 57%|█████▋    | 47001/82783 [8:58:02<18:12:25,  1.83s/it]

loss: 0.18470372259616852


 58%|█████▊    | 48001/82783 [9:09:26<17:41:25,  1.83s/it]

loss: 0.17088569700717926


 59%|█████▉    | 49001/82783 [9:20:52<17:27:54,  1.86s/it]

loss: 0.16614121198654175


 60%|██████    | 50001/82783 [9:32:16<16:49:16,  1.85s/it]

loss: 0.14262811839580536


 62%|██████▏   | 51001/82783 [9:43:42<16:47:59,  1.90s/it]

loss: 0.1875983476638794


 63%|██████▎   | 52001/82783 [9:55:08<16:08:59,  1.89s/it]

loss: 0.13361525535583496


 64%|██████▍   | 53001/82783 [10:06:34<15:52:02,  1.92s/it]

loss: 0.13841943442821503


 65%|██████▌   | 54001/82783 [10:17:58<15:32:08,  1.94s/it]

loss: 0.14640529453754425


 66%|██████▋   | 55001/82783 [10:29:24<14:01:46,  1.82s/it]

loss: 0.128090962767601


 68%|██████▊   | 56001/82783 [10:40:50<13:47:17,  1.85s/it]

loss: 0.1412775069475174


 69%|██████▉   | 57001/82783 [10:52:18<13:21:21,  1.86s/it]

loss: 0.15001873672008514


 70%|███████   | 58001/82783 [11:03:43<12:30:53,  1.82s/it]

loss: 0.1538134217262268


 71%|███████▏  | 59001/82783 [11:15:07<12:33:12,  1.90s/it]

loss: 0.15126799046993256


 72%|███████▏  | 60001/82783 [11:26:33<12:02:22,  1.90s/it]

loss: 0.20201095938682556


 74%|███████▎  | 61001/82783 [11:37:59<11:18:00,  1.87s/it]

loss: 0.14178985357284546


 75%|███████▍  | 62001/82783 [11:49:24<11:40:43,  2.02s/it]

loss: 0.25346097350120544


 76%|███████▌  | 63001/82783 [12:00:49<10:05:04,  1.84s/it]

loss: 0.2046862542629242


 77%|███████▋  | 64001/82783 [12:12:14<9:41:13,  1.86s/it] 

loss: 0.12824460864067078


 79%|███████▊  | 65001/82783 [12:23:40<8:53:58,  1.80s/it] 

loss: 0.16032381355762482


 80%|███████▉  | 66001/82783 [12:35:07<11:14:24,  2.41s/it]

loss: 0.18393327295780182


 81%|████████  | 67001/82783 [12:46:34<7:59:57,  1.82s/it] 

loss: 0.13660895824432373


 82%|████████▏ | 68001/82783 [12:57:59<7:48:40,  1.90s/it]

loss: 0.1370510458946228


 83%|████████▎ | 69001/82783 [13:09:24<7:03:04,  1.84s/it]

loss: 0.20160020887851715


 85%|████████▍ | 70001/82783 [13:20:54<6:43:03,  1.89s/it]

loss: 0.16550502181053162


 86%|████████▌ | 71001/82783 [13:32:19<6:12:44,  1.90s/it]

loss: 0.1497700959444046


 87%|████████▋ | 72001/82783 [13:43:46<5:29:12,  1.83s/it]

loss: 0.20412731170654297


 88%|████████▊ | 73001/82783 [13:55:11<5:04:04,  1.87s/it]

loss: 0.13970842957496643


 89%|████████▉ | 74001/82783 [14:06:37<4:35:49,  1.88s/it]

loss: 0.13396908342838287


 91%|█████████ | 75001/82783 [14:18:02<4:19:24,  2.00s/it]

loss: 0.16510213911533356


 92%|█████████▏| 76001/82783 [14:29:27<3:38:22,  1.93s/it]

loss: 0.15648697316646576


 93%|█████████▎| 77001/82783 [14:40:52<2:54:32,  1.81s/it]

loss: 0.13247765600681305


 94%|█████████▍| 78001/82783 [14:52:17<2:27:39,  1.85s/it]

loss: 0.12941677868366241


 95%|█████████▌| 79001/82783 [15:03:41<1:56:07,  1.84s/it]

loss: 0.170022651553154


 97%|█████████▋| 80001/82783 [15:15:05<1:22:23,  1.78s/it]

loss: 0.19038181006908417


 98%|█████████▊| 81001/82783 [15:26:29<54:31,  1.84s/it]  

loss: 0.14887650310993195


 99%|█████████▉| 82001/82783 [15:37:54<24:10,  1.86s/it]

loss: 0.1915241777896881


100%|██████████| 82783/82783 [15:46:46<00:00,  1.46it/s]


Run training decoder ...
Epoch 2 / 5


  0%|          | 1/82783 [00:00<14:36:16,  1.57it/s]

loss: 0.23310789465904236


  1%|          | 1001/82783 [11:24<40:51:13,  1.80s/it]

loss: 0.1270824670791626


  2%|▏         | 2001/82783 [22:48<41:44:02,  1.86s/it]

loss: 0.15453661978244781


  4%|▎         | 3001/82783 [34:13<41:40:18,  1.88s/it]

loss: 0.16385430097579956


  5%|▍         | 4001/82783 [45:38<40:00:27,  1.83s/it]

loss: 0.17026159167289734


  6%|▌         | 5001/82783 [57:03<38:59:29,  1.80s/it]

loss: 0.16456855833530426


  7%|▋         | 6001/82783 [1:08:27<38:10:41,  1.79s/it]

loss: 0.2258499711751938


  8%|▊         | 7001/82783 [1:19:53<39:02:46,  1.85s/it]

loss: 0.12577447295188904


 10%|▉         | 8001/82783 [1:31:19<42:22:46,  2.04s/it]

loss: 0.16480731964111328


 11%|█         | 9001/82783 [1:42:44<38:53:47,  1.90s/it]

loss: 0.2172001302242279


 12%|█▏        | 10001/82783 [1:54:10<36:53:12,  1.82s/it]

loss: 0.16825106739997864


 13%|█▎        | 11001/82783 [2:05:34<36:47:32,  1.85s/it]

loss: 0.12178140133619308


 14%|█▍        | 12001/82783 [2:16:58<35:32:57,  1.81s/it]

loss: 0.26701486110687256


 16%|█▌        | 13001/82783 [2:28:23<35:23:18,  1.83s/it]

loss: 0.1596044898033142


 17%|█▋        | 14001/82783 [2:39:47<34:37:59,  1.81s/it]

loss: 0.15203048288822174


 18%|█▊        | 15001/82783 [2:51:11<34:12:14,  1.82s/it]

loss: 0.12548156082630157


 19%|█▉        | 16001/82783 [3:02:36<34:03:32,  1.84s/it]

loss: 0.33075153827667236


 21%|██        | 17001/82783 [3:14:00<33:10:02,  1.82s/it]

loss: 0.13980942964553833


 22%|██▏       | 18001/82783 [3:25:27<32:21:48,  1.80s/it]

loss: 0.18993264436721802


 23%|██▎       | 19001/82783 [3:36:52<33:06:30,  1.87s/it]

loss: 0.14045608043670654


 24%|██▍       | 20001/82783 [3:48:17<35:07:21,  2.01s/it]

loss: 0.12826047837734222


 25%|██▌       | 21001/82783 [3:59:41<30:35:58,  1.78s/it]

loss: 0.1564246416091919


 27%|██▋       | 22001/82783 [4:11:06<29:49:50,  1.77s/it]

loss: 0.21814560890197754


 28%|██▊       | 23001/82783 [4:22:30<30:33:46,  1.84s/it]

loss: 0.13960644602775574


 29%|██▉       | 24001/82783 [4:33:56<30:40:36,  1.88s/it]

loss: 0.19498668611049652


 30%|███       | 25001/82783 [4:45:23<30:45:35,  1.92s/it]

loss: 0.13574808835983276


 31%|███▏      | 26001/82783 [4:56:51<28:51:27,  1.83s/it]

loss: 0.14678740501403809


 33%|███▎      | 27001/82783 [5:08:17<30:47:49,  1.99s/it]

loss: 0.12100744992494583


 34%|███▍      | 28001/82783 [5:19:43<27:28:56,  1.81s/it]

loss: 0.14251883327960968


 35%|███▌      | 29001/82783 [5:31:07<28:59:39,  1.94s/it]

loss: 0.12998369336128235


 36%|███▌      | 30001/82783 [5:42:34<27:10:56,  1.85s/it]

loss: 0.14072763919830322


 37%|███▋      | 31001/82783 [5:54:01<25:34:40,  1.78s/it]

loss: 0.16690316796302795


 39%|███▊      | 32001/82783 [6:05:25<27:15:59,  1.93s/it]

loss: 0.25086063146591187


 40%|███▉      | 33001/82783 [6:16:51<26:03:52,  1.88s/it]

loss: 0.12526826560497284


 41%|████      | 34001/82783 [6:28:16<26:14:39,  1.94s/it]

loss: 0.17078107595443726


 42%|████▏     | 35001/82783 [6:39:43<25:16:50,  1.90s/it]

loss: 0.1447305679321289


 43%|████▎     | 36001/82783 [6:51:07<24:26:00,  1.88s/it]

loss: 0.154936745762825


 45%|████▍     | 37001/82783 [7:02:36<24:16:26,  1.91s/it]

loss: 0.1304730325937271


 46%|████▌     | 38001/82783 [7:14:04<23:47:34,  1.91s/it]

loss: 0.18060274422168732


 47%|████▋     | 39001/82783 [7:25:32<23:13:21,  1.91s/it]

loss: 0.2733386754989624


 48%|████▊     | 40001/82783 [7:36:59<21:08:25,  1.78s/it]

loss: 0.13329406082630157


 50%|████▉     | 41001/82783 [7:48:26<22:15:59,  1.92s/it]

loss: 0.14860837161540985


 51%|█████     | 42001/82783 [7:59:52<21:40:58,  1.91s/it]

loss: 0.16586720943450928


 52%|█████▏    | 43001/82783 [8:11:20<20:08:57,  1.82s/it]

loss: 0.1890133172273636


 53%|█████▎    | 44001/82783 [8:22:46<19:48:29,  1.84s/it]

loss: 0.136681467294693


 54%|█████▍    | 45001/82783 [8:34:14<20:34:52,  1.96s/it]

loss: 0.146916463971138


 56%|█████▌    | 46001/82783 [8:45:41<19:23:33,  1.90s/it]

loss: 0.14563794434070587


 57%|█████▋    | 47001/82783 [8:57:10<19:59:54,  2.01s/it]

loss: 0.26821067929267883


 58%|█████▊    | 48001/82783 [9:08:37<18:12:23,  1.88s/it]

loss: 0.13860349357128143


 59%|█████▉    | 49001/82783 [9:20:03<18:56:53,  2.02s/it]

loss: 0.22073397040367126


 60%|██████    | 50001/82783 [9:31:29<17:15:15,  1.89s/it]

loss: 0.14235571026802063


 62%|██████▏   | 51001/82783 [9:42:56<16:14:08,  1.84s/it]

loss: 0.21280266344547272


 63%|██████▎   | 52001/82783 [9:54:23<15:47:52,  1.85s/it]

loss: 0.1640334576368332


 64%|██████▍   | 53001/82783 [10:05:47<15:29:08,  1.87s/it]

loss: 0.13844367861747742


 65%|██████▌   | 54001/82783 [10:17:09<14:36:50,  1.83s/it]

loss: 0.1700219362974167


 66%|██████▋   | 55001/82783 [10:28:33<14:14:06,  1.84s/it]

loss: 0.1361301839351654


 68%|██████▊   | 56001/82783 [10:40:01<14:54:39,  2.00s/it]

loss: 0.15378984808921814


 69%|██████▉   | 57001/82783 [10:51:26<14:21:22,  2.00s/it]

loss: 0.12436414510011673


 70%|███████   | 58001/82783 [11:02:49<13:53:54,  2.02s/it]

loss: 0.15493178367614746


 71%|███████▏  | 59001/82783 [11:14:11<12:58:19,  1.96s/it]

loss: 0.1812020242214203


 72%|███████▏  | 60001/82783 [11:25:37<12:57:35,  2.05s/it]

loss: 0.14269030094146729


 74%|███████▎  | 61001/82783 [11:37:05<11:30:47,  1.90s/it]

loss: 0.14222301542758942


 75%|███████▍  | 62001/82783 [11:48:32<11:05:04,  1.92s/it]

loss: 0.1248227059841156


 76%|███████▌  | 63001/82783 [11:59:56<10:29:22,  1.91s/it]

loss: 0.1686263382434845


 77%|███████▋  | 63389/82783 [12:04:18<3:26:59,  1.56it/s] 

<font size="5">Save Trained Model</font>

In [None]:
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).to(device)

torch.save(dalle2.state_dict(), dalle2_save_path)