# Unconditional Diffusion Model

In [None]:
from dataclasses import dataclass
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from tqdm import tqdm
from utils.models import UNet

## Configuration

In [None]:
@dataclass
class Config:
    """
    Run configuration for training and sampling.

    Every attribute is annotated so that `@dataclass` registers it as a field.
    The values remain readable as class attributes (e.g. `Config.device`), and
    individual settings can be overridden per instance, e.g. `Config(batch_size=8)`.
    """
    # Compute device: CUDA when available, otherwise CPU
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # data_dir = r"D:\Data\Datasets\landscape"
    data_dir: str = r"D:\Data\Datasets\S2TLD_extracted_small"
    model_name: str = "s2tld2"
    diffusion_steps: int = 1000
    learn_rate: float = 1e-4
    batch_size: int = 3
    image_size: int = 64
    num_samples: int = 3  # Number of samples to generate
    sample_freq_epochs: int = 3  # Frequency with which to generate samples (i.e. every x epochs)
    load_from_checkpoint: bool = True  # Start training from checkpoint if one exists for the given model name

## Utils

In [None]:
def save_image_batch(image_batch, filename):
    """
    Save a batch of image tensors to disk as a single PNG grid image.

    Parameters:
    - image_batch: batch of image tensors with shape (batch_size, channels, height, width).
      Values are mapped with `x * 0.5 + 0.5` and clamped to [0, 1], i.e. the input is
      expected to be in the [-1, 1] range.
    - filename: File path and name to save image without extension (".png" is appended).
      Missing parent directories are created.
    """
    # Map images from [-1, 1] to [0, 1]
    image_batch = torch.clamp(image_batch * 0.5 + 0.5, 0, 1)

    # Create image grid and convert to a channels-last array for PIL
    image_grid = make_grid(image_batch)
    image_grid = np.transpose(image_grid.detach().cpu().numpy(), (1, 2, 0))
    pil_image = Image.fromarray((image_grid * 255).astype(np.uint8))

    path = f"{filename}.png"
    directory = os.path.dirname(path)

    # Create the directory if needed. An empty directory component means the
    # current working directory, which os.makedirs('') would reject.
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save image to disk
    pil_image.save(path)


def save_checkpoint(state, filename):
    """
    Serialize a checkpoint object to disk with torch.save.

    Parameters:
    - state: object to save (passed unchanged to torch.save)
    - filename: File path and name to save the checkpoint without extension
      (".pt" is appended). Missing parent directories are created.
    """
    path = f"{filename}.pt"
    directory = os.path.dirname(path)

    # Create the directory if needed. An empty directory component means the
    # current working directory, which os.makedirs('') would reject.
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save checkpoint state
    torch.save(state, path)


def load_checkpoint(checkpoint, model, optimizer, diffusion):
    """
    Restore model, optimizer and diffusion state from a checkpoint mapping.

    Parameters:
    - checkpoint: mapping with the keys 'state_dict', 'optimizer', 'diffusion',
      'epoch' and 'loss_history'
    - model, optimizer, diffusion: objects exposing `load_state_dict`

    Returns:
    - tuple of (checkpoint['epoch'], checkpoint['loss_history'])
    """
    # Restore each component from its own checkpoint entry, in order
    restore_plan = (
        (model, 'state_dict'),
        (optimizer, 'optimizer'),
        (diffusion, 'diffusion'),
    )
    for component, key in restore_plan:
        component.load_state_dict(checkpoint[key])

    resume_epoch = checkpoint['epoch']
    history = checkpoint['loss_history']
    return resume_epoch, history


def get_dataloader(data_dir, batch_size, image_size):
    """
    Build a shuffled DataLoader over an ImageFolder dataset rooted at `data_dir`.

    Parameters:
    - data_dir: root directory passed to ImageFolder
    - batch_size: batch size passed to the DataLoader
    - image_size: target size passed to RandomResizedCrop; the preceding Resize
      uses round(image_size * 5/4)

    Returns:
    - DataLoader yielding batches from the dataset with shuffle=True
    """
    # TODO: Review if this can be improved.
    resize_target = round(image_size * 5/4)
    channel_stats = (0.5, 0.5, 0.5)
    preprocessing = transforms.Compose([
        transforms.Resize(resize_target),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        # Normalize from [0, 1] to [-1, 1]
        transforms.Normalize(channel_stats, channel_stats),
    ])

    image_folder = ImageFolder(root=data_dir, transform=preprocessing)

    # Report the tensor size of the first transformed sample
    first_image = image_folder[0][0]
    print(first_image.size())

    # Batch and shuffle the data
    return DataLoader(image_folder, batch_size=batch_size, shuffle=True)

## Diffusion Process

In [None]:
class Diffusion:
    """
    Forward (noising) and reverse (denoising) steps over a linear beta schedule.

    The schedule tensors `beta`, `alpha` (= 1 - beta) and `alpha_bar`
    (cumulative product of alpha) are stored on `device` and are indexed with
    `t - 1`, i.e. time steps are 1-based.
    """

    def __init__(self, device, diffusion_steps=1000, image_size=64, beta_1=1e-4, beta_T=0.02):
        """
        Parameters:
        - device: device on which the schedule tensors are created
        - diffusion_steps: number of time steps T in the schedule
        - image_size: accepted for backward compatibility; not used by this class
        - beta_1, beta_T: first and last value of the linear beta schedule
        """
        # Initialize class variables
        self.diffusion_steps = diffusion_steps
        self.device = device
        self.beta_1 = beta_1
        self.beta_T = beta_T

        # Compute beta decay schedule
        self.compute_beta_schedule()

    def compute_beta_schedule(self):
        """
        (Re)compute beta, alpha and alpha_bar from beta_1, beta_T and diffusion_steps.
        """
        self.beta = torch.linspace(self.beta_1, self.beta_T, self.diffusion_steps, device=self.device)
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def apply_noise(self, x_0):
        """
        Apply noise to batch of image tensors based on beta decay schedule at a uniformly distributed time step sample.

        Parameters:
        - x_0: batch of image tensors; the first dimension is the batch dimension
          and the schedule factors are broadcast as (batch, 1, 1, 1)

        Returns:
        - tuple (x_t, t, epsilon): the noised batch, the sampled integer time
          steps (one per batch element, in 1..diffusion_steps) and the noise used
        """
        # Sample timestep: t ∼ Uniform({1, . . . , T})
        t = torch.randint(1, self.diffusion_steps + 1, (x_0.shape[0],), device=self.device)

        # Sample noise: epsilon ~ N(0,I)
        epsilon = torch.randn_like(x_0, device=self.device)

        # Apply noise to image based on beta schedule
        sqrt_alpha_bar_t = torch.sqrt(self.alpha_bar[t-1]).view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - self.alpha_bar[t-1]).view(-1, 1, 1, 1)
        x_t = (sqrt_alpha_bar_t * x_0) + (sqrt_one_minus_alpha_bar_t * epsilon)

        return x_t, t, epsilon

    def remove_noise(self, x_t, t, epsilon_pred):
        """
        Perform one reverse step from time step t to t - 1.

        Parameters:
        - x_t: current batch of noisy images
        - t: scalar time step in 1..diffusion_steps, shared by the whole batch
        - epsilon_pred: predicted noise with the same shape as x_t

        Returns:
        - the batch at time step t - 1

        Raises:
        - ValueError: if t is outside 1..diffusion_steps
        """
        # t = 0 would otherwise wrap around and silently index the last schedule entry
        if not 1 <= t <= self.diffusion_steps:
            raise ValueError(f"t must be in [1, {self.diffusion_steps}], got {t}")

        # Fresh noise is added on every step except the last one (t == 1).
        # It is created alongside x_t so this class does not depend on any
        # global configuration object.
        if t > 1:
            z = torch.randn_like(x_t)
        else:
            z = torch.zeros_like(x_t)

        inv_sqrt_alpha_t = 1 / torch.sqrt(self.alpha[t-1])
        one_minus_alpha_t = 1 - self.alpha[t-1]
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - self.alpha_bar[t-1])
        subtract = (one_minus_alpha_t / sqrt_one_minus_alpha_bar_t) * epsilon_pred
        sigma_t = torch.sqrt(self.beta[t-1])

        return inv_sqrt_alpha_t * (x_t - subtract) + sigma_t * z

    def state_dict(self):
        """
        Get the current state of the diffusion model for checkpointing.

        The 'alpha_t' and 'alpha_bar_t' tensors are included in the returned
        dict, but load_state_dict recomputes them from the scalar parameters.
        """
        return {
            'diffusion_steps': self.diffusion_steps,
            'beta_1': self.beta_1,
            'beta_T': self.beta_T,
            'alpha_t': self.alpha,
            'alpha_bar_t': self.alpha_bar
        }

    def load_state_dict(self, state):
        """
        Load the state of the diffusion model from a checkpoint.

        Reads 'diffusion_steps', 'beta_1' and 'beta_T' from `state` and
        recomputes the schedule tensors on this instance's device.
        """
        self.diffusion_steps = state['diffusion_steps']
        self.beta_1 = state['beta_1']
        self.beta_T = state['beta_T']
        self.compute_beta_schedule()

## Training

In [None]:
def _sample_images(model, diffusion, sample_shape, device):
    """Run the reverse process from x_T ~ N(0, I) down to t = 1 and return the final batch."""
    x_t = torch.normal(0, 1, sample_shape, device=device)  # Initialized to x_T

    # Iterate over all reverse diffusion time steps from T to 1
    for t in tqdm(range(diffusion.diffusion_steps, 0, -1), desc="Sampling"):
        # Use an integer time step tensor, matching the dtype the model receives
        # during training (apply_noise samples t with torch.randint)
        t_vec = torch.full((sample_shape[0],), t, device=device, dtype=torch.long)
        epsilon_pred = model(x_t, t_vec)
        x_t = diffusion.remove_noise(x_t, t, epsilon_pred)

    return x_t


def train_model(num_epochs, dataloader, checkpoint=None):
    """
    Train the noise-prediction model, checkpointing every epoch and periodically
    saving generated samples.

    Parameters:
    - num_epochs: total number of epochs to reach; when resuming, training
      continues from the epoch stored in the checkpoint
    - dataloader: iterable of (images, label) batches; labels are ignored.
      `dataloader.dataset[0][0]` is used to determine the sample shape.
    - checkpoint: optional checkpoint mapping as written by this function
      (keys 'epoch', 'loss_history', 'state_dict', 'optimizer', 'diffusion')

    Side effects:
    - writes models/<Config.model_name>/checkpoint.pt after every epoch
    - writes results/<Config.model_name>/<epoch>.png every
      Config.sample_freq_epochs epochs
    """
    # Initialize other training parameters
    criterion = nn.MSELoss()
    diffusion = Diffusion(device=Config.device, diffusion_steps=Config.diffusion_steps)
    model = UNet(device=Config.device).to(Config.device)
    optimizer = optim.AdamW(model.parameters(), lr=Config.learn_rate)

    if checkpoint:
        # Initialize parameters for training from last checkpoint
        start_epoch, loss_history = load_checkpoint(checkpoint, model, optimizer, diffusion)
    else:
        # Initialize parameters for training from scratch
        start_epoch = 0
        loss_history = []

    for epoch in range(start_epoch, num_epochs):
        model.train()  # Set model to train mode (sampling below switches it to eval)
        epoch_loss = 0

        # Train model (algorithm 1)
        for x_0, _ in tqdm(dataloader, desc="Training"):
            x_0 = x_0.to(Config.device)  # Move images to active compute device

            x_t, t, epsilon = diffusion.apply_noise(x_0)  # Apply noise to images (forward process)
            epsilon_pred = model(x_t, t)  # Predict noise using model (reverse process)
            batch_loss = criterion(epsilon_pred, epsilon)  # Compute loss

            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

        # Update loss history (stores the summed batch loss of the epoch)
        loss_history.append(epoch_loss)

        # Display epoch and mean batch loss
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss/len(dataloader):.4f}")

        # Save model training checkpoint
        save_checkpoint({
            'epoch': epoch + 1,
            'loss_history': loss_history,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'diffusion': diffusion.state_dict()
        }, os.path.join("models", Config.model_name, "checkpoint"))

        # Sample images from model (algorithm 2)
        if (epoch + 1) % Config.sample_freq_epochs == 0:
            with torch.no_grad():
                model.eval()
                sample_shape = (Config.num_samples, *dataloader.dataset[0][0].size())
                samples = _sample_images(model, diffusion, sample_shape, Config.device)
                save_image_batch(samples, os.path.join("results", Config.model_name, f"{epoch + 1}"))

In [None]:
# Load the checkpoint, but only if resuming is requested AND a checkpoint file
# actually exists; otherwise (e.g. on the very first run) train from scratch.
checkpoint_path = os.path.join("models", Config.model_name, "checkpoint.pt")
if Config.load_from_checkpoint and os.path.isfile(checkpoint_path):
    # map_location lets a checkpoint saved on one device be loaded on another
    checkpoint = torch.load(checkpoint_path, map_location=Config.device)
else:
    checkpoint = None

# Initialize dataloader
dataloader = get_dataloader(Config.data_dir, Config.batch_size, Config.image_size)

# Train model
torch.cuda.empty_cache()
train_model(500, dataloader, checkpoint)