# Diffusion Model for Jupyter Notebook

**Author:** Elisa Warner  
**Email:** elisawa@umich.edu  
**Date:** 04/12/2023

## Load Libraries

In [None]:
import numpy as np
import torchvision
import torch
import torchvision.transforms
import torchvision.models as models
from unet_mha import *
from config import *
import os

## Utility Functions for Diffusion Model

In [None]:
class Utility_Diffusion():
    """
    Saves and calculates parameters needed for the diffusion model
    
    Parameters:
     - beta_start = the min range of the beta scheduler
     - beta_end = the max range of the beta scheduler
     - time = total number of time steps
    
    Attributes:
     - beta = linearly spaced noise schedule with `time` entries
     - alpha_t = 1 - beta
     - alpha_bar_t = cumulative product of alpha_t
     - T = total number of time steps
    """
    
    def __init__(self, beta_start=BETA_START, beta_end=BETA_END, time=T):
        self.beta = np.linspace(beta_start, beta_end, time)
        self.alpha_t = 1 - self.beta
        self.alpha_bar_t = np.cumprod(self.alpha_t)
        self.T = time
    
    def Samplet(self, N):
        """
        Samples N time steps t from a uniform distribution
        
        Returns an integer array of length N with values in [0, T - 1],
        i.e. valid 0-based indices into the schedule arrays
        """
        return np.random.randint(0, self.T, size=N)
    
    def SampleNoise(self, N):
        """
        Samples normally distributed noise at the same size as the image
        
        Returns a NumPy array of shape (N, 3, SQ_SIZE, SQ_SIZE)
        """
        return np.random.normal(size = (N,3,SQ_SIZE,SQ_SIZE))
    
    # forward process
    def GetXt(self, x0, t, noise):
        """
        Performs the forward process for adding step-wise Gaussian noise to the image x0
        
        Parameters:
         - x0 = the clean image
         - t = 0-based time step used to index the schedule
         - noise = Gaussian noise to mix into x0
        
        Returns sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        """
        # only the cumulative product is needed for the closed-form forward step
        alpha_bar_t = self.alpha_bar_t[t]
        return (np.sqrt(alpha_bar_t) * x0) + (np.sqrt(1 - alpha_bar_t) * noise)

## Load Dataset

In [None]:
# Define transforms applied to images
# The pipeline runs in order: convert to tensor, resize to a SQ_SIZE x SQ_SIZE
# square, randomly mirror horizontally (augmentation), then normalize per channel.
# NOTE(review): the Normalize mean/std look like the standard ImageNet channel
# statistics, so the output is not confined to [-1, 1]; the visualization and
# sampling cells clamp to [-1, 1] before display. Confirm this is intended.
transforms = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor(),
     torchvision.transforms.Resize((SQ_SIZE,SQ_SIZE), antialias=True),
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

In [None]:
# Load dataset and dataloader
# Images are read from `cat_directory` (presumably defined in config — verify)
# and passed through the `transforms` pipeline as they are loaded.
dataset = torchvision.datasets.ImageFolder(cat_directory, transform=transforms)
# Shuffled mini-batches; the training loop unpacks each batch as (images, labels)
# and ignores the labels.
train_dataloader = torch.utils.data.DataLoader(dataset, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Sanity check: report how many images the dataset found
print("This dataset contains %s examples" % len(dataset))

### A. View Random Samples

In [None]:
# test that the Dataset works
import matplotlib.pyplot as plt
show_this_many = 3

# squeeze=False keeps `ax` two-dimensional, so indexing also works when
# show_this_many == 1
fig, ax = plt.subplots(1, show_this_many, squeeze=False)
for i in range(show_this_many):
    # pick a random example from the dataset
    idx = np.random.randint(len(dataset))
    
    # Move channels last for imshow: (C, H, W) -> (H, W, C).
    # `.T` would reverse every axis and so display the image transposed;
    # it is also deprecated for tensors that are not 2-D.
    x = dataset[idx][0].permute(1, 2, 0)
    # clamp to [-1, 1], rescale to [0, 1], then convert to 8-bit pixel values
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)
    
    ax[0, i].imshow(x)
    ax[0, i].axis('off')

### B. View Forward Process

In [None]:
# test noise generator
# Shows the first dataset image at `steps` evenly spaced noise levels.
steps = 10
fig, ax = plt.subplots(1, steps, figsize=(20,40))
# derive the stride from `steps` so changing `steps` keeps the panels evenly spaced
stepsize = T // steps
sample_params = Utility_Diffusion() 

for i in range(steps):
    
    # convert the noise to a tensor first, as the training loop does
    noise = torch.Tensor(sample_params.SampleNoise(1))[0]
    # Move channels last for imshow: (C, H, W) -> (H, W, C).
    # `.T` would reverse every axis and so display the image transposed.
    x = sample_params.GetXt(dataset[0][0], i*stepsize, noise).permute(1, 2, 0)
    # clamp to [-1, 1], rescale to [0, 1], then convert to 8-bit pixel values
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)
    
    ax[i].imshow(x)
    ax[i].set_title(i*stepsize)
    ax[i].axis("off")
plt.show()

## Load Model

In [None]:
# Build the network. On any device other than the CPU it is wrapped in
# DataParallel. UNet(3, 3) presumably means 3 input and 3 output channels — confirm
# against the UNet definition.
model = torch.nn.DataParallel(UNet(3, 3)) if device != "cpu" else UNet(3, 3)
model.to(device)

# Optimizer: SGD with momentum over every model parameter (no weight decay)
optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=LR)

In [None]:
# import saved parameters
# Resume from MODEL_OUT when it exists; otherwise start fresh and create the
# results log. os.path.isfile also works when MODEL_OUT contains a directory
# component, which a lookup in os.listdir(".") would miss.
if os.path.isfile(MODEL_OUT):
    # map_location remaps stored tensors onto the current `device`, so a
    # checkpoint written on another device can still be restored
    savedData = torch.load(MODEL_OUT, map_location=device)
    startEpoch = savedData['epoch']
    model.load_state_dict(savedData['model_state_dict'])
    optimizer.load_state_dict(savedData['optimizer_state_dict'])
    print("Model and optimizer loaded. Model left off at epoch", startEpoch+1)
else:
    print("No model found. Creating new model.")
    startEpoch = 0
    
    # "wb" truncates any previous log so a new run starts with a clean file
    with open(RESULTS_OUT, "wb") as f:
        f.write(("Begin training.\n").encode())

# lower the learning rate when the monitored loss stops decreasing
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', min_lr=1e-7)

In [None]:
# view optimizer settings
# As the last expression in the cell, the notebook displays the optimizer's repr.
optimizer

Run the cell below to change optimizer parameters after initialization (optional)

## Training Loop

In [None]:
# Training loop
# Each epoch noises every batch at random time steps, trains the model to
# predict that noise (MSE), logs the mean epoch loss and saves a checkpoint.
print("Begin training.")

model.train()
diff_params = Utility_Diffusion()
loss_fn = torch.nn.MSELoss()            # built once instead of once per batch
num_batches = len(train_dataloader)

# NOTE(review): a fresh run has startEpoch == 0, so this range covers epochs
# 1 .. EPOCHS-1 (EPOCHS - 1 epochs in total) — confirm that is intended.
for epoch in range(startEpoch+1, EPOCHS):
    total_loss = 0
    
    for idx, (img_batch, _) in enumerate(train_dataloader):
        # progress through the current epoch, rounded to one decimal place
        print("Progress: {}%".format(np.round(idx / num_batches * 100, 1)), end = "\r")
        
        optimizer.zero_grad()
        
        # sample one time step and one noise image per example
        n = img_batch.shape[0]
        batch_t = diff_params.Samplet(n)
        batch_noise = torch.Tensor(diff_params.SampleNoise(n))
        
        # generate noisy image: forward process applied example by example,
        # because every example has its own time step
        batch_noisy_img = torch.zeros((img_batch.shape))
        
        for i in range(n):
            batch_noisy_img[i,:,:,:] = diff_params.GetXt(img_batch[i,:,:,:], batch_t[i], batch_noise[i,:,:,:])
        
        # prediction: the model receives the noisy images and their time steps
        pred = model(batch_noisy_img.to(device), torch.Tensor(batch_t).to(device))
        
        # loss: mean squared error between the true and the predicted noise
        loss = loss_fn(batch_noise.to(device), pred)
        total_loss += loss.item()
        
        # backward pass
        loss.backward()
        optimizer.step()
    
    avg_loss = total_loss / num_batches
    
    # ReduceLROnPlateau counts calls to step() towards its patience, so it is
    # stepped once per epoch on the mean loss rather than on every noisy batch
    scheduler.step(avg_loss)
    
    ##### DOCUMENT #####
    print("Epoch %s: %s" % (epoch, avg_loss))
    
    with open(RESULTS_OUT,"ab") as f:
        f.write( ("+-- Epoch %s: %s\n" % (epoch, avg_loss)).encode())
    
    #### save model #####
    # store the last batch loss detached and on the CPU so the checkpoint holds
    # no autograd graph and can be loaded on a machine without the same device
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.detach().cpu(),
            }, MODEL_OUT)

### UTILITY Clear Cache : Run if CUDA memory is full

Use this code if you run into a CUDA memory error to clear the cache. Sometimes the below code does not work and there is still a CUDA memory error. If this is the case, you may have to try the following:  
    1. Refresh the notebook  
    2. Exit Jupyter Notebook and restart. Alternatively, change the kernel to no kernel and then back to Python 3.  
    3. Restart computer/virtual instance  

## Sampling Code

### A. Load model

In [None]:
# Re-import configuration and model code so the sampling section can be run
# on its own (e.g. after a kernel restart).
from config import *
# NOTE(review): the training section imports from `unet_mha`, while this cell
# imports from `unet`. If both modules define `UNet`, this import replaces the
# one used for training — confirm the checkpoint matches this architecture.
from unet import *
import os
import matplotlib.pyplot as plt

# NOTE(review): presumably this only takes effect if it is set before CUDA is
# initialized in this kernel — confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # debug

In [None]:
# reverse process
def Sample(n=1):
    """
    Runs the reverse (denoising) process, starting from pure Gaussian noise
    
    Uses the global `model` and `device`. The model is switched to eval mode
    and is left in eval mode afterwards.
    
    Parameters:
     - n = number of images to generate in one batch (default 1)
    
    Returns the generated batch as a CPU tensor; the starting noise has shape
    (n, 3, SQ_SIZE, SQ_SIZE)
    """
    sample_params = Utility_Diffusion()
    
    # x_T: start from pure noise
    xt = torch.Tensor(sample_params.SampleNoise(n)).to(device)
    
    model.eval()
    with torch.no_grad():
        for t in range(sample_params.T - 1, -1, -1):
            
            # assign z: time steps are 0-indexed here, so the last step is
            # t == 0 and it is the only one that adds no fresh noise
            if t > 0:
                z = torch.randn_like(xt)
            else:
                z = torch.zeros_like(xt)
            
            # assign beta, alpha, alpha-bar for this step as plain floats
            beta_t = float(sample_params.beta[t])
            alpha_t = float(sample_params.alpha_t[t])
            alpha_bar_t = float(sample_params.alpha_bar_t[t])
            
            # predict the noise contained in x_t
            # NOTE(review): training passes the time steps as a float tensor,
            # sampling passes integers — confirm the model handles both alike.
            t_batch = torch.full((n,), t, dtype=torch.long, device=device)
            pred_noise = model(xt, t_batch)
            
            # sample: posterior mean, then noise scaled by sigma_t = sqrt(beta_t).
            # The noise term must stay outside the 1/sqrt(alpha_t) factor.
            mean = (xt - ((1 - alpha_t) / (1 - alpha_bar_t) ** 0.5) * pred_noise) / alpha_t ** 0.5
            xt = mean + (beta_t ** 0.5) * z
    
    xt = xt.detach().cpu()
    return xt

In [None]:
# Load the checkpoint written by the training loop.
# map_location remaps stored tensors onto the current `device`, so a checkpoint
# written on another device can still be restored.
savedData = torch.load(MODEL_OUT, map_location=device)
epoch = savedData['epoch']

# import model
model = UNet(3, 3)

# Wrap exactly as in training so the state-dict keys line up.
# NOTE(review): a checkpoint saved from a DataParallel model presumably carries
# "module."-prefixed keys and would not load into the bare model on the CPU
# branch — confirm before sampling on a different kind of device.
if device != "cpu":
    model = torch.nn.DataParallel(model)
model.to(device)

model.load_state_dict(savedData['model_state_dict'])

# drop the reference to the checkpoint so its memory can be reclaimed
savedData = 0
print(epoch)

In [None]:
# Generate one image and map it into displayable 8-bit pixel values:
# clamp to [-1, 1], rescale to [0, 1], then scale to [0, 255]
x = Sample()
x = (x.clamp(-1, 1) + 1) / 2
x = (x * 255).type(torch.uint8)

# Move channels last for imshow: (C, H, W) -> (H, W, C).
# `.T` would reverse every axis and so display the image transposed.
plt.imshow(x[0].permute(1, 2, 0))
plt.axis('off')
plt.savefig("Generated_Image_Epoch_%s" % epoch)