# Imagen Training Script on CocoCaptions Dataset

CocoCaptions Dataset - [documentation](https://pytorch.org/vision/main/generated/torchvision.datasets.CocoCaptions.html)

imagen-pytorch [library](https://github.com/lucidrains/imagen-pytorch)

In [None]:
import os 
import yaml
import math
import wandb
import logging
import warnings
import numpy as np
from time import time
from PIL import Image
from pathlib import Path
from flatdict import FlatDict
import matplotlib.pyplot as plt
from IPython.display import clear_output

import torch
import torchvision
from torch.utils import data
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision import datasets, utils
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from utils.data_utils import CocoDataset
from utils.train_utils import get_emb_tensor, display_images, save_checkpoint, print_epoch_stats
import webdataset as wds
warnings.filterwarnings("ignore")

%config InlineBackend.figure_format = 'retina'
# Choose the compute device once for the whole notebook: CUDA when torch
# reports it as available, CPU otherwise. The chosen device is echoed below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Load the experiment configuration from YAML.
# The path is assembled from components instead of a "configs\\..." string:
# a backslash is only a separator on Windows, so the literal form is treated
# as a single (non-existent) file name on Linux/macOS.
cfg = yaml.safe_load((Path("configs") / "imagen-large-config.yaml").read_text())
# Flattened copy of the nested config using "." as the key delimiter
# (the commented-out wandb.init cell passes this as its config).
cfg_flat = dict(FlatDict(cfg, delimiter='.'))

In [None]:
# wandb.init(project="imagen", entity="camlaedtke", config=cfg_flat)

In [None]:
# Per-image preprocessing applied to every decoded image before batching.
# NOTE(review): how Resize/CenterCrop interpret cfg["dataset"]["image_size"]
# depends on whether it is an int or an (h, w) pair — confirm against the YAML.
preproc = T.Compose([
    T.Resize(cfg["dataset"]["image_size"]),
    T.RandomHorizontalFlip(),  # random augmentation: the same sample can differ between passes
    T.CenterCrop(cfg["dataset"]["image_size"]),
    T.ToTensor()  # display_data below treats the result as channel-first with values in [0, 1]
])

# Streaming dataset over the cc3m tar shards 00000..00331.
# The stages run in the order written: shuffle (buffer argument 1000), decode
# images to PIL, expose the jpg/png entry as "image" and the txt entry as
# "caption", preprocess the image, and emit (image, caption) tuples.
cc3m_dataset = (
    wds.WebDataset("cc3m/{00000..00331}.tar")
    .shuffle(1000)
    .decode("pil")
    .rename(image="jpg;png", caption="txt")
    .map_dict(image=preproc)
    .to_tuple("image", "caption")
)

# Batches the stream using 4 worker processes; drop_last=True discards a
# trailing partial batch.
# NOTE(review): the dataset above is built without an explicit length, so
# len(cc3m_dataloader) may raise — confirm before handing this loader to code
# that calls len() on it.
cc3m_dataloader = DataLoader(
    dataset = cc3m_dataset, 
    batch_size = cfg["train"]["batch_size"], 
    drop_last = True,
    num_workers = 4,
    prefetch_factor = 8,
    pin_memory = True
)

In [None]:
# Pull one (image, caption) pair directly from the dataset, bypassing the DataLoader.
X, y = next(iter(cc3m_dataset))

In [None]:
# Show the sampled image's values as a NumPy array (rendered as the cell output).
X.numpy()

In [None]:
# Fetch one batch of images and captions through the DataLoader.
X_batch, y_batch = next(iter(cc3m_dataloader))

In [None]:
def display_data(display_list, captions, rows=3):
    """Plot images in a grid, each titled with its caption.

    Args:
        display_list: Sequence of image tensors. Each one is permuted with
            (1, 2, 0) and multiplied by 255 before display, so it is assumed
            to be channel-first with values in [0, 1] — TODO confirm against
            the preprocessing pipeline.
        captions: Sequence of caption strings, index-aligned with
            ``display_list``.
        rows: Number of grid rows. Defaults to 3, the layout this function
            always used.
    """
    n_images = len(display_list)
    # Round the column count UP so every image gets a cell. Plain floor
    # division (n_images // rows) gives 0 columns for fewer than `rows` images
    # and too few cells when n_images is not a multiple of `rows`; in both
    # cases plt.subplot raises.
    cols = max(1, (n_images + rows - 1) // rows)

    plt.figure(figsize=(16, 16))
    for idx, image in enumerate(display_list):
        # Channel-first -> channel-last, then scale to 8-bit for imshow.
        img = image.cpu().permute(1, 2, 0).numpy() * 255
        img = img.astype(np.uint8)
        plt.subplot(rows, cols, idx + 1)
        plt.title(captions[idx], fontsize=10)
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
%matplotlib inline
# Preview the first six images of the fetched batch together with their captions.
display_data(X_batch[0:6], y_batch[0:6])

In [None]:
##### MODEL #####
# Constructor arguments read from each U-Net's config section. The base and
# super-resolution U-Nets take exactly the same set of keys.
_UNET_CFG_KEYS = (
    "dim",
    "cond_dim",
    "dim_mults",
    "num_resnet_blocks",
    "layer_attns",
    "layer_cross_attns",
    "attn_heads",
    "ff_mult",
    "memory_efficient",
    "dropout",
)


def _build_unet(unet_cfg):
    """Build a Unet from one config section, passing the keys in _UNET_CFG_KEYS."""
    return Unet(**{key: unet_cfg[key] for key in _UNET_CFG_KEYS})


BaseUnet = _build_unet(cfg["model"]["base_unet"])
SRUnet = _build_unet(cfg["model"]["sr_unet1"])

# Move the model to the device selected at the top of the notebook instead of
# calling .cuda() unconditionally, so a CPU-only machine does not crash here.
imagen = Imagen(
    unets = (BaseUnet, SRUnet),
    text_encoder_name = cfg["model"]["text_encoder_name"],
    image_sizes = cfg["model"]["image_sizes"],
    cond_drop_prob = cfg["model"]["cond_drop_prob"],
    timesteps = cfg["model"]["timesteps"],
).to(device)

##### TRAINING #####
# NOTE(review): warmup_steps and cosine_decay_max_steps go through eval(),
# which implies they are stored as strings in the YAML (presumably arithmetic
# expressions — confirm). eval() executes arbitrary code, so only run this
# with config files you trust; storing plain integers would remove the need.
trainer = ImagenTrainer(
    imagen,
    lr = cfg["train"]["lr"],
    fp16 = cfg["train"]["amp"],
    use_ema = cfg["train"]["use_ema"],
    max_grad_norm = cfg["train"]["max_grad_norm"],
    warmup_steps = eval(cfg["train"]["warmup_steps"]),
    cosine_decay_max_steps = eval(cfg["train"]["cosine_decay_max_steps"]),
)

In [None]:
def train(cfg, dataloader, trainer, epoch, i, device):
    """Run one pass over ``dataloader``, training U-Net number ``i``.

    For every batch this moves the images to ``device``, computes text
    embeddings with ``get_emb_tensor``, calls the trainer to get the loss,
    applies the update, hands the step to ``save_checkpoint`` and logs the
    loss to wandb. Average wall-clock timings per phase are printed at the end.

    Args:
        cfg: Nested config dict; reads ``cfg["train"]["base_unet_max_batch_size"]``
            when ``i == 1`` and ``cfg["train"]["sr_unet1_max_batch_size"]``
            otherwise.
        dataloader: Iterable yielding ``(images, texts)`` batches. It does not
            need to support ``len()``.
        trainer: Trainer object; called for the loss and via ``.update()``.
        epoch: 1-based epoch number, used to compute the logged global step.
        i: U-Net number being trained.
        device: Device the image batch is moved to.

    Returns:
        Tuple ``(trainer, loss_arr)`` where ``loss_arr`` holds one loss per step.
    """
    # A DataLoader over a length-less iterable dataset (such as a streamed
    # WebDataset) raises TypeError on len(). Resolve the length once up front
    # and fall back to "unknown" instead of crashing on the first step.
    try:
        steps_per_epoch = len(dataloader)
    except TypeError:
        steps_per_epoch = None

    max_batch_key = "base_unet_max_batch_size" if i == 1 else "sr_unet1_max_batch_size"

    loss_arr = []
    fetch_times = []; embed_times = []; loss_times = []; update_times = []; step_times = []
    for step, batch in enumerate(dataloader):
        step_start = time()

        # Only the unpack + device transfer is timed here; waiting for the
        # loader to produce the batch happens in the `for` statement itself,
        # outside this timed region.
        fetch_start = time()
        images, texts = batch
        images = images.to(device)
        fetch_end = time()
        fetch_times.append(fetch_end-fetch_start)

        embed_start = time()
        text_embeds, text_masks = get_emb_tensor(cfg, texts, device)
        embed_end = time()
        embed_times.append(embed_end-embed_start)

        loss_start = time()
        loss = trainer(
            images,
            text_embeds = text_embeds,
            text_masks = text_masks,
            unet_number = i,
            max_batch_size = cfg["train"][max_batch_key]
        )
        loss_end = time()
        loss_times.append(loss_end-loss_start)

        update_start = time()
        trainer.update(unet_number = i)
        update_end = time()
        update_times.append(update_end-update_start)

        step_end = time()
        step_times.append(step_end-step_start)

        loss_arr.append(loss)
        save_checkpoint(cfg, step, loss, trainer)

        if steps_per_epoch is None:
            # Without a known epoch length a global step cannot be derived
            # from the epoch number, so log the within-epoch step instead.
            curr_step = step
            progress = f"{step+1}"
        else:
            curr_step = int(steps_per_epoch*(epoch-1) + step)
            progress = f"{step+1}/{steps_per_epoch}"
        wandb.log({f"Train Loss {i}": loss, f"Train {i} Step": curr_step})
        print(f"\r   Train Step {progress}, Train Loss: {loss:.4f}", end='')

    step_time = np.mean(step_times)
    fetch_time = np.mean(fetch_times)
    embed_time = np.mean(embed_times)
    loss_time = np.mean(loss_times)
    update_time = np.mean(update_times)
    print()
    print(f"      Step: {step_time:.4f}s, Img load: {fetch_time:.4f}s, Embed: {embed_time:.4f}s, "\
          f"Loss: {loss_time:.4f}s, Update: {update_time:.4f}s")
    return trainer, loss_arr



def run_train_loop(cfg, trainer, dataloader, device):
    """Train U-Nets 1 and 2 for every configured epoch and log sample images.

    Each epoch trains U-Net 1 and then U-Net 2 on ``dataloader`` via ``train``,
    prints the epoch statistics, and saves the trainer to
    ``cfg["train"]["checkpoint_path"]`` unless the last recorded loss is NaN.
    After both U-Nets are trained, three fixed prompts are sampled and the
    resulting images are logged to wandb together with the epoch number.

    NOTE(review): a second function named ``run_train_loop`` with a different
    signature is defined later in this file; once that definition has run it
    replaces this one.

    Args:
        cfg: Nested config dict; reads ``epochs``, ``checkpoint_path`` and
            ``cond_scale`` from ``cfg["train"]``.
        trainer: Trainer object passed to ``train`` and used for saving/sampling.
        dataloader: Batch source handed to ``train``.
        device: Device handed to ``train``.
    """
    for epoch in range(1, cfg["train"]["epochs"]+1):
        print(f"\nEpoch {epoch}/{cfg['train']['epochs']}")

        for i in (1, 2):
            print(f"--- Unet {i} ---")

            started = time()
            trainer, losses = train(cfg, dataloader, trainer, epoch, i, device)
            finished = time()
            elapsed_minutes = (finished - started) / 60

            print_epoch_stats(elapsed_minutes, losses)

            # Never overwrite the checkpoint with a diverged model.
            if math.isnan(losses[-1]):
                continue
            trainer.save(cfg["train"]["checkpoint_path"])

        prompts = [
            'red flowers in a white vase',
            'a puppy looking anxiously at a giant donut on the table',
            'the milky way galaxy in the style of monet'
        ]
        generated = trainer.sample(prompts, cond_scale = cfg["train"]["cond_scale"])
        arrays = display_images(generated)
        pil_images = [Image.fromarray(arr) for arr in arrays]
        wandb_images = [wandb.Image(pil) for pil in pil_images]
        wandb.log({"Samples": wandb_images, "Epoch": epoch})




def run_train_loop(cfg, trainer, train_dataloader, valid_dataloader, device):
    """Train U-Nets 1 and 2 each epoch, validate every fifth epoch, log samples.

    Each epoch trains U-Net 1 and then U-Net 2 via ``train``. On epochs
    divisible by 5 the U-Net is also validated; on other epochs the validation
    loss list is the placeholder ``[0]``. The trainer is saved to
    ``cfg["train"]["checkpoint_path"]`` unless the latest training or
    validation loss is NaN. After both U-Nets, three fixed prompts are sampled,
    the notebook output is cleared, and the images are logged to wandb.

    Args:
        cfg: Nested config dict; reads ``epochs``, ``checkpoint_path`` and
            ``cond_scale`` from ``cfg["train"]``.
        trainer: Trainer object passed to ``train``/``validate`` and used for
            saving and sampling.
        train_dataloader: Batch source handed to ``train``.
        valid_dataloader: Batch source handed to ``validate``.
        device: Device handed to ``train`` and ``validate``.
    """
    for epoch in range(1, cfg["train"]["epochs"]+1):
        print(f"\nEpoch {epoch}/{cfg['train']['epochs']}")

        for i in (1,2):

            print(f"--- Unet {i} ---")
            start = time()

            trainer, train_loss_arr = train(cfg, train_dataloader, trainer, epoch, i, device)

            # Placeholder used on epochs without validation.
            valid_loss_arr = [0]
            if epoch % 5 == 0:
                # NOTE(review): `validate` is neither defined nor imported in
                # the visible part of this file — confirm where it comes from.
                valid_loss_arr = validate(cfg, valid_dataloader, trainer, epoch, i, device)

            end = time()
            e_time = (end-start)/60

            print_epoch_stats(e_time, train_loss_arr, valid_loss_arr)

            # Guard the checkpoint on the training loss as well. Checking only
            # valid_loss_arr[-1] is not enough: it is the placeholder 0 on
            # four of every five epochs, so a NaN training loss would still
            # overwrite the last good checkpoint.
            final_losses = [valid_loss_arr[-1]]
            if train_loss_arr:
                final_losses.append(train_loss_arr[-1])
            if not any(math.isnan(value) for value in final_losses):
                trainer.save(cfg["train"]["checkpoint_path"])

        texts = [
            'red flowers in a white vase',
            'a puppy looking anxiously at a giant donut on the table',
            'the milky way galaxy in the style of monet'
        ]
        sampled_images = trainer.sample(texts, cond_scale = cfg["train"]["cond_scale"])
        clear_output()
        image_list = display_images(sampled_images)
        images_pil = [Image.fromarray(image) for image in image_list]
        wandb.log({"Samples": [wandb.Image(image) for image in images_pil], "Epoch": epoch})

In [None]:
# Optionally restore model weights (only_model=True) from the configured checkpoint.
if cfg["train"]["load_checkpoint"]:
    try:
        trainer.load(cfg["train"]["checkpoint_path"], strict=False, only_model=True)
    except Exception as err:
        # Still best-effort (training continues from scratch), but say so:
        # a silent failure looks exactly like a successful resume. Catching
        # Exception rather than using a bare except also lets
        # KeyboardInterrupt/SystemExit propagate.
        print(f"Could not load checkpoint from {cfg['train']['checkpoint_path']}: {err!r}")
    else:
        print("Loaded checkpoint")

In [None]:
# Load from the configured checkpoint with only_model=False (presumably this
# restores trainer state beyond the model weights — confirm in the library docs).
# Unlike the guarded load in the previous cell, this call is unconditional,
# so any loading error propagates.
trainer.load(cfg["train"]["checkpoint_path"], strict=False, only_model=False)

In [None]:
# torch.backends.cudnn.benchmark = True # Uses ~2GB more VRAM

In [None]:
# Start training with the five-argument run_train_loop (the later definition).
# NOTE(review): train_dataloader and valid_dataloader are not defined in the
# visible cells (only cc3m_dataloader is) — confirm where they are created.
run_train_loop(cfg, trainer, train_dataloader, valid_dataloader, device)

In [None]:
# Prompts used for the manual sampling cells that follow.
texts = [
    'red flowers on a beach by the sunset',
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
]

In [None]:
# Generate images for the prompts above using the configured cond_scale.
sampled_images = trainer.sample(texts, cond_scale = cfg["train"]["cond_scale"])

In [None]:
# Pass the samples to display_images and keep what it returns.
imgs = display_images(sampled_images)