In [21]:
from tqdm import tqdm
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from utils import ResidualConvBlock, UnetDown, EmbedFC, UnetUp, plot_sample, CustomDataset, transform

### UNET model Discussion 

This network is the core "brain" of the diffusion model, responsible for predicting noise at each timestep, conditioned on both the time step t and the context (label) c.

Layer 1: expands the channel from 3 to 64 channels with residual connection, (batch, 3, 16, 16) -> (batch, 64, 16, 16) study the features 

Layer 2 and 3 are downsample path, this is part of encoding step 
* down1: (Batch, 64, 16, 16) -> max pooling -> (Batch, 64, 8, 8) no residual connection 
* down2: (Batch, 64, 8, 8) -> max pooling and more filters -> (Batch, 128, 4, 4) no residual connection

Layer 4: (Batch, 128, 4, 4) -> use avg pooling to convert to a latent vector -> (Batch, 128, 1, 1) 

Layer 5,6 7 and 8 are context and time embeddings 
The context can be, for example, sound or text, and time is the diffusion timestep.
* EmbedFC is a small fully connected block made of two linear layers with a GELU activation function.
* t: (1,1) -> (batch,128,1,1)
* context(1,5) -> (batch, 128, 1,1) this is used with down 2 for up1
* t: (1,1) -> (batch, 64, 1, 1)
* context(1,5) -> (batch, 64,1,1) this is used with down 1 for up2

Layer 9, 10 and 11 are upsampling path (decoding step) 
* up0: (128,1,1) from the latent vector of layer 4 and expand it back with ConvTranspose2d back to (Batch, 128, 4, 4)
* up1: combines the layer 5 and 6 context and time embeddings with up0, plus the skip connection from down2, to produce an upsampled (Batch, 64, 8, 8)
* up2: combines the layer 7 and 8 context and time embeddings with up1, plus the skip connection from down1, to produce an upsampled (Batch, 64, 16, 16)

Layer 12 is the final output layer 
* concat of up3 along with original layer 1 sample (Batch, 128, 16, 16) -> standard 2d convolution -> (Batch, 64, 16, 16) -> standard 2d convolution -> (Batch, 3, 16, 16)

See the ContextUnet_markdown.md for more detail

In [15]:
class ContextUnet(nn.Module):
    """Small U-Net that predicts the noise in an image, conditioned on a timestep and a context vector.

    Data flow: an initial residual conv block, two `UnetDown` stages, a pooled
    bottleneck vector, then `up0` -> `up1` -> `up2` with skip connections from
    the encoder, and a final pair of 3x3 convolutions back to `in_channels`.
    Time and context embeddings modulate the decoder inputs as
    `context_emb * features + time_emb`.

    Args:
        in_channels: number of channels of the input (and output) image.
        n_feat: number of feature maps after the initial convolution; the
            second down stage and the bottleneck use `2 * n_feat`.
        n_cfeat: length of the context vector.
        height: spatial size of the (square) input. Must be a positive
            multiple of 4 because the bottleneck pools by `height // 4` and
            `up0` expands by the same factor.

    Raises:
        ValueError: if `height` is not a positive multiple of 4.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super().__init__()

        # Fail early: with any other height the decoder output cannot be
        # concatenated with the full-resolution skip connection in forward().
        if height < 4 or height % 4 != 0:
            raise ValueError(f"height must be a positive multiple of 4, got {height}")

        # number of input channels, number of intermediate feature maps and number of classes
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  # assume h == w. must be divisible by 4, so 28,24,20,16...

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Initialize the down-sampling path of the U-Net with two levels.
        # Shapes below assume each UnetDown halves the spatial size (see utils) -- TODO confirm.
        self.down1 = UnetDown(n_feat, n_feat)        # -> (batch, n_feat,   h/2, h/2)
        self.down2 = UnetDown(n_feat, 2 * n_feat)    # -> (batch, 2*n_feat, h/4, h/4)

        # Pool the whole bottleneck map down to one vector per channel.
        # The kernel is tied to the input size (it was hard-coded to 4, which
        # only matches height == 16) so it always mirrors the up0 kernel below.
        self.to_vec = nn.Sequential(nn.AvgPool2d(self.h // 4), nn.GELU())

        # Embed the timestep and context labels with small fully connected networks
        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)

        # Initialize the up-sampling path of the U-Net with three levels
        self.up0 = nn.Sequential(
            # kernel == stride == h // 4 expands the 1x1 bottleneck back to (h/4, h/4)
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),  # normalize
            nn.ReLU(),
        )
        # input channel counts include the concatenated skip connection
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Initialize the final convolutional layers to map to the same number of channels as the input image
        self.out = nn.Sequential(
            # Conv2d(in_channels, out_channels, kernel_size, stride, padding): 3x3 / 1 / 1 keeps the spatial size
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),  # reduce number of feature maps
            nn.GroupNorm(8, n_feat),  # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),  # map to same number of channels as input
        )

    def forward(self, x, t, c=None):
        """Predict the noise contained in `x`.

        x : (batch, in_channels, h, w) : noisy input image
        t : timestep value(s) fed to the 1-input EmbedFC layers; callers in this
            notebook pass t / timesteps (presumably reshaped to (-1, 1) inside
            EmbedFC -- confirm in utils)
        c : (batch, n_cfeat) or None   : context vector; None means "no context"
            and is replaced by zeros

        Returns a tensor with `in_channels` channels.
        """
        # pass the input image through the initial convolutional layer
        x = self.init_conv(x)
        # pass the result through the down-sampling path
        down1 = self.down1(x)
        down2 = self.down2(down1)

        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)

        # no context supplied: use an all-zero context matching x's device and dtype
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # embed context and timestep, reshaped so they broadcast over the spatial dims
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)     # (batch, 2*n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)         # (batch, n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1 + temb1, down2)  # scale by context, shift by time, then add skip
        up3 = self.up2(cemb2*up2 + temb2, down1)
        # concatenate with the full-resolution features before the output convolutions
        out = self.out(torch.cat((up3, x), 1))
        return out


### Denoise model maths discussion 
The denoise process uses this paper: https://arxiv.org/abs/2006.11239

This formula corresponds to Algorithm 2 in the paper (sample generation once the model has been trained).
Basically, for each timestep from T down to 1, it does two steps:
1. randomly generate noise for each timestep so the underlying assumption that the variance is non-zero stays true:
   random noise z * b_t[t].sqrt(); no random noise is added back at the final step (timestep 1).
2. subtract the predicted noise for that timestep from X_t to get X_t-1


The key equation is shown in the denoise_add_noise method and the sample_ddpm is basically the algorithm 2 in the paper 

b_t[t]: amount of noise 

a_t[t]: 1 - b_t[t] amount of signal remain 

ab_t[t]: the cumulative remaining signal (the product of a_t up to timestep t)

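x_t = x_0 * ab_t[t].sqrt() + noise * (1 - ab_t[t]).sqrt()   (closed form from the original image)

For a single step, x_t = x_t-1 * a_t[t].sqrt() + noise_t * b_t[t].sqrt(), so a first guess for predicting x_t-1 would be
(x_t - predicted_noise * (1 - ab_t[t]).sqrt()) / a_t[t].sqrt()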

but the noise to remove needs to be the noise added at that timestep, not the total accumulated noise, since ab_t is cumulative.

E[X|Y] = Covar(X,Y)/Var(Y)*Y 

Y = total noise (predicted noise * (1-ab_t[t]).sqrt()) 
X = b_t[t].sqrt()*predicted noise 

Var(Y) = 1 - ab_t[t]
Covar(X,Y) = b_t[t]

E[X|Y] = b_t[t]/(1-ab_t[t]) * (1-ab_t[t]).sqrt() * predicted_noise = (1-a_t[t]) * predicted_noise / (1-ab_t[t]).sqrt()

In [10]:
# helper function: one reverse-diffusion step -- strips the predicted noise, then re-injects a little fresh noise
def denoise_add_noise(x, t, pred_noise, z=None):
    """Compute x_{t-1} from x_t for timestep index `t`.

    x          : current noisy sample x_t
    t          : index into the module-level schedules b_t / a_t / ab_t
    pred_noise : noise predicted by the model for x_t
    z          : noise to add back; drawn from N(0, 1) with x's shape when None
                 (pass 0 to add no noise)
    """
    extra = torch.randn_like(x) if z is None else z
    # weight of the predicted noise in the posterior mean: (1 - a_t) / sqrt(1 - ab_t)
    noise_weight = (1 - a_t[t]) / (1 - ab_t[t]).sqrt()
    denoised = (x - pred_noise * noise_weight) / a_t[t].sqrt()
    # the re-injected noise is scaled by sqrt(b_t)
    return denoised + b_t.sqrt()[t] * extra

In [11]:
# sample using the standard DDPM algorithm
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
    """Generate `n_sample` images by running the reverse process from pure noise.

    Returns `(samples, intermediate)`: the final tensor and a NumPy array
    stacking the snapshots kept for plotting. A snapshot is kept at the first
    step, every `save_rate` steps, and for each of the last seven steps.
    """
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    snapshots = []  # generated steps kept for plotting
    for step in reversed(range(1, timesteps + 1)):
        print(f'sampling timestep {step:3d}', end='\r')

        # normalised time, shaped (1, 1, 1, 1)
        t = torch.tensor([step / timesteps])[:, None, None, None].to(device)

        # fresh noise to inject back in; the last step (step == 1) adds none
        if step > 1:
            z = torch.randn_like(samples)
        else:
            z = 0

        eps = nn_model(samples, t)    # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, step, eps, z)

        keep_snapshot = step == timesteps or step < 8 or step % save_rate == 0
        if keep_snapshot:
            snapshots.append(samples.detach().cpu().numpy())

    return samples, np.stack(snapshots)

In [12]:
# hyperparameters

# diffusion hyperparameters
timesteps = 1000              # number of diffusion steps
beta1, beta2 = 1e-4, 0.02     # first and last value of the linear noise schedule

# network hyperparameters
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first GPU when present, else CPU
n_feat = 64                   # 64 hidden dimension feature
n_cfeat = 5                   # context vector is of size 5
height = 16                   # 16x16 image
save_dir = './weights/'       # where model checkpoints are written / read

In [13]:
# construct DDPM noise schedule
# b_t: per-step noise level, linear from beta1 to beta2 over timesteps + 1 entries
b_t = beta1 + torch.linspace(0, 1, timesteps + 1, device=device) * (beta2 - beta1)
# a_t: per-step fraction of signal kept
a_t = 1 - b_t
# ab_t: running product of a_t, computed as exp(cumsum(log(a_t)))
ab_t = a_t.log().cumsum(dim=0).exp()
# index 0 means "no noise added yet", so all of the signal remains
ab_t[0] = 1

In [16]:
# construct model: 3-channel images, sized by the hyperparameters above, then moved to the target device
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height,
)
nn_model = nn_model.to(device)

## training 

Algorithm 1 in the probability diffusion paper. 

Below are the key steps:  
1. sample an original image X0
2. sample a random timestep
3. sample noise from a Gaussian distribution N(0, I) (unit variance, non-zero)
4. create the noisy image for the model input:
     * Formula: X_t = X_0 * ab_t.sqrt() + noise * (1-ab_t).sqrt(), see perturb_input function, x is x_0
5. feed the noisy image along with the timestep to the Unet model to predict the sampled noise.
6. loss is MSE = || actual_noise - predicted_noise ||^2

In [18]:
# training hyperparameters
batch_size = 100
n_epoch = 32
lrate = 1e-3   # initial learning rate

# load dataset: sprite images plus their label vectors
dataset = CustomDataset(
    "./sprites_1788_16x16.npy",
    "./sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# construct optimizer over all model parameters
optim = torch.optim.Adam(params=nn_model.parameters(), lr=lrate)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


In [19]:
# helper function: perturbs an image to a specified noise level
def perturb_input(x, t, noise):
    """Return the noised version of `x` at timestep index `t`.

    Computes sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise. Indexing with
    `[t, None, None, None]` appends three singleton axes so the per-sample
    scales broadcast over the remaining dimensions of `x`.
    """
    signal_scale = ab_t.sqrt()[t, None, None, None]
    noise_scale = (1 - ab_t[t, None, None, None]).sqrt()
    return signal_scale * x + noise_scale * noise

In [None]:
# training without context code

# set into train mode
nn_model.train()

# Create the checkpoint directory up front so a bad path fails before any
# training time is spent. makedirs also handles nested paths and an
# already-existing directory, unlike the exists()/mkdir() pair.
os.makedirs(save_dir, exist_ok=True)

for ep in range(n_epoch):
    print(f'epoch {ep}')

    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:   # x: images (labels are ignored: training without context)
        optim.zero_grad()
        x = x.to(device)

        # perturb data: one random timestep in [1, timesteps] per image
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        x_pert = perturb_input(x, t, noise)

        # use network to recover noise; the timestep is normalised to (0, 1]
        pred_noise = nn_model(x_pert, t / timesteps)

        # loss is mean squared error between the predicted and true noise
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()

    # save model every 4th epoch and after the final epoch
    if ep % 4 == 0 or ep == n_epoch - 1:
        # join instead of concatenating so save_dir works with or without a trailing slash
        ckpt_path = os.path.join(save_dir, f"model_{ep}.pth")
        torch.save(nn_model.state_dict(), ckpt_path)
        print('saved model at ' + ckpt_path)

epoch 0


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 894/894 [11:26<00:00,  1.30it/s]


saved model at ./weights/model_0.pth
epoch 1


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 894/894 [11:24<00:00,  1.31it/s]


epoch 2


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 894/894 [11:29<00:00,  1.30it/s]


epoch 3


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 894/894 [11:25<00:00,  1.30it/s]


epoch 4


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 894/894 [11:29<00:00,  1.30it/s]


saved model at ./weights/model_4.pth
epoch 5


 39%|████████████████████████████████████████████████▊                                                                            | 349/894 [04:29<07:12,  1.26it/s]

In [8]:
# restore the fully trained weights, then switch the model to eval mode
checkpoint = torch.load(f"{save_dir}/model_trained.pth", map_location=device)
nn_model.load_state_dict(checkpoint)
nn_model.eval()
print("Loaded in Model")

Loaded in Model


In [9]:
# visualize samples: draw 32 images with the loaded weights and animate the saved denoising steps
plt.clf()
samples, intermediate_ddpm = sample_ddpm(n_sample=32)
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation_ddpm.to_jshtml())

gif animating frame 56 of 57

<Figure size 640x480 with 0 Axes>

In [None]:
# restore the checkpoint saved after epoch 0, then switch the model to eval mode
checkpoint = torch.load(f"{save_dir}/model_0.pth", map_location=device)
nn_model.load_state_dict(checkpoint)
nn_model.eval()
print("Loaded in Model")

In [None]:
# visualize samples: draw 32 images with the loaded weights and animate the saved denoising steps
plt.clf()
samples, intermediate_ddpm = sample_ddpm(n_sample=32)
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation_ddpm.to_jshtml())

In [None]:
# restore the checkpoint saved after epoch 4, then switch the model to eval mode
checkpoint = torch.load(f"{save_dir}/model_4.pth", map_location=device)
nn_model.load_state_dict(checkpoint)
nn_model.eval()
print("Loaded in Model")

In [None]:
# visualize samples: draw 32 images with the loaded weights and animate the saved denoising steps
plt.clf()
samples, intermediate_ddpm = sample_ddpm(n_sample=32)
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation_ddpm.to_jshtml())

In [None]:
# load in model weights and set to eval mode
nn_model.load_state_dict(torch.load(f"{save_dir}/model_8.pth", map_location=device))
nn_model.eval()
print("Loaded in Model")b

In [None]:
# visualize samples: draw 32 images with the loaded weights and animate the saved denoising steps
plt.clf()
samples, intermediate_ddpm = sample_ddpm(n_sample=32)
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation_ddpm.to_jshtml())

In [None]:
# restore the checkpoint saved after the final epoch (31), then switch the model to eval mode
checkpoint = torch.load(f"{save_dir}/model_31.pth", map_location=device)
nn_model.load_state_dict(checkpoint)
nn_model.eval()
print("Loaded in Model")

In [None]:
# visualize samples: draw 32 images with the loaded weights and animate the saved denoising steps
plt.clf()
samples, intermediate_ddpm = sample_ddpm(n_sample=32)
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation_ddpm.to_jshtml())