In [None]:
import torch
import yaml
import argparse
import os
import numpy as np
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
import glob
import cv2
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
import torch

In [None]:
# Pick the compute device: Apple-silicon MPS wins when present, otherwise
# CUDA if available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Using mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Dataset

In [None]:
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

In [None]:
# ToTensor maps pixel values to [0, 1]; Normalize then computes
# (x - mean) / std with mean = 0.5 and std = 0.5 (a standard deviation,
# not a variance), which rescales the pixels to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into ./data on first use. Each item is an
# (image, label) pair; MNIST digits are 28x28 grayscale, so after `transform`
# the image is a 1x28x28 tensor.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 32 (image, label) pairs.
data_loader = DataLoader(dataset=mnist, batch_size=32, shuffle=True)


# Modelling

In [None]:
from dit_model import DIT
# Diffusion Transformer trained directly in MNIST pixel space.
# im_size / im_channels must match the tensors produced by the dataset cell:
# MNIST digits are 1x28x28 and that cell applies no resize or padding.
# (The previous 32 / 4 values were latent-space settings for a VAE-encoded
# dataset and would make the first forward pass fail on MNIST batches.)
# NOTE: a checkpoint saved from a DIT built with different im_size /
# im_channels will not load into this model.
model = DIT(
    im_size=28,
    im_channels=1,
    config={
        # assumes DIT cuts the image into non-overlapping patch_size x
        # patch_size patches, so im_size must be divisible by it
        # (28 / 2 -> 14 x 14 tokens) -- TODO confirm against dit_model
        'patch_size': 2,
        'num_layers': 12,
        'hidden_size': 768,
        'num_heads': 12,
        'head_dim': 64,
        'timestep_emb_dim': 768,
    },
).to(device)

# Training

In [None]:
class LinearNoiseScheduler:
    r"""
    Linear beta schedule for DDPM.

    Betas are spaced linearly from ``beta_start`` to ``beta_end`` over
    ``num_timesteps`` steps. The per-step alphas and their cumulative
    products are precomputed once and reused by both the forward (noising)
    and the reverse (sampling) process.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        alphas = 1. - betas
        alpha_cum_prod = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.alphas = alphas
        self.alpha_cum_prod = alpha_cum_prod
        self.sqrt_alpha_cum_prod = torch.sqrt(alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - alpha_cum_prod)

    def add_noise(self, original, noise, t):
        r"""
        Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        :param original: clean sample x_0, batch dimension first
        :param noise: noise tensor with the same shape as ``original``
        :param t: per-sample timestep indices of shape (B,)
        :return: noisy sample x_t, same shape as ``original``
        """
        device = original.device
        batch_size = original.shape[0]
        # (B,) -> (B, 1, ..., 1) so each per-sample coefficient broadcasts over
        # every non-batch dimension of ``original``.
        coeff_shape = (batch_size,) + (1,) * (original.dim() - 1)

        signal_scale = self.sqrt_alpha_cum_prod.to(device)[t].reshape(coeff_shape)
        noise_scale = self.sqrt_one_minus_alpha_cum_prod.to(device)[t].reshape(coeff_shape)
        return signal_scale * original + noise_scale * noise

    def sample_prev_timestep(self, xt, pred, t):
        r"""
        One reverse-diffusion step: draw x_{t-1} from x_t and the model's
        noise prediction.

        :param xt: current sample x_t
        :param pred: model noise prediction for x_t
        :param t: current (scalar) timestep
        :return: tuple (x_{t-1}, estimate of x_0 clamped to [-1, 1]);
                 at t == 0 the posterior mean is returned without added noise
        """
        device = xt.device
        beta_t = self.betas.to(device)[t]
        alpha_t = self.alphas.to(device)[t]
        alpha_bar_t = self.alpha_cum_prod.to(device)[t]
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cum_prod.to(device)[t]

        # Clean-sample estimate implied by the predicted noise.
        x0 = (xt - sqrt_one_minus_alpha_bar_t * pred) / torch.sqrt(alpha_bar_t)
        x0 = torch.clamp(x0, -1., 1.)

        # Posterior mean of x_{t-1} given x_t and the predicted noise.
        mean = (xt - beta_t * pred / sqrt_one_minus_alpha_bar_t) / torch.sqrt(alpha_t)

        if t == 0:
            return mean, x0

        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        alpha_bar_prev = self.alpha_cum_prod.to(device)[t - 1]
        variance = (1 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(xt.device)
        return mean + sigma * z, x0

In [None]:
# DDPM linear beta schedule over 1000 diffusion steps.
scheduler = LinearNoiseScheduler(
    num_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
)

# The model is trained to predict the injected noise, so plain MSE is the loss.
criterion = torch.nn.MSELoss()
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0)
model.train()

In [None]:
# Resume from a previous run's weights if a checkpoint is present.
# NOTE: the checkpoint must come from a DIT built with the same im_size /
# im_channels / config as `model`; otherwise load_state_dict raises on
# mismatched keys or shapes. torch.load unpickles the file, so only load
# checkpoints you produced yourself.
dit_ckpt_path = os.path.join('celebhq', 'dit_ckpt.pth')
if os.path.exists(dit_ckpt_path):
    model.load_state_dict(torch.load(dit_ckpt_path, map_location=device))
    # Report success only after the weights were actually loaded.
    print('Loaded DiT checkpoint')

In [None]:
# Noise-prediction (DDPM) training loop for the DiT on MNIST pixels.
# Each epoch trains on every batch, prints the mean loss and overwrites
# the checkpoint.
NUM_EPOCHS = 500
DIT_CKPT_DIR = 'celebhq'
DIT_CKPT_PATH = os.path.join(DIT_CKPT_DIR, 'dit_ckpt.pth')
# torch.save does not create missing parent directories.
os.makedirs(DIT_CKPT_DIR, exist_ok=True)

# Make sure dropout/normalization layers are in training mode even if an
# earlier cell switched the model to eval().
model.train()
for epoch_idx in range(NUM_EPOCHS):
    losses = []

    for batch in tqdm(data_loader):
        # torchvision datasets yield (image, label) pairs; the label is not
        # used for unconditional diffusion training. The images are already
        # pixel tensors, not VAE mean/logvar latents, so they are used as-is.
        im = batch[0] if isinstance(batch, (list, tuple)) else batch
        im = im.float().to(device)

        # Sample Gaussian noise and one random timestep per image.
        noise = torch.randn_like(im)
        t = torch.randint(0, scheduler.num_timesteps, (im.shape[0],)).to(device)

        # Noise the images to timestep t and regress the injected noise.
        noisy_im = scheduler.add_noise(im, noise, t)
        pred = model(noisy_im, t)
        loss = criterion(pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # No extra optimizer.step() here: gradients were already applied per
    # batch, and stepping on zeroed gradients would still move AdamW weights.
    mean_loss = float(np.mean(losses)) if losses else float('nan')
    print('Finished epoch:{} | Loss : {:.4f}'.format(epoch_idx + 1, mean_loss))
    torch.save(model.state_dict(), DIT_CKPT_PATH)