<font size="5">Import Libraries</font>

In [None]:
import torch
from torchvision import transforms as T
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import os
from tqdm import tqdm
from dalle_pytorch import DiscreteVAE, DALLE
from dalle_pytorch.tokenizer import SimpleTokenizer
import pandas as pd

<font size="5">Setting Dataset & Training Parameters</font>

In [None]:
# ---------------------------------------------------------------------------
# User-tunable settings. Every later cell reads these names, so change values
# here rather than editing the cells below.
# ---------------------------------------------------------------------------

# Change your input size here
# (side length, in pixels, that every image is resized / center-cropped to)
input_image_size = 256

# Change your batch size here
# (number of images, and captions for DALLE, per optimizer step)
batch_size = 1

# Change your epoch here
# (number of full passes over the dataset, used by both training loops)
epoch = 5

# Change your train image root path here
train_img_path = "./Flower_Dataset_Combine/ImagesCombine/"

# Change your train annot csv path here
# (the CSV is read with pandas; the training cells use its `file_name` and
# `caption` columns)
train_annot_path = "./Flower_Dataset_Combine/New_captions.csv"

# Change your device ("cpu" or "cuda")
# Uses the GPU when one is available and falls back to the CPU otherwise, so
# the notebook still runs on a machine without CUDA instead of failing on the
# first `.to(device)`. Assign "cpu" explicitly to force CPU execution.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Change your VAE save path here (ends with ".pth")
vae_save_path = "./vae.pth"

# Change your dalle model save path here (ends with ".pth")
dalle_save_path = "./dalle.pth"

<font size="5">Data Preprocessing</font>

In [None]:
# Image pipeline applied to every training image:
#   1. force RGB so grayscale / RGBA / palette files all yield 3 channels,
#   2. resize, then center-crop to an `input_image_size` square,
#   3. convert the PIL image to a tensor.
transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    T.Resize(input_image_size),
    T.CenterCrop(input_image_size),
    T.ToTensor()
])

# Load the annotations, drop duplicated / incomplete rows and then rebuild a
# contiguous 0..N-1 index. Without `reset_index(drop=True)` the dropped rows
# leave holes in the index, so a lookup such as `train_csv.loc[i]` for
# i in range(len(train_csv)) raises KeyError on a removed label and never
# reaches the rows whose original label is >= len(train_csv).
train_csv = (
    pd.read_csv(train_annot_path)
    .drop_duplicates()
    .dropna()
    .reset_index(drop=True)
)

# Fail early, with a readable message, if the CSV lacks the columns that the
# training cells read.
missing_columns = {"file_name", "caption"} - set(train_csv.columns)
if missing_columns:
    raise KeyError(
        f"{train_annot_path} is missing required column(s): {sorted(missing_columns)}"
    )

<font size="5">Create VAE Model</font>

In [None]:
# Build the discrete VAE used as the image tokenizer.
# `image_size` follows `input_image_size` so the model always agrees with the
# size produced by the preprocessing transform; a separate literal here would
# silently diverge when only the config value is changed.
# NOTE(review): the remaining hyper-parameters are passed straight through to
# dalle_pytorch's DiscreteVAE - see that library's README for their meaning.
vae = DiscreteVAE(
    image_size = input_image_size,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
).to(device)

# Resume from an existing checkpoint when there is one. `map_location` remaps
# the stored tensors onto the configured device, so a checkpoint written on a
# GPU can still be loaded on a CPU-only machine (and vice versa).
if os.path.exists(vae_save_path):
    vae.load_state_dict(torch.load(vae_save_path, map_location=device))

<font size="5">Train VAE Model</font>

In [None]:
# ---------------------------------------------------------------------------
# Train the discrete VAE on the images listed in `train_csv`.
# Each optimizer step uses one mini-batch of up to `batch_size` images; the
# weights are checkpointed to `vae_save_path` periodically and at the end.
# ---------------------------------------------------------------------------
train_size = len(train_csv)
# Start offset of every mini-batch. Slicing with these offsets makes the last
# batch shorter automatically when `batch_size` does not divide `train_size`.
idx_list = range(0, train_size, batch_size)

# Not needed for VAE training itself; still created here so that the name
# `tokenizer` exists after this cell exactly as before.
tokenizer = SimpleTokenizer()
opt = Adam(
    vae.parameters(),
    lr = 3e-4,
    weight_decay=0.01,
    betas = (0.9, 0.999)
)
# `verbose=True` is deprecated in recent PyTorch releases, so learning-rate
# reductions are reported manually after each `sched.step(...)` below.
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
)

# Checkpoint / log every this many optimizer steps. Counting steps (instead of
# testing the sample offset) keeps the cadence independent of `batch_size`.
save_every = 100

# Plain Python list => purely positional access, which stays correct even if
# the DataFrame index is not a contiguous 0..N-1 range.
file_names = train_csv['file_name'].tolist()

vae.train()

for curr_epoch in range(epoch):
    print("Run training discrete vae ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")

    # Running sum kept on the device so that no host/device synchronisation is
    # forced on every step.
    epoch_loss_sum = torch.zeros((), device=device)
    epoch_steps = 0

    for step, batch_idx in enumerate(tqdm(idx_list)):
        batch_images = []
        for image_name in file_names[batch_idx:batch_idx + batch_size]:
            image_path = os.path.join(train_img_path, image_name)
            # The context manager closes the file handle once the pixels have
            # been converted to a tensor.
            with Image.open(image_path) as image:
                batch_images.append(transform(image))

        # One forward pass over the whole mini-batch instead of one per image.
        images = torch.stack(batch_images).to(device)
        loss = vae(images, return_loss=True)

        opt.zero_grad()
        loss.backward()
        opt.step()

        epoch_loss_sum += loss.detach()
        epoch_steps += 1

        if step % save_every == 0:
            torch.save(vae.state_dict(), vae_save_path)
            print(f"loss: {loss.item()}")

    # Drive the scheduler with the mean loss of the epoch rather than with the
    # (noisy) loss of whichever batch happened to be last. Skipped entirely
    # when the dataset is empty.
    if epoch_steps > 0:
        epoch_mean_loss = (epoch_loss_sum / epoch_steps).item()
        print(f"epoch mean loss: {epoch_mean_loss}")

        lr_before = opt.param_groups[0]['lr']
        sched.step(epoch_mean_loss)
        lr_after = opt.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}")

torch.save(vae.state_dict(), vae_save_path)

<font size="5">Create DALLE Model</font>

In [None]:
# Tokenizer that turns captions into the integer sequences fed to DALLE.
tokenizer = SimpleTokenizer()

# Build the DALLE transformer on top of the VAE created above.
# NOTE(review): `text_seq_len` must match the sequence length the tokenizer
# emits (the training cell calls `tokenizer.tokenize` with its default
# length) - confirm both stay in sync if either is changed.
dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 49408,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 1,                  # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
).to(device)

# Resume from an existing checkpoint when there is one. `map_location` remaps
# the stored tensors onto the configured device, so a checkpoint written on a
# GPU can still be loaded on a CPU-only machine (and vice versa).
# NOTE(review): if DALLE keeps `vae` as a sub-module, its weights are part of
# this state dict and loading it replaces the VAE weights loaded/trained
# earlier - verify this is intended when the VAE has been retrained.
if os.path.exists(dalle_save_path):
    dalle.load_state_dict(torch.load(dalle_save_path, map_location=device))

<font size="5">Train DALLE Model</font>

In [None]:
# ---------------------------------------------------------------------------
# Train DALLE on the (caption, image) pairs listed in `train_csv`.
# Each optimizer step uses one mini-batch of up to `batch_size` pairs; the
# weights are checkpointed to `dalle_save_path` periodically and at the end.
# ---------------------------------------------------------------------------
train_size = len(train_csv)
# Start offset of every mini-batch. Slicing with these offsets makes the last
# batch shorter automatically when `batch_size` does not divide `train_size`.
idx_list = range(0, train_size, batch_size)

opt = Adam(
    dalle.parameters(),
    lr = 3e-4,
    weight_decay=0.01,
    betas = (0.9, 0.999)
)
# `verbose=True` is deprecated in recent PyTorch releases, so learning-rate
# reductions are reported manually after each `sched.step(...)` below.
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
)

# Checkpoint / log every this many optimizer steps. Counting steps (instead of
# testing the sample offset) keeps the cadence independent of `batch_size`.
save_every = 100

# Plain Python lists => purely positional access, which stays correct even if
# the DataFrame index is not a contiguous 0..N-1 range. Captions are cast to
# `str` so a caption that pandas parsed as a number can still be tokenized.
file_names = train_csv['file_name'].tolist()
captions = train_csv['caption'].astype(str).tolist()

dalle.train()

for curr_epoch in range(epoch):
    print("Run training dalle ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")

    # Running sum kept on the device so that no host/device synchronisation is
    # forced on every step.
    epoch_loss_sum = torch.zeros((), device=device)
    epoch_steps = 0

    for step, batch_idx in enumerate(tqdm(idx_list)):
        batch_slice = slice(batch_idx, batch_idx + batch_size)

        batch_images = []
        for image_name in file_names[batch_slice]:
            image_path = os.path.join(train_img_path, image_name)
            # The context manager closes the file handle once the pixels have
            # been converted to a tensor.
            with Image.open(image_path) as image:
                batch_images.append(transform(image))
        images = torch.stack(batch_images).to(device)

        # One row of token ids per caption, in the same order as `images`.
        texts = tokenizer.tokenize(captions[batch_slice]).to(device)

        # One forward pass over the whole mini-batch instead of one per pair.
        loss = dalle(texts, images, return_loss=True)

        opt.zero_grad()
        loss.backward()
        opt.step()

        epoch_loss_sum += loss.detach()
        epoch_steps += 1

        if step % save_every == 0:
            torch.save(dalle.state_dict(), dalle_save_path)
            print(f"average loss: {loss.item()}")

    # Drive the scheduler with the mean loss of the epoch rather than with the
    # (noisy) loss of whichever batch happened to be last. Skipped entirely
    # when the dataset is empty.
    if epoch_steps > 0:
        epoch_mean_loss = (epoch_loss_sum / epoch_steps).item()
        print(f"epoch mean loss: {epoch_mean_loss}")

        lr_before = opt.param_groups[0]['lr']
        sched.step(epoch_mean_loss)
        lr_after = opt.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}")

torch.save(dalle.state_dict(), dalle_save_path)