# The Annotated Diffusion Model
- https://huggingface.co/blog/annotated-diffusion

In [1]:
import math
import numpy as np
from inspect import isfunction
from functools import partial
from datasets import load_dataset


import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm.auto import tqdm
from einops import rearrange
from PIL import Image
import requests

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
from torch.optim import Adam
from torchvision.utils import save_image


%matplotlib inline

In [2]:
# Network helpers
# Residual module: adds the input to the output of a particular function 
# (i.e., add residual connection to a particular function)
# Also, alias for the up and down sampling operations

def exists(x):
    """Return True for any value other than ``None`` (falsy values count as existing)."""
    return not (x is None)

def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When ``d`` is a plain Python function it is called (lazily) to produce
    the fallback; any other object is returned as-is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val

class Residual(nn.Module):
    """Wrap a callable with a skip connection: ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded untouched to fn.
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x
    
def upsample(dim):
    """Build a transposed convolution that keeps ``dim`` channels (kernel 4, stride 2, padding 1)."""
    return nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)

def downsample(dim):
    """Build a strided convolution that keeps ``dim`` channels (kernel 4, stride 2, padding 1)."""
    return nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)

In [3]:
# Position embeddings
# Sinusoidal position embeddings to encode t, inspired by Transformer
# Input: Tensor (batch_size,) with the timestep (noise level) of each noisy image in a batch
# Output: Tensor (batch_size, dim) where dim = dimensionality of position embeddings
# This is then added to each residual block

class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding of a batch of timesteps.

    Args:
        dim: size of the produced embedding. The first ``dim // 2`` entries
            are sines and the remaining ``dim // 2`` are cosines, so ``dim``
            is expected to be even and at least 4 (``dim // 2 - 1`` is used
            as a divisor).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time`` (1-D tensor, one timestep per sample) as ``(len(time), 2 * (dim // 2))``."""
        device = time.device
        half_dim = self.dim // 2
        # Geometric progression of frequencies from 1 down to 1/10000.
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # Phase = timestep * frequency (outer product via broadcasting).
        # This must be a product; adding the timestep would not give a
        # sinusoidal position encoding.
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In [4]:
# ResNet/ConvNext block
# DDPM authors used Wide ResNet block
# Phil Wang used ConvNeXT block though eventually removed from his implementation as
# it didn't work well. Nonetheless, we'll continue using it as results were decent.
# ConvNeXT = https://arxiv.org/abs/2201.03545
# SiLU = Sigmoid Linear Unit aka "swish" (https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html)
# GELU = Gaussian Error Linear Units (https://pytorch.org/docs/stable/generated/torch.nn.GELU.html)

class Block(nn.Module):
    """3x3 convolution -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply the block; ``scale_shift`` is an optional ``(scale, shift)`` pair.

        When given, the normalized activations are modulated as
        ``(scale + 1) * x + shift`` before the non-linearity.
        """
        normed = self.norm(self.proj(x))

        if exists(scale_shift):
            scale, shift = scale_shift
            normed = (scale + 1) * normed + shift

        return self.act(normed)
    

class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection and optional time conditioning.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        time_emb_dim: size of the time embedding; when ``None`` no
            conditioning MLP is created.
        groups: number of groups passed to each ``Block``'s GroupNorm.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # A 1x1 convolution matches channel counts on the skip path when needed.
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        hidden = self.block1(x)

        conditioned = exists(self.mlp) and exists(time_emb)
        if conditioned:
            # Project the embedding to dim_out channels and broadcast it
            # over the two spatial axes.
            projected = self.mlp(time_emb)
            hidden = rearrange(projected, 'b c -> b c 1 1') + hidden

        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
    

class ConvNextBlock(nn.Module):
    """ConvNeXT-style block: depthwise 7x7 conv, then a two-conv bottleneck, plus a skip path.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        time_emb_dim: size of the time embedding; when ``None`` no
            conditioning MLP is created.
        mult: width multiplier of the hidden layer (``dim_out * mult`` channels).
        norm: whether to apply GroupNorm before the first 3x3 convolution.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
        else:
            self.mlp = None

        # groups=dim makes this a depthwise convolution.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        hidden_dim = dim_out * mult
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, hidden_dim),
            nn.Conv2d(hidden_dim, dim_out, 3, padding=1),
        )

        # A 1x1 convolution matches channel counts on the skip path when needed.
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        hidden = self.ds_conv(x)

        conditioned = exists(self.mlp) and exists(time_emb)
        if conditioned:
            # The conditioning is added right after the depthwise conv,
            # broadcast over the two spatial axes.
            condition = self.mlp(time_emb)
            hidden = hidden + rearrange(condition, 'b c -> b c 1 1')

        hidden = self.net(hidden)
        return hidden + self.res_conv(x)

In [5]:
# Attention which is added in between the conv blocks. Two variants of attention
# Regular multi-head self attention: As used in Transformer
# Linear attention variant: Where the time and memory requirements scale linear in
# sequence length as opposed to quadratic for regular attention

class Attention(nn.Module):
    """Multi-head softmax self-attention over the spatial positions of a feature map.

    Args:
        dim: number of input (and output) channels.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape

        # Split channels into heads and flatten the spatial grid into a sequence.
        def split_heads(t):
            return rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=1))
        q = q * self.scale

        logits = einsum('b h d i, b h d j -> b h i j', q, k)
        # Subtract the row-wise maximum before the softmax.
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        mixed = einsum('b h i j, b h d j -> b h i d', weights, v)
        mixed = rearrange(mixed, 'b h (x y) d -> b (h d) x y', x=height, y=width)
        return self.to_out(mixed)

class LinearAttention(nn.Module):
    """Linear-attention variant: builds a per-head context matrix from keys and values.

    The (d x e) context is computed first and then applied to the queries,
    so no (positions x positions) attention matrix is formed.

    Args:
        dim: number of input (and output) channels.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        _, _, height, width = x.shape

        # Split channels into heads and flatten the spatial grid into a sequence.
        def split_heads(t):
            return rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=1))

        # Queries are normalized over the feature axis, keys over positions.
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h=self.heads, x=height, y=width)
        return self.to_out(attended)

In [6]:
# Group normalization: DDPM interleaves convolution/attention layers of UNet with group
# normalization. The PreNorm class here will be used to apply group norm before the
# attention layer

class PreNorm(nn.Module):
    """Apply a single-group GroupNorm to the input, then call ``fn`` on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))

In [7]:
# Conditional UNet (conditioned on the noise level)
# Input: Batch of noisy images (batch_size, num_channels, height, width) and
# noise level / timestep (batch_size,)
# Output: Predicted noise for each input image (batch_size, num_channels, height, width)

# Network
# Step 1: Convolutional layer is applied on the batch of noisy images, and position
# embeddings are computed for the noise levels
# Step 2: A sequence of downsampling stages are applied. Each downsampling stage has
# two ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + downsample
# Step 3: In the middle of network, ResNet/ConvNeXT blocks are applied, interleaved
# with attention
# Step 4: A sequence of upsampling stages are applied. Each upsampling stage has
# two ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + upsample
# Step 5: ResNet/ConvNeXT followed by a convolutional layer

class Unet(nn.Module):
    """U-Net noise predictor, optionally conditioned on the diffusion timestep.

    Args:
        dim: base channel width; per-resolution widths are ``dim * m`` for
            each ``m`` in ``dim_mults``.
        init_dim: channels produced by the first convolution; defaults to
            ``dim // 3 * 2``.
        out_dim: channels of the output; defaults to ``channels``.
        dim_mults: channel multipliers, one per resolution level.
        channels: number of channels of the input images.
        with_time_emb: when False no time embedding is built and the blocks
            are called with ``None`` as conditioning.
        resnet_block_groups: GroupNorm groups used by ``ResnetBlock``
            (only relevant when ``use_convnext`` is False).
        use_convnext: choose ``ConvNextBlock`` (True) or ``ResnetBlock``.
        convnext_mult: hidden-width multiplier passed to ``ConvNextBlock``.
    """

    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8),
                 channels=3, with_time_emb=True, resnet_block_groups=8,
                 use_convnext=True, convnext_mult=2):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        # Channel widths per level and the (in, out) pair of every stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            # The deepest stage keeps its resolution (no downsampling).
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # The up path mirrors the down path but skips the first (in, out)
        # pair, so it has one stage fewer than the down path.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE(review): this loop has num_resolutions - 1 iterations, so
            # is_last is never True here and every up stage upsamples.
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        # dim_out * 2: the input is the concatenation of the
                        # current features with the matching skip tensor.
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        upsample(dim_in) if not is_last else nn.Identity()
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        """Run the network on images ``x`` with per-sample timesteps ``time``.

        ``time`` is only used when the model was built with a time MLP.
        """
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Stack of skip tensors saved on the way down.
        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            # Concatenate x with the most recently saved skip tensor along
            # the channel axis (the skip stack is consumed in reverse order).
            # One saved tensor (from the first down stage) is left unused
            # because the up path has one stage fewer.
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)

In [8]:
# Forward diffusion process
# Original DDPM used a linear schedule but "Improved Denoising Diffusion Probabilistic
# Models" showed that better results can be achieved with a cosine schedule

def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas following a squared-cosine cumulative-alpha curve.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        s: small offset added to the normalized time before taking the cosine.

    Returns:
        1-D tensor of betas clipped to ``[0.0001, 0.9999]``.
    """
    # timesteps + 1 grid points are needed so that the ratio of consecutive
    # cumulative alphas yields exactly `timesteps` betas.
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # Normalize so the cumulative product starts at 1.
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas evenly spaced from 0.0001 to 0.02 (inclusive)."""
    first_beta, last_beta = 0.0001, 0.02
    return torch.linspace(first_beta, last_beta, steps=timesteps)

def quadratic_beta_schedule(timesteps):
    """Return ``timesteps`` betas growing quadratically from 0.0001 to 0.02.

    The square roots of the two endpoints are interpolated linearly and the
    result is squared, so the first and last betas equal the endpoints.
    """
    beta_start = 0.0001
    beta_end = 0.02
    # Both endpoints go through a square root so that squaring the
    # interpolated values recovers beta_start and beta_end.
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    """Return ``timesteps`` betas shaped like a sigmoid between 0.0001 and 0.02.

    The sigmoid is evaluated on an even grid over ``[-6, 6]`` and rescaled
    to the beta range.
    """
    low, high = 0.0001, 0.02
    grid = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(grid) * (high - low) + low

In [9]:
# We start with setting T = 200 and define the various quantities derived from beta_t,
# such as alpha_t = 1 - beta_t and its cumulative product (alpha_bar_t). We also define an
# extract function which gathers the values at the timestep indices t of a batch

# Number of diffusion steps T.
timesteps = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
# alpha_bar_t: running product of the alphas up to and including step t.
alphas_cumprod = torch.cumprod(alphas, axis=0)
# alpha_bar_{t-1}: same sequence shifted right by one, with 1.0 for t = 0.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

# sqrt(1 / alpha_t): scales the predicted mean in the reverse (sampling) step, see p_sample
sqrt_recip_alphas = torch.sqrt(1.0/alphas)

# calculations for diffusion q(x_t | x_0): weights of the clean image and of the noise
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
# A sequence of posterior variances that is used to add noise during sampling
posterior_variances = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    """Gather ``a[t]`` for a batch of indices and reshape it for broadcasting.

    Args:
        a: 1-D tensor of per-timestep values; indexed on the CPU copy of ``t``.
        t: 1-D integer tensor of indices, one per batch element.
        x_shape: shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)``
        dimensions, moved to ``t``'s device.
    """
    gathered = a.gather(-1, t.cpu())
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(t.shape[0], *singleton_dims).to(t.device)

In [10]:
# How does image weight and noise weight change during forward diffusion?
# Image weight: Decreases from 0.9999 to 0.3636
# Noise weight: Increases from 0.0100 to 0.9316
# Note that the image and noise weights do not sum to 1; their squares do,
# since sqrt(alpha_bar_t)^2 + sqrt(1 - alpha_bar_t)^2 = 1

# Sample noisy images from clean image and timestep
def q_sample(x_start, t, noise=None):
    """Draw a noisy version of ``x_start`` at timestep(s) ``t`` in closed form.

    Args:
        x_start: clean input tensor.
        t: 1-D integer tensor of timestep indices, one per batch element.
        noise: noise to mix in; standard-normal noise shaped like
            ``x_start`` is drawn when omitted.

    Returns:
        ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.

    NOTE: the two weights are printed on every call, which is verbose when
    this function is used inside a training loop.
    """
    noise = torch.randn_like(x_start) if noise is None else noise

    shape = x_start.shape
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    print(f'Image weight: {sqrt_alphas_cumprod_t}, Noise weight: {sqrt_one_minus_alphas_cumprod_t}')

    signal = sqrt_alphas_cumprod_t * x_start
    return signal + sqrt_one_minus_alphas_cumprod_t * noise

# Print the image/noise mixing weights of q_sample at a few representative timesteps.
for i in [0, 1, 25, 50, 75, 100, 150, 199]:
    # x_shape=(1,) keeps the gathered value as a one-element tensor so .item() works.
    image_weight = extract(sqrt_alphas_cumprod, torch.tensor([i]), (1,))
    noise_weight = extract(sqrt_one_minus_alphas_cumprod, torch.tensor([i]), (1, ))
    print(f'timestep {str(i):3s} - Image weight: {image_weight.item():.4f}, Noise weight: {noise_weight.item():.4f}')

timestep 0   - Image weight: 0.9999, Noise weight: 0.0100
timestep 1   - Image weight: 0.9998, Noise weight: 0.0173
timestep 25  - Image weight: 0.9826, Noise weight: 0.1858
timestep 50  - Image weight: 0.9357, Noise weight: 0.3527
timestep 75  - Image weight: 0.8636, Noise weight: 0.5042
timestep 100 - Image weight: 0.7723, Noise weight: 0.6353
timestep 150 - Image weight: 0.5617, Noise weight: 0.8273
timestep 199 - Image weight: 0.3636, Noise weight: 0.9316


In [11]:
# How do the latent and noise weights change during reverse diffusion (sampling)?
# Latent weight: Stays about 1.00
# Removed noise weight: Reduces from 0.0215 to 0.01
# Added back noise: Reduces from 0.1412 to 0.00
# Thus, we only remove very little noise (max 0.02, min 0.01) but add back in more noise (0.14, min 0)

# Sample output images from noise and timestep
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Perform one reverse-diffusion step, going from ``x`` at step ``t`` to the previous step.

    Args:
        model: noise predictor called as ``model(x, t)``.
        x: current noisy batch.
        t: 1-D integer tensor holding the timestep of each batch element.
        t_index: Python int timestep; at 0 the mean is returned without
            adding noise.

    Returns:
        The predicted mean, plus ``sqrt(posterior_variance_t)``-scaled
        Gaussian noise when ``t_index`` is not 0.
    """
    beta_t = extract(betas, t, x.shape)
    noise_scale_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    recip_sqrt_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    # Equation 11 in the paper: Use the model (noise predictor) to predict the mean
    predicted_noise = model(x, t)
    model_mean = recip_sqrt_alpha_t * (x - beta_t * predicted_noise / noise_scale_t)

    if t_index == 0:
        return model_mean

    variance_t = extract(posterior_variances, t, x.shape)
    fresh_noise = torch.randn_like(x)

    # Algo 2 line 4:
    return model_mean + torch.sqrt(variance_t) * fresh_noise

# Print, for a few representative timesteps (from last to first), the three
# coefficients p_sample uses: the weight of the current latent, the weight of
# the predicted noise that is subtracted, and the std of the noise added back.
for i in reversed([0, 1, 25, 50, 75, 100, 150, 199]):
    betas_t = extract(betas, torch.tensor([i]), (1,))
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, torch.tensor([i]), (1,))
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, torch.tensor([i]), (1,))
    posterior_variance_t = extract(posterior_variances, torch.tensor([i]), (1,))
    posterior_sd_t = torch.sqrt(posterior_variance_t)
    # Coefficient of model(x, t) inside the mean (before the outer sqrt(1/alpha_t) scaling).
    removed_noise_weight = betas_t / sqrt_one_minus_alphas_cumprod_t
    print(f'timestep {str(i):3s} - Latent weight: {sqrt_recip_alphas_t.item():.4f}, ' \
          f'Removed noise weight: {removed_noise_weight.item():.4f}, ' \
          f'Added noise weight: {posterior_sd_t.item():.4f}')

timestep 199 - Latent weight: 1.0102, Removed noise weight: 0.0215, Added noise weight: 0.1412
timestep 150 - Latent weight: 1.0076, Removed noise weight: 0.0183, Added noise weight: 0.1224
timestep 100 - Latent weight: 1.0051, Removed noise weight: 0.0159, Added noise weight: 0.0997
timestep 75  - Latent weight: 1.0038, Removed noise weight: 0.0151, Added noise weight: 0.0862
timestep 50  - Latent weight: 1.0026, Removed noise weight: 0.0145, Added noise weight: 0.0701
timestep 25  - Latent weight: 1.0013, Removed noise weight: 0.0140, Added noise weight: 0.0491
timestep 1   - Latent weight: 1.0001, Removed noise weight: 0.0115, Added noise weight: 0.0082
timestep 0   - Latent weight: 1.0000, Removed noise weight: 0.0100, Added noise weight: 0.0000


In [12]:
# Get sample image: download it over HTTP and open the raw response stream with PIL
# (requires network access)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
# image

In [13]:
# Prepare images for the diffusion process via transformations
# Step 1: Resize and center-crop to image_size x image_size
# Step 2: Convert to a tensor with values in the [0, 1] range
# Step 3: Rescale so values are in the [-1, 1] range
image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(),  # Turn into a torch tensor of shape CHW and divide by 255 so values are in [0, 1] range
    Lambda(lambda t: (t * 2) - 1)  # Set to [-1, 1] range
])

# Add a leading batch dimension of size 1.
x_start = transform(image).unsqueeze(0)

In [14]:
# Sanity check: expect (batch, channels, height, width)
x_start.shape

torch.Size([1, 3, 128, 128])

In [15]:
# Reverse transform: Takes in a single CHW Tensor with values in [-1, 1] and returns a PIL image
# NOTE(review): values outside [-1, 1] (possible after adding noise) are not clamped
# before the uint8 cast - confirm this is acceptable for visualization.
reverse_transform = Compose([
    Lambda(lambda t: (t + 1) / 2),  # [-1, 1] -> [0, 1]
    Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    Lambda(lambda t: t * 255.),  # [0, 1] -> [0, 255]
    Lambda(lambda t: t.numpy().astype(np.uint8)),  # tensor -> uint8 numpy array
    ToPILImage()
])

In [16]:
# reverse_transform(x_start.squeeze())

In [17]:
# Forward diffusion process via the "nice" property (sample x_t directly from x_0)
def q_sample(x_start, t, noise=None):
    """Draw a noisy version of ``x_start`` at timestep(s) ``t`` in closed form.

    Args:
        x_start: clean input tensor.
        t: 1-D integer tensor of timestep indices, one per batch element.
        noise: noise to mix in; standard-normal noise shaped like
            ``x_start`` is drawn when omitted.

    Returns:
        ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.

    NOTE: the two weights are printed on every call, which is verbose when
    this function is used inside a training loop.
    """
    noise = torch.randn_like(x_start) if noise is None else noise

    shape = x_start.shape
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    print(f'Image weight: {sqrt_alphas_cumprod_t}, Noise weight: {sqrt_one_minus_alphas_cumprod_t}')

    signal = sqrt_alphas_cumprod_t * x_start
    return signal + sqrt_one_minus_alphas_cumprod_t * noise

def get_noisy_image(x_start, t):
    """Noise ``x_start`` at timestep ``t`` and convert the result back to an image.

    The noisy tensor is squeezed (dropping size-1 dimensions such as the
    batch axis) before being passed through ``reverse_transform``.
    """
    noisy = q_sample(x_start, t=t)
    return reverse_transform(noisy.squeeze())

In [18]:
# Example timestep batch (a single element) for trying out get_noisy_image
t = torch.tensor([50])
# get_noisy_image(x_start, t)

In [19]:
# Visualize adding noise across various timesteps
# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    """Show images in a grid of matplotlib axes without ticks.

    Args:
        imgs: a list of images (one row) or a list of lists (several rows).
        with_orig: when True, the module-level ``image`` is prepended to
            every row and the first cell is titled 'Original image'.
        row_title: optional sequence of y-labels, one per row.
        **imshow_kwargs: forwarded to ``Axes.imshow``.
    """
    # Promote a flat list of images to a single-row grid.
    grid = imgs if isinstance(imgs[0], list) else [imgs]

    n_rows = len(grid)
    n_cols = len(grid[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200, 200), nrows=n_rows, ncols=n_cols, squeeze=False)

    for r, row_imgs in enumerate(grid):
        cells = [image] + row_imgs if with_orig else row_imgs
        for c, cell in enumerate(cells):
            axis = axs[r, c]
            axis.imshow(np.asarray(cell), **imshow_kwargs)
            axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        top_left = axs[0, 0]
        top_left.set(title='Original image')
        top_left.title.set_size(8)
    if row_title is not None:
        for r in range(n_rows):
            axs[r, 0].set(ylabel=row_title[r])

    plt.tight_layout()

In [20]:
# plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 25, 50, 75, 100, 199]])

In [21]:
# Define the loss function
def p_losses(denoise_model, x_start, t, noise=None, loss_type='l1'):
    """Compute the noise-prediction loss for one batch.

    The clean batch is noised with ``q_sample`` at timesteps ``t`` and the
    model's prediction is compared against the noise that was mixed in.

    Args:
        denoise_model: callable ``(x_noisy, t) -> predicted noise``.
        x_start: clean input batch.
        t: 1-D integer tensor of timesteps, one per batch element.
        noise: noise to mix in; standard-normal noise shaped like
            ``x_start`` is drawn when omitted.
        loss_type: ``'l1'``, ``'l2'`` or ``'huber'``.

    Returns:
        Scalar loss tensor.

    Raises:
        NotImplementedError: for any other ``loss_type``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == 'huber':
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError(f'unknown loss_type: {loss_type!r}')

    return loss

In [None]:
# Define a Dataset (we can use any dataset but here we use Fashion MNIST)
dataset = load_dataset('fashion_mnist')
# NOTE: image_size is re-bound here (it was 128 for the sample image above).
image_size = 28
channels = 1
batch_size = 128

In [None]:
# Define image transformations, which include random horizontal flips that the paper
# reported to improve sample quality slightly
# Training-time preprocessing: random flip, convert to a [0, 1] tensor, then
# rescale to [-1, 1].
transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # (t * 2) - 1 maps [0, 1] onto [-1, 1]; sampled images are later mapped
    # back with (x + 1) * 0.5, so the two must stay inverse to each other.
    transforms.Lambda(lambda t: (t * 2) - 1)
])

def transforms(examples):
    """Batch transform for the dataset: add ``pixel_values`` and drop ``image``.

    Every entry of ``examples['image']`` is converted to grayscale ('L')
    and passed through the module-level ``transform``. ``examples`` is
    modified in place and also returned.

    NOTE: once defined, this function shadows the ``transforms`` module
    imported from torchvision, so code that uses ``transforms.<name>``
    must run before this definition.
    """
    pixel_values = []
    for img in examples['image']:
        pixel_values.append(transform(img.convert('L')))

    examples['pixel_values'] = pixel_values
    del examples['image']
    return examples

# Apply the transform lazily on access and drop the class labels, which are not used here.
transformed_dataset = dataset.with_transform(transforms).remove_columns('label')

# Shuffled mini-batches over the training split.
dataloader = DataLoader(transformed_dataset['train'], batch_size=batch_size, shuffle=True)

In [None]:
# Fetch one batch and show which keys it contains.
batch = next(iter(dataloader))
print(batch.keys())

In [None]:
# Sample from the model during training to track progress. How sampling is done:
# Step 1: Sample pure noise from a Gaussian
# Step 2: Use the neural network to gradually denoise it (via the learned conditional probability)

@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Perform one reverse-diffusion step, going from ``x`` at step ``t`` to the previous step.

    Args:
        model: noise predictor called as ``model(x, t)``.
        x: current noisy batch.
        t: 1-D integer tensor holding the timestep of each batch element.
        t_index: Python int timestep; at 0 the mean is returned without
            adding noise.

    Returns:
        The predicted mean, plus ``sqrt(posterior_variance_t)``-scaled
        Gaussian noise when ``t_index`` is not 0.
    """
    beta_t = extract(betas, t, x.shape)
    noise_scale_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    recip_sqrt_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    # Equation 11 in the paper: Use the model (noise predictor) to predict the mean
    predicted_noise = model(x, t)
    model_mean = recip_sqrt_alpha_t * (x - beta_t * predicted_noise / noise_scale_t)

    if t_index == 0:
        return model_mean

    variance_t = extract(posterior_variances, t, x.shape)
    fresh_noise = torch.randn_like(x)

    # Algo 2 line 4:
    return model_mean + torch.sqrt(variance_t) * fresh_noise
    
# Algorithm 2 (returning the images of every timestep)
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse process from pure noise and keep every intermediate batch.

    Args:
        model: noise predictor; its first parameter decides the device.
        shape: shape of the batch to generate; ``shape[0]`` is the batch size.

    Returns:
        List with one numpy array per timestep, ordered from the first
        denoising step (t = timesteps - 1) to the last (t = 0).
    """
    device = next(model.parameters()).device
    n_samples = shape[0]

    # Start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    frames = []

    countdown = tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps)
    for step in countdown:
        t_batch = torch.full((n_samples,), step, device=device, dtype=torch.long)
        img = p_sample(model, img, t_batch, step)
        frames.append(img.cpu().numpy())

    return frames

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate ``batch_size`` square images; returns what ``p_sample_loop`` returns."""
    target_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=target_shape)

In [None]:
# Define logic to periodically save generated images
def num_to_groups(num, divisor):
    """Split ``num`` into chunks of at most ``divisor``.

    Returns a list of ``num // divisor`` entries equal to ``divisor``,
    followed by the remainder when it is non-zero, e.g.
    ``num_to_groups(10, 4) == [4, 4, 2]``.
    """
    groups, remainder = divmod(num, divisor)
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# Directory for generated sample grids; created if it does not exist yet.
results_folder = Path('./results')
results_folder.mkdir(exist_ok = True)
# Interval (in steps) between two sample dumps during training.
save_and_sample_every = 1000

In [None]:
# Define model and move to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The base channel width (dim) is set to image_size; three resolution levels
# with widths dim, 2 * dim and 4 * dim.
model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)

model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

In [None]:
# Start training
epochs = 5
# Counts optimizer steps across epochs. `step` restarts at 0 every epoch, so
# it cannot be used to schedule or name the periodic sample dumps.
global_step = 0

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # The last batch of an epoch may be smaller than the configured size.
        batch_size = batch['pixel_values'].shape[0]
        batch = batch['pixel_values'].to(device)

        # Algo 1 line 3: Sample t uniformly for each sample in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()

        loss = p_losses(model, batch, t, loss_type='huber')

        if step % 100 == 0:
            print(f'Loss: {loss.item()}')

        loss.backward()
        optimizer.step()
        global_step += 1

        # save generated images
        if global_step % save_and_sample_every == 0:
            milestone = global_step // save_and_sample_every
            batches = num_to_groups(4, batch_size)
            # sample() returns one numpy array per timestep; only the final,
            # fully denoised batch ([-1]) is saved.
            all_images_list = [
                torch.from_numpy(
                    sample(model, image_size=image_size, batch_size=n, channels=channels)[-1]
                )
                for n in batches
            ]
            all_images = torch.cat(all_images_list, dim=0)
            # Map from [-1, 1] back to [0, 1] for saving.
            all_images = (all_images + 1) * 0.5
            save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow=6)

In [None]:
# Inference: sample 64 images
# `samples` holds one numpy array per timestep; samples[-1] is the final denoised batch.
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show one of them (fixed index)
random_index = 5
# NOTE(review): reshape (not a transpose) from (channels, H, W) to (H, W, channels);
# this only preserves the pixel layout when channels == 1.
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

In [None]:
# Create gif of denoising process
# Index of the sample (within the batch of 64) to animate.
random_index = 53

fig = plt.figure()
ims = []
# One frame per sampling step, in the order the steps were produced.
for i in range(timesteps):
    # NOTE(review): reshape (not a transpose) to (H, W, channels); only layout-preserving when channels == 1.
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()