### **Diffusion!**

![](https://raw.githubusercontent.com/okarthikb/diffusion/main/diffusion.png)

Unconditional DDPM training and sampling algorithms from [Ho et al.](https://arxiv.org/abs/2006.11239). We implement batched variants of both.

In [None]:
!pip install einops

In [None]:
import torch, torchvision, matplotlib
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import imageio as iio
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from torch import einsum
from torch.optim import Adam
from torchvision import datasets, transforms
from einops import rearrange, repeat
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm
from IPython.display import HTML


# Raise matplotlib's size cap for animations embedded as HTML/JS (the value
# is in MB), so 2 ** 128 effectively removes the cap for the jshtml animation
# rendered at the end of the notebook.
matplotlib.rcParams['animation.embed_limit'] = 2 ** 128

### U-Net architecture

In [None]:
def position_embeddings(l, d):
  """Return an (l, d) float table of fixed sinusoidal embeddings.

  Row i embeds position i + 1. Columns 2k and 2k + 1 share the frequency
  w_k = 1e-4 ** (2 * (k + 1) / d); even columns hold sin(w_k * pos) and odd
  columns hold cos(w_k * pos). An odd d is supported: the frequency vector is
  built for d + 1 columns and the final (cos) column is dropped.
  """
  k = torch.arange(2, d + 2, 2)                       # 2, 4, ..., 2 * ceil(d/2)
  w = (1e-4 ** (k / d)).repeat_interleave(2)[:d]      # one frequency per column
  pos = w * torch.arange(1, l + 1).unsqueeze(1)       # (l, 1) * (d,) -> (l, d)
  pos[:, ::2] = torch.sin(pos[:, ::2])
  pos[:, 1::2] = torch.cos(pos[:, 1::2])
  return pos

In [None]:
def get_conv_down_up(
  x_size, in_channels, out_channels, kernel_size, stride, padding
):
  """Build a matched (Conv2d, ConvTranspose2d) pair for an x_size input.

  `down` maps in_channels -> out_channels and shrinks the spatial size;
  `up` maps out_channels -> in_channels and restores exactly x_size, so
  up(down(x)).shape == x.shape for an (b, in_channels, x_size, x_size) x.
  The transposed conv's output_padding absorbs the remainder lost in the
  floor division of the forward conv.
  """
  # spatial size produced by `down` (Conv2d output-size formula, dilation 1)
  reduced = (x_size + 2 * padding - kernel_size) // stride + 1
  # size a ConvTranspose2d with zero output padding would produce from it
  restored = (reduced - 1) * stride - 2 * padding + kernel_size
  down = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
  up = nn.ConvTranspose2d(
    out_channels, in_channels, kernel_size, stride, padding,
    output_padding=x_size - restored,
  )
  return down, up

In [None]:
# Multi-head self-attention over the pixels of a feature map: d is the channel
# dim (it plays the role of the token embedding dim) and every (h, w) position
# is one token.
class MultiHeadAttention(nn.Module):
  """Pre-norm multi-head self-attention with a residual connection.

  Args:
    d: number of input/output channels.
    n_head: number of attention heads; must divide d (default 1, i.e. a
      single head).
  """

  def __init__(self, d, n_head=1):
    super().__init__()
    assert d % n_head == 0, 'n_head must divide d'
    self.d, self.nh = d, n_head
    self.wx, self.wo = nn.Linear(d, 3 * d), nn.Linear(d, d)
    self.gn = nn.GroupNorm(1, d)

  def forward(self, x):
    """Attend over all h * w positions of x (b, c, h, w); same-shape output."""
    b, c, h, w = x.shape
    dh = self.d // self.nh  # per-head dim
    # (b, c, h, w) -> (b, h*w, c): one token per pixel
    tokens = self.gn(x).flatten(2).transpose(1, 2)
    # (b, n, 3d) with channel layout (t nh dh) -> (3, b, nh, n, dh)
    qkv = self.wx(tokens).reshape(b, h * w, 3, self.nh, dh)
    q, k, v = qkv.permute(2, 0, 3, 1, 4)
    # scaled dot-product attention: scale by sqrt of the *per-head* dim
    # (previously sqrt(d), which over-damps the logits when n_head > 1)
    scores = einsum('bhic, bhjc -> bhij', q, k) / dh ** 0.5
    hs = F.softmax(scores, -1) @ v  # (b, nh, n, dh)
    att = self.wo(hs.transpose(1, 2).reshape(b, h * w, self.d))
    return x + att.transpose(1, 2).reshape(b, c, h, w)

In [None]:
# ConvNeXt-style block (arXiv:2201.03545) adapted for diffusion: an optional
# time embedding is projected to in_channels and added to the output of the
# first (depthwise 7x7) conv.
class ConvNeXtBlock(nn.Module):
  """Depthwise conv -> [+ time emb] -> norm/conv/GELU/norm/conv, plus shortcut.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels of the output feature map.
    t_emb_d: size of the time embedding; falsy (None/0) builds a block without
      time conditioning.
    m: channel expansion factor of the hidden conv layer.
  """

  def __init__(self, in_channels, out_channels, t_emb_d, m=2):
    super().__init__()
    self.t_fc = nn.Linear(t_emb_d, in_channels) if t_emb_d else None
    # depthwise 7x7 conv (groups=in_channels); stride 1 / padding 3 keep h, w
    self.ds_conv = nn.Conv2d(in_channels, in_channels, 7, 1, 3, 1, in_channels)
    self.sequential = nn.Sequential(
      nn.GroupNorm(1, in_channels),
      nn.Conv2d(in_channels, out_channels * m, 3, 1, 1),
      nn.GELU(),
      nn.GroupNorm(1, out_channels * m),
      nn.Conv2d(out_channels * m, out_channels, 3, 1, 1)
    )

    # 1x1 conv shortcut only when the channel count changes
    if in_channels == out_channels:
      self.shortcut = nn.Identity()
    else:
      self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

  def forward(self, x, t_emb=None):
    """Apply the block to x (b, in_channels, h, w).

    Args:
      x: input feature map.
      t_emb: optional (b, t_emb_d) time embedding; None skips conditioning.

    Raises:
      ValueError: if t_emb is given but the block was built without t_emb_d.
    """
    x_proj = self.shortcut(x)
    hidden = self.ds_conv(x)

    if t_emb is not None:
      if self.t_fc is None:
        raise ValueError(
          'ConvNeXtBlock was built without t_emb_d; cannot apply t_emb'
        )
      # (b, in_channels) -> (b, in_channels, 1, 1) to broadcast over h, w
      hidden = hidden + self.t_fc(F.silu(t_emb))[:, :, None, None]

    return self.sequential(hidden) + x_proj

![](https://raw.githubusercontent.com/okarthikb/diffusion/main/unet.png)

In [None]:
class UNet(nn.Module):
  """Noise-prediction U-Net built from ConvNeXt blocks and self-attention.

  Encoder levels have 16, 32, 64 and 128 channels; the decoder goes back
  through 64, 32 and 8 channels, concatenating the matching encoder output at
  each level, and a final ConvNeXt block maps to 1 output channel.

  Args:
    max_t: largest diffusion timestep; the fixed time-embedding table has
      max_t + 1 rows and is indexed directly by t.
    m: channel expansion factor inside every ConvNeXtBlock.
    n_head: number of heads in every MultiHeadAttention layer.
    x_size: side length of the square input images (default 28, MNIST).
  """

  def __init__(self, max_t, m, n_head, x_size=28):
    super().__init__()
    self.max_t, self.m = max_t, m

    t_emb_d = 4 * x_size  # time embedding dim = 4 x image size
    # non-trainable Parameter so it follows .to(device) and the state_dict
    self.t_emb = nn.Parameter(
      position_embeddings(max_t + 1, t_emb_d), requires_grad=False
    )

    # (in_channels, out_channels, kernel_size, stride, padding) per level
    level_specs = [(16, 32, 3, 2, 1), (32, 64, 3, 2, 1), (64, 128, 3, 3, 1)]
    downscales, upscales, size = [], [], x_size
    for c_in, c_out, k, s, p in level_specs:
      # build each pair for the resolution it actually receives (28 -> 14 -> 7
      # for MNIST), so `up` restores exactly that size for the skip concat
      down, up = get_conv_down_up(size, c_in, c_out, k, s, p)
      downscales.append(down)
      upscales.append(up)
      size = (size + 2 * p - k) // s + 1

    self.downscales = nn.ModuleList(downscales)
    self.upscales = nn.ModuleList(reversed(upscales))  # deepest level first

    Block = lambda in_channels, out_channels: nn.ModuleList([
      ConvNeXtBlock(in_channels, out_channels, t_emb_d, m),
      MultiHeadAttention(out_channels, n_head)
    ])

    self.downblocks = nn.ModuleList([
      Block(1, 16), Block(32, 32), Block(64, 64), Block(128, 128)
    ])
    # decoder inputs are [upsampled, skip] concatenated along channels
    self.upblocks = nn.ModuleList([
      Block(128, 64), Block(64, 32), Block(32, 8)
    ])

    self.final_conv = ConvNeXtBlock(8, 1, t_emb_d, m)

  def forward(self, x, t=None):
    """Predict the noise component of x.

    Args:
      x: image batch with 1 channel, (b, 1, x_size, x_size).
      t: integer tensor of timesteps (one per image) used to index the time
        embedding table, or None to run without time conditioning.
    """
    cache = []  # encoder outputs, consumed in reverse order by the decoder
    t_emb = None if t is None else self.t_emb.index_select(0, t)

    conv, attn = self.downblocks[0]
    x = attn(conv(x, t_emb))
    for down, (conv, attn) in zip(self.downscales, self.downblocks[1:]):
      cache.append(x)
      x = attn(conv(down(x), t_emb))

    for up, (conv, attn) in zip(self.upscales, self.upblocks):
      x = attn(conv(torch.cat([up(x), cache.pop()], 1), t_emb))

    return self.final_conv(x, t_emb)

### Get forward/reverse diffusion sampling functions

In [None]:
# function to get forward/reverse diffusion process samplers
def get_samplers(max_t, device, s=0.008, shape=(1, 28, 28)):
  """Build forward (q) and reverse (p) DDPM samplers for a cosine schedule.

  abar_t = cos^2(pi/2 * (t / max_t + s) / (1 + s)), normalised so abar_0 = 1;
  beta_t = 1 - abar_t / abar_{t-1} clipped to [1e-4, 1 - 1e-4], beta_0 = 0.

  Args:
    max_t: number of diffusion steps T.
    device: device on which schedule tensors and samples are created.
    s: small offset of the cosine schedule.
    shape: per-sample shape drawn by the reverse sampler.

  Returns:
    (q_sample, p_sample_trajectory).
  """
  ts = torch.arange(0, max_t + 1, device=device)

  abar = ((torch.pi / 2) * (ts / max_t + s) / (1 + s)).cos() ** 2
  abar = abar / abar[0]

  beta = (1 - abar[1:] / abar[:-1]).clip(1e-4, 1 - 1e-4)
  beta = torch.cat([torch.zeros(1, device=device), beta])
  alpha, sigma = 1 - beta, beta.sqrt()

  # NOTE(review): q_sample uses the unclipped abar while the reverse step uses
  # the clipped beta, so the two only agree where the clip is inactive.
  # reverse step: x_{t-1} = coeff2 * (x_t - coeff1 * eps) + sigma * z
  coeff1, coeff2 = beta / (1 - abar).sqrt(), 1 / alpha.sqrt()

  # shape (T + 1, 1, 1, 1) so a selected row broadcasts over a (b c h w) batch
  sqrt_abar = abar.sqrt().view(-1, 1, 1, 1)
  sqrt_1m_abar = (1 - abar).sqrt().view(-1, 1, 1, 1)

  def q_sample(x, t):
    """Noise each x_0 in the batch to its own step t; return (x_t, eps)."""
    eps = torch.randn_like(x)
    mu_t = sqrt_abar.index_select(0, t) * x
    std_t = sqrt_1m_abar.index_select(0, t)
    return mu_t + std_t * eps, eps

  def p_sample_trajectory(model, n_sample=1, return_all_steps=False):
    """Run the reverse process from x_T ~ N(0, I) down to x_0.

    Follows the DDPM sampling algorithm (Algorithm 2 of Ho et al.) with
    sigma_t^2 = beta_t and no added noise at the last step. The model is run
    in eval mode under no_grad, and its previous train/eval mode is restored.

    Returns:
      x_0, or (x_0, [x_T, ..., x_0]) when return_all_steps is True.
    """
    was_training = model.training
    model.eval()

    x_t = torch.randn(n_sample, *shape, device=device)
    # only keep the trajectory when the caller asked for it
    x_all = [x_t] if return_all_steps else None

    try:
      with torch.no_grad():
        for t in range(max_t, 0, -1):
          z = torch.randn_like(x_t) if t > 1 else 0
          t_batch = torch.full((n_sample,), t, dtype=ts.dtype, device=device)
          eps = model(x_t, t_batch)
          x_t = coeff2[t] * (x_t - coeff1[t] * eps) + sigma[t] * z
          if return_all_steps:
            x_all.append(x_t)
    finally:
      model.train(was_training)

    return (x_t, x_all) if return_all_steps else x_t

  return q_sample, p_sample_trajectory

### Training

In [None]:
# training hyperparameters
epochs = 30
batch_size = 128
lr = 3e-4
max_t = 500  # number of diffusion steps T
# fall back to CPU so the notebook still runs on machines without CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# ToTensor() yields pixels in [0, 1]; rescale them to [-1, 1] for the model
to_model_range = transforms.Compose([
  transforms.ToTensor(),
  transforms.Lambda(lambda x: 2 * x - 1),
])
dataset = MNIST(root='./', download=True, transform=to_model_range)

# drop_last keeps every training batch at exactly batch_size images
dataloader = DataLoader(
  dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

In [None]:
# noise-prediction network (moved to `device`) and its Adam optimizer
model = UNet(max_t=max_t, m=4, n_head=1)
model = model.to(device)
optimizer = Adam(params=model.parameters(), lr=lr)

In [None]:
# q_sample noises a clean batch so that every image sits at its own timestep
# of the forward process; p_sample_trajectory runs the reverse process
q_sample, p_sample_trajectory = get_samplers(max_t=max_t, device=device)

In [None]:
# DDPM training algorithm (Algorithm 1 of Ho et al.): draw a random timestep
# per image, noise the images with q_sample, and regress the added noise
model.train()
# epochs = 5  (train few epochs at a time, see if model is producing anything)
for epoch in range(1, epochs + 1):
  bar = tqdm(dataloader, ascii=' >=', desc=f'epoch {epoch}/{epochs}')
  for x, _ in bar:
    # one timestep per image actually in the batch, not the nominal batch_size
    t = torch.randint(1, max_t + 1, (x.shape[0],), device=device)
    x_t, eps = q_sample(x.to(device), t)
    loss = F.mse_loss(model(x_t, t), eps)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    bar.set_postfix({"loss": f"{loss.item():.4f}"})

### Testing

Rerun cells to sample anew.

In [None]:
n_sample, rows = 25, 5


def reshape(batch):
  """Tile a (rows * cols, 1, h, w) batch into one (rows * h, cols * w) grid."""
  return rearrange(batch, '(r f) 1 h w -> (r h) (f w)', r=rows)


# keep the whole trajectory as well; the animation cell below uses x_all
x_0, x_all = p_sample_trajectory(model, n_sample=n_sample, return_all_steps=True)

fig = plt.figure(figsize=(6, 6))
plt.imshow(reshape(x_0.cpu().numpy()))
plt.axis('off')
plt.show()

In [None]:
skip = 2
# every `skip`-th step of the trajectory, offset by max_t % skip so that the
# last element of x_all (the final sample) is always among the frames
images = [reshape(step.cpu().numpy()) for step in x_all[max_t % skip::skip]]


def update(i):
  """Redraw the figure with frame i of the denoising trajectory."""
  plt.clf()
  plt.imshow(images[i])
  plt.axis('off')


plt.ioff()  # avoid displaying the empty figure below as a static image
fig = plt.figure(figsize=(6, 6))
ani = animation.FuncAnimation(
  fig, update, frames=range(len(images)), interval=200
)

In [None]:
# `display` is only predefined inside IPython/Jupyter; import it explicitly so
# this cell also works when the notebook is executed as a plain script
from IPython.display import display

display(HTML(ani.to_jshtml()))