# Conditional Diffusion Model - V2

In [None]:
from dataclasses import dataclass
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from tqdm import tqdm
from copy import deepcopy
from utils.models_v2 import UNet, EMA
from utils.diffusion_v2 import Diffusion

## Configuration

In [None]:
@dataclass
class Config:
    """
    Global settings for data loading, training and sampling.

    The values are plain class attributes (none carry type annotations), so they are
    read directly as Config.<name> rather than through an instance.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # data_dir = r"D:\Data\Datasets\landscape"
    data_dir = r"D:\Data\Datasets\S2TLD_extracted_small"  # ImageFolder root: one sub-folder per class
    model_name = "s2tld_cond1_v2"  # Names the models/<name>/ and results/<name>/ output folders
    diffusion_steps = 1000  # Number of diffusion time steps (T)
    learn_rate = 1e-4
    batch_size = 3
    image_size = 64  # Side length, in pixels, of the square training images
    num_samples = 3  # Number of samples to generate
    sample_freq_epochs = 3  # Frequency with which to generate samples (i.e. every x epochs)
    load_from_checkpoint = True  # Start training from checkpoint if one exists for the given model name
    train_unconditional_prob = 0.1  # Probability of training model unconditionally for a given batch
    # Classifier free guidance scale: the torch.lerp weight from the unconditional towards the
    # conditional noise prediction; 1 = purely conditional, >1 amplifies the class condition
    cfg_scale = 3

## Utils

In [None]:
def save_image_batch(image_batch, filename):
    """
    Save a batch of image tensors to disk as a single PNG grid image.

    Parameters:
    - image_batch: batch of image tensors in the model's [-1, 1] range with shape
      (batch_size, channels, height, width)
    - filename: file path and name to save the image at, without extension; ".png" is
      appended. Missing parent directories are created.
    """
    # Map from [-1, 1] back to [0, 1], clamping any overshoot from the sampler
    image_batch = torch.clamp(image_batch.detach() * 0.5 + 0.5, 0, 1)

    # Tile the batch into one (C, H, W) grid, then convert to an (H, W, C) uint8 PIL image
    image_grid = make_grid(image_batch)
    image_grid = np.transpose(image_grid.cpu().numpy(), (1, 2, 0))
    pil_image = Image.fromarray((image_grid * 255).astype(np.uint8))

    path = f"{filename}.png"
    directory = os.path.dirname(path)
    # A bare filename has no directory component, and os.makedirs("") would raise.
    # exist_ok avoids a race between an existence check and the creation.
    if directory:
        os.makedirs(directory, exist_ok=True)

    pil_image.save(path)


def save_checkpoint(state, filename):
    """
    Serialize a training checkpoint with torch.save.

    Parameters:
    - state: object to save (in this file, a dict of epoch, loss history and state dicts)
    - filename: file path and name without extension; ".pt" is appended.
      Missing parent directories are created.
    """
    path = f"{filename}.pt"
    directory = os.path.dirname(path)
    # A bare filename has no directory component, and os.makedirs("") would raise.
    # exist_ok avoids a race between an existence check and the creation.
    if directory:
        os.makedirs(directory, exist_ok=True)

    torch.save(state, path)


def load_checkpoint(checkpoint, model, optimizer, diffusion, ema_model=None):
    """
    Restore training state from a checkpoint dict in the format saved by train_model.

    Parameters:
    - checkpoint: dict with keys 'state_dict', 'optimizer', 'diffusion', 'epoch',
      'loss_history' and, when ema_model is given, 'ema_state_dict'
    - model, optimizer, diffusion: objects restored in place via load_state_dict
    - ema_model: optional EMA copy of the model; restored only when provided

    Returns:
    - (epoch, loss_history) as stored in the checkpoint
    """
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    diffusion.load_state_dict(checkpoint['diffusion'])
    # Compare against None explicitly: an object's truthiness can come from __len__
    # (e.g. an empty nn.Sequential is falsy), which would silently skip the restore.
    if ema_model is not None:
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
    return checkpoint['epoch'], checkpoint['loss_history']


def get_dataloader(data_dir, batch_size, image_size):
    """
    Build a shuffling DataLoader over an ImageFolder dataset.

    Parameters:
    - data_dir: ImageFolder root directory (one sub-folder per class)
    - batch_size: number of images per batch
    - image_size: side length, in pixels, of the square crops produced

    Returns:
    - DataLoader yielding (images, class_indices) batches, images normalized to [-1, 1]
    """
    # TODO: Review if this can be improved.
    # Resize the shorter side to 5/4 of the target size, then take a random crop covering
    # 80-100% of the area and resize it to image_size x image_size.
    pre_crop_size = round(image_size * 5/4)
    augment = [
        transforms.Resize(pre_crop_size),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
    ]
    # ToTensor yields [0, 1]; (x - 0.5) / 0.5 then maps each channel to [-1, 1]
    to_model_range = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataset = ImageFolder(root=data_dir, transform=transforms.Compose(augment + to_model_range))

    # Sanity check: report the tensor shape of the first transformed sample
    first_image, _ = dataset[0]
    print(first_image.size())

    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Training

In [None]:
def _sample_images(model, diffusion, num_classes, sample_shape):
    """
    Generate images by running the full reverse diffusion process (algorithm 2).

    One class label is drawn at random per call. It conditions every sample at every
    time step.

    Parameters:
    - model: noise predictor called as model(x_t, t, y); y=None gives the unconditional prediction
    - diffusion: object providing diffusion_steps and remove_noise(x_t, t, epsilon_pred)
    - num_classes: number of class labels to draw the conditioning label from
    - sample_shape: shape of the sample batch, (num_samples, *image_shape)

    Returns:
    - the final denoised batch
    """
    model.eval()
    with torch.no_grad():
        x_t = torch.normal(0, 1, sample_shape, device=Config.device)  # Initialized to x_T
        num_samples = x_t.shape[0]

        # Draw the conditioning label once, before the loop. Re-drawing it at every
        # time step would steer successive denoising steps toward different classes.
        label = torch.randint(0, num_classes, (1,)).item()
        y = torch.full((num_samples,), label, dtype=torch.long, device=Config.device)

        # Iterate over all reverse diffusion time steps from T to 1
        for t in tqdm(range(diffusion.diffusion_steps, 0, -1), desc="Sampling"):
            t_vec = t * torch.ones(num_samples, device=Config.device)
            epsilon_pred = model(x_t, t_vec, y)
            # Classifier-free guidance: lerp(unc, cond, w) = unc + w * (cond - unc).
            # With w == 1 this is exactly the conditional prediction, so the extra
            # unconditional forward pass is only needed for other weights (w == 0 is
            # then correctly the purely unconditional prediction).
            if Config.cfg_scale != 1:
                unc_epsilon_pred = model(x_t, t_vec, None)
                epsilon_pred = torch.lerp(unc_epsilon_pred, epsilon_pred, Config.cfg_scale)
            x_t = diffusion.remove_noise(x_t, t, epsilon_pred)

    return x_t


def train_model(num_epochs, dataloader, checkpoint=None):
    """
    Train the class-conditional noise-prediction UNet (algorithm 1).

    Writes a checkpoint every epoch and saves a grid of generated samples every
    Config.sample_freq_epochs epochs.

    Parameters:
    - num_epochs: epoch number to train up to; when resuming, training continues from
      the checkpoint's epoch instead of starting over
    - dataloader: DataLoader yielding (images, class_indices) batches; its dataset must
      expose `.classes` (e.g. ImageFolder)
    - checkpoint: optional dict in the format saved below; restores model, optimizer,
      diffusion and EMA state plus the epoch counter and loss history
    """
    # Initialize other training parameters
    num_classes = len(dataloader.dataset.classes)
    criterion = nn.MSELoss()
    diffusion = Diffusion(device=Config.device, diffusion_steps=Config.diffusion_steps)
    model = UNet(num_classes=num_classes, device=Config.device).to(Config.device)
    optimizer = optim.AdamW(model.parameters(), lr=Config.learn_rate)

    # EMA copy of the model: updated after every optimizer step, never trained directly
    ema = EMA(0.995)
    ema_model = deepcopy(model).eval().requires_grad_(False)

    if checkpoint:
        # Resume from the last checkpoint
        start_epoch, loss_history = load_checkpoint(checkpoint, model, optimizer, diffusion, ema_model)
    else:
        # Train from scratch
        start_epoch = 0
        loss_history = []

    checkpoint_path = os.path.join("models", Config.model_name, "checkpoint")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0

        # Train model (algorithm 1)
        for x_0, y in tqdm(dataloader, desc="Training"):
            x_0 = x_0.to(Config.device)
            # Drop the labels for a random fraction of batches so the same network also
            # learns the unconditional prediction used by classifier-free guidance
            y = None if torch.rand(1) < Config.train_unconditional_prob else y.to(Config.device)

            x_t, t, epsilon = diffusion.apply_noise(x_0)  # Apply noise to images (forward process)
            epsilon_pred = model(x_t, t, y)  # Predict the added noise (reverse process)
            batch_loss = criterion(epsilon_pred, epsilon)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

            ema.step_ema(ema_model, model)

        # NOTE: the history stores the per-epoch SUM of batch losses (kept for compatibility
        # with existing checkpoints); the printout shows the per-batch mean.
        loss_history.append(epoch_loss)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / max(len(dataloader), 1):.4f}")

        save_checkpoint({
            'epoch': epoch + 1,
            'loss_history': loss_history,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'diffusion': diffusion.state_dict(),
            'ema_state_dict': ema_model.state_dict()
        }, checkpoint_path)

        # Sample images from model (algorithm 2)
        if (epoch + 1) % Config.sample_freq_epochs == 0:
            # NOTE(review): samples come from `model`, not `ema_model`; the EMA weights are
            # only checkpointed. Pass ema_model here if EMA samples are wanted.
            sample_shape = (Config.num_samples, *dataloader.dataset[0][0].size())
            samples = _sample_images(model, diffusion, num_classes, sample_shape)
            save_image_batch(samples, os.path.join("results", Config.model_name, f"{epoch + 1}"))

In [None]:
# Load the checkpoint, if requested and one exists for this model name
checkpoint_path = os.path.join("models", Config.model_name, "checkpoint.pt")
if Config.load_from_checkpoint and os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine
    checkpoint = torch.load(checkpoint_path, map_location=Config.device)
else:
    if Config.load_from_checkpoint:
        print(f"No checkpoint found at {checkpoint_path}; training from scratch")
    checkpoint = None

# Initialize dataloader
dataloader = get_dataloader(Config.data_dir, Config.batch_size, Config.image_size)

# Train model
# {'green': 0, 'off': 1, 'red': 2, 'yellow': 3}
torch.cuda.empty_cache()
train_model(1000, dataloader, checkpoint)