# Imagen Training Script on CocoCaptions Dataset

CocoCaptions Dataset - [documentation](https://pytorch.org/vision/main/generated/torchvision.datasets.CocoCaptions.html)

In [None]:
import os 
import yaml
import math
import wandb
import logging
import warnings
import numpy as np
from time import time
from PIL import Image
from pathlib import Path
from flatdict import FlatDict
import matplotlib.pyplot as plt
from IPython.display import clear_output

import torch
import torchvision
from torch.utils import data
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision import datasets, utils
from imagen_pytorch.t5 import t5_encode_text
from imagen_pytorch import Unet, Imagen, ImagenTrainer, ElucidatedImagen
from transformations import ComposeDouble, FunctionWrapperDouble, select_random_label, select_fixed_label

# Quiet the notebook output: the root logger only lets ERROR-and-above records
# through, and every Python warning is suppressed.
# NOTE(review): the next line assigns an attribute called `propagate` on the
# `logging` module itself rather than on a Logger instance, so it presumably
# has no effect on log propagation -- confirm the intent.
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

%config InlineBackend.figure_format = 'retina'
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Load the experiment configuration. The path is assembled from its parts with
# pathlib instead of a hard-coded "\\" separator (which only resolves on
# Windows), and the YAML is read explicitly as UTF-8 rather than with the
# platform's default encoding.
cfg = yaml.safe_load(Path("configs", "elucidated_imagen-medium-config.yaml").read_text(encoding="utf-8"))
# cfg = yaml.safe_load(Path("configs", "imagen-medium-config.yaml").read_text(encoding="utf-8"))
# Flattened copy with "."-joined keys, used as the experiment-tracking config.
cfg_flat = dict(FlatDict(cfg, delimiter='.'))

In [None]:
# wandb.login()

In [None]:
# Start the Weights & Biases run and attach the flattened config to it.
wandb.init(
    project="imagen",
    entity="camlaedtke",
    config=cfg_flat,
    # resume=True,
)

In [None]:
def get_emb_tensor(cfg, targets, device):
    """Encode captions with T5 and move the results to `device`.

    Args:
        cfg: Configuration mapping; `cfg["model"]["text_encoder_name"]` is
            passed to `t5_encode_text` as the encoder name.
        targets: Caption text handed straight to `t5_encode_text`.
        device: Destination given to `.to()` for both returned tensors.

    Returns:
        A `(text_embeds, text_masks)` pair, both moved to `device`.
    """
    encoder_name = cfg["model"]["text_encoder_name"]
    embeds, masks = t5_encode_text(targets, name=encoder_name)
    return embeds.to(device), masks.to(device)


def display_data(display_list, label_list):
    """Plot images in a three-column grid, each titled with its label.

    Args:
        display_list: Sequence of image tensors. Each one is permuted with
            (1, 2, 0) and scaled by 255 before plotting, so it is assumed to
            be channel-first with values in [0, 1] -- TODO confirm with
            callers.
        label_list: Titles, indexed in step with `display_list`.
    """
    n_cols = 3
    # Round the row count up so a partially filled last row still fits. Floor
    # division produced too few rows (or zero) whenever the number of images
    # was below three or not a multiple of three, which makes subplot() raise.
    n_rows = math.ceil(len(display_list) / n_cols)

    plt.figure(figsize=(16, 12))
    for idx, image in enumerate(display_list):
        # Channels last + uint8 is the layout imshow expects for RGB data.
        img = (image.cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.title(label_list[idx], fontsize=12)
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Paired (image, target) pipeline. The image-side ops run in this order:
# resize -> random horizontal flip -> centre crop -> tensor conversion, all
# registered with input=True/target=False. The final step registers
# `select_random_label` with input=False/target=True.
_image_size = cfg["dataset"]["image_size"]
_image_ops = [
    T.Resize(_image_size),
    T.RandomHorizontalFlip(),
    T.CenterCrop(_image_size),
    T.ToTensor(),
]
data_transforms = ComposeDouble(
    [FunctionWrapperDouble(op, input=True, target=False) for op in _image_ops]
    + [FunctionWrapperDouble(select_random_label, input=False, target=True)]
)


def _coco_captions(split):
    """Build the CocoCaptions dataset described by `cfg["dataset"][split]`."""
    split_cfg = cfg["dataset"][split]
    return datasets.CocoCaptions(
        root=split_cfg["root"],
        annFile=split_cfg["ann_file"],
        transforms=data_transforms,
    )


# Both splits share the same transform pipeline.
train_dataset = _coco_captions("train")
valid_dataset = _coco_captions("val")


# The two loaders differ only in the dataset they draw from, so every other
# setting is declared once and shared.
_loader_options = dict(
    batch_size=cfg["train"]["batch_size"],
    shuffle=True,
    drop_last=True,
    num_workers=3,
    prefetch_factor=5,
    pin_memory=True,
)

train_dataloader = DataLoader(dataset=train_dataset, **_loader_options)

valid_dataloader = DataLoader(dataset=valid_dataset, **_loader_options)

In [None]:
# Pull one batch from the training loader (images in X_batch, targets in y_batch).
X_batch, y_batch = next(iter(train_dataloader))

In [None]:
%matplotlib inline
# Plot the first six images of the batch, titled with their targets.
display_data(X_batch[0:6], y_batch[0:6])

In [None]:
# Base U-Net: every hyper-parameter comes from the "base_unet" section of the
# model config, except the text embedding width, which is shared model-wide.
_base_cfg = cfg["model"]["base_unet"]
BaseUnet = Unet(
    dim=_base_cfg["dim"],
    text_embed_dim=cfg["model"]["text_embed_dim"],
    cond_dim=_base_cfg["cond_dim"],
    dim_mults=_base_cfg["dim_mults"],
    num_resnet_blocks=_base_cfg["num_resnet_blocks"],
    layer_attns=_base_cfg["layer_attns"],
    layer_cross_attns=_base_cfg["layer_cross_attns"],
    attn_heads=_base_cfg["attn_heads"],
    ff_mult=_base_cfg["ff_mult"],
    memory_efficient=_base_cfg["memory_efficient"],
    dropout=_base_cfg["dropout"],
)


# Super-resolution U-Net: the listed keys are copied verbatim from the
# "sr_unet1" section of the model config into same-named Unet arguments; the
# text embedding width is shared model-wide.
_SR_UNET_KEYS = (
    "dim",
    "cond_dim",
    "dim_mults",
    "num_resnet_blocks",
    "layer_attns",
    "layer_cross_attns",
    "attn_heads",
    "ff_mult",
    "memory_efficient",
    "dropout",
)
_sr_cfg = cfg["model"]["sr_unet1"]
SRUnet = Unet(
    text_embed_dim=cfg["model"]["text_embed_dim"],
    **{key: _sr_cfg[key] for key in _SR_UNET_KEYS},
)

# Two-stage cascade (base unet followed by the super-resolution unet) with all
# sampler/noise hyper-parameters read from the "model" config section.
# The model is placed on `device` (selected at the top of the notebook) rather
# than an unconditional .cuda(), so the CPU fallback chosen there is honoured.
imagen = ElucidatedImagen(
    unets = (BaseUnet, SRUnet),
    image_sizes = cfg["model"]["image_sizes"],
    text_embed_dim = cfg["model"]["text_embed_dim"],
    text_encoder_name = cfg["model"]["text_encoder_name"],
    cond_drop_prob = cfg["model"]["cond_drop_prob"],
    num_sample_steps = cfg["model"]["num_sample_steps"],
    sigma_min = cfg["model"]["sigma_min"],
    sigma_max = cfg["model"]["sigma_max"],
    # NOTE(review): the argument is `sigma_data` but the config key read here
    # is "sigma_delta" -- confirm the key name against the YAML file.
    sigma_data = cfg["model"]["sigma_delta"],
    rho = cfg["model"]["rho"],
    P_mean = cfg["model"]["P_mean"],
    P_std = cfg["model"]["P_std"],
    S_churn = cfg["model"]["S_churn"],
    S_tmin = cfg["model"]["S_tmin"],
    S_tmax = cfg["model"]["S_tmax"],
    S_noise = cfg["model"]["S_noise"],
).to(device)

# imagen = Imagen(
#     unets = (BaseUnet, SRUnet),
#     text_encoder_name = cfg["model"]["text_encoder_name"], 
#     image_sizes = cfg["model"]["image_sizes"], 
#     cond_drop_prob = cfg["model"]["cond_drop_prob"],
#     timesteps = cfg["model"]["timesteps"],
# ).cuda()

##### TRAINING #####
def _schedule_steps(value):
    """Resolve a scheduler step count read from the config.

    String values are evaluated as Python expressions (presumably so the YAML
    can hold entries such as "None" or simple arithmetic -- confirm against
    the config file). Any other value, e.g. a plain YAML integer or null, is
    returned unchanged instead of making `eval` raise a TypeError.
    """
    if isinstance(value, str):
        # NOTE(security): eval() executes arbitrary code. This is only
        # acceptable while the config file is trusted and local; switch to
        # explicit parsing if configs can come from elsewhere.
        return eval(value)
    return value


# Wrap the model in the library trainer; learning-rate schedule lengths come
# from the "train" config section.
trainer = ImagenTrainer(
    imagen,
    lr = cfg["train"]["lr"],
    amp = cfg["train"]["amp"],
    use_ema = cfg["train"]["use_ema"],
    warmup_steps = _schedule_steps(cfg["train"]["warmup_steps"]),
    cosine_decay_max_steps = _schedule_steps(cfg["train"]["cosine_decay_max_steps"]),
)

In [None]:
def display_images(display_list):
    """Show images side by side and return them as uint8 arrays.

    Args:
        display_list: Sequence of image tensors, assumed channel-first
            (C, H, W) with values in [0, 1] -- TODO confirm against the
            sampler output passed in by callers.

    Returns:
        A list with one `np.uint8` array per input image, in (H, W, C) layout.
    """
    image_list = []
    plt.figure(figsize=(10, 10), dpi=150)
    n_images = len(display_list)
    for idx, tensor in enumerate(display_list):
        scaled = tensor.cpu().numpy() * 255
        # Move the channel axis last with an explicit transpose. Swapping axes
        # 0 and 2 would also exchange height and width, i.e. show and return
        # every picture transposed.
        hwc = np.transpose(scaled, (1, 2, 0))
        # Clip before the cast so slightly out-of-range values saturate
        # instead of wrapping around in uint8.
        img = np.ascontiguousarray(np.clip(hwc, 0, 255).astype(np.uint8))
        image_list.append(img)
        plt.subplot(1, n_images, idx + 1)
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()
    return image_list


def save_checkpoint(cfg, step, loss, trainer):
    """Save a periodic checkpoint when the step and loss allow it.

    The trainer is saved to `cfg["train"]["checkpoint_path"]` only when `step`
    is a non-zero multiple of `cfg["train"]["checkpoint_rate"]` and `loss` is
    not NaN; otherwise nothing happens.

    Args:
        cfg: Configuration mapping with a "train" section.
        step: Index of the step that just finished.
        loss: Loss value of that step.
        trainer: Object exposing `save(path)`.
    """
    if step % cfg["train"]["checkpoint_rate"] != 0:
        return
    if step == 0 or math.isnan(loss):
        return
    trainer.save(cfg["train"]["checkpoint_path"])
        
        
def print_epoch_stats(e_time, train_loss_arr, valid_loss_arr):
    """Print elapsed minutes plus the mean train and validation loss.

    NaN entries are left out of both means.

    Args:
        e_time: Elapsed time in minutes.
        train_loss_arr: Sequence of training losses.
        valid_loss_arr: Sequence of validation losses.
    """
    def _mean_without_nan(losses):
        # Average only the positions that are not NaN.
        return np.mean(losses, where=~np.isnan(losses))

    train_mean = _mean_without_nan(train_loss_arr)
    valid_mean = _mean_without_nan(valid_loss_arr)
    print(f"   Time: {e_time:.0f} min, Train Loss: {train_mean:.4f}, Valid Loss: {valid_mean:.4f}")
    
    
def train(cfg, train_dataloader, trainer, epoch, i, device):
    """Run one pass over `train_dataloader` for unet number `i`.

    For every batch the images are moved to `device`, the captions are encoded
    with `get_emb_tensor`, the trainer is called to obtain the loss and
    `trainer.update` is applied. Afterwards the step is checkpointed through
    `save_checkpoint`, logged to wandb and shown on the progress line. The
    wall-clock time of each phase is recorded and the per-step averages are
    printed once the loader is exhausted.

    Args:
        cfg: Configuration mapping; its "train" section supplies the per-unet
            `max_batch_size` and the checkpoint settings.
        train_dataloader: Sized iterable of `(images, texts)` batches.
        trainer: Callable trainer exposing `update(unet_number=...)`.
        epoch: Epoch number; `epoch - 1` full passes are counted into the
            logged global step.
        i: Unet number being trained (1 selects the base-unet batch limit,
            any other value the sr_unet1 limit).
        device: Device the image batch is moved to.

    Returns:
        `(trainer, train_loss_arr)`, where `train_loss_arr` holds the loss of
        every step in order.
    """
    limit_key = "base_unet_max_batch_size" if i == 1 else "sr_unet1_max_batch_size"
    phases = ("step", "fetch", "embed", "loss", "update")
    elapsed = {phase: [] for phase in phases}
    train_loss_arr = []

    for step, batch in enumerate(train_dataloader):
        t_step = time()

        # Host-to-device transfer of the image batch.
        t_phase = time()
        images, texts = batch
        images = images.to(device)
        elapsed["fetch"].append(time() - t_phase)

        # Caption encoding.
        t_phase = time()
        text_embeds, text_masks = get_emb_tensor(cfg, texts, device)
        elapsed["embed"].append(time() - t_phase)

        # Loss computation through the trainer.
        t_phase = time()
        loss = trainer(
            images,
            text_embeds=text_embeds,
            text_masks=text_masks,
            unet_number=i,
            max_batch_size=cfg["train"][limit_key],
        )
        elapsed["loss"].append(time() - t_phase)

        # Optimizer update for the unet being trained.
        t_phase = time()
        trainer.update(unet_number=i)
        elapsed["update"].append(time() - t_phase)

        elapsed["step"].append(time() - t_step)

        train_loss_arr.append(loss)
        save_checkpoint(cfg, step, loss, trainer)

        n_steps = len(train_dataloader)
        curr_step = int(n_steps * (epoch - 1) + step)
        wandb.log({f"Train Loss {i}": loss, f"Train {i} Step": curr_step})
        print(f"\r   Train Step {step+1}/{n_steps}, Train Loss: {loss:.4f}", end='')

    step_time, fetch_time, embed_time, loss_time, update_time = (
        np.mean(elapsed[phase]) for phase in phases
    )
    print()
    print(f"      Step: {step_time:.4f}s, Img load: {fetch_time:.4f}s, Embed: {embed_time:.4f}s, "
          f"Loss: {loss_time:.4f}s, Update: {update_time:.4f}s")
    return trainer, train_loss_arr


def validate(cfg, valid_dataloader, trainer, epoch, i, device):
    """Measure the loss of unet number `i` over `valid_dataloader`.

    The trainer is switched to eval mode and autograd is disabled while the
    batches run, and the previous train/eval mode is restored afterwards.
    Calling the trainer the same way as in training would otherwise let
    validation batches back-propagate (imagen-pytorch's trainer runs the
    backward pass inside its forward call while in training mode), leaving
    gradients behind that the next `trainer.update` would apply.

    Args:
        cfg: Configuration mapping; its "train" section supplies the per-unet
            `max_batch_size`.
        valid_dataloader: Sized iterable of `(images, texts)` batches.
        trainer: Callable trainer; also used as a module (`training`, `eval`,
            `train`).
        epoch: Epoch number; `epoch - 1` full passes are counted into the
            logged global step.
        i: Unet number being evaluated (1 selects the base-unet batch limit,
            any other value the sr_unet1 limit).
        device: Device the image batch is moved to.

    Returns:
        List with the loss of every validation step, in order.
    """
    limit_key = "base_unet_max_batch_size" if i == 1 else "sr_unet1_max_batch_size"
    valid_loss_arr = []

    was_training = trainer.training
    trainer.eval()
    try:
        with torch.no_grad():
            for step, batch in enumerate(valid_dataloader):
                images, texts = batch
                images = images.to(device)
                text_embeds, text_masks = get_emb_tensor(cfg, texts, device)

                loss = trainer(
                    images,
                    text_embeds = text_embeds,
                    text_masks = text_masks,
                    unet_number = i,
                    max_batch_size = cfg["train"][limit_key],
                )
                valid_loss_arr.append(loss)

                curr_step = int(len(valid_dataloader)*(epoch-1) + step)
                wandb.log({f"Validation Loss {i}": loss, f"Valid {i} Step": curr_step})
                print(f"\r   Valid Step {step+1}/{len(valid_dataloader)}, Valid Loss: {loss:.4f}", end='')
    finally:
        # Hand the trainer back in whatever mode the caller had it in, even if
        # a batch raised.
        trainer.train(was_training)
    print()
    return valid_loss_arr



def _last_loss_is_nan(losses):
    """Return True when `losses` is non-empty and its final entry is NaN."""
    return bool(losses) and math.isnan(losses[-1])


def run_train_loop(cfg, trainer, train_dataloader, valid_dataloader, device):
    """Train both unets for `cfg["train"]["epochs"]` epochs and log samples.

    Each epoch trains unet 1 and then unet 2 for one pass over
    `train_dataloader`; every fifth epoch a validation pass follows each
    training pass. After each unet the trainer is checkpointed, unless the
    most recent training or validation loss is NaN. At the end of every epoch
    three fixed prompts are sampled, displayed and logged to wandb.

    Args:
        cfg: Configuration mapping; the "train" section supplies `epochs`,
            `checkpoint_path` and `cond_scale`.
        trainer: Trainer passed to `train`/`validate`; must expose `save` and
            `sample`.
        train_dataloader: Loader handed to `train`.
        valid_dataloader: Loader handed to `validate`.
        device: Device handed to `train`/`validate`.
    """
    for epoch in range(1, cfg["train"]["epochs"]+1):
        print(f"\nEpoch {epoch}/{cfg['train']['epochs']}")

        for i in (1,2):

            print(f"--- Unet {i} ---")
            start = time()

            trainer, train_loss_arr = train(cfg, train_dataloader, trainer, epoch, i, device)

            # Placeholder shown in the stats line on epochs without validation.
            valid_loss_arr = [0]
            if epoch % 5 == 0:
                valid_loss_arr = validate(cfg, valid_dataloader, trainer, epoch, i, device)

            end = time()
            e_time = (end-start)/60

            print_epoch_stats(e_time, train_loss_arr, valid_loss_arr)

            # Guard the checkpoint on the training loss as well: on epochs
            # without validation the placeholder above is never NaN, so a
            # diverged run would otherwise overwrite the last good checkpoint.
            # Empty loss lists are tolerated instead of raising IndexError.
            if _last_loss_is_nan(train_loss_arr) or _last_loss_is_nan(valid_loss_arr):
                print("   Skipping checkpoint: last loss is NaN")
            else:
                trainer.save(cfg["train"]["checkpoint_path"])

        texts = [
            'red flowers in a white vase',
            'a puppy looking anxiously at a giant donut on the table',
            'the milky way galaxy in the style of monet'
        ]
        sampled_images = trainer.sample(texts, cond_scale = cfg["train"]["cond_scale"])
        image_list = display_images(sampled_images)
        images_pil = [Image.fromarray(image) for image in image_list]
        wandb.log({"Samples": [wandb.Image(image) for image in images_pil], "Epoch": epoch})

In [None]:
# Best-effort resume: try to restore the trainer from the configured
# checkpoint. A failure (e.g. no checkpoint written yet) is reported instead
# of being swallowed silently, and KeyboardInterrupt/SystemExit are no longer
# caught, so an incompatible or corrupt checkpoint does not go unnoticed.
try:
    trainer.load(cfg["train"]["checkpoint_path"], strict=False)
except Exception as err:
    print(f"Checkpoint not loaded ({type(err).__name__}: {err}); continuing with current weights")
else:
    print("Loaded checkpoint")

In [None]:
# torch.backends.cudnn.benchmark = True

In [None]:
# Kick off the full training run with the objects built in the cells above.
run_train_loop(cfg, trainer, train_dataloader, valid_dataloader, device)

In [None]:
# Prompts for manual sampling.
texts = [
    'red flowers on a beach by the sunset',
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
]

In [None]:
# Sample one image per prompt.
# NOTE(review): cond_scale is fixed at 5 here, whereas run_train_loop reads
# cfg["train"]["cond_scale"] -- confirm the difference is intentional.
sampled_images = trainer.sample(texts, cond_scale = 5)

In [None]:
# Plot the sampled images; the returned uint8 arrays are kept in `imgs`.
imgs = display_images(sampled_images)