In [181]:
import torch.optim as optim
from collections import OrderedDict
from einops.layers.torch import Rearrange, Reduce

import torch as t
from typing import Union, Optional, Tuple
from torch import nn
import torch.nn.functional as F
import plotly.express as px
import plotly.graph_objects as go
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from fancy_einsum import einsum
import os
from tqdm import tqdm
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, TensorDataset
from dataclasses import dataclass
import wandb

import w5d1_utils
from abc import ABC, abstractmethod


from dataclasses import dataclass
from torchvision import transforms, datasets
from typing import Tuple
import time
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
import torchinfo
import importlib
import w5d1_solutions

MAIN = True

In [182]:
def gradient_images(n_images: int, img_size: tuple[int, int, int]) -> t.Tensor:
    '''Generate n_images random linear colour gradients.

    Each image blends from one random RGB-style colour at the top-left pixel
    to another random colour at the bottom-right pixel.

    n_images: number of images to generate
    img_size: (channels, height, width) of every image

    Returns a float tensor of shape (n_images, channels, height, width) with values in [0, 1].
    '''
    (C, H, W) = img_size
    # randint's upper bound is exclusive, so 256 is required for colour value 255 to be reachable.
    corners = t.randint(0, 256, (2, n_images, C))
    xs = t.linspace(0, W / (W + H), W)
    ys = t.linspace(0, H / (W + H), H)
    (x, y) = t.meshgrid(xs, ys, indexing="xy")
    grid = x + y  # shape (H, W): 0 at the top-left, largest at the bottom-right
    far_corner = grid[-1, -1]
    # A 1x1 image has an all-zero grid; skip the rescale there to avoid 0/0 = NaN.
    if far_corner > 0:
        grid = grid / far_corner
    # Broadcast (1, 1, H, W) against (n_images, C, 1, 1) instead of materialising repeats.
    grid = grid[None, None, :, :]
    base = corners[0][:, :, None, None]
    ranges = (corners[1] - corners[0])[:, :, None, None]
    gradients = base + grid * ranges
    assert gradients.shape == (n_images, C, H, W)
    return gradients / 255

def plot_img(img: t.Tensor, title: Optional[str] = None) -> None:
    '''Show one image with plotly.

    img: tensor of shape (channels, height, width); values are clipped to [0, 1] before display
    title: optional figure title (a larger top margin is used when one is given)
    '''
    channels_last = rearrange(img, "c h w -> h w c")
    pixels = (255 * channels_last.clip(0, 1)).to(t.uint8)
    top_margin = 70 if title else 40
    fig = px.imshow(pixels, title=title)
    fig.update_layout(margin=dict(t=top_margin, l=40, r=40, b=40))
    fig.show()

def plot_img_grid(imgs: t.Tensor, title: Optional[str] = None, cols: Optional[int] = None) -> None:
    '''Plots a grid of images, with optional title. Splits according to cols.

    imgs: tensor of shape (batch, channels, height, width); values are clipped to [0, 1]
    title: optional figure title
    cols: number of images per row; defaults to int(sqrt(batch)) + 1
    '''
    b = imgs.shape[0]
    # Clip before the uint8 cast, as plot_img does: casting floats outside [0, 255]
    # to uint8 does not saturate, so unclipped inputs would render as garbage colours.
    imgs = rearrange(imgs, "b c h w -> b h w c").clip(0, 1)
    imgs = (255 * imgs).to(t.uint8)
    if cols is None:
        cols = int(b**0.5) + 1
    fig = px.imshow(imgs, facet_col=0, facet_col_wrap=cols, title=title)
    # Blank out the per-facet annotation text so only the images are shown.
    for annotation in fig.layout.annotations:
        annotation["text"] = ""
    fig.show()

def plot_img_slideshow(imgs: t.Tensor, title: Optional[str] = None) -> None:
    '''Plots slideshow of images (useful for visualising denoising).

    imgs: tensor of shape (frames, channels, height, width); values are clipped to [0, 1]
    title: optional figure title
    '''
    # Clip before the uint8 cast, as plot_img does: casting floats outside [0, 255]
    # to uint8 does not saturate, so unclipped inputs would render as garbage colours.
    imgs = rearrange(imgs, "b c h w -> b h w c").clip(0, 1)
    imgs = (255 * imgs).to(t.uint8)
    # The leading axis becomes the animation slider.
    fig = px.imshow(imgs, animation_frame=0, title=title)
    fig.show()

if MAIN:
    # Demo: draw five random 16x16 three-channel gradients and show each one individually.
    print("A few samples from the input distribution: ")
    img_shape = (3, 16, 16)
    n_images = 5
    # `imgs` stays in the notebook namespace; the normalization demo cell reads imgs[0].
    imgs = gradient_images(n_images, img_shape)
    for i in range(n_images):
        plot_img(imgs[i])

A few samples from the input distribution: 


In [183]:
def normalize_img(img: t.Tensor) -> t.Tensor:
    '''Affinely rescale pixel values so that [0, 1] maps onto [-1, 1].'''
    doubled = img * 2
    return doubled - 1

def denormalize_img(img: t.Tensor) -> t.Tensor:
    '''Map values from [-1, 1] back to [0, 1], clamping anything that falls outside.'''
    shifted = img + 1
    halved = shifted / 2
    return halved.clamp(0, 1)

if MAIN:
    # Demo of the normalize/denormalize round trip on the first gradient from the earlier cell.
    plot_img(imgs[0], "Original")
    # plot_img clips to [0, 1], so the negative half of the normalized range is shown as 0.
    plot_img(normalize_img(imgs[0]), "Normalized")
    plot_img(denormalize_img(normalize_img(imgs[0])), "Denormalized")

In [184]:
def linear_schedule(max_steps: int, min_noise: float = 0.0001, max_noise: float = 0.02) -> t.Tensor:
    '''
    Build the forward-process variances (betas) as an evenly spaced ramp.

    max_steps: total number of steps of noise addition
    min_noise: variance used at the first step
    max_noise: variance used at the last step

    out: shape (max_steps,), the amount of noise at each step
    '''
    betas = t.linspace(start=min_noise, end=max_noise, steps=max_steps)
    return betas


# Module-level 200-step linear schedule used by the forward-process demo cell below.
betas = linear_schedule(max_steps=200)

def q_forward_slow(x: t.Tensor, num_steps: int, betas: t.Tensor) -> t.Tensor:
    '''Return the input image with num_steps iterations of noise added according to schedule.

    Each step applies x <- sqrt(1 - beta) * x + sqrt(beta) * eps with fresh standard
    Gaussian eps. The input tensor is left untouched; a new tensor is returned.

    x: shape (channels, height, width)
    betas: shape (T, ) with T >= num_steps (if T < num_steps only T steps are applied)

    out: shape (channels, height, width)
    '''
    # Work on a copy: updating x in place would also noise the caller's tensor, so
    # repeated calls on the same image would silently accumulate noise.
    noised = x.clone()
    for beta in betas[:max(num_steps, 0)]:
        noised = (1 - beta) ** 0.5 * noised + (beta ** 0.5) * t.randn_like(noised)
    return noised

if MAIN:
    # Demo: noise one clean gradient for an increasing number of steps.
    x = normalize_img(gradient_images(1, (3, 16, 16))[0])
    for n in [1, 10, 50, 200]:
        # Pass a copy so every plot starts from the clean image; otherwise an in-place
        # forward process would carry the previous iteration's noise into the next one.
        xt = q_forward_slow(x.clone(), n, betas)
        plot_img(denormalize_img(xt), f"Equation 2 after {n} step(s)")
    # Pure noise of the same shape, for visual comparison with the 200-step result.
    plot_img(denormalize_img(t.randn_like(xt)), "Random Gaussian noise")

In [185]:
# def q_forward_slow(x: t.Tensor, num_steps: int, betas: t.Tensor) -> t.Tensor:
#     '''Return the input image with num_steps iterations of noise added according to schedule.
#     x: shape (channels, height, width)
#     betas: shape (T, ) with T >= num_steps

#     out: shape (channels, height, width)
#     '''
#     for step in range(num_steps):
#         noise = t.randn_like(x) * betas[step]
#         x += noise
#     return x


def q_forward_fast(x: t.Tensor, num_steps: int, betas: t.Tensor) -> t.Tensor:
    '''Closed-form forward process: jump straight to num_steps of noise without a loop.

    x: image tensor to noise
    num_steps: how many leading entries of betas to apply
    betas: shape (T, ) forward-process variances

    Returns sqrt(alpha_bar) * x + sqrt(1 - alpha_bar) * eps, where alpha_bar is the
    product of (1 - beta) over the first num_steps betas and eps is standard Gaussian.
    '''
    alpha_bar = (1 - betas[:num_steps]).prod()
    signal_scale = alpha_bar.sqrt()
    noise_std = (1 - alpha_bar).sqrt()  # a standard deviation, not a variance
    eps = t.randn_like(x)
    return x * signal_scale + eps * noise_std

if MAIN:
    # Demo: same noise levels as the step-by-step demo, but computed in closed form.
    x = normalize_img(gradient_images(1, (3, 16, 16))[0])
    betas = linear_schedule(max_steps=200)
    for n in [1, 10, 50, 200]:
        xt = q_forward_fast(x, n, betas)
        plot_img(denormalize_img(xt), f"Equation 4 after {n} steps")

In [186]:
class NoiseSchedule(nn.Module):
    '''Precomputed linear noise schedule stored as module buffers.

    Buffers (each of shape (max_steps,)):
        betas:      forward-process variances from linear_schedule
        alphas:     1 - betas
        alpha_bars: cumulative product of alphas, so alpha_bars[k] covers alphas[0..k] inclusive

    The accessor methods index these buffers directly, so their argument is a
    zero-based index into the schedule.
    '''
    betas: t.Tensor
    alphas: t.Tensor
    alpha_bars: t.Tensor

    def __init__(self, max_steps: int, device: Union[t.device, str]) -> None:
        '''
        max_steps: length of the schedule
        device: device on which the buffers are created
        '''
        super().__init__()
        self.max_steps = max_steps
        betas = linear_schedule(max_steps=max_steps).to(device)
        # Derived tensors inherit the device of betas, so no further .to() is needed.
        alphas = 1 - betas
        alpha_bars = alphas.cumprod(dim=-1)

        # save buffers
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bars", alpha_bars)

    @property
    def device(self) -> t.device:
        '''Device the schedule currently lives on.

        Read from the buffer rather than stored at construction time, so it stays
        correct after the module is moved with .to() / .cuda() / .cpu().
        '''
        return self.betas.device

    # The accessors below are deliberately not wrapped in t.inference_mode(): the buffers
    # never require grad, and inference tensors cannot be saved for backward, which would
    # break expressions such as `pred * schedule.alpha_bar(k)` when `pred` requires grad.

    def beta(self, num_steps: Union[int, t.Tensor]) -> t.Tensor:
        '''
        Returns the beta(s) at the given schedule index/indices.
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), or a 0-dim tensor if num_steps is an int
        '''
        return self.betas[num_steps]

    def alpha(self, num_steps: Union[int, t.Tensor]) -> t.Tensor:
        '''
        Returns the alpha(s) at the given schedule index/indices.
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), or a 0-dim tensor if num_steps is an int
        '''
        return self.alphas[num_steps]

    def alpha_bar(self, num_steps: Union[int, t.Tensor]) -> t.Tensor:
        '''
        Returns the alpha_bar(s) at the given schedule index/indices.
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), or a 0-dim tensor if num_steps is an int
        '''
        return self.alpha_bars[num_steps]

    def __len__(self) -> int:
        '''Number of steps in the schedule.'''
        return self.max_steps

    def extra_repr(self) -> str:
        '''Shown inside the module's repr.'''
        return f"max_steps={self.max_steps}"

In [187]:
# def q_forward_fast(x: t.Tensor, num_steps: int, betas: t.Tensor) -> t.Tensor:
#     '''Equivalent to Equation 2 but without a for loop.'''
#     at = 1 - betas[:num_steps]
#     at = at.prod()
#     mean = x * at.sqrt()
#     var = (1 - at).sqrt()
#     noise = t.randn_like(x) * var
#     return mean + noise

def noise_img(
    img: t.Tensor, noise_schedule: NoiseSchedule, max_steps: Optional[int] = None
) -> tuple[t.Tensor, t.Tensor, t.Tensor]:
    '''
    Noise every image in the batch by its own uniformly sampled schedule index.

    img: An image tensor of shape (B, C, H, W)
    noise_schedule: The NoiseSchedule to follow
    max_steps: if provided, sampled indices are drawn below min(len(noise_schedule), max_steps)

    Returns a tuple composed of:
    num_steps: an int tensor of shape (B,), each entry drawn from [1, upper bound)
    noise: the unscaled, standard Gaussian noise added to each image, a tensor of shape (B, C, H, W)
    noised: the final noised image, a tensor of shape (B, C, H, W)
    '''
    # A missing (or zero) max_steps means "no extra cap beyond the schedule length".
    step_cap = max_steps or t.inf
    upper = min(len(noise_schedule), step_cap)
    batch = img.shape[0]
    num_steps = t.randint(1, upper, (batch,), device=img.device)

    alpha_bars = noise_schedule.alpha_bar(num_steps)
    noise = t.randn_like(img)
    # Give the per-image scalars trailing singleton axes so they broadcast over (C, H, W).
    signal_scale = alpha_bars.sqrt()[:, None, None, None]
    noise_scale = (1 - alpha_bars).sqrt()[:, None, None, None]
    noised = img * signal_scale + noise * noise_scale
    return num_steps, noise, noised

if MAIN:
    # Demo: noise a single gradient by a random index below 10 of a 200-step schedule.
    # noise_schedule, img, num_steps, noise and noised are reused by the reconstruction demo cell.
    noise_schedule = NoiseSchedule(max_steps=200, device="cpu")
    img = gradient_images(1, (3, 16, 16))
    (num_steps, noise, noised) = noise_img(normalize_img(img), noise_schedule, max_steps=10)
    plot_img(img[0], "Gradient")
    # The raw noise is standard Gaussian; plot_img clips it to [0, 1] for display.
    plot_img(noise[0], "Applied Unscaled Noise")
    plot_img(denormalize_img(noised[0]), "Gradient with Noise Applied")

In [188]:
def reconstruct(noisy_img: t.Tensor, noise: t.Tensor, num_steps: t.Tensor, noise_schedule: NoiseSchedule) -> t.Tensor:
    '''
    Invert the closed-form noising: given the noised image and the unscaled noise, recover the image.

    Computes noisy_img / sqrt(alpha_bar) - noise * sqrt((1 - alpha_bar) / alpha_bar),
    with alpha_bar looked up per image from noise_schedule at index num_steps.

    noisy_img: tensor of shape (B, C, H, W)
    noise: tensor of shape (B, C, H, W), true or predicted unscaled noise
    num_steps: int tensor of shape (B,)

    Returns img, a tensor with shape (B, C, H, W)
    '''
    per_image = noise_schedule.alpha_bar(num_steps)
    abar = rearrange(per_image, "b -> b 1 1 1")  # broadcastable over (C, H, W)
    noise_scale = ((1 - abar) / abar).sqrt()
    rescaled_img = noisy_img / abar.sqrt()
    return rescaled_img - noise * noise_scale

# def reconstruct(noisy_img: t.Tensor, noise: t.Tensor, num_steps: t.Tensor, noise_schedule: NoiseSchedule) -> t.Tensor:
#     '''
#     Subtract the scaled noise from noisy_img to recover the original image. We'll later use this with the model's output to log reconstructions during training. We'll use a different method to sample images once the model is trained.

#     Returns img, a tensor with shape (B, C, H, W)
#     '''
#     alpha_bars = rearrange(noise_schedule.alpha_bar(num_steps), "b -> b 1 1 1")
#     # return (noisy_img / alpha_bars.sqrt() - noise * (1 - alpha_bars).sqrt()) / (alpha_bars.sqrt())

#     reconstructed = noisy_img / alpha_bars.sqrt() - noise * ((1 - alpha_bars) / alpha_bars).sqrt()
#     return reconstructed

if MAIN:
    # Round-trip check: undo the noising from the previous demo cell using the true noise.
    reconstructed = reconstruct(noised, noise, num_steps, noise_schedule)
    denorm = denormalize_img(reconstructed)
    plot_img(img[0], "Original Gradient")
    plot_img(denorm[0], "Reconstruction")
    # With the exact noise the reconstruction should match the original up to float tolerance.
    t.testing.assert_close(denorm, img)

In [189]:
import math

@dataclass
class DiffusionArgs:
    '''Bundle of default hyperparameters for a diffusion training run.

    NOTE(review): field meanings are presumed from their names; this class is not
    referenced elsewhere in the visible notebook (the training cell uses a plain dict).
    '''
    # optimisation
    lr: float = 0.001
    image_shape: tuple = (3, 4, 5)
    epochs: int = 10
    max_steps: int = 100
    batch_size: int = 128
    # logging
    img_log_interval_seconds: int = 10
    n_images_to_log: int = 3
    # dataset sizes
    n_images: int = 50000
    n_eval_images: int = 1000
    # runtime switches
    cuda: bool = True
    track: bool = True
    # model width
    hidden_size: int = 128

class DiffusionModel(nn.Module, ABC):
    '''Abstract interface shared by the noise-prediction models in this notebook.

    Concrete subclasses are expected to set both attributes below and implement forward.
    '''
    img_shape: tuple[int, ...]
    noise_schedule: Optional[NoiseSchedule]

    @abstractmethod
    def forward(self, images: t.Tensor, num_steps: t.Tensor) -> t.Tensor:
        '''Map a batch of images plus their per-image step indices to a tensor; defined by subclasses.'''

@dataclass(frozen=True)
class TinyDiffuserConfig:
    '''Immutable settings for TinyDiffuser.'''
    # length of the noise schedule the model is built with (required)
    max_steps: int
    # shape of a single image
    img_shape: Tuple[int, ...] = (3, 4, 5)
    # width of the MLP's hidden layer
    hidden_size: int = 128

class TinyDiffuser(DiffusionModel):
    '''Toy noise predictor: a two-layer MLP over the flattened image plus its step index.'''

    def __init__(self, config: TinyDiffuserConfig):
        '''
        Build the MLP (Linear, ReLU, Linear) and a NoiseSchedule of config.max_steps steps.

        The schedule is created on the notebook-level `device`.
        '''
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.img_shape = config.img_shape
        self.noise_schedule = NoiseSchedule(config.max_steps, device=device)
        self.max_steps = config.max_steps
        n_pixels = math.prod(self.img_shape)
        # One extra input feature carries the step index alongside the pixels.
        layers = [
            nn.Linear(n_pixels + 1, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, n_pixels),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, noised_images: t.Tensor, num_steps: t.Tensor) -> t.Tensor:
        '''
        Predict the noise present in each noised image.

        noised_images: tensor of shape (B, C, H, W)
        num_steps: tensor of shape (B,)

        Returns
        noise_pred: tensor of shape (B, C, H, W)
        '''
        channels, height, width = self.img_shape[0], self.img_shape[1], self.img_shape[2]
        pixels = rearrange(noised_images, "b c h w -> b (c h w)")
        step_column = num_steps.unsqueeze(1)  # (B, 1), appended as the last feature
        features = t.cat([pixels, step_column], dim=1)
        flat_pred = self.mlp(features)
        return rearrange(flat_pred, "b (c h w) -> b c h w", c=channels, h=height, w=width)

if MAIN:
    # Smoke test: push a few gradients through an untrained TinyDiffuser.
    img_shape = (3, 4, 5)
    n_images = 5
    imgs = gradient_images(n_images, img_shape)
    # One step value per image; all zeros here.
    n_steps = t.zeros(imgs.size(0))
    print(imgs.shape)
    # Positional args are (max_steps, img_shape, hidden_size) = (16, img_shape, 100).
    model_config = TinyDiffuserConfig(16, img_shape, 100)
    model = TinyDiffuser(model_config)
    out = model(imgs, n_steps)
    plot_img(out[0].detach(), "Noise prediction of untrained model")

torch.Size([5, 3, 4, 5])


In [190]:
def log_images(
    img: t.Tensor, noised: t.Tensor, noise: t.Tensor, noise_pred: t.Tensor, reconstructed: t.Tensor, num_images: int = 3
) -> list[wandb.Image]:
    '''
    Assemble side-by-side comparison panels for Weights and Biases.

    Each panel is a 2x3 mosaic: the top row is (noised, true noise, original image) and the
    bottom row is (noised, predicted noise, reconstructed image). Only the first num_images
    batch entries are converted.
    '''
    top_row = t.cat((noised, noise, img), dim=-1)
    bottom_row = t.cat((noised, noise_pred, reconstructed), dim=-1)
    panels = t.cat((top_row, bottom_row), dim=-2)
    return [wandb.Image(panel) for panel in panels[:num_images]]

def train(
    model: DiffusionModel,
    trainset: TensorDataset,
    config_dict,
    testset: Optional[TensorDataset] = None,
) -> DiffusionModel:
    '''
    Train a noise-prediction model with MSE loss, logging to Weights and Biases.

    model: the model to train; its noise_schedule is replaced by a fresh NoiseSchedule
    trainset: dataset yielding 1-tuples of normalized images
    config_dict: mapping with keys "lr", "batch_size", "max_steps", "epochs",
        "img_log_interval_seconds" and "n_images_to_log"
    testset: optional evaluation dataset; when given, mean test loss is logged after every epoch

    Returns the trained model (the same object that was passed in).
    '''
    model.train()
    wandb.init(project="diffusion", config=config_dict)
    base_lr = config_dict["lr"]
    batch_size = config_dict["batch_size"]
    optimizer = optim.Adam(model.parameters(), lr=base_lr)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    # Only build the eval loader when there is something to evaluate on.
    test_loader = None
    if testset is not None:
        test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
    schedule = NoiseSchedule(max_steps=config_dict["max_steps"], device=device)
    model.noise_schedule = schedule
    # Anneal over the number of optimizer steps in the whole run; "max_steps" in the
    # config is the diffusion schedule length and has nothing to do with training length.
    total_train_steps = max(1, config_dict["epochs"] * len(train_loader))
    n_examples_seen = 0
    n_steps = 0
    try:
        for epoch in range(config_dict["epochs"]):
            model.train()
            train_loader_progress = tqdm(train_loader, desc=f"E{epoch+1}")
            for (data,) in train_loader_progress:
                data = data.to(device)
                num_steps, noise, noised = noise_img(data, model.noise_schedule, config_dict["max_steps"])
                # Linearly anneal the learning rate. It has to be written into param_groups;
                # assigning an `lr` attribute on the optimizer object has no effect.
                current_lr = base_lr * (1 - n_steps / total_train_steps)
                for group in optimizer.param_groups:
                    group["lr"] = current_lr
                optimizer.zero_grad()
                output = model(noised, num_steps)
                loss = F.mse_loss(output, noise)
                loss.backward()
                optimizer.step()
                train_loader_progress.set_postfix(loss=loss.item())
                n_examples_seen += data.shape[0]
                n_steps += 1
                wandb.log({"loss": loss.item()}, step=n_examples_seen)

                # NOTE: despite the key's name, this interval is counted in optimizer steps.
                if (n_steps + 1) % config_dict["img_log_interval_seconds"] == 0:
                    with t.inference_mode():
                        noise_pred = output.detach()
                        # Reconstruct from the model's prediction, not the true noise, so the
                        # logged image reflects what the model has actually learned.
                        reconstructed = reconstruct(noised, noise_pred, num_steps, schedule)
                        images = log_images(
                            data, noised, noise, noise_pred, reconstructed,
                            num_images=config_dict["n_images_to_log"],
                        )
                        wandb.log({"images": images}, step=n_examples_seen)

            if test_loader is not None:
                model.eval()
                total_loss = 0
                for (img,) in tqdm(test_loader, desc=f"Epoch {epoch+1} eval"):
                    img = img.to(device)
                    num_steps, noise, noised = noise_img(img, schedule)
                    with t.inference_mode():
                        noise_pred = model(noised, num_steps)
                        loss = F.mse_loss(noise_pred, noise)
                    total_loss += loss.item()
                wandb.log({"test_loss": total_loss / len(test_loader)}, step=n_examples_seen)
    finally:
        # Close the W&B run even if training is interrupted or raises.
        wandb.finish()
    return model
from typing import Dict, Any
if MAIN:
    # Training run configuration; passed both to wandb.init and to train().
    config_dict: Dict[str, Any] = dict(
        lr=0.001,
        image_shape=(3, 4, 5),
        hidden_size=128,
        epochs=20,
        max_steps=100,  # length of the diffusion noise schedule
        batch_size=128,
        img_log_interval_seconds=1000,  # NOTE(review): train() compares this against a step counter, not wall-clock time
        n_images_to_log=3,
        n_images=50000,
        n_eval_images=1000,
        device=device,
    )
    model_config = TinyDiffuserConfig(config_dict["max_steps"], config_dict["image_shape"], config_dict["hidden_size"])
    model = TinyDiffuser(model_config).to(device).train()
    # Train and eval sets are independent random draws of normalized gradient images.
    trainset = TensorDataset(normalize_img(gradient_images(config_dict["n_images"], config_dict["image_shape"])))
    testset = TensorDataset(normalize_img(gradient_images(config_dict["n_eval_images"], config_dict["image_shape"])))
    # `model` is reused by the sampling cell below.
    model = train(model, trainset, config_dict, testset)

E1: 100%|██████████| 391/391 [00:00<00:00, 526.83it/s, loss=0.299]
Epoch 1 eval: 100%|██████████| 8/8 [00:00<00:00, 1677.81it/s]
E2: 100%|██████████| 391/391 [00:00<00:00, 607.49it/s, loss=0.304]
Epoch 2 eval: 100%|██████████| 8/8 [00:00<00:00, 1702.06it/s]
E3: 100%|██████████| 391/391 [00:00<00:00, 583.81it/s, loss=0.269]
Epoch 3 eval: 100%|██████████| 8/8 [00:00<00:00, 1625.38it/s]
E4: 100%|██████████| 391/391 [00:00<00:00, 517.47it/s, loss=0.26] 
Epoch 4 eval: 100%|██████████| 8/8 [00:00<00:00, 1626.64it/s]
E5: 100%|██████████| 391/391 [00:00<00:00, 571.53it/s, loss=0.277]
Epoch 5 eval: 100%|██████████| 8/8 [00:00<00:00, 1984.65it/s]
E6: 100%|██████████| 391/391 [00:00<00:00, 585.77it/s, loss=0.266]
Epoch 6 eval: 100%|██████████| 8/8 [00:00<00:00, 1498.97it/s]
E7: 100%|██████████| 391/391 [00:00<00:00, 504.17it/s, loss=0.266]
Epoch 7 eval: 100%|██████████| 8/8 [00:00<00:00, 1088.44it/s]
E8: 100%|██████████| 391/391 [00:01<00:00, 355.43it/s, loss=0.227]
Epoch 8 eval: 100%|██████████|

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▄▂▂▂▂▂▂▁▁▂▂▁▂▁▁▂▁▁▁▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
test_loss,█▄▃▃▃▃▂▂▂▂▂▁▂▂▁▁▂▁▁▂

0,1
loss,0.24014
test_loss,0.22596


In [191]:
def sample(model, n_samples: int, return_all_steps: bool = False) -> t.Tensor:
    """
    Sample, following Algorithm 2 in the DDPM paper

    Starts from standard Gaussian noise and walks the model's noise schedule from its
    last index down to 0, at each index removing the predicted noise and (except at
    index 0) adding fresh Gaussian noise with standard deviation sqrt(beta).

    model: The trained noise-predictor; must expose noise_schedule and img_shape.
        It is switched to eval mode and left that way.
    n_samples: The number of samples to generate
    return_all_steps: if true, return the tensor of intermediate results from every step, rather than just the final image tensor.

    out: shape (B, C, H, W), the denoised images
            or (T, B, C, H, W), if return_all_steps=True (where [i,:,:,:,:]th element is result of (i+1) steps of sampling)
         The returned tensor is allocated with t.zeros defaults (CPU, float32).
    """
    schedule = model.noise_schedule
    assert schedule is not None

    T = len(schedule)
    # Sample on whatever device the schedule buffers actually live on.
    schedule_device = schedule.betas.device
    out = t.zeros(T, n_samples, *model.img_shape)
    model.eval()

    # x_T ~ N(0, I)
    x = t.randn(size=(n_samples, *model.img_shape), device=schedule_device)
    for i, step in enumerate(tqdm(range(T - 1, -1, -1))):
        alpha = schedule.alpha(step)
        alpha_bar = schedule.alpha_bar(step)
        beta = schedule.beta(step)
        # Feed the model the same index that is used to look up the schedule. Training pairs
        # num_steps = k with alpha_bars[k], so passing step + 1 here would be off by one.
        step_batch = t.full((n_samples,), fill_value=step, device=schedule_device)
        eps = model(x, step_batch)
        eps_scale = (1 - alpha) / ((1 - alpha_bar).sqrt())
        mean = (x - eps_scale * eps) / alpha.sqrt()
        if step > 0:
            # sigma_t = sqrt(beta_t); no noise is added on the final denoising step.
            x = mean + beta.sqrt() * t.randn_like(x)
        else:
            x = mean
        out[i] = x

    if return_all_steps:
        return out
    else:
        return out[-1]


if MAIN:
    # Generate six samples from the trained model and show them in a 3-column grid.
    print("Generating multiple images")
    assert isinstance(model, DiffusionModel)
    with t.inference_mode():
        samples = sample(model, 6)
        samples_denormalized = denormalize_img(samples).cpu()
    plot_img_grid(samples_denormalized, title="Sample denoised images", cols=3)
if MAIN:
    # Show the denoising trajectory of a single sample as an animation.
    print("Printing sequential denoising")
    assert isinstance(model, DiffusionModel)
    with t.inference_mode():
        # Keep every 10th step along the first axis and only batch element 0.
        samples = sample(model, 1, return_all_steps=True)[::10, 0, :]
        samples_denormalized = denormalize_img(samples).cpu()
    plot_img_slideshow(samples_denormalized, title="Sample denoised image slideshow")

Generating multiple images


100%|██████████| 100/100 [00:00<00:00, 10911.02it/s]


Printing sequential denoising


100%|██████████| 100/100 [00:00<00:00, 13126.48it/s]
