# **Setting Up A Diffusion UNet**

In this notebook (NB), we will train an unconditional diffusion model from scratch. It is built mostly from the pipeline components developed earlier in this course, together with the model-specific components from the **Karras implementation NB (21_karras_implementation)**. 

## **Setup**

In [5]:
import os, timm, torch, random, datasets, math, fastcore.all as fc
import numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import k_diffusion as K, torchvision.transforms as T
import torchvision.transforms.functional as TF, torch.nn.functional as F

from torch.utils.data import DataLoader, default_collate
from pathlib import Path
from torch.nn import init
from fastcore.foundation import L
from torch import nn, tensor
from datasets import load_dataset
from operator import itemgetter
from torcheval.metrics import MulticlassAccuracy
from functools import partial
from torch.optim import lr_scheduler
from torch import optim

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *

import logging
# Suppress every logging record of severity WARNING and below to keep notebook output quiet.
logging.disable(logging.WARNING)

# Presumably seeds the Python/NumPy/torch RNGs; set_seed is expected to come from the
# miniai star-imports above — confirm.
set_seed(42)
# Cap fastcore's default CPU count at 8.
if fc.defaults.cpus>8 : fc.defaults.cpus=8

In [4]:
from fastprogress import progress_bar
from diffusers import UNet2DModel, DDIMPipeline, DDPMPipeline, DDIMScheduler, DDPMScheduler

In [6]:
# Compact, non-scientific tensor printing for inspecting values in the notebook.
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)
# Re-seeds torch's global RNG, overriding the torch state left by set_seed(42) in the setup cell.
torch.manual_seed(1)
# Plot images with a reversed grayscale colormap, at a small figure DPI.
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70

## **Load Data**

In [7]:
# Column names in the Hugging Face dataset: images under 'image', class labels under 'label'.
xl, yl = 'image', 'label'
name = "fashion_mnist"
# NOTE(review): n_steps is not used by the continuous-sigma noisify() defined below — confirm it is still needed.
n_steps = 1000
bs = 512
# Load the dataset splits (downloaded from the Hub on first use, per the output below).
dsd = load_dataset(name)

README.md:   0%|          | 0.00/9.02k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/30.9M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/5.18M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [8]:
# Standard deviation of the (transformed) training images, used as sigma_data in the
# Karras-style preconditioning of scalings(). It is hard-coded and depends on the
# preprocessing done by the in-place transform transformi() (padding and rescaling),
# so recompute it if that transform changes.
sig_data = 0.66

In [9]:
@inplace
def transformi(b):
    """Replace the image column ``b[xl]`` of a batch with model-ready tensors.

    Each image goes through ``TF.to_tensor`` (floats in [0, 1]). It is then
    zero-padded by 2 pixels on every side and rescaled with ``*2 - 1`` to [-1, 1],
    so the padded border ends up at -1. Assuming 28x28 Fashion-MNIST inputs,
    this produces 32x32 tensors.
    """
    converted = []
    for img in b[xl]:
        padded = F.pad(TF.to_tensor(img), (2, 2, 2, 2))
        converted.append(padded*2 - 1)
    b[xl] = converted

def scalings(sig, data_std=None):
    """Return the Karras-style preconditioning coefficients for noise level(s) ``sig``.

    Args:
        sig: noise standard deviation(s) as a tensor (``.sqrt()`` is called on the
            derived variance, so a plain Python float is not supported).
        data_std: standard deviation of the clean data. When ``None`` (the
            default), the module-level ``sig_data`` is read at call time, which
            matches the original behaviour.

    Returns:
        ``(c_skip, c_out, c_in)`` where, with ``v = sig**2 + data_std**2``:
        ``c_skip = data_std**2 / v``, ``c_out = sig*data_std / sqrt(v)``,
        ``c_in = 1 / sqrt(v)``.
    """
    if data_std is None:
        data_std = sig_data
    # Total variance at this noise level: data variance plus noise variance.
    totvar = sig**2 + data_std**2
    std = totvar.sqrt()
    #      c_skip               c_out             c_in
    return data_std**2/totvar, sig*data_std/std, 1/std

def noisify(x0):
    """Corrupt a batch of clean images with Gaussian noise at random noise levels.

    One noise level ``sig`` is drawn per sample from a log-normal distribution
    (``ln sig ~ N(-1.2, 1.2**2)``). The preconditioning coefficients come from
    ``scalings(sig)``.

    Args:
        x0: batch of clean images with the batch dimension first. Any number of
            trailing dimensions is supported; ``sig`` is broadcast over them.

    Returns:
        ``((noised_input * c_in, sig), target)``.
        - ``sig`` always has shape ``(len(x0),)``, even for a batch of one.
        - ``target = (x0 - c_skip * noised_input) / c_out`` is the network output
          that, scaled by ``c_out`` and added to ``c_skip * noised_input``,
          reconstructs ``x0``.
    """
    # Shape sig as (B, 1, 1, ...) so it broadcasts over every non-batch dimension.
    bcast_shape = (-1,) + (1,) * (x0.ndim - 1)
    # Sampled on the default device, then moved to x0's device/dtype. Sampling sig
    # before the noise keeps the original RNG draw order.
    sig = (torch.randn([len(x0)])*1.2-1.2).exp().to(x0).reshape(bcast_shape)
    noise = torch.randn_like(x0)
    c_skip, c_out, c_in = scalings(sig)
    noised_input = x0 + noise*sig
    target = (x0 - c_skip*noised_input) / c_out
    # reshape(-1) instead of squeeze(): squeeze() would turn a batch of size 1 into
    # a 0-d tensor, breaking the per-sample (B,) shape expected downstream.
    return (noised_input*c_in, sig.reshape(-1)), target

def collate_ddpm(b):
    """Collate dataset rows, keep only the image column ``xl``, and noisify it."""
    batch = default_collate(b)
    return noisify(batch[xl])

def dl_ddpm(ds):
    """Build a DataLoader over ``ds`` that yields noisified batches of ``bs`` items.

    Loading runs in the main process (``num_workers=0``).
    NOTE(review): no ``shuffle=True`` is passed (DataLoader's default is not to
    shuffle), so the training split is visited in the same order every epoch.
    Consider shuffling the train loader.
    """
    return DataLoader(ds, collate_fn=collate_ddpm, batch_size=bs, num_workers=0)

In [10]:
# Attach transformi as an on-access transform, then wrap the train/test splits in
# loaders whose batches are already noisified (see collate_ddpm).
tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

In [13]:
# (rows, columns) per split; the two columns are the image and label fields.
tds.shape

{'train': (60000, 2), 'test': (10000, 2)}

## **Train Model**

The unconditional model will be trained using the UNet architecture from previous NBs and the **Diffusers** library.

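Additionally, we will be using the [SiLU](https://mlarchive.com/machine-learning/activation-functions-all-you-need-to-know/) (Sigmoid Linear Unit) activation function, $\mathrm{SiLU}(x) = x \cdot \sigma(x)$, where $\sigma$ is the sigmoid function. Note that SiLU is not the sigmoid itself: it multiplies the input by its sigmoid.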

![title](imgs/SiLU.png)

We use the same convolution as in the Tiny ImageNet notebook, known as a **pre-activation convolution**. 

A pre-activation convolution is an architectural design in which the normalization layer (typically batch normalization) and the activation function are applied *before* the convolution rather than after it. The design is best known from pre-activation residual networks (ResNets), where it was introduced to improve the training of very deep models.

In [14]:
def unet_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    """Build a pre-activation conv unit: [norm(ni)] -> [act()] -> Conv2d(ni, nf).

    Args:
        ni, nf: input / output channels.
        ks: kernel size; padding is ``ks//2``, so odd kernels at stride 1 keep the
            spatial size.
        stride: convolution stride.
        act: activation class (instantiated with no arguments), or a falsy value to skip it.
        norm: normalization class (called with ``ni``), or a falsy value to skip it.
        bias: whether the convolution has a bias term.
    """
    parts = []
    if norm:
        parts.append(norm(ni))
    if act:
        parts.append(act())
    parts.append(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias))
    return nn.Sequential(*parts)

In [15]:
# Residual block in the style of the earlier ResNet blocks, but deliberately without
# stride / down-sampling options: down-sampling is handled separately in down_block(),
# which is similar to the approach used in Diffusers.
class UnetResBlock(nn.Module):
    """Two stacked pre-activation convolutions plus a residual (skip) path.

    Args:
        ni: input channels.
        nf: output channels; defaults to ``ni``.
        ks: kernel size of both convolutions (stride 1, padding ``ks//2``).
        act, norm: activation / normalization classes forwarded to ``unet_conv``.

    The skip path is ``fc.noop`` (identity) when ``ni == nf``. Otherwise it is a
    1x1 conv that projects the input to ``nf`` channels so it can be added to the
    conv output.
    """
    def __init__(self, ni, nf=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d):
        super().__init__()
        nf = ni if nf is None else nf
        first = unet_conv(ni, nf, ks, act=act, norm=norm)
        second = unet_conv(nf, nf, ks, act=act, norm=norm)
        self.convs = nn.Sequential(first, second)
        self.idconv = nn.Conv2d(ni, nf, 1) if ni != nf else fc.noop

    def forward(self, x):
        out = self.convs(x)
        return out + self.idconv(x)

By not adding _stride_ and _down-sampling_ to `UnetResBlock()`, we are ensuring that our approach is similar to the one used in the original `DDPM` architecture.

To simplify how down-sampling blocks are wired into the UNet, we introduce the `SavedResBlock()` and `SavedConv()` modules. They behave exactly like `UnetResBlock()` and `nn.Conv2d()`, but additionally store their output activations in a `.saved` attribute. This makes those activations accessible later while building the rest of the model, e.g. as skip connections for the up-sampling path.

In [17]:
class SaveModule:
    """Mixin that records a layer's most recent output in ``self.saved``.

    List it *before* an ``nn.Module`` subclass in the bases. Its ``forward``
    then runs first, delegates to the next class in the MRO via ``super()``,
    stores the result, and returns it unchanged.
    """
    def forward(self, x, *args, **kwargs):
        result = super().forward(x, *args, **kwargs)
        self.saved = result
        return result

# Concrete "saving" layers. SaveModule is listed first, so its forward() runs and its
# super().forward() call dispatches (via the MRO) to the second base's forward().
# Each instance keeps its latest output in `.saved`.
class SavedResBlock(SaveModule, UnetResBlock): pass # forward = UnetResBlock.forward, output also stored in .saved
class SavedConv(SaveModule, nn.Conv2d): pass        # forward = nn.Conv2d.forward, output also stored in .saved

In [18]:
def down_block(ni, nf, add_down=True, num_layers=1):
    """Build one encoder stage of the UNet as an ``nn.Sequential``.

    - It contains ``num_layers`` ``SavedResBlock``s. The first maps ``ni -> nf``
      channels; the rest map ``nf -> nf``.
    - When ``add_down`` is true, a stride-2 3x3 ``SavedConv`` (padding 1) follows.
      It roughly halves the spatial resolution.
    - Every layer stores its output in ``.saved``, so the activations can be
      collected afterwards.
    """
    layers = []
    for i in range(num_layers):
        in_ch = ni if i == 0 else nf
        layers.append(SavedResBlock(ni=in_ch, nf=nf))
    res = nn.Sequential(*layers)
    if add_down:
        res.append(SavedConv(nf, nf, 3, stride=2, padding=1))
    return res

In [23]:
def upsample(nf):
    """Up-sampling stage: double the spatial resolution, then apply a 3x3 conv.

    ``nn.Upsample(scale_factor=2.)`` uses its default nearest-neighbour mode. The
    following ``nf -> nf`` conv (padding 1) keeps the channel count and the new
    spatial size unchanged.
    """
    doubling = nn.Upsample(scale_factor=2.)
    refine = nn.Conv2d(nf, nf, 3, padding=1)
    return nn.Sequential(doubling, refine)

In [24]:
class UpBlock(nn.Module):
    """Decoder stage: residual blocks fed with concatenated skip activations, then optional up-sampling.

    Args:
        ni: channels of the skip activation consumed by the *last* residual block.
        prev_nf: channels of the incoming ``x`` (input to the first residual block).
        nf: output channels of every residual block.
        add_up: if true, finish with ``upsample(nf)``; otherwise ``nn.Identity()``.
        num_layers: number of residual blocks.

    Residual block ``i`` takes ``x_ch + skip_ch`` input channels, where
    ``x_ch = prev_nf`` for the first block (``nf`` afterwards) and
    ``skip_ch = ni`` for the last block (``nf`` otherwise).
    """
    def __init__(self, ni, prev_nf, nf, add_up=True, num_layers=2):
        super().__init__()
        blocks = []
        for i in range(num_layers):
            x_ch = prev_nf if i == 0 else nf
            skip_ch = ni if i == num_layers - 1 else nf
            blocks.append(UnetResBlock(x_ch + skip_ch, nf))
        self.resnets = nn.ModuleList(blocks)
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, ups):
        """Run the residual blocks, each on ``cat([x, skip], dim=1)``, then up-sample.

        ``ups`` is a list of saved activations. Each residual block pops one from
        its end, so the caller's list is mutated (``num_layers`` items are removed).
        """
        for block in self.resnets:
            skip = ups.pop()
            x = block(torch.cat([x, skip], dim=1))
        return self.up(x)