# Implement Diffusion Models from Scratch
In this assignment, you will implement a small version of the well-known [Stable Diffusion](https://arxiv.org/abs/2112.10752) model.

You will:
- Implement each module of the diffusion model to understand how it is composed.
- Implement a simple noise sampling process to understand how diffusion works.
- Train the model you built on the MNIST dataset, inspect the sampling results, and experience the power of diffusion models.

## Setup
We recommend working on Colab with a GPU enabled, since this assignment needs a fair amount of compute.
In Colab, you can enable the GPU by clicking `Runtime -> Change Runtime Type -> Hardware accelerator` and selecting `GPU`.
The dependencies will be installed once the notebook cells are executed.

In [None]:
!pip install datasets
!pip install einops
!pip install transformers

In [78]:
import os
import io
import math
import urllib.request  # load_from_url below calls urllib.request.urlopen, so import the submodule explicitly
import random
import requests
from pathlib import Path
from functools import partial

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam

from torchvision import transforms
from torchvision.utils import save_image

from transformers import CLIPTokenizer, CLIPTextModel

from datasets import load_dataset
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.animation as animation

import datetime
import numpy as np
from tqdm.notebook import trange, tqdm

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
#@title random seed
def set_seed(seed: int = 3207, verbose=False) -> None:
    """Set random seed for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    if verbose:
        print(f"Random seed set as {seed}")

set_seed(verbose=True)

In [80]:
#@title Utilities for Testing
def save_auto_grader_data():
    torch.save(
        {'output': auto_grader_data['output']},
        'autograder.pt'
    )

def rel_error(x, y):
    return torch.max(
        torch.abs(x - y)
        / (torch.maximum(torch.tensor(1e-8), torch.abs(x) + torch.abs(y)))
    ).item()

def check_error(name, x, y, tol=1e-3):
    error = rel_error(x, y)
    if error > tol:
        print(f'The relative error for {name} is {error}, should be smaller than {tol}')
    else:
        print(f'The relative error for {name} is {error}')

def check_loss(loss, threshold):
    if loss > threshold:
        print(f'The minimum loss {loss} should be <= the threshold loss {threshold}')
    else:
        print(f'The minimum loss {loss} is smaller than threshold loss {threshold}')

def load_from_url(url):
    return torch.load(io.BytesIO(urllib.request.urlopen(url).read()))

test_data = load_from_url('https://github.com/jun-tian/CS182_Project_diffusion/raw/main/Diffusion/test_data.pt')
auto_grader_data = load_from_url('https://github.com/jun-tian/CS182_Project_diffusion/raw/main/Diffusion/auto_grader_data.pt')
auto_grader_data['output'] = {}
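
To get a feel for what the `tol=1e-3` threshold in `check_error` means, here is a scalar analogue of `rel_error`. This is only an illustrative sketch: the name `scalar_rel_error` is hypothetical, and the real helper above takes the maximum of this quantity over all tensor elements.

```python
def scalar_rel_error(x: float, y: float) -> float:
    """Scalar analogue of rel_error: |x - y| / max(1e-8, |x| + |y|)."""
    return abs(x - y) / max(1e-8, abs(x) + abs(y))


# A 0.1% deviation gives roughly 5e-4, because the denominator sums both magnitudes.
print(scalar_rel_error(1.0, 1.001))
# The metric is scale-invariant: the same relative deviation at a larger scale.
print(scalar_rel_error(1000.0, 1001.0))
# The 1e-8 floor avoids dividing by zero when both values are zero.
print(scalar_rel_error(0.0, 0.0))
# Values of equal magnitude and opposite sign give the maximum value of 1.
print(scalar_rel_error(1.0, -1.0))
```

Because the denominator is `|x| + |y|` rather than `|y|`, a tolerance of `1e-3` corresponds to an elementwise deviation of roughly 0.2%.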

## Implement Modules in the Diffusion Model

Below, you'll implement different modules in the diffusion model, including Resnet blocks, Spatial Transformer blocks and the U-Net backbone. It's important to note that this implementation is a simplified version of the official stable diffusion architecture used in real applications. To make it easier to implement, we'll offer some architecture diagrams to help you understand those modules.

In [6]:
#@title Helper Functions 
# (You will utilize some of the helper functions in your implementation)

def exists(x):
    """return true if x is not none"""
    return x is not None

def default(val, d):
    """return val if exists(val) else d"""
    if exists(val):
        return val
    return d() if callable(d) else d

# normalization functions
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

def normalization(dim):
    return nn.GroupNorm(num_groups=16, num_channels=dim)

# upsample and downsample
def Upsample(dim, dim_out = None):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
    )

def Downsample(dim, dim_out = None):
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )
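
The `Rearrange` pattern in `Downsample` moves each 2x2 spatial patch into the channel dimension before the 1x1 convolution mixes the channels. The sketch below reproduces that index mapping in plain Python on nested lists so the layout is easy to inspect. The function name `space_to_depth` is hypothetical and the batch dimension is omitted for brevity.

```python
def space_to_depth(image, p=2):
    """Rearrange one image of shape [C][H][W] (nested lists) into
    [C * p * p][H // p][W // p], mirroring the einops pattern
    'c (h p1) (w p2) -> (c p1 p2) h w' with p1 = p2 = p."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    out = []
    for c in range(channels):
        for p1 in range(p):          # row offset inside each patch
            for p2 in range(p):      # column offset inside each patch
                out.append([
                    [image[c][h * p + p1][w * p + p2] for w in range(width // p)]
                    for h in range(height // p)
                ])
    return out


# One channel, 4x4 pixels, numbered 0..15 in row-major order.
image = [[[row * 4 + col for col in range(4)] for row in range(4)]]
patches = space_to_depth(image)
print(len(patches), len(patches[0]), len(patches[0][0]))  # 4 2 2
print(patches[0])  # [[0, 2], [8, 10]]: the top-left pixel of every 2x2 patch
```

No pixel is discarded: the spatial resolution halves while the channel count grows by a factor of four, which is why the convolution in `Downsample` takes `dim * 4` input channels.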

### ResNet Module
First, you'll need to implement the ResNet module below. Note that this ResNet differs from the one you have seen in lecture, because the timestep must be incorporated into the forward pass. You don't need to know what the timestep is yet; just remember that it is a tensor of size `[Batchsize, time_emb_dim]`.
To help you get the correct answer, we provide an architecture diagram below.

<img src="https://github.com/jun-tian/CS182_Project_diffusion/blob/main/images/resnet.png?raw=true" alt="resnt" width="600" height="500" align="bottom" />

In [7]:
#@title Build Your Resnet

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim = None, dropout=0.1):
        """
        Args:
          dim: input dimension
          dim_out: output dimension
          time_emb_dim: the dimension for the timestep embedding
          dropout: probability of dropout
        """
        super().__init__()
        
        ########################################################################
        # TODO: Define layers for time_embed, input, output and skip connection.
        #       Hint: 1. Use nn.Sequential to contain different layers in one element
        #             2. We use SiLU as the out activation function
        #             3. Feel free to use the helper function normalization defined above
        #             4. The skip connection may require a Conv if the input and output
        #                dimensions do not match
        ########################################################################
        raise NotImplementedError()
        ########################################################################

    def forward(self, x, timestep):
        """
        Args:
            x: input tensor [B, C, H, W]
            timestep: timestep [B, time_emb_dim]
        """

        time_emb = None
        scale_shift = None
        h = None
        output = None

        ########################################################################
        # TODO: Implement the forward pass of the `ResnetBlock` class.
        #       Hint: 1. Please refer to the diagram above
        #             2. einops.rearrange and tensor.chunk() may be useful
        ########################################################################
        raise NotImplementedError()
        ########################################################################
        return output

In [None]:
#@title Test for Resnet
dim = 32
dim_out = 32
time_emb_dim = 10
dropout = 0
batchsize = 10
image_size = 32
model = ResnetBlock(dim, dim_out, time_emb_dim, dropout)

# test
model.load_state_dict(test_data['weights']['resnet'])
x, timestep = test_data["input"]["resnet"]
y = test_data["output"]["resnet"]
output = model(x, timestep).detach()
check_error("resnet", output, y)


# auto_grader  
model.load_state_dict(auto_grader_data['weights']['resnet'])
x, timestep = auto_grader_data["input"]["resnet"]
output = model(x, timestep).detach()
auto_grader_data["output"]["resnet"] = output

### Text Embedding
A key difference from past diffusion models is that latent diffusion models incorporate general-purpose conditioning mechanisms to condition the diffusion process with, for example, text, images, or layout maps. This opens the way for image generation from a prompt, image-to-image generation, and more.

This is achieved by creating an embedding of the conditioning $y$ using a domain specific encoder $\tau_{\theta}$. In the case of text conditioning, this can be done by using a CLIP or BERT embedder to first generate token embeddings and then passing the embeddings through a transformer. In this demo we will be using CLIP embeddings along with the frozen CLIP transformer. 

This conditioning is incorporated into the model by augmenting the U-Net with cross-attention, where the keys and values are created from $\tau_{\theta}(y)$ and the queries come from the representation of the U-Net at the specific layer. The same conditioning $\tau_{\theta}(y)$ is incorporated in each attention mechanism at every denoising iteration.
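
The cross-attention described above can be sketched for a single head in plain Python. This is a minimal illustration, not the layer you will implement: the learned projections for queries, keys, and values are omitted, and the function names are hypothetical. Queries stand in for U-Net features (one vector per spatial position); keys and values stand in for the text embedding $\tau_{\theta}(y)$ (one vector per token).

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    """Single-head scaled dot-product attention on nested lists.

    queries: [N][d] image features; keys: [K][d] and values: [K][e] from text.
    Returns [N][e]: for each query, a weighted average of the values.
    """
    scale = 1.0 / math.sqrt(len(queries[0]))
    outputs = []
    for q in queries:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        weights = softmax(scores)
        outputs.append([
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return outputs


# Two spatial positions attending over three text tokens.
pixels = [[1.0, 0.0], [0.0, 1.0]]
token_keys = [[4.0, 0.0], [0.0, 4.0], [0.0, 0.0]]
token_values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
for row in cross_attention(pixels, token_keys, token_values):
    print([round(v, 3) for v in row])
```

Each output row has the dimensionality of the values, and each spatial position puts most of its weight on the token whose key best matches its query.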

<img src="https://github.com/jun-tian/CS182_Project_diffusion/blob/main/images/ldm.png?raw=true" width="650"/>

Since the text embedder is a thin wrapper around the CLIP API, we have implemented it for you.

In [18]:
class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-base-patch16", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

### Attention Module

Next, we need to implement the attention module in the diffusion model. Stable Diffusion is a text-conditional model, so we will use a Spatial Transformer to incorporate textual information into the forward pass.

Below, we divide the Spatial Transformer into four classes. Implementing the first three classes one at a time makes the full Spatial Transformer easier to understand, following a modular programming approach.

In [14]:
#@title Build Your Spatial Transformer
class FeedForward(nn.Module):
    """Feed Forward layer with GELU activation and dropout."""
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.1):
        super().__init__()

        self.net = None
        ########################################################################
        # TODO: Define layers for the `FeedForward` class.
        #       Hint: 1. Linear -> GELU -> Dropout -> Linear
        #             2. inner dim = mult x dim
        #             3. If dim_out is None, then just use dim as the output dimension
        ########################################################################
        raise NotImplementedError()
        ########################################################################

    def forward(self, x):
        return self.net(x)


class CrossAttention(nn.Module):
    """A standard multi-head cross-attention layer"""
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.1):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.heads = heads

        self.scale = None
        self.to_q = None
        self.to_k = None
        self.to_v = None
        self.to_out = None

        ########################################################################
        # TODO: Define layers for the `CrossAttention` class.
        #       Hint: 1. Define the weights for q, k, v. (Linear layer without bias)
        #             2. scale is a coefficient for scaled dot attention
        #             3. to_out is the output projection (Linear->Dropout)
        ########################################################################
        raise NotImplementedError()
        ########################################################################

    def forward(self, x, context=None):
        """
        Args:
          x: input tensor [B, C, H, W] or [B, (H, W), C]
          context: text embedding [B, K, E]
        """

        context = default(context, x) # if context is None, then do self-attention
        output = None
        ########################################################################
        # TODO: Implement the forward pass of the `CrossAttention` class.
        #       Hint: Remember there are multiple heads. rearrange and einsum may be useful
        ########################################################################
        raise NotImplementedError()
        ########################################################################
        return output


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0.1, context_dim=None):
        super().__init__()

        self.attn1 = None
        self.attn2 = None
        self.ff = None
        self.norm1 = None
        self.norm2 = None
        self.norm3 = None
        ########################################################################
        # TODO: Define layers for the `BasicTransformerBlock` class.
        #       Hint: 1. Use the CrossAttention and FeedForward Class
        #             2. attn1 is self-attention; attn2 is cross-attention
        #             3. We adopt LayerNorm as the normalization function here
        ########################################################################
        raise NotImplementedError()
        ########################################################################


    def forward(self, x, context=None):
        """
        Args:
          x: input tensor [B, C, H, W] or [B, (H,W), C]
          context: text embedding [B, K, E]
        """
        output = None
        ########################################################################
        # TODO: Implement forward pass for the `BasicTransformerBlock` class.
        #       Hint: 1. Can be done in three lines
        #             2. We use residual connections in each layer
        ########################################################################
        raise NotImplementedError()
        ########################################################################
        return output


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads=4, d_head=32,
                 depth=1, dropout=0.1, context_dim=512):
        super().__init__()
        input_dim = in_channels
        inner_dim = n_heads * d_head
        output_dim = in_channels
        self.norm = nn.GroupNorm(num_groups=16, num_channels=in_channels, eps=1e-6, affine=True)

        self.proj_in = None
        self.transformer_blocks = None
        self.proj_out = None
        ########################################################################
        # TODO: Define layers for the `SpatialTransformer` class.
        #       Hint: 1. proj_in and proj_out are projection layer for altering dimensions, with kernel_size=1 and stride=1
        #             2. depth is the number of BasicTransformerBlock we use in our transformer blocks
        ########################################################################
        raise NotImplementedError()
        ########################################################################

    def forward(self, x, context=None):
        """
        note: if no context is given, cross-attention defaults to self-attention
        Args:
            x: input image, shape: B, C, H, W
            context: text embedding, shape: B, K, E
        """
        output = None
        ########################################################################
        # TODO: Implement forward pass for the `SpatialTransformer` class.
        #       Hint: 1. norm -> proj_in -> inner transformer -> proj_out
        #             2. Again, there are multiple heads, rearrange may be useful
        #             3. We use residual connection with the output and the input
        ########################################################################
        raise NotImplementedError()
        ########################################################################
        return output

In [None]:
#@title Test for Attention
in_channels = 16
n_heads = 4
d_head=8
depth=1
dropout=0
context_dim=16
model = SpatialTransformer(in_channels, n_heads, d_head,
              depth, dropout, context_dim)

# test
model.load_state_dict(test_data['weights']['attention'])
x, context = test_data["input"]["attention"]
y = test_data["output"]["attention"]
output = model(x, context).detach()
check_error("attention", output, y)

# auto_grader
model.load_state_dict(auto_grader_data["weights"]["attention"])
x, context = auto_grader_data["input"]["attention"]
output = model(x, context).detach()
auto_grader_data["output"]["attention"] = output 

### Position Embeddings
As the parameters of the neural network are shared across time (noise level), we employ sinusoidal position embeddings to encode t. This makes the neural network "know" at which particular time step (noise level) it is operating, for every image in a batch.

The `SinusoidalPosEmb` module takes a tensor of shape `(batch_size,)` as input (i.e. the noise levels of several noisy images in a batch) and turns it into a tensor of shape `(batch_size, dim)`, with dim being the dimensionality of the position embeddings. This embedding is then added to each residual block, as we have seen in the ResNet blocks.

We have already provided it for you, because it is not the focus of this assignment, but we recommend that you read the code.

In [17]:
class SinusoidalPosEmb(nn.Module):
    """
    Build sinusoidal embeddings.
    This matches the implementation in Denoising Diffusion Probabilistic Models.
    Used for both position and time embeddings.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # input: timesteps [B, ]
        # output: embeddings [B, dim]
        assert len(time.shape) == 1

        device = time.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=device) * -emb)
        emb = time.float()[:, None] * emb[None, :]
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
        if self.dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0,1,0,0))
        return emb
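
To see what the module above produces for one timestep, here is a plain-Python mirror of its `forward` for a single scalar `t`. The function name `sinusoidal_embedding` is hypothetical; the arithmetic follows the code above line by line.

```python
import math


def sinusoidal_embedding(t, dim):
    """Embedding of a single timestep t, laid out as [sin..., cos..., (pad)]."""
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000.
    step = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-i * step) for i in range(half_dim)]
    emb = [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]
    if dim % 2 == 1:  # zero pad odd dimensions, as in the module
        emb.append(0.0)
    return emb


print(sinusoidal_embedding(0, 8))   # all sines are 0 and all cosines are 1 at t = 0
print([round(v, 4) for v in sinusoidal_embedding(10, 8)])
```

Note that `half_dim - 1` appears in a denominator, so this formula needs `dim >= 4`. The high-frequency components change quickly between neighbouring timesteps, while the low-frequency ones vary slowly across the whole schedule.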

### Unet Model
After implementing the modules above, we will now use them to build the actual diffusion model. The main component of the diffusion model is a denoising U-Net, also called the noise predictor.

Below is a picture of a general ResNet-based U-Net, which may help you understand the U-Net architecture, but the actual architecture to be implemented will differ. More specifically, we will use a mix of ResNet and Spatial Transformer blocks during downsampling and upsampling.

<img src="https://github.com/jun-tian/CS182_Project_diffusion/blob/main/images/resnet_unet.png?raw=true" width="600"/>

In [19]:
#@title Build Your Unet Model
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        context_dim = None,
        dropout=0.1
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim) # init dimension after init_conv
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:])) # [(init_dim, dim), (dim, dim * 2), ...]

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        time_dim = dim * 4
        resnet_block = partial(ResnetBlock, time_emb_dim=time_dim, dropout=dropout)
        transformer_block = partial(SpatialTransformer, dropout=dropout, context_dim=context_dim)

        ########################################################################
        # TODO: Define layers for the `UNet` class.
        ########################################################################

        # Implement time embeddings (Hint: SinusoidalPosEmb->Linear->SiLU->Linear)
        # We will change the dim in the first Linear layer (from dim to time_dim)
        self.time_mlp = None

        # Implement the initial convolutional layer (kernel size=7, padding=3)
        self.init_conv = None

        # Implement the downsampling process
        # Hint: 1. use the in_out list to get the input and output dimension in each down sample iteration
        #       2. We use Resnet -> Attention -> Resnet -> Attention -> Down Sample in each downsample iteration
        #       3. The helper function Downsample may be useful.
        #       4. If it is the last layer before the middle layers, we do conv(kernel=3, padding=1)
        #          instead of downsample
        raise NotImplementedError()

        # Implement the middle bottleneck layers
        # Hint 1. Keep the dim as mid_dim
        #      2. Resnet -> Attention -> Resnet
        raise NotImplementedError()


        # Implement the upsampling process
        # Hint: 1. The structure is the same as downsampling 
        #         (Note that we use a residual connection here, so take care of the input dimension)
        #       2. We use Resnet -> Attention -> Resnet -> Attention -> Upsample in each upsample iteration
        #       3. The helper function Upsample may be useful.
        #       4. If it is the last iteration for upsampling, we do conv(kernel=3, padding=1) instead of up sample
        raise NotImplementedError()

        # Implement the final output layers
        # Hint: 1. Since we have the residual connection in the Unet, 
        #          the input channel for the output layer is 2 x init_dim
        #       2. First use a single layer to change the dimension from 2 x init_dim to
        #          init_dim (kernel size=3, padding=1)
        #       3. Then do normalization -> SiLU -> Conv(kernel size=3, padding=1)
        self.final_conv = None
        self.out = None
        
        ########################################################################

    def forward(self, x, time, context):
        """
        forward pass of the unet model
        Args:
            x: input image [Batch_size, Channels, H, W]
            time: timestep [Batch_size, 1]
            context: text embedding [Batch_size, K, E]
        Returns:
            the predicted noise [Batch_size, Channels, H, W]
        """
        output = None
        ########################################################################
        # TODO: Implement forward pass for the `UNet` class.
        # Hint: 1. We use residual connections here, remember to save your intermediate
        #          feature maps in the downsampling phase and concat them to the feature maps in upsampling
        #       2. More specifically, save the model input and the outputs for each attention
        #          layers in the downsampling
        #       3. Concatenate them to the corresponding model output(before final_conv)
        #          and inputs for each resnet block in the upsampling phase
        ########################################################################
        raise NotImplementedError()
        ########################################################################
        
        return output
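
The channel bookkeeping at the top of `Unet.__init__` can be checked in isolation before you wire up any layers. The sketch below repeats the `dims` and `in_out` computation in plain Python; the helper name `channel_plan` and the `is_last` flag are hypothetical, with `is_last` marking the stage where the hints above ask for a 3x3 convolution instead of a downsample.

```python
def channel_plan(dim, dim_mults, init_dim=None):
    """Return (dim_in, dim_out, is_last) for every downsampling stage."""
    init_dim = dim if init_dim is None else init_dim
    dims = [init_dim] + [dim * m for m in dim_mults]
    in_out = list(zip(dims[:-1], dims[1:]))
    return [
        (dim_in, dim_out, index == len(in_out) - 1)
        for index, (dim_in, dim_out) in enumerate(in_out)
    ]


# Same configuration as the Unet test cell: dim=16, init_dim=16, dim_mults=(1, 2, 4).
for dim_in, dim_out, is_last in channel_plan(16, (1, 2, 4), init_dim=16):
    print(dim_in, dim_out, is_last)
# 16 16 False
# 16 32 False
# 32 64 True
```

Printing this plan first makes it easier to reason about how many channels each skip connection contributes when the feature maps are concatenated during upsampling.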

In [None]:
#@title Test for Unet Model
dim = 16
init_dim=16
model = Unet(dim, init_dim, 
              out_dim = None, dim_mults=(1, 2, 4),
              channels = 1, context_dim = 10,
              dropout=0)

# test
model.load_state_dict(test_data['weights']['unet'])
x, timestep, context = test_data["input"]["unet"]
y = test_data["output"]["unet"]
output = model(x, timestep, context).detach()
check_error("unet", output, y)


# auto_grader
model.load_state_dict(auto_grader_data["weights"]["unet"])
x, timestep, context = auto_grader_data["input"]["unet"]
output = model(x, timestep, context).detach()
auto_grader_data["output"]["unet"] = output

**Answer the following question in your writeup:**

**Question**: Why might a U-Net be a good choice for the model backbone? List two reasons.

## Diffusion Process
After implementing different kinds of modules used in diffusion models, we now will walk through the algorithm itself. We will first talk about the basics of the diffusion process. Diffusion models are a class of **generative** models that break down the generation process into iterative denoising steps. They work by taking as input an image $x_0$, successively adding Gaussian noise to it, and then learning how to recover the original input by reversing this noising process. Once trained, a diffusion model can then generate new images by running this learned denoising process on randomly sampled noise.

More concretely, given a starting input $x_0$ sampled from the data distribution $q(x)$, the forward diffusion process iteratively adds Gaussian noise according to a variance schedule $\beta_1, ... , \beta_T$, producing $x_t$ with distribution $q(x_t|x_{t-1})$. The data is also scaled down by a factor of $\sqrt{1-\beta_t}$ at each step so that the overall variance does not grow when adding noise.

$$q(x_t|x_{t-1}) = \mathcal{N}(x_t; \mu_t = \sqrt{1-\beta_t}x_{t-1}, \Sigma_t=\beta_t I)$$

This corresponds to a transition of the following Markov chain.

![forwards](https://github.com/jun-tian/CS182_Project_diffusion/blob/main/images/forwards_diffusion.png?raw=true)
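
The role of the $\sqrt{1-\beta_t}$ factor can be checked with a short variance computation. As a simplifying assumption for illustration, suppose each coordinate of $x_{t-1}$ has unit variance and is independent of the added noise $\epsilon \sim \mathcal{N}(0, I)$:

$$\mathrm{Var}(x_t) = \mathrm{Var}\left(\sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon\right) = (1-\beta_t)\,\mathrm{Var}(x_{t-1}) + \beta_t = (1-\beta_t) + \beta_t = 1$$

Without the scaling, the variance would instead grow by $\beta_t$ at every step.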

Using the Markov chain property that each step only depends on the previous one, we can get the following expression for the posterior after repeatedly transitioning along the chain from timesteps 1 to T.

$$q(x_{1:T}|x_{0}) = \prod_{t=1}^{T} q(x_t|x_{t-1})$$

However, this means that sampling $x_t$ at an arbitrary timestep would require sampling from the distribution $t$ times. The forward process can be made more efficient by reparameterizing. 

Defining $\alpha_t = 1 - \beta_t,\; \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \: \epsilon \sim N(0,I)$ gives us the following closed-form expression:

$$ \begin{align}
x_t & = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon \\
&= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon \\
&= \;... \\
&= \sqrt{\bar\alpha_t}\,x_{0} + \sqrt{1-\bar\alpha_t}\,\epsilon \\
\end{align}$$
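
The elided steps in the derivation can be checked numerically. The sketch below tracks, step by step, the coefficient multiplying $x_0$ and the variance of the accumulated noise, then compares them with the closed-form values $\sqrt{\bar\alpha_t}$ and $1-\bar\alpha_t$. It is a plain-Python illustration with hypothetical function names; the schedule mirrors the linear schedule with `beta_start = 0.0001` and `beta_end = 0.02`.

```python
import math


def linear_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced betas from beta_start to beta_end, inclusive."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def step_by_step_coefficients(betas):
    """Apply x_t = sqrt(alpha_t) x_{t-1} + sqrt(1 - alpha_t) eps one step at a time."""
    signal, noise_var = 1.0, 0.0
    for beta in betas:
        alpha = 1.0 - beta
        signal *= math.sqrt(alpha)              # coefficient on x_0 shrinks
        noise_var = alpha * noise_var + beta    # old noise is scaled, new noise is added
    return signal, noise_var


def closed_form_coefficients(betas):
    """Return sqrt(alpha_bar_t) and 1 - alpha_bar_t directly from the product."""
    alpha_bar = 1.0
    for beta in betas:
        alpha_bar *= 1.0 - beta
    return math.sqrt(alpha_bar), 1.0 - alpha_bar


betas = linear_betas(200)
print(step_by_step_coefficients(betas))
print(closed_form_coefficients(betas))
```

The two printed pairs agree up to floating-point rounding, which is why $x_t$ can be sampled in one shot from $x_0$ instead of iterating through all $t$ steps.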

These $\alpha_t$ and $\bar\alpha_t$ can be precomputed for all timesteps at the start, so any $x_t$ can be sampled quickly. Given this noisy input $x_t$, the model is tasked with reconstructing $x_{t-1}$ by predicting the noise in $x_t$ and then sampling from the posterior $q(x_{t-1}|x_t, x_0)$. We sample from $q(x_{t-1}|x_t, x_0)$ instead of $q(x_{t-1}|x_t)$ because the latter is intractable without conditioning on $x_0$. The posterior $q(x_{t-1}|x_t, x_0)$ has variance $\sigma_t^2 = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_{t}} \beta_t$.

This corresponds to learning the reverse process of the above Markov chain. We approximate the backwards distribution $q(x_{t-1}|x_t)$ as $p_{\theta}(x_{t-1}|x_t)$ using our model.
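
The posterior variance $\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_{t}} \beta_t$ needs a value for $\bar\alpha_0$ at the first timestep. The sketch below assumes the usual convention $\bar\alpha_0 = 1$ (an empty product) and evaluates the formula on a tiny hand-picked schedule. The function name and the beta values are hypothetical and chosen only to keep the numbers readable.

```python
def posterior_variances(betas):
    """Variance of q(x_{t-1} | x_t, x_0) for every step, assuming alpha_bar_0 = 1."""
    alpha_bar_prev = 1.0
    variances = []
    for beta in betas:
        alpha_bar = alpha_bar_prev * (1.0 - beta)
        variances.append((1.0 - alpha_bar_prev) / (1.0 - alpha_bar) * beta)
        alpha_bar_prev = alpha_bar
    return variances


betas = [0.1, 0.2, 0.3]
variances = posterior_variances(betas)
for beta, variance in zip(betas, variances):
    print(beta, round(variance, 4))
```

Two properties are visible here: the variance at the first step is exactly zero, so the final denoising step is deterministic, and every posterior variance is at most the corresponding $\beta_t$.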


![backwards](https://github.com/jun-tian/CS182_Project_diffusion/blob/main/images/backwards_diffusion.png?raw=true)

As $T \longrightarrow \infty$, the input signal is effectively destroyed, and can be assumed to be $\mathcal{N}(0, I)$ noise. Thus, we can sample $x_T$ from $\mathcal{N}(0, I)$ and treat it as our starting point. We can then run our diffusion model in an autoregressive style to generate a new image, feeding the denoised version of the image at one iteration to the input of the next iteration. This denoising process consists of T denoising steps, in which the model reconstructs $x_{T-1}$ from $x_T$, $x_{T-2}$ from $x_{T-1}$, and so on until it reaches $x_0$, a newly generated image.

These algorithms for training and sampling a new image are summarized in the image below. For more details, read the original [Denoising Diffusion Probabilistic Models (DDPM)](https://arxiv.org/pdf/2006.11239.pdf) paper.

![algorithms](https://github.com/jun-tian/CS182_Project_diffusion/blob/main/images/Algorithms.jpg?raw=true)
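
The sampling loop from the DDPM paper can be exercised on a toy problem without any neural network. In the sketch below the data distribution is assumed to be a single scalar value `x0`, so the exact noise can be computed in closed form; this hypothetical `oracle_noise` function stands in for the trained U-Net. All names are illustrative and timesteps are 0-indexed.

```python
import math
import random


def make_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear betas and their cumulative products alpha_bar."""
    step = (beta_end - beta_start) / (timesteps - 1)
    betas = [beta_start + i * step for i in range(timesteps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


def oracle_noise(x_t, t, alpha_bars, x0):
    """Exact noise in x_t when the data is the single point x0 (replaces the model)."""
    return (x_t - math.sqrt(alpha_bars[t]) * x0) / math.sqrt(1.0 - alpha_bars[t])


def sample(timesteps, x0, rng):
    """Start from pure noise and denoise step by step down to t = 0."""
    betas, alpha_bars = make_schedule(timesteps)
    x = rng.gauss(0.0, 1.0)
    for t in reversed(range(timesteps)):
        eps = oracle_noise(x, t, alpha_bars, x0)
        # Posterior mean: remove the predicted noise, then undo the sqrt(alpha_t) scaling.
        mean = (x - betas[t] / math.sqrt(1.0 - alpha_bars[t]) * eps) / math.sqrt(1.0 - betas[t])
        if t > 0:
            variance = (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t]) * betas[t]
            x = mean + math.sqrt(variance) * rng.gauss(0.0, 1.0)
        else:
            x = mean  # no noise is added at the final step
    return x


print(sample(200, x0=0.5, rng=random.Random(0)))
```

With a perfect noise predictor the loop returns `x0` up to floating-point rounding, whatever noise it starts from. With a learned predictor the same loop produces samples from the learned distribution instead.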

Diffusion models use a variance schedule, which specifies the variance, $\beta_t$, of the Gaussian noise added to the input at timestep t. The LDM paper simply uses an increasing linear schedule, but other schedules, such as the cosine schedule described in [Nichol et al. 2021](https://arxiv.org/pdf/2102.09672.pdf), have given promising results. Below we have implemented four classic schedules for you.

In [30]:
#@title classic schedule functions
def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
  
schedules = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
    "quadratic": quadratic_beta_schedule,
    "sigmoid": sigmoid_beta_schedule
}

In [53]:
#@title build your scheduler
class Scheduler:
  """Define the coefficients in the training and sampling algorithm"""
  def __init__(self, schedule='linear', timesteps=200):
    self.timesteps = timesteps

    ########################################################################
    # TODO: Define parameters for `Scheduler` class.
    # Hint: 1. Feel free to use the schedules we have implemented for you to get betas
    #       2. Refer back to the background section for how the variables are defined.
    #       3. The torch.cumprod and F.pad functions may be useful
    ########################################################################
    # define beta schedule
    self.betas = None

    # define alphas 
    alphas = None
    alphas_cumprod = None
    alphas_cumprod_prev = None
    self.sqrt_reciprocal_alphas = None

    # coefficients for the closed-form forward process q(x_t | x_0)
    self.sqrt_alphas_cumprod = None
    self.sqrt_one_minus_alphas_cumprod = None

    # calculate the variance of the posterior q(x_{t-1} | x_t, x_0)
    self.posterior_variance = None

    ########################################################################

In [None]:
#@title test for scheduler
scheduler = Scheduler(schedule="linear", timesteps=200)
betas = scheduler.betas
sqrt_reciprocal_alphas = scheduler.sqrt_reciprocal_alphas
sqrt_alphas_cumprod = scheduler.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = scheduler.sqrt_one_minus_alphas_cumprod
posterior_variance = scheduler.posterior_variance

true_betas = test_data["output"]["betas"]
true_sqrt_reciprocal_alphas = test_data["output"]["sqrt_reciprocal_alphas"]
true_sqrt_alphas_cumprod = test_data["output"]["sqrt_alphas_cumprod"]
true_sqrt_one_minus_alphas_cumprod = test_data["output"]["sqrt_one_minus_alphas_cumprod"]
true_posterior_variance = test_data["output"]["posterior_variance"]

check_error("betas", betas, true_betas)
check_error("sqrt_reciprocal_alphas", sqrt_reciprocal_alphas, true_sqrt_reciprocal_alphas)
check_error("sqrt_alphas_cumprod", sqrt_alphas_cumprod, true_sqrt_alphas_cumprod)
check_error("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod, true_sqrt_one_minus_alphas_cumprod)
check_error("posterior_variance", posterior_variance, true_posterior_variance)

scheduler = Scheduler(schedule="sigmoid", timesteps=100)
betas = scheduler.betas
sqrt_reciprocal_alphas = scheduler.sqrt_reciprocal_alphas
sqrt_alphas_cumprod = scheduler.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = scheduler.sqrt_one_minus_alphas_cumprod
posterior_variance = scheduler.posterior_variance

auto_grader_data["output"]["betas"] = betas
auto_grader_data["output"]["sqrt_reciprocal_alphas"] = sqrt_reciprocal_alphas
auto_grader_data["output"]["sqrt_alphas_cumprod"] = sqrt_alphas_cumprod
auto_grader_data["output"]["sqrt_one_minus_alphas_cumprod"] = sqrt_one_minus_alphas_cumprod
auto_grader_data["output"]["posterior_variance"] = posterior_variance

## Understand the Diffusion Process on a single image
Below, you will observe the diffusion process on a single picture, which will help you understand the diffusion model. To be more specific, this is a process of gradually adding noise to the picture, done by sampling from $q(x_{1:T}|x_{0})$.
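
As a rough illustration (not the closed-form `q_sample` you will implement below), the Markov chain $q(x_{1:T}|x_0) = \prod_{t=1}^{T} q(x_t|x_{t-1})$ can be simulated one step at a time. The sketch assumes the standard DDPM transition $q(x_t|x_{t-1}) = \mathcal{N}(\sqrt{1-\beta_t}\,x_{t-1}, \beta_t I)$ and a linear schedule with the same endpoints as `linear_beta_schedule`. `forward_chain` and the all-ones toy input are hypothetical names used only for this example.

```python
import torch

def forward_chain(x0, betas, generator=None):
    """Apply q(x_t | x_{t-1}) step by step and return [x_0, x_1, ..., x_T]."""
    xs = [x0]
    x = x0
    for beta in betas:
        # fresh Gaussian noise at every step
        noise = torch.randn(x.shape, generator=generator)
        # x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise
        x = torch.sqrt(1.0 - beta) * x + torch.sqrt(beta) * noise
        xs.append(x)
    return xs

gen = torch.Generator().manual_seed(0)      # fixed seed so the run is repeatable
betas = torch.linspace(0.0001, 0.02, 200)   # assumed: linear schedule, 200 steps
x0 = torch.ones(1, 1, 64, 64)               # toy "image" with constant value 1
xs = forward_chain(x0, betas, generator=gen)

# the mean drifts toward 0 and the spread grows as noise accumulates
for step in (0, 50, 100, 200):
    print(step, xs[step].mean().item(), xs[step].std().item())
```

The mean of the tensor shrinks toward 0 while its spread grows, which is the gradual loss of signal that the visualization later in this section makes visible.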

Run the following code to load an image.

In [None]:
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

In [None]:
# Define the image transforms
image_size = 128
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(), # convert to a torch.Tensor of shape CHW, scaled to [0, 1]
    transforms.Lambda(lambda t: (t * 2) - 1),
])

reverse_transform = transforms.Compose([
     transforms.Lambda(lambda t: (t + 1) / 2),
     transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     transforms.Lambda(lambda t: t * 255.),
     transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
     transforms.ToPILImage(),
])

x_start = transform(image).unsqueeze(0)
x_start.shape  # output torch.Size([1, 3, 128, 128])

In [None]:
reverse_transform(x_start.squeeze())

Implement the q sampling function below. This function adds noise to the input image. As discussed in the previous part, it needs the coefficients defined in the scheduler.
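
The `extract` helper defined in the next cell is what lets per-example coefficients broadcast against a batch of images. Here is a minimal sketch of its behavior, using a made-up coefficient table `coeffs` in place of a real scheduler attribute:

```python
import torch

def extract(a, t, x_shape):
    # same helper as in the notebook: pick a[t] for every example in the batch
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    # reshape to (batch, 1, 1, ...) so the result broadcasts against x
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

coeffs = torch.linspace(0.0, 0.9, 10)   # hypothetical per-timestep coefficients
x = torch.ones(3, 1, 4, 4)              # toy batch of three 4x4 images
t = torch.tensor([0, 3, 9])             # a different timestep for each example

c_t = extract(coeffs, t, x.shape)
scaled = c_t * x                        # each image is scaled by its own coefficient
print(c_t.shape)
print(scaled[:, 0, 0, 0])
```

Each example in the batch ends up multiplied by the coefficient for its own timestep, without any explicit loop over the batch.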

In [41]:
#@title q sample
linear_scheduler = Scheduler(schedule='linear', timesteps=200)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    x_noisy = None
    ########################################################################
    # TODO: Implement the forward diffusion process (sample from q(x_t | x_0)).
    # Hint: 1. If noise is not defined, then you can sample it from a normal distribution
    #       2. The extract function will help you to extract the t-th coefficients from the scheduler (defined above as linear_scheduler)
    ########################################################################
    raise NotImplementedError()
    ########################################################################
    return x_noisy

def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)
    
    # turn back into PIL image
    noisy_image = reverse_transform(x_noisy.squeeze())
    return noisy_image

In [None]:
#@title test for q sample
noise = test_data["input"]["q_sample"]
y = test_data["output"]["q_sample"]
t = torch.tensor([50])
x_noisy = q_sample(x_start, t, noise)
check_error("q_sample", x_noisy, y)

noise = auto_grader_data["input"]["q_sample"]
t = torch.tensor([100])
x_noisy = q_sample(x_start, t, noise)
auto_grader_data["output"]["q_sample"] = x_noisy

In [None]:
# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

In [None]:
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

**Question**

**Screenshot your visualization above** and include it in your submission of the written assignment. Also, describe the picture you observed briefly. What kind of process is this?


## Train Your Diffusion Models
To further understand how to train a diffusion model, you will train a small diffusion model on the MNIST dataset below. The figure below shows the overall process for the training and inference stages. Note that the official Stable Diffusion model uses an AutoEncoderKL to encode the image from RGB pixel space into a latent space, downsampling it by a factor of $f$. This reduces the demand on computing resources and speeds up inference when the raw image is large.

However, our demo does not use AutoEncoderKL, because MNIST images are small enough to run the training and inference stages directly at the pixel level. Keep in mind that the model treats pixel-level and latent-level inputs in much the same way; its essential functionality remains the same.

<img src="https://github.com/jun-tian/CS182_Project_diffusion/blob/main/images/training.png?raw=true" alt="resnt" width="900" height="500" align="bottom" />
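
To get a feel for the savings from working in latent space, here is a back-of-the-envelope sketch. The 512×512 RGB input, $f = 8$, and 4 latent channels are assumed example values (a commonly used Stable Diffusion configuration, not something defined in this notebook), and `latent_shape` is a hypothetical helper.

```python
import math

def latent_shape(height, width, f, latent_channels):
    """Shape of the latent after downsampling each spatial side by a factor f."""
    return (latent_channels, height // f, width // f)

pixel_shape = (3, 512, 512)                                   # assumed RGB input
z_shape = latent_shape(512, 512, f=8, latent_channels=4)      # assumed f and channels

pixel_elems = math.prod(pixel_shape)
latent_elems = math.prod(z_shape)
print(z_shape, pixel_elems, latent_elems, pixel_elems / latent_elems)

# the MNIST images used in this notebook, resized to 32x32 with one channel
mnist_elems = math.prod((1, 32, 32))
print(mnist_elems)
```

By comparison, a 32×32 single-channel MNIST image has only 1024 values, which is why working directly in pixel space is affordable here.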

Run the following code to prepare for the training stage.

In [44]:
losses = {
    'l1': F.l1_loss,
    'l2': F.mse_loss,
    'huber': F.smooth_l1_loss,
}

sample_text = [
    'The number 0',
    'The number 1',
    'The number 2',
    'The number 3',
    'The number 4',
    'The number 5',
    'The number 6',
    'The number 7',
    'The number 8',
    'The number 9'
]

In [None]:
#@title Dataset
dataset = load_dataset("mnist", split="train")

image_size = 32
channels = 1
batch_size = 64

transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1),
])

def data_transforms(examples):
    examples["pixel_values"] = [transform(image) for image in examples['image']]
    examples["text"] = [f"The number {label}" for label in examples['label']]

    del examples['image']
    del examples['label']
    return examples

transformed_dataset = dataset.with_transform(data_transforms)
dataloader = DataLoader(transformed_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
batch = next(iter(dataloader))
print(batch.keys())

In [None]:
#@title Initialize the Model
model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,),
    context_dim=512,
    dropout=0.1
)
model.to(torch_device)

optimizer = Adam(model.parameters(), lr=1e-3)

textembedder = FrozenCLIPEmbedder(max_length=5).to(torch_device)
sample_embeddings = textembedder(sample_text)

In [61]:
#@title Define your diffusion class for sampling and the loss function

class Diffusion(nn.Module):
    def __init__(self, model, sample_emb=sample_embeddings, image_size=32, schedule="linear", timesteps=200, loss_type="l1"):
      super().__init__()
      self.model = model
      self.image_size = image_size
      self.scheduler = Scheduler(schedule, timesteps)
      self.loss_fn = losses[loss_type]
      self.sample_emb = sample_emb

    def p_losses(self, x_start, t, context, noise=None):

        loss = None
        ########################################################################
        # TODO: Define the loss function.
        # Hint: 1. If the noise is not defined, then sample it from a normal distribution
        #       2. Your previously implemented q_sample function will be useful to add noise
        ########################################################################
       
        # 1. sample a noise
        raise NotImplementedError()

        # 2. add noise to the image
        raise NotImplementedError()

        # 3. predict the noise
        raise NotImplementedError()

        # 4. calculate the loss
        raise NotImplementedError()
        ########################################################################

        return loss

    @torch.no_grad()
    def p_sample(self, x, t, t_index):

        output = None
        ########################################################################
        # TODO: Define the p sampling function.
        # Hint: 1. Recall the sampling algorithm we have shown you in the previous diffusion process part
        #       2. The extract function can help you extract the t-th coefficient from the scheduler
        #       3. We only need to add noise (z in the algorithm defined above) if t>1
        ########################################################################
        raise NotImplementedError()
        ########################################################################
        return output
        
    # (including returning all images)
    @torch.no_grad()
    def p_sample_loop(self, shape):
        device = next(self.model.parameters()).device
        b = shape[0]
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        imgs = []

        for i in tqdm(reversed(range(0, self.scheduler.timesteps)), desc='sampling loop time step', total=self.scheduler.timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), i)
            imgs.append(img.cpu().numpy())
        imgs = np.array(imgs)
        return imgs

    @torch.no_grad()
    def sample(self, image_size, batch_size=10, channels=3):
        return self.p_sample_loop(shape=(batch_size, channels, image_size, image_size))

    def sample_t(self, batch_size):
      return torch.randint(0, self.scheduler.timesteps, (batch_size,), device=torch_device).long()

In [62]:
# define the save parameters
results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 300

# initialize the diffusion class for training
timesteps = 200
diffusion_model = Diffusion(model, schedule='linear', timesteps=timesteps, loss_type='huber')

In [None]:
#@title Start Training
epochs = 1 # need 14 minutes to train 1 epoch
min_loss = 1000
print("Start time:", datetime.datetime.now())

for epoch in range(epochs):
    for step, batch in (pbar := tqdm(enumerate(dataloader), desc='training', total=len(dataloader))):
        optimizer.zero_grad()

        batch_size = batch["pixel_values"].shape[0]
        text = batch["text"]
        batch = batch["pixel_values"].to(torch_device)

        # text embedding
        context = textembedder(text)

        # sample t uniformly for every example in the batch
        t = diffusion_model.sample_t(batch_size)

        # calculate the loss
        loss = diffusion_model.p_losses(batch, t, context)

        loss.backward()
        optimizer.step()

        # Show data
        pbar.set_postfix(loss=loss.item(), epoch=f"{epoch+1}/{epochs}")

        if loss.item() < min_loss:
            min_loss = loss.item()

        # save generated images
        if step != 0 and step % save_and_sample_every == 0:
            milestone = step // save_and_sample_every
            all_images = torch.Tensor(diffusion_model.sample(image_size, batch_size=10, channels=channels)).flatten(end_dim=1)
            all_images = (all_images + 1) * 0.5
            save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = len(sample_text))

print("End time:", datetime.datetime.now())

In [64]:
#@title Test your implementation
auto_grader_data["output"]["min_loss"] = min_loss
check_loss(min_loss, threshold=0.02)

The minimum loss 0.013211900368332863 is smaller than threshold loss 0.02


To help you understand the sampling process, we saved a few generated images during training. You can find them in the `results` folder. Below, we also show the performance of the trained model.
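
`p_sample_loop` stacks one batch of images per timestep, so the array returned by `sample` has shape `(timesteps, batch, channels, height, width)`. A small sketch with random stand-in data (the toy sizes and variable names are assumptions for illustration) shows how the training loop's `flatten(end_dim=1)` and the `(x + 1) * 0.5` rescaling act on that layout:

```python
import numpy as np
import torch

# toy stand-in for the output of diffusion_model.sample(...)
toy_timesteps, toy_batch, toy_channels, toy_size = 5, 10, 1, 32
fake = np.random.randn(toy_timesteps, toy_batch, toy_channels, toy_size, toy_size)
toy_samples = torch.from_numpy(fake.astype(np.float32))

# merge the timestep and batch axes: one long list of images for save_image
grid_input = toy_samples.flatten(end_dim=1)
print(grid_input.shape)

# keep only the last step and map values from [-1, 1] to [0, 1] for display
final_images = ((toy_samples[-1] + 1) * 0.5).clamp(0, 1)
print(final_images.shape, final_images.min().item(), final_images.max().item())
```

Indexing with `[-1]` keeps only the last step of the reverse process, i.e. the final generated images, whereas the flattened tensor contains every intermediate step.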

In [None]:
# generate a sample image to show the model performance
%matplotlib inline
samples = torch.Tensor(diffusion_model.sample(image_size, batch_size=10, channels=channels))
samples = (samples + 1) * 0.5
random_index = 8
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

In [None]:
# generate a dynamic gif to show the model performance (in the Files)
random_index = 3

fig = plt.figure()
ims = []
for i in range(timesteps):
    image = samples[i][random_index].permute(1, 2, 0).clamp(0, 1)
    im = plt.imshow(image, cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

**Question**

**Screenshot one of your visualizations above** and include it in your submission of the written assignment. Answer the following question:
- How does the model perform and does it meet your expectations? If not, what do you think are the directions for improvement?

## Submission

Download the file and upload it to Gradescope.
Gradescope will run an autograder on the files you submit. 

It is unlikely but still possible that your implementation fails some test cases on Gradescope. Check your code carefully to ensure its correctness.

In [None]:
!rm -f submission.zip
!zip submission.zip -r *.ipynb auto_grader_data.pt