<font size="5">Import Libraries</font>

In [1]:
import torch
from torchvision import transforms as T
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
import os
from tqdm import tqdm
from dalle_pytorch import DiscreteVAE, DALLE
from dalle_pytorch.tokenizer import SimpleTokenizer
from torchvision.datasets.coco import CocoCaptions

  from .autonotebook import tqdm as notebook_tqdm


<font size="5">Setting Dataset & Training Parameters</font>

In [2]:
# ---------------------------------------------------------------------------
# Notebook configuration: every value a reader is expected to tune lives here.
# ---------------------------------------------------------------------------

# Change your input size here
# (side length, in pixels, used by the resize / centre-crop preprocessing)
input_image_size = 256

# Change your batch size here
# (number of dataset samples gathered per optimisation step)
batch_size = 1

# Change your epoch here
# (number of full passes over the training set, for both training loops)
epoch = 5

# Change your train image root path here
train_img_path = "./train2014/"

# Change your train annot json path here
train_annot_path = "./annotations/captions_train2014.json"

# Change your device ("cpu" or "cuda")
device = "cuda"

# Fall back to the CPU instead of crashing at the first `.to(device)` call
# when CUDA was requested but this machine has no usable GPU.
if device.startswith("cuda") and not torch.cuda.is_available():
    print("CUDA is not available; falling back to CPU.")
    device = "cpu"

# Change your VAE save path here (ends with ".pth")
vae_save_path = "./vae.pth"

# Change your dalle model save path here (ends with ".pth")
dalle_save_path = "./dalle.pth"

# Change the test result image save path (should be a directory or folder)
test_img_save_path = "./result"

# exist_ok=True makes the cell idempotent and avoids the check-then-create race.
os.makedirs(test_img_save_path, exist_ok=True)

<font size="5">Data Preprocessing</font>

In [3]:
def _ensure_rgb(img):
    """Return ``img`` as-is when its mode is already RGB, else an RGB conversion of it."""
    if img.mode == 'RGB':
        return img
    return img.convert('RGB')


# Per-image pipeline: force RGB, resize, centre-crop to `input_image_size`,
# then convert to a tensor.
_preprocessing_steps = [
    T.Lambda(_ensure_rgb),
    T.Resize(input_image_size),
    T.CenterCrop(input_image_size),
    T.ToTensor(),
]
transform = T.Compose(_preprocessing_steps)

# COCO captions dataset; indexing it yields an (image, captions) pair.
train_data = CocoCaptions(root=train_img_path, annFile=train_annot_path, transform=transform)

loading annotations into memory...
Done (t=0.92s)
creating index...
index created!


<font size="5">Create VAE Model</font>

In [4]:
# Discrete VAE used as the image tokenizer for DALLE.
# `image_size` follows the configured `input_image_size` so the model always
# matches the size produced by the preprocessing pipeline.
vae = DiscreteVAE(
    image_size = input_image_size,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
).to(device)

# Resume from an existing checkpoint if one is present.
# map_location lets a checkpoint saved on one device load on another.
if os.path.exists(vae_save_path):
    vae.load_state_dict(torch.load(vae_save_path, map_location=device))

<font size="5">Train VAE Model</font>

In [5]:
# Train the discrete VAE on the images of `train_data` (captions are ignored).

train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

# Cadence is expressed in optimisation steps rather than sample offsets, so it
# stays regular for any `batch_size`, not just batch_size == 1.
save_every = 100    # checkpoint + LR-scheduler update interval
log_every = 1000    # loss print interval

opt = Adam(
    vae.parameters(),
    lr = 3e-4,
    weight_decay=0.01,
    betas = (0.9, 0.999)
)
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
    verbose=True,
)

for curr_epoch in range(epoch):
    print("Run training discrete vae ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")

    # Running loss over the current `save_every` window. The scheduler is fed
    # the window mean, which is far less noisy than one mini-batch loss.
    window_loss = 0.0
    window_steps = 0

    for step, batch_idx in enumerate(tqdm(idx_list)):
        # The last batch of an epoch may be shorter than `batch_size`.
        batch_end = min(batch_idx + batch_size, train_size)

        # Stack the (already transformed) image tensors into one batch.
        images = torch.stack(
            [train_data[curr_idx][0] for curr_idx in range(batch_idx, batch_end)],
            dim=0,
        ).float().to(device)

        opt.zero_grad()
        loss = vae(images, return_loss = True)
        loss.backward()
        opt.step()

        # Detach so the running sum does not keep autograd graphs alive.
        window_loss = window_loss + loss.detach()
        window_steps += 1

        if step != 0 and step % save_every == 0:
            torch.save(vae.state_dict(), vae_save_path)
            sched.step(float(window_loss / window_steps))
            window_loss = 0.0
            window_steps = 0

        if step % log_every == 0:
            print(f"loss: {loss.item()}")

# Final checkpoint after all epochs.
torch.save(vae.state_dict(), vae_save_path)

Run training discrete vae ...
Epoch 1 / 5


  0%|          | 4/82783 [00:03<15:50:20,  1.45it/s]

loss: 0.40275079011917114


  1%|          | 1004/82783 [00:43<1:04:01, 21.29it/s]

loss: 0.2165435552597046


  2%|▏         | 2003/82783 [01:20<1:13:22, 18.35it/s]

loss: 0.6831685900688171


  3%|▎         | 2806/82783 [01:50<1:04:30, 20.66it/s]

Epoch 00028: reducing learning rate of group 0 to 1.5000e-04.


  4%|▎         | 3005/82783 [01:57<1:01:12, 21.72it/s]

loss: 0.37978482246398926


  5%|▍         | 4006/82783 [02:37<1:09:56, 18.77it/s]

loss: 0.4275996685028076


  6%|▌         | 4904/82783 [03:10<1:02:42, 20.70it/s]

Epoch 00049: reducing learning rate of group 0 to 7.5000e-05.


  6%|▌         | 5006/82783 [03:14<1:05:56, 19.66it/s]

loss: 0.332499623298645


  7%|▋         | 6004/82783 [03:51<1:00:30, 21.15it/s]

loss: 0.1905183494091034


  8%|▊         | 6987/82783 [04:29<51:30, 24.53it/s]  

<font size="5">Create DALLE Model</font>

In [None]:
# Text tokenizer shared by the DALLE training and test cells.
tokenizer = SimpleTokenizer()

dalle = DALLE(
    dim = 1024,
    vae = vae,                                # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = tokenizer.vocab_size,   # must cover every id the tokenizer can emit, otherwise the text embedding is indexed out of range
    text_seq_len = 256,                       # text sequence length
    depth = 12,                               # should aim to be 64
    heads = 16,                               # attention heads
    dim_head = 64,                            # attention head dimension
    attn_dropout = 0.1,                       # attention dropout
    ff_dropout = 0.1                          # feedforward dropout
).to(device)

# Resume from an existing checkpoint if one is present.
# map_location lets a checkpoint saved on one device load on another.
if os.path.exists(dalle_save_path):
    dalle.load_state_dict(torch.load(dalle_save_path, map_location=device))

<font size="5">Train DALLE Model</font>

In [None]:
# Train DALLE on (caption, image) pairs. Each dataset sample provides one image
# and several captions; the image is repeated once per caption.

train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

# Cadence is expressed in optimisation steps rather than sample offsets, so it
# stays regular for any `batch_size`, not just batch_size == 1.
save_every = 100    # checkpoint + LR-scheduler update interval
log_every = 1000    # loss print interval

opt = Adam(
    dalle.parameters(),
    lr = 3e-4,
    weight_decay=0.01,
    betas = (0.9, 0.999)
)
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
    verbose=True,
)

for curr_epoch in range(epoch):
    print("Run training dalle ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")

    # Running loss over the current `save_every` window. The scheduler is fed
    # the window mean, which is far less noisy than one mini-batch loss.
    window_loss = 0.0
    window_steps = 0

    for step, batch_idx in enumerate(tqdm(idx_list)):
        # The last batch of an epoch may be shorter than `batch_size`.
        batch_end = min(batch_idx + batch_size, train_size)

        image_list = []
        text_list = []

        for curr_idx in range(batch_idx, batch_end):
            sample_image, captions = train_data[curr_idx]
            caption_tokens = tokenizer.tokenize(captions).long()

            # Pair the image with each of its captions: one image row per caption row.
            image_list.append(
                sample_image.unsqueeze(0).expand(len(caption_tokens), -1, -1, -1)
            )
            text_list.append(caption_tokens)

        # Concatenate on the host and move to the device once per batch.
        text = torch.cat(text_list, dim=0).to(device)
        image = torch.cat(image_list, dim=0).float().to(device)

        opt.zero_grad()
        loss = dalle(text, image, return_loss = True)
        loss.backward()
        opt.step()

        # Detach so the running sum does not keep autograd graphs alive.
        window_loss = window_loss + loss.detach()
        window_steps += 1

        if step != 0 and step % save_every == 0:
            torch.save(dalle.state_dict(), dalle_save_path)
            sched.step(float(window_loss / window_steps))
            window_loss = 0.0
            window_steps = 0

        if step % log_every == 0:
            print(f"loss: {loss.item()}")

# Final checkpoint after all epochs.
torch.save(dalle.state_dict(), dalle_save_path)

<font size="5">Test DALLE model with several inputs</font>

In [None]:
# Generate one image per prompt and save each into `test_img_save_path`.

test_inputs = ['Closeup of bins of food that include broccoli and bread.'] # text input for the model (can be more than one)

text = tokenizer.tokenize(test_inputs).to(device)

test_img_tensors = dalle.generate_images(text)

output_dir = Path(test_img_save_path)
output_dir.mkdir(parents=True, exist_ok=True)

for test_idx, test_img_tensor in enumerate(test_img_tensors):
    # Rescale each image to [0, 1] before the 8-bit conversion; values outside
    # that range would otherwise wrap around when cast to uint8.
    # NOTE(review): assumes the decoder output is not already bounded to [0, 1] — confirm.
    pixels = test_img_tensor.detach().float().cpu()
    low, high = pixels.min(), pixels.max()
    pixels = (pixels - low) / (high - low).clamp(min=1e-5)
    test_img = T.ToPILImage()(pixels)

    # Build a filesystem-safe name: the index keeps duplicate prompts apart,
    # and the slug drops path separators / punctuation and is length-limited.
    prompt = test_inputs[test_idx]
    slug = "".join(ch if ch.isalnum() else "_" for ch in prompt).strip("_")[:80]
    test_img.save(output_dir / f"{test_idx:03d}_{slug}.jpg")