In [None]:
import os
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from matplotlib.animation import FuncAnimation, PillowWriter
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

from DDPM import *

# Run on the first CUDA GPU when torch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'


### Validate model IO

In [None]:
# Smoke test: push one batch of random inputs through a freshly built UNet and
# print the shape of whatever it returns.
BATCH_SIZE = 16
INPUT_SIZE = (BATCH_SIZE, 1, 28, 28)

# Random conditioning inputs, allocated directly on the target device.
randomLabels = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,), device=device)
randomTimes = torch.randint(0, 100, (BATCH_SIZE,), device=device).to(torch.float32)
classMasks = torch.ones((BATCH_SIZE,), device=device)

dummyInput = torch.rand(INPUT_SIZE, device=device)

# UNet and NUM_CLASSES come from the `from DDPM import *` star import.
module = UNet(numClasses=NUM_CLASSES).to(device)

# profileModel(module, input_size=INPUT_SIZE)

# Call the module itself rather than .forward() so that any registered hooks
# run, and skip autograd bookkeeping because only the output shape is inspected.
with torch.no_grad():
    output = module(dummyInput, randomLabels, randomTimes, classMasks)
if output is not None:
    print(output.shape)

### Train DDPM model

In [None]:


# ---- Hyperparameters (hard-coded for this experiment) ----
n_epoch = 20
batch_size = 128*3
numTimesteps = 400 # 500
# Fall back to the CPU when CUDA is unavailable instead of always requesting "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
lr = 1e-4
save_model = False
savePath = './DiffusionData/'
guidanceStrengths = [0.0, 0.5, 2.0] # strength of generative guidance
samplesPerClass = 4  # generated (and real) examples shown per class in each grid

# Neither save_image nor ani.save creates a missing output directory.
os.makedirs(savePath, exist_ok=True)


def collectRealImages(images, labels, numClasses, perClass):
    """Gather real examples ordered by class, for display next to generated ones.

    Slot ``k + j * numClasses`` of the result holds the j-th image in ``images``
    whose label equals ``k``. When the batch holds fewer than ``j + 1`` images of
    class ``k``, ``images[0]`` is used as a placeholder.

    Args:
        images: batch of images, indexed along dimension 0.
        labels: 1-D tensor of class labels aligned with ``images``.
        numClasses: number of classes to lay out per row.
        perClass: number of examples to collect for each class.

    Returns:
        Tensor of shape ``(numClasses * perClass, *images.shape[1:])`` with the
        dtype and device of ``images``.
    """
    picked = images.new_empty((numClasses * perClass, *images.shape[1:]))
    for k in range(numClasses):
        # as_tuple=True always yields a 1-D index tensor, so a class with exactly
        # one match is still indexable (squeeze() would collapse it to 0-dim).
        matches = (labels == k).nonzero(as_tuple=True)[0]
        for j in range(perClass):
            idx = matches[j] if j < matches.numel() else 0
            picked[k + (j * numClasses)] = images[idx]
    return picked


def saveSamplingGif(frames, numRows, numCols, path):
    """Write a GIF showing how the sampled images evolve over the stored steps.

    Args:
        frames: array-like indexed as ``frames[step, sample, channel]``; channel 0
            of sample ``row * numCols + col`` is drawn in grid cell (row, col).
        numRows: number of grid rows.
        numCols: number of grid columns.
        path: destination file for the GIF.
    """
    # squeeze=False keeps axs two-dimensional even for a single row or column.
    fig, axs = plt.subplots(nrows=numRows, ncols=numCols, sharex=True, sharey=True,
                            figsize=(8,3), squeeze=False)

    def animate_diff(i):
        print(f'gif animating frame {i} of {frames.shape[0]}', end='\r')
        # Images are negated for display; share one colour scale across the frame.
        negated = -frames[i]
        lo, hi = negated.min(), negated.max()
        plots = []
        for row in range(numRows):
            for col in range(numCols):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(negated[(row*numCols)+col, 0], cmap='gray', vmin=lo, vmax=hi))
        return plots

    try:
        ani = FuncAnimation(fig, animate_diff, interval=200, blit=False, repeat=True, frames=frames.shape[0])
        ani.save(path, dpi=100, writer=PillowWriter(fps=5))
    finally:
        # Figures are created once per epoch and guidance strength; release them.
        plt.close(fig)


# NOTE(review): the validation cell builds UNet(numClasses=NUM_CLASSES) while this
# one relies on UNet()'s default -- confirm the two agree.
ddpm = DDPM(model=UNet(), betas=(1e-4, 0.02), numTimesteps=numTimesteps, dropoutRate=0.4, device=device)
ddpm.to(device)

# optionally load a model
# ddpm.load_state_dict(torch.load("./data/diffusion_outputs/ddpm_unet01_mnist_9.pth"))

tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1

dataset = MNIST("./data", train=True, download=True, transform=tf)
# dataset = trainset
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
optim = torch.optim.Adam(ddpm.parameters(), lr=lr)

for ep in range(n_epoch):
    print(f'epoch {ep}')
    ddpm.train()

    # linear lrate decay, applied to every parameter group
    for group in optim.param_groups:
        group['lr'] = lr*(1-ep/n_epoch)

    pbar = tqdm(dataloader)
    for x, c in pbar:
        optim.zero_grad()
        x = x.to(device)
        c = c.to(device)
        loss = ddpm(x, c)
        loss.backward()

        pbar.set_description(f"loss: {loss.item():.4f}")
        optim.step()

    # ---- Per-epoch evaluation: sample at each guidance strength and save grids ----
    ddpm.eval()
    with torch.no_grad():
        n_sample = samplesPerClass*NUM_CLASSES
        for w in guidanceStrengths:
            x_gen, x_gen_store = ddpm.sample(n_sample, (1, 28, 28), classifierGuidance=w)

            # append some real images at bottom, ordered by class; x and c are the
            # last training batch of this epoch
            x_real = collectRealImages(x, c, NUM_CLASSES, samplesPerClass)

            x_all = torch.cat([x_gen, x_real])
            # Invert intensities for display; one row per block of NUM_CLASSES images,
            # matching the k + j*NUM_CLASSES layout of the real images.
            grid = make_grid(x_all*-1 + 1, nrow=NUM_CLASSES)
            imagePath = savePath + f"image_ep{ep}_w{w}.png"
            save_image(grid, imagePath)
            print('saved image at ' + imagePath)

            if ep%5==0 or ep == n_epoch-1:
                # create gif of images evolving over time, based on x_gen_store
                gifPath = savePath + f"gif_ep{ep}_w{w}.gif"
                saveSamplingGif(x_gen_store, samplesPerClass, NUM_CLASSES, gifPath)
                print('saved image at ' + gifPath)
    # optionally save model
    if save_model and ep == n_epoch-1:
        modelPath = savePath + f"model_{ep}.pth"
        torch.save(ddpm.state_dict(), modelPath)
        print('saved model at ' + modelPath)