[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/khetansarvesh/CV/blob/main/data_augmentation/ddpm/dit_runner.ipynb)

In [None]:
import argparse
import glob
import os

import cv2
import numpy as np
import torch
import torchvision
import yaml
from PIL import Image
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

In [None]:
# Pick the compute device: Apple's MPS backend wins when it is available,
# otherwise CUDA if present, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Using mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Dataset

In [None]:
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

In [None]:
# Pre-processing: turn each image into a tensor, then normalise with
# mean 0.5 and standard deviation 0.5 so that pixel values span [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split, stored under ./data (downloaded if it is missing),
# with the pre-processing above applied to every image.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 images.
data_loader = DataLoader(dataset=mnist, shuffle=True, batch_size=32)


# Modelling

In [None]:
def get_time_embedding(time_steps, temb_dim):
    """Return sinusoidal embeddings for a batch of diffusion timesteps.

    Args:
        time_steps: 1D tensor of timesteps, e.g. ``tensor([1, 10, 500, 40, 300])``.
        temb_dim: size of the embedding produced for each timestep, e.g. 128.

    Returns:
        Tensor of shape ``(len(time_steps), 2 * (temb_dim // 2))`` whose first
        half holds the sines and whose second half holds the cosines.
    """
    half_dim = temb_dim // 2

    # Exponent i / (d / 2) for i = 0 .. d/2 - 1, i.e. 2i / d.
    exponents = torch.arange(
        start=0,
        end=half_dim,
        dtype=torch.float32,
        device=time_steps.device,
    ) / half_dim

    # Denominator 10000^(2i / d) of the positional-encoding formula.
    denominators = 10000 ** exponents

    # (B, 1) / (half_dim,) broadcasts to (B, half_dim): pos / 10000^(2i / d).
    angles = time_steps[:, None] / denominators

    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

In [None]:
# NOTE: this cell used to hold a stray debugging expression that called a
# layer through `self` with a local `t_emb`; neither name exists at notebook
# scope, so running the cell raised NameError. It is intentionally left empty.

In [None]:
class ae_transformers(nn.Module):
    """ViT-style noise predictor for diffusion on small square images.

    The image is cut into non-overlapping patches; each patch is linearly
    embedded and given a learned position embedding. A token built from the
    diffusion timestep is prepended to the patch sequence, the sequence is run
    through a Transformer encoder, and the patch tokens are projected back to
    pixels and reassembled into an image with the same shape as the input.

    Args:
        im_size: height/width of the (square) input image.
        im_channels: number of image channels.
        patch_size: side length of each square patch; must divide ``im_size``.
        hidden_size: Transformer model width.
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        temb_dim: size of the sinusoidal timestep embedding.

    The defaults reproduce the configuration that was previously hard-coded
    (1x28x28 images, 4x4 patches, width 768), so parameter names and shapes
    are unchanged for default-constructed models.

    Raises:
        ValueError: if ``im_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, im_size=28, im_channels=1, patch_size=4, hidden_size=768,
                 num_heads=2, num_layers=2, temb_dim=128):
        super().__init__()

        if im_size % patch_size != 0:
            raise ValueError(
                f"im_size ({im_size}) must be divisible by patch_size ({patch_size})"
            )

        self.im_channels = im_channels
        self.patch_size = patch_size
        self.grid_size = im_size // patch_size  # patches per image side
        self.temb_dim = temb_dim

        num_patches = self.grid_size * self.grid_size
        patch_dim = im_channels * patch_size * patch_size

        self.patch_embedding = nn.Sequential(nn.LayerNorm(patch_dim),
                                             nn.Linear(patch_dim, hidden_size),
                                             nn.LayerNorm(hidden_size))
        self.position_embedding = nn.Parameter(data=torch.randn(1, num_patches, hidden_size),
                                               requires_grad=True)
        self.embedding_dropout = nn.Dropout(p=0.1)

        # One pre-norm encoder layer, stacked ``num_layers`` times.
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size,
                                                   nhead=num_heads,
                                                   dim_feedforward=4 * hidden_size,
                                                   activation="gelu",
                                                   batch_first=True,
                                                   norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                                        num_layers=num_layers)

        # Final linear layer: token -> flattened patch pixels
        self.proj_out = nn.Linear(hidden_size, patch_dim)

        # Time projection: sinusoidal embedding -> model width
        self.ti_1 = nn.Linear(temb_dim, 400)
        self.ti_2 = nn.Linear(400, hidden_size)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep(s) ``t``.

        Args:
            x: image batch of shape ``(B, im_channels, im_size, im_size)``.
            t: diffusion timesteps, either one per image (length ``B``) or a
                single value that is applied to the whole batch.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch = x.shape[0]
        c, p, g = self.im_channels, self.patch_size, self.grid_size

        # Timesteps must live on the same device as the image; a single
        # timestep is expanded so that every image gets one.
        t = torch.as_tensor(t, device=x.device).long().reshape(-1)
        if t.numel() == 1:
            t = t.expand(batch)

        # Sinusoidal embedding projected to the model width: (B, hidden_size)
        t_emb = get_time_embedding(t, self.temb_dim)
        time_token = self.ti_2(self.ti_1(t_emb))

        # Patchify: b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)
        x = x.reshape(batch, c, g, p, g, p)
        x = x.permute(0, 2, 4, 3, 5, 1)
        x = x.reshape(batch, g * g, p * p * c)

        # Patch embedding plus learned position embedding
        x = self.patch_embedding(x)
        x = self.position_embedding + x

        # Prepend the timestep as an extra token: (B, 1 + num_patches, hidden_size)
        x = torch.cat([time_token[:, None, :], x], dim=1)

        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)

        # Drop the timestep token, then map every patch token back to pixels
        x = self.proj_out(x[:, 1:])

        # Unpatchify: b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)
        x = x.reshape(batch, g, g, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4)
        x = x.reshape(batch, c, g * p, g * p)
        return x

In [None]:
# Build the denoiser from the `ae_transformers` class defined in this
# notebook and move it to the selected device. The previous call targeted a
# `DIT` class that is neither defined nor imported here, with a 4-channel
# 32x32 configuration that does not match the 1x28x28 MNIST batches.
model = ae_transformers().to(device)

# Training

In [None]:
# NOTE: these three values are not consumed in this cell; the DataLoader built
# earlier fixes its own batch size. Presumably they are meant for sampling /
# visualisation code -- confirm before relying on them.
batch_size = 64
num_samples = 100
num_grid_rows = 10

model.train()
# `Adam` was never imported by name, so reference it through torch.optim.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Linear beta schedule over 1000 timesteps, moved to the training device once;
# everything derived from it is created on that device automatically.
betas = torch.linspace(0.0001, 0.02, 1000).to(device)
alpha_cum_prod = torch.cumprod(1. - betas, dim=0)  # alpha_bar for each timestep
sqrt_alpha_cum_prod = torch.sqrt(alpha_cum_prod)  # sqrt(alpha_bar)
sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - alpha_cum_prod)  # sqrt(1 - alpha_bar)

In [None]:
# Resume from a previously saved checkpoint when one is present on disk.
dit_ckpt_path = os.path.join('celebhq', 'dit_ckpt.pth')
if os.path.exists(dit_ckpt_path):
    print('Loaded DiT checkpoint')
    saved_state = torch.load(dit_ckpt_path, map_location=device)
    model.load_state_dict(saved_state)

In [None]:
# The checkpoint is written into this directory after every epoch, so make
# sure it exists before training starts.
os.makedirs('celebhq', exist_ok=True)

for epoch in range(40):  # running for 40 epochs
    losses = []

    for im, _ in tqdm(data_loader):
        optimizer.zero_grad()

        im = im.float().to(device)
        noise = torch.randn_like(im)  # random noise, same shape and device as the images

        # One random timestep per image in the batch
        t = torch.randint(low=0, high=alpha_cum_prod.shape[0], size=(im.shape[0],), device=device)

        # Forward diffusion: sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        signal_scale = sqrt_alpha_cum_prod[t][:, None, None, None]
        noise_scale = sqrt_one_minus_alpha_cum_prod[t][:, None, None, None]
        noisy_im = signal_scale * im + noise_scale * noise

        noise_pred = model(noisy_im, t)  # predicting the added noise

        loss = torch.nn.functional.mse_loss(noise_pred, noise)  # loss function
        losses.append(loss.item())
        loss.backward()  # backpropagating the loss
        optimizer.step()

    print('Finished epoch:{} | Loss : {:.4f}'.format(epoch + 1, np.mean(losses)))
    torch.save(model.state_dict(), os.path.join('celebhq', 'dit_ckpt.pth'))
    print()