# Lab 04 : Diffusion Model (DDPM) for MNIST Images -- exercise


In [None]:
# For Google Colaboratory
# This cell only does something when the notebook runs inside Colab
# (detected through the presence of the 'google.colab' module); elsewhere it is a no-op.
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    # NOTE(review): hard-coded Drive location of the lab folder; adjust it if the
    # course files were copied somewhere else in your Drive.
    path_to_file = '/content/gdrive/My Drive/CS5242_2025_codes/labs_lecture08/lab04_dm_image'
    print(path_to_file)
    # move to Google Drive directory
    # (later cells use relative resources such as `utils` and pic/unet.png)
    os.chdir(path_to_file)
    # IPython shell escape: prints the current working directory
    !pwd

In [None]:
# Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import time

#import utils
import matplotlib.pyplot as plt
import logging
logging.getLogger().setLevel(logging.CRITICAL) # remove warnings
import os, datetime

# Report the PyTorch build and pick the compute device:
# the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))  # name of GPU 0
print(device)


### MNIST dataset 

In [None]:
# `utils` is the lab's local helper module. check_mnist_dataset_exists presumably
# makes sure the MNIST tensors are available on disk and returns the data root
# directory -- TODO confirm against utils.py.
from utils import check_mnist_dataset_exists
data_path=check_mnist_dataset_exists()

# data_path is concatenated as-is, so it is expected to end with a path separator.
# NOTE(review): torch.load unpickles the file; only load data from a trusted source.
train_data=torch.load(data_path+'mnist/train_data.pt')
print(train_data.size())


In [None]:
# Global constants
n = train_data.size(1) # num of pixels along one spatial dimension
dz = 128 # latent dimension
dID = 128 # hidden dimension for ID features
bs = 100 # batch size
N = train_data.size(0) # num of training data
print('n,dz,dID,bs,N:',n,dz,dID,bs,N)

# Diffusion / UNet hyper-parameters
# (a dead `d = 64` assignment that was immediately overwritten has been removed;
#  the effective value was, and still is, 48)
d = 48 # hidden dimension for image features
dPE = 128 # hidden dimension for time 
beta_1 = 0.0001 # first value of the linear noise schedule
beta_T = 0.02 # last value of the linear noise schedule
num_t = 150 # number of diffusion time steps
print('beta_1,beta_T,num_t,d,dPE:',beta_1,beta_T,num_t,d,dPE)


### DDPM denoiser with UNet architecture
Reference: [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597)

Diffusion models require an expressive denoiser to predict the noise that was added to the clean image. A standard denoiser for images is the UNet.

The task is to implement UNet, which is designed according to the diagram below:
<center>
<img src="pic/unet.png" style="height:500px"/>
</center>
    
Implement the UNet with batch normalization, ReLU activations and residual connections.

Hints: You may use PyTorch modules `nn.Conv2d`, `nn.ConvTranspose2d` and `nn.BatchNorm2d`.


In [None]:
# Network design

# Exercise skeleton (to be completed): entry stage of the UNet.
# Per the shape comments in forward(), it is meant to inject the time information
# into h and then map [bs, in_dim, in_n, in_n] -> [bs, out_dim, in_n, in_n] with two
# convolution layers, i.e. the spatial size is kept.
# NOTE(review): every `# COMPLETE HERE` line is an unfinished statement, so this
# cell does not parse until the exercise has been filled in.
class first_block(nn.Module): 
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # COMPLETE HERE
        
    def forward(self, h, t):
        # h: image features, t: time features (shapes given on the lines below)
        # Add time information t to h
        t = # COMPLETE HERE # [bs, in_dim]
        h = # COMPLETE HERE               # [bs, in_dim, in_n, in_n]
        # First convolution layer
        h = # COMPLETE HERE           # [bs, out_dim, in_n, in_n]
        # Second convolution layer
        h = # COMPLETE HERE       # [bs, out_dim, in_n, in_n]
        return h

# Exercise skeleton (to be completed): one encoder stage of the UNet.
# Per the shape comments in forward(), it is meant to inject the time information
# into h and map [bs, in_dim, in_n, in_n] -> [bs, out_dim, in_n/2, in_n/2] with
# three convolution layers (the first one halves the spatial size).
# NOTE(review): `padding` is presumably forwarded to the down-sampling convolution
# so that odd sizes work (the UNet comments list 7 -> 4) -- confirm when implementing.
class down_sampling_block(nn.Module):
    def __init__(self, in_dim, out_dim, padding):
        super().__init__()
        # COMPLETE HERE
        
    def forward(self, h, t):
        # h: image features, t: time features (shapes given on the lines below)
        # Add time information t to h
        t = # COMPLETE HERE # [bs, in_dim]
        h = # COMPLETE HERE               # [bs, in_dim, in_n, in_n]        
        # First convolution layer
        h = # COMPLETE HERE       # [bs, out_dim, in_n/2, in_n/2]
        # Second convolution layer
        h = # COMPLETE HERE       # [bs, out_dim, in_n/2, in_n/2]
        # Third convolution layer
        h = # COMPLETE HERE       # [bs, out_dim, in_n/2, in_n/2]
        return h

# Exercise skeleton (to be completed): one decoder stage of the UNet.
# Per the shape comments in forward(), it is meant to up-sample `h_level` to
# [bs, out_dim, in_n*2, in_n*2], concatenate it with the encoder feature map
# `h_level_minus_one` (giving 2*out_dim channels) and reduce back to out_dim
# channels with two more convolution layers.
# NOTE(review): `output_padding` is presumably forwarded to the transposed
# convolution so that odd target sizes can be reached (the UNet comments list
# 4 -> 7) -- confirm when implementing.
class up_sampling_block(nn.Module):
    def __init__(self, in_dim, out_dim, output_padding):
        super().__init__()
        # COMPLETE HERE
        
    def forward(self, h_level, h_level_minus_one, t):
        # h_level: features of the coarser level, h_level_minus_one: skip features
        # from the matching encoder level, t: time features
        # Add time information t to h
        t = # COMPLETE HERE      # [bs, in_dim=2*out_dim]
        # First convolution layer
        h_level = # COMPLETE HERE  # [bs, out_dim, in_n*2, in_n*2]
        # Concatenate down-sampling and up-sampling
        h = # COMPLETE HERE # [bs, in_dim=2*out_dim, in_n*2, in_n*2]
        # Second convolution layer
        h = # COMPLETE HERE                # [bs, out_dim, in_n*2, in_n*2]
        # Third convolution layer
        h = # COMPLETE HERE            # [bs, out_dim, in_n*2, in_n*2]           
        return h

# Define UNet architecture
# Exercise skeleton (to be completed): the full UNet noise predictor.
# Per the shape comments in forward(), it takes a noisy image batch h_t and the
# sampled time steps, runs one entry stage, three down-sampling stages and three
# up-sampling stages, and returns a tensor with the same [bs, n, n] shape as the input.
# NOTE(review): the concrete sizes quoted in the comments (64, 128, 256, 512)
# correspond to d = 64, whereas the constants cell ends up with d = 48; the
# symbolic shapes (d, 2d, 4d, 8d) are the ones to rely on.
class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # COMPLETE HERE
    
    def forward(self, h_t, sample_t):
        # h_t: noisy images, sample_t: time-step indices
        h = # COMPLETE HERE                        # [bs, 1, n, n], [100, 1, 28, 28]
        t = # COMPLETE HERE                 # [bs, 1], [100, 1]
        # Encoder path
        h1_down = # COMPLETE HERE            # [bs, d, n, n], [100, 64, 28, 28]
        h2_down = # COMPLETE HERE      # [bs, 2d, n/2, n/2], [100, 128, 14, 14]     
        h3_down = # COMPLETE HERE      # [bs, 4d, n/4, n/4], [100, 256, 7, 7]
        h4_down = # COMPLETE HERE      # [bs, 8d, n/8, n/8], [100, 512, 4, 4]
        # Decoder path
        h3_up = # COMPLETE HERE # [bs, 4d, n/4, n/4], [100, 256, 7, 7]
        h2_up = # COMPLETE HERE   # [bs, 2d, n/2, n/2], [100, 128, 14, 14]
        h1_up = # COMPLETE HERE   # [bs, d, n, n], [100, 64, 28, 28] 
        # Back to a single channel, then drop the channel dimension
        h = # COMPLETE HERE                   # [bs, 1, n, n], [bs, 1, 28, 28]
        h = # COMPLETE HERE                             # [bs, n, n], [bs, 28, 28]
        return h

# Define DDPM architecture
class DDPM(nn.Module):
    """Denoising diffusion model built around a UNet noise predictor.

    The noise schedule is linear: beta_t goes from `beta_1` to `beta_T` in
    `num_t` steps, alpha_t = 1 - beta_t and alpha_bar_t is the cumulative
    product of the alpha_t.

    Args:
        num_t: number of diffusion time steps.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.
    """

    def __init__(self, num_t, beta_1, beta_T):
        super().__init__()
        self.num_t = num_t
        # The schedule tensors are plain attributes (not parameters/buffers);
        # they are placed on the notebook-level `device` at construction time.
        self.alpha_t = 1.0 - torch.linspace(beta_1, beta_T, num_t).to(device) # [num_t]
        self.alpha_bar_t = torch.cumprod(self.alpha_t, dim=0) # [num_t]
        self.UNet = UNet()

    def forward_process(self, x0, sample_t, eps): # add noise
        """Return x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps.

        Args:
            x0: clean images, [batch, n, n].
            sample_t: time-step indices into the schedule, [batch].
            eps: noise with the same shape as `x0`.

        Returns:
            The noisy images x_t, [batch, n, n].
        """
        alpha_bar = self.alpha_bar_t[sample_t] # [batch]
        # view(-1,1,1) follows the actual batch size instead of the global `bs`,
        # so batches of any size (e.g. a smaller last batch) are handled.
        sqrt_alpha_bar_t = alpha_bar.sqrt().view(-1,1,1) # [batch, 1, 1]
        sqrt_one_minus_alpha_bar_t = ( 1.0 - alpha_bar ).sqrt().view(-1,1,1) # [batch, 1, 1]
        x_t = sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * eps # [batch, n, n]
        return x_t

    def backward_process(self, x_t, sample_t): # denoise
        """Run the UNet on (x_t, sample_t) and return its output.

        The training code regresses this output onto the noise that was added,
        so it is the noise prediction (not x_{t-1} itself).
        """
        noise_pred_t = self.UNet(x_t, sample_t) # [batch, n, n]
        return noise_pred_t

    def generate_process_ppdm(self, num_images):
        """Generate `num_images` images by iterating the reverse process.

        Starts from Gaussian noise at the last time step and walks down to
        step 0, using the noise predicted by the UNet at every step.

        Returns:
            Generated images, [num_images, n, n] (`n` is the notebook-level
            image size).
        """
        dev = self.alpha_bar_t.device # same device as the noise schedule
        t = self.num_t-1 # use the instance's schedule length, not the global num_t
        batch_t = torch.full((num_images,), t, dtype=torch.long, device=dev)
        batch_x_t = torch.randn(num_images, n, n, device=dev) # t=T => t=T-1 in python
        set_t = list(range(t-1,0,-1)) + [0] # T-2, T-3, ..., 1, 0
        for t_minus_one in set_t:
            batch_noise_pred_t = self.backward_process(batch_x_t, batch_t)
            alpha_bar_now = self.alpha_bar_t[t]
            alpha_bar_prev = self.alpha_bar_t[t_minus_one]
            # Std of the noise injected when going from step t to step t-1
            sigma_t = ( (1.0-alpha_bar_prev)/ (1.0-alpha_bar_now)* (1.0-alpha_bar_now/alpha_bar_prev) ).sqrt()
            # c1 * (x_t - c2 * eps_pred) is the predicted clean image rescaled to step t-1
            c1 = alpha_bar_prev.sqrt() / alpha_bar_now.sqrt()
            c2 = ( 1.0 - alpha_bar_now + 1e-10 ).sqrt()
            # Weight of the predicted noise; 1e-10 guards the sqrt against tiny negatives
            c3 = ( 1.0 - alpha_bar_prev - sigma_t.square() + 1e-10 ).sqrt()
            fresh_noise = torch.randn(num_images, n, n, device=dev)
            batch_x_t = c1 * ( batch_x_t - c2 * batch_noise_pred_t ) + c3 * batch_noise_pred_t + sigma_t * fresh_noise
            t = t_minus_one
            batch_t = torch.full((num_images,), t_minus_one, dtype=torch.long, device=dev)
        return batch_x_t



# Build the DDPM model and move its parameters to the selected device
net = DDPM(num_t, beta_1, beta_T).to(device)
def display_num_param(net):
    """Print the total number of parameters of `net`, raw and in millions."""
    total = sum(p.numel() for p in net.parameters())
    print('Number of parameters: {} ({:.2f} million)'.format(total, total/1e6))
# Report the size of the freshly built model
display_num_param(net)


# Test the forward pass, backward pass and gradient update with a single batch
init_lr = 0.001
optimizer = torch.optim.Adam(net.parameters(), lr=init_lr)
# `verbose=True` is not passed: the argument is deprecated in recent PyTorch
# releases, and this scheduler is never stepped in this cell anyway.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=1)
idx_images = torch.randint(0, N, (bs,)) # random image indices in {0,1,...,N-1} [bs]
batch_x0 = train_data[idx_images,:,:].to(device) # [bs, n, n]
batch_sample_t = torch.randint(0, num_t, (bs,)).long().to(device) # random integer in {0,1,...,T-1} [bs]
print('batch_sample_t',batch_sample_t.size())
batch_noise_t = torch.randn(batch_x0.size()).to(device) # [bs, n, n]
# Forward process: add noise to the clean images
x_t = net.forward_process(batch_x0, batch_sample_t, batch_noise_t) # [bs, n, n]
print('x_t',x_t.size())
# Backward process: predict the noise that was added
noise_pred_t = net.backward_process(x_t, batch_sample_t) # [bs, n, n]
print('noise_pred_t',noise_pred_t.size())
# Loss = MSE between predicted and true noise, followed by one optimizer step
loss_PPDM = torch.nn.MSELoss()(noise_pred_t, batch_noise_t)
loss = loss_PPDM
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Check that generation runs end-to-end on a few images
with torch.no_grad():
    batch_x_0 = net.generate_process_ppdm(4)
    print('batch_x_0',batch_x_0.size())


In [None]:
## Training loop
# Start from a freshly initialised model (discards the one used in the single-batch test)
net = DDPM(num_t, beta_1, beta_T)
net = net.to(device)
display_num_param(net)

# Optimizer
init_lr = 0.0003
optimizer = torch.optim.AdamW(net.parameters(), lr=init_lr)
# `verbose=True` is not passed: the argument is deprecated in recent PyTorch
# releases; the current lr is printed after every epoch below instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=1)

# Number of epochs
nb_epochs = 20

# Training loop
start = time.time()
for epoch in range(nb_epochs):

    running_loss = 0.0
    num_batches = 0

    # Visit the N training images in a new random order at every epoch
    # (N comes from the data instead of a hard-coded 60000)
    shuffled_indices = torch.randperm(N)

    for count in range(0,N,bs):

        idx_images = shuffled_indices[count : count+bs]
        batch_x0 = train_data[idx_images,:,:].to(device) # [batch, n, n]
        # One random time step per image actually present in the batch
        batch_sample_t = torch.randint(0, num_t, (batch_x0.size(0),)).long().to(device) # [batch]
        batch_noise_t = torch.randn(batch_x0.size()).to(device) # [batch, n, n]
        x_t = net.forward_process(batch_x0, batch_sample_t, batch_noise_t) # [batch, n, n]
        noise_pred_t = net.backward_process(x_t, batch_sample_t) # [batch, n, n]
        # The network is trained to predict the noise that was added
        loss_PPDM = torch.nn.MSELoss()(noise_pred_t, batch_noise_t)
        loss = loss_PPDM
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # COMPUTE STATS
        running_loss += loss.detach().item()
        num_batches += 1

    # AVERAGE STATS THEN DISPLAY
    total_loss = running_loss/num_batches
    scheduler.step(total_loss) # lr is multiplied by 0.95 when the epoch loss plateaus
    elapsed = (time.time()-start)/60
    print('epoch=',epoch, '\t time=', elapsed,'min', '\t lr=', optimizer.param_groups[0]['lr']  ,'\t loss=', total_loss )

    # PLOT (every 5 epochs, skipping epoch 0)
    if epoch>0 and not epoch%5:
        net.eval()
        with torch.no_grad():
            num_generated_images = 16
            batch_x_0 = net.generate_process_ppdm(num_generated_images)
            x_hat = batch_x_0.squeeze().detach().to('cpu')
        figure, axis = plt.subplots(4, 4)
        figure.set_size_inches(10,10)
        # Fill the 4x4 grid column by column: image cpt goes to row cpt%4, column cpt//4
        for cpt in range(num_generated_images):
            i, j = cpt % 4, cpt // 4
            axis[i,j].imshow(x_hat[cpt,:,:], cmap='gray')
            axis[i,j].set_title("Generated w/ DDPM")
            axis[i,j].axis('off')
        plt.show()
        net.train()

    # Check lr value: stop once the scheduler has decayed lr below 2e-4
    if optimizer.param_groups[0]['lr'] < 2*10**-4: 
        print("\n lr is equal to min lr -- training stopped\n")
        break
         

In [None]:
# Sample a batch of images from the trained DDPM and show them in a 4x4 grid

net.eval()
with torch.no_grad():
    num_generated_images = 16
    batch_x_0 = net.generate_process_ppdm(num_generated_images)
    print('batch_x_0',batch_x_0.size())
    x_hat = batch_x_0.squeeze().detach().to('cpu')

figure, axis = plt.subplots(4, 4)
figure.set_size_inches(10,10)

# Column-major layout: image cpt is drawn at row cpt%4, column cpt//4
for cpt in range(16):
    i, j = cpt % 4, cpt // 4
    axis[i,j].imshow(x_hat[cpt,:,:], cmap='gray')
    axis[i,j].set_title("Generated w/ DDPM")
    axis[i,j].axis('off')

plt.show()
