In [18]:
from abc import ABC, abstractmethod
import time
from dataclasses import dataclass
from typing import Optional, Union, Tuple, List
import torch
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from torch.utils.data import DataLoader
import wandb
from torchvision import transforms
import torchinfo
from torch import nn
import plotly.express as px
from einops.layers.torch import Rearrange
from torch.utils.data import TensorDataset
from tqdm import tqdm
from torchvision import datasets
from pathlib import Path
from fancy_einsum import einsum

MAIN = __name__ == "__main__"

device = "cuda" if torch.cuda.is_available() else "cpu"



In [19]:
def gradient_images(n_images: int, img_size: tuple[int, int, int]) -> torch.Tensor:
    '''Generate n_images of img_size, each a color gradient
    '''    
    (C, H, W) = img_size
    corners = torch.randint(0, 255, (2, n_images, C))
    xs = torch.linspace(0, W / (W + H), W)
    ys = torch.linspace(0, H / (W + H), H)
    (x, y) = torch.meshgrid(xs, ys, indexing="xy")
    grid = x + y
    grid = grid / grid[-1, -1]
    grid = repeat(grid, "h w -> b c h w", b=n_images, c=C)
    base = repeat(corners[0], "n c -> n c h w", h=H, w=W)
    ranges = repeat(corners[1] - corners[0], "n c -> n c h w", h=H, w=W)
    gradients = base + grid * ranges
    assert gradients.shape == (n_images, C, H, W)
    return gradients / 255

def plot_img(img: torch.Tensor, title: Optional[str] = None) -> None:
    '''Plots a single image, with optional title.
    '''
    img = rearrange(img, "c h w -> h w c").clip(0, 1)
    img = (255 * img).to(torch.uint8)
    fig = px.imshow(img, title=title)
    fig.update_layout(margin=dict(t=70 if title else 40, l=40, r=40, b=40))
    fig.show()

def plot_img_grid(imgs: torch.Tensor, title: Optional[str] = None, cols: Optional[int] = None) -> None:
    '''Plots a grid of images, with optional title.
    '''
    b = imgs.shape[0]
    imgs = (255 * imgs).to(torch.uint8).squeeze()
    if imgs.ndim == 3:
        imgs = repeat(imgs, "b h w -> b 3 h w")
    imgs = rearrange(imgs, "b c h w -> b h w c")
    if cols is None: cols = int(b**0.5) + 1
    fig = px.imshow(imgs, facet_col=0, facet_col_wrap=cols, title=title)
    for annotation in fig.layout.annotations: annotation["text"] = ""
    fig.show()

def plot_img_slideshow(imgs: torch.Tensor, title: Optional[str] = None) -> None:
    '''Plots slideshow of images.
    '''
    imgs = (255 * imgs).to(torch.uint8).squeeze()
    if imgs.ndim == 3:
        imgs = repeat(imgs, "b h w -> b 3 h w")
    imgs = rearrange(imgs, "b c h w -> b h w c")
    fig = px.imshow(imgs, animation_frame=0, title=title)
    fig.show()

if MAIN:
    print("A few samples from the input distribution: ")
    image_shape = (3, 16, 16)
    n_images = 5
    imgs = gradient_images(n_images, image_shape)
    for i in range(n_images):
        plot_img(imgs[i])

A few samples from the input distribution: 


In [20]:
def normalize_img(img: torch.Tensor) -> torch.Tensor:
    return img * 2 - 1

def denormalize_img(img: torch.Tensor) -> torch.Tensor:
    return ((img + 1) / 2).clamp(0, 1)

if MAIN:
    plot_img(imgs[0], "Original")
    plot_img(normalize_img(imgs[0]), "Normalized")
    plot_img(denormalize_img(normalize_img(imgs[0])), "Denormalized")

In [21]:
def linear_schedule(max_steps: int, min_noise: float = 0.0001, max_noise: float = 0.02) -> torch.Tensor:
    '''
    Return the forward process variances as in the paper.

    max_steps: total number of steps of noise addition
    out: shape (step=max_steps, ) the amount of noise at each step
    '''
    betas = torch.linspace(min_noise, max_noise, max_steps)
    return betas

if MAIN:
    betas = linear_schedule(max_steps=200)

    '''Plot the betas on the x axis'''
    fig = px.line(x=torch.arange(len(betas)), y=betas)

    fig.show()

In [22]:
def q_forward_slow(x: torch.Tensor, num_steps: int, betas: torch.Tensor) -> torch.Tensor:
    '''Return the input image with num_steps iterations of noise added according to schedule.
    x: shape (channels, height, width)
    betas: shape (T, ) with T >= num_steps

    out: shape (channels, height, width)
    '''
    for t in range(num_steps):
        noise = torch.randn_like(x) * betas[t] ** 0.5
        x = x * (1 - betas[t]) ** 0.5 + noise
    
    return x

if MAIN:
    x = normalize_img(gradient_images(1, (3, 16, 16))[0])
    for n in [1, 10, 50, 200]:
        xt = q_forward_slow(x, n, betas)
        plot_img(denormalize_img(xt), f"Equation 2 after {n} step(s)")
    plot_img(denormalize_img(torch.randn_like(xt)), "Random Gaussian noise")

In [23]:
def q_forward_fast(x: torch.Tensor, num_steps: int, betas: torch.Tensor) -> torch.Tensor:
    '''Equivalent to Equation 2 but without a for loop.'''
    
    alphas = torch.prod(1 - betas[:num_steps])
    
    x = x * alphas ** 0.5 + torch.randn_like(x) * (1 - alphas) ** 0.5
    
    return x


if MAIN:
    for n in [1, 10, 50, 200]:
        xt = q_forward_fast(x, n, betas)
        plot_img(denormalize_img(xt), f"Equation 4 after {n} steps")

In [24]:
class NoiseSchedule(nn.Module):
    betas: torch.Tensor
    alphas: torch.Tensor
    alpha_bars: torch.Tensor

    def __init__(self, max_steps: int, device: Union[torch.device, str]) -> None:
        super().__init__()
        self.max_steps = max_steps
        self.device = device
        
        self.register_buffer("betas", linear_schedule(max_steps).to(device))
        self.register_buffer("alphas", (1 - self.betas).to(device))
        self.register_buffer("alpha_bars", torch.cumprod(self.alphas, dim=0).to(device))
        self.to(device)

    @torch.inference_mode()
    def beta(self, num_steps: Union[int, torch.Tensor]) -> torch.Tensor:
        '''
        Returns the beta(s) corresponding to a given number of noise steps
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), where batch_size is one if num_steps is an int
        '''
        return self.betas[num_steps]

    @torch.inference_mode()
    def alpha(self, num_steps: Union[int, torch.Tensor]) -> torch.Tensor:
        '''
        Returns the alphas(s) corresponding to a given number of noise steps
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), where batch_size is one if num_steps is an int
        '''
        return self.alphas[num_steps]

    @torch.inference_mode()
    def alpha_bar(self, num_steps: Union[int, torch.Tensor]) -> torch.Tensor:
        '''
        Returns the alpha_bar(s) corresponding to a given number of noise steps
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), where batch_size is one if num_steps is an int
        '''
        return self.alpha_bars[num_steps]

    def __len__(self) -> int:
        return self.max_steps

    def extra_repr(self) -> str:
        return f"max_steps={self.max_steps}"

In [25]:
def noise_img(
    img: torch.Tensor, noise_schedule: NoiseSchedule, max_steps: Optional[int] = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Adds a uniform random number of steps of noise to each image in img.

    img: An image tensor of shape (B, C, H, W)
    noise_schedule: The NoiseSchedule to follow
    max_steps: if provided, only perform the first max_steps of the schedule

    Returns a tuple composed of:
    num_steps: an int tensor of shape (B,) of the number of steps of noise added to each image
    noise: the unscaled, standard Gaussian noise to be scaled andadded to each image, a tensor of shape (B, C, H, W)
    noised: the final noised image, a tensor of shape (B, C, H, W)
    '''
    
    
    batch_size = img.shape[0]
    max_steps = max_steps or len(noise_schedule)
    num_steps = torch.randint(0, max_steps, (batch_size,)).to(noise_schedule.device)
    
    noise = torch.randn_like(img) 
    scaled_noise = noise * (1 - noise_schedule.alpha_bar(num_steps)[:, None, None, None]) ** 0.5
    noised = img * noise_schedule.alpha_bar(num_steps)[:, None, None, None] ** 0.5 + scaled_noise
    return num_steps, noise, noised

if MAIN:
    noise_schedule = NoiseSchedule(max_steps=200, device="cpu")
    img = gradient_images(1, (3, 16, 16))
    (num_steps, noise, noised) = noise_img(normalize_img(img), noise_schedule, max_steps=10)
    plot_img(img[0], "Gradient")
    plot_img(noise[0], "Applied Unscaled Noise")
    plot_img(denormalize_img(noised[0]), "Gradient with Noise Applied")

In [26]:
def reconstruct(noisy_img: torch.Tensor, noise: torch.Tensor, num_steps: torch.Tensor, noise_schedule: NoiseSchedule) -> torch.Tensor:
    '''
    Subtract the scaled noise from noisy_img to recover the original image. We'll later 
    use this with the model's output to log reconstructions during training. We'll use a 
    different method to sample images once the model is trained.

    Returns img, a tensor with shape (B, C, H, W)
    '''
    scaled_noise = noise * (1 - noise_schedule.alpha_bar(num_steps)[:, None, None, None]) ** 0.5
    img = noisy_img - scaled_noise
    scaled_image = img / noise_schedule.alpha_bar(num_steps)[:, None, None, None] ** 0.5

    return scaled_image

if MAIN:
    reconstructed = reconstruct(noised, noise, num_steps, noise_schedule)
    denorm = denormalize_img(reconstructed)
    plot_img(img[0], "Original Gradient")
    plot_img(denorm[0], "Reconstruction")
    torch.testing.assert_close(denorm, img)

In [58]:
@dataclass
class DiffusionArgs():
    lr: float = 0.001
    image_shape: tuple = (3, 4, 5)
    epochs: int = 10
    max_steps: int = 100
    batch_size: int = 128
    seconds_between_image_logs: int = 10
    n_images_per_log: int = 3
    n_images: int = 50000
    n_eval_images: int = 1000
    cuda: bool = False
    track: bool = False

class DiffusionModel(nn.Module, ABC):
    image_shape: tuple[int, ...]
    noise_schedule: Optional[NoiseSchedule]

    @abstractmethod
    def forward(self, images: torch.Tensor, num_steps: torch.Tensor) -> torch.Tensor:
        ...

@dataclass(frozen=True)
class TinyDiffuserConfig:
    image_shape: Tuple[int, ...] = (3, 4, 5)
    hidden_size: int = 128
    max_steps: int = 200

class TinyDiffuser(DiffusionModel):
    def __init__(self, config: TinyDiffuserConfig):
        '''
        A toy diffusion model composed of an MLP (Linear, ReLU, Linear)
        '''
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.image_shape = config.image_shape
        self.max_steps = config.max_steps
        c, h, w = self.image_shape
        self.flat_shape = torch.prod(torch.tensor(self.image_shape))
        self.mlp = nn.Sequential(
            nn.Linear(self.flat_shape+1, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.flat_shape),
            Rearrange('b (c h w) -> b c h w', c=c, h=h, w=w)
        )

    def forward(self, images: torch.Tensor, num_steps: torch.Tensor) -> torch.Tensor:
        '''
        Given a batch of images and noise steps applied, attempt to predict the noise that was applied.
        images: tensor of shape (B, C, H, W)
        num_steps: tensor of shape (B,)

        Returns
        noise_pred: tensor of shape (B, C, H, W)
        '''
        images = torch.cat([images.flatten(1), torch.unsqueeze(num_steps, dim=1)], dim=1)
        noise_pred = self.mlp(images)
        return noise_pred

if MAIN:
    image_shape = (3, 4, 5)
    n_images = 5
    imgs = gradient_images(n_images, image_shape)
    n_steps = torch.zeros(imgs.size(0))
    model_config = TinyDiffuserConfig(image_shape, 16, 100)
    model = TinyDiffuser(model_config)
    out = model(imgs, n_steps)
    plot_img(imgs[0].detach(), "Original of untrained model")
    plot_img(out[0].detach(), "Noise prediction of untrained model")

In [59]:
args = DiffusionArgs()

def train(args):
    '''
    Train a diffusion model on the gradient images dataset
    '''
    device = torch.device("cuda" if args.cuda else "cpu")
    model_config = TinyDiffuserConfig(args.image_shape, 128, args.max_steps)
    model = TinyDiffuser(model_config)
    model.to(device)
    model.noise_schedule = NoiseSchedule(args.max_steps, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train_loader = gradient_images(args.n_images, args.image_shape)

    loss = 1e10
    i = 0
    while loss > 0.26:
        batch = gradient_images(args.batch_size, args.image_shape)
        batch = batch.to(device)
        num_steps, noise, noised = noise_img(batch, model.noise_schedule, args.max_steps)
        noise_pred = model(noised, num_steps)
        loss = torch.mean((noise_pred - noise) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Loss {loss.item()}")
        i += 1

In [60]:
def log_images(
    img: torch.Tensor, noised: torch.Tensor, noise: torch.Tensor, noise_pred: torch.Tensor, reconstructed: torch.Tensor, num_images: int = 3
) -> list[wandb.Image]:
    '''
    Convert tensors to a format suitable for logging to Weights and Biases. Returns an image with the ground truth in the upper row, and model reconstruction on the bottom row. Left is the noised image, middle is noise, and reconstructed image is in the rightmost column.
    '''
    actual = torch.cat((noised, noise, img), dim=-1)
    pred = torch.cat((noised, noise_pred, reconstructed), dim=-1)
    log_img = torch.cat((actual, pred), dim=-2)
    images = [wandb.Image(i) for i in log_img[:num_images]]
    return images

def train(
    model: DiffusionModel, 
    args: DiffusionArgs, 
    trainset: TensorDataset,
    testset: Optional[TensorDataset] = None
) -> DiffusionModel:
    '''
    Train a diffusion model on the gradient images dataset
    '''
    device = torch.device("cuda" if args.cuda else "cpu")
    model.to(device)
    model.noise_schedule = NoiseSchedule(args.max_steps, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=True) if testset else None

    wandb.init(project="Curt-Joseph-diffusion", config=args)
    wandb.watch(model)

    for epoch in range(args.epochs):
        for i, (img,) in enumerate(train_loader):
            img = img.to(device)
            num_steps, noise, noised = noise_img(img, model.noise_schedule)
            noise_pred = model(noised, num_steps)
            loss = torch.mean((noise_pred - noise) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")
                wandb.log({"loss": loss.item()})
                if test_loader:
                    with torch.no_grad():
                        model.eval()    
                        test_img = next(iter(test_loader))[0].to(device)
                        test_num_steps, test_noise, test_noised = noise_img(test_img, model.noise_schedule, args.max_steps)
                        test_noise_pred = model(test_noised, test_num_steps)
                        test_reconstructed = reconstruct(test_noised, test_noise_pred, num_steps, model.noise_schedule)
                        test_loss = torch.mean((test_noise_pred - test_noise) ** 2)
                        wandb.log({"test_loss": test_loss.item()})
                        images = log_images(test_img, test_noised, test_noise, test_noise_pred, test_reconstructed)
                        wandb.log({"test_images": images})
                        model.train()
                else:
                    with torch.no_grad():
                        reconstructed = reconstruct(noised, noise_pred, num_steps, model.noise_schedule)
                        images = log_images(img, noised, noise, noise_pred, reconstructed)
                        wandb.log({"images": images})
    return model

if MAIN:
    args = DiffusionArgs(epochs=3) # This shouldn't take long to train
    model_config = TinyDiffuserConfig(max_steps=args.max_steps)
    model = TinyDiffuser(model_config).to(device).train()
    trainset = TensorDataset(normalize_img(gradient_images(args.n_images, args.image_shape)))
    testset = TensorDataset(normalize_img(gradient_images(args.n_eval_images, args.image_shape)))
    model = train(model, args, trainset, testset)

VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▅▂▁▁▁▁▁▁▁▂▁
test_loss,█▅▂▁▂▁▁▁▁▁▁▁

0,1
loss,0.27384
test_loss,0.25158


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016699645148279766, max=1.0…

Epoch 0, batch 0, loss 3.5304651260375977
Epoch 0, batch 100, loss 0.8125084042549133
Epoch 0, batch 200, loss 0.6366317868232727
Epoch 0, batch 300, loss 0.5284217596054077
Epoch 1, batch 0, loss 0.3473640978336334
Epoch 1, batch 100, loss 0.3378453850746155
Epoch 1, batch 200, loss 0.2892999053001404
Epoch 1, batch 300, loss 0.30319270491600037
Epoch 2, batch 0, loss 0.2963164746761322
Epoch 2, batch 100, loss 0.23404461145401
Epoch 2, batch 200, loss 0.28790798783302307
Epoch 2, batch 300, loss 0.2909761369228363


In [63]:
def sample(model: DiffusionModel, n_samples: int, return_all_steps: bool = False) -> Union[torch.Tensor, list[torch.Tensor]]:
    '''
    Sample, following Algorithm 2 in the DDPM paper

    model: The trained noise-predictor
    n_samples: The number of samples to generate
    return_all_steps: if true, return a list of the reconstructed tensors generated at 
    each step, rather than just the final reconstructed image tensor.

    out: shape (B, C, H, W), the denoised images
            or (T, B, C, H, W), if return_all_steps=True (where ith element is batched 
            result of (i+1) steps of sampling)
    '''
    schedule = model.noise_schedule
    assert schedule is not None
    device = schedule.device
    reconstruction = [torch.randn(n_samples, *model.image_shape, device=device).unsqueeze(0)]
    with torch.inference_mode():
        for i in range(schedule.max_steps):
            # get  z
            if i < range(schedule.max_steps):
                noise = torch.randn(n_samples, *model.image_shape, device=device)
            else:
                noise = torch.zeros((n_samples, *model.image_shape), device=device)

            noise_pred = model(test_noised, i)
            reconstruction.append(
                reconstruct(
                    reconstruction[-1], 
                    noise_pred, 
                    num_steps=i, 
                    noise_schedule = model.noise_schedule
                    )
                )

            # estimate noise


if MAIN:
    print("Generating multiple images")
    assert isinstance(model, DiffusionModel)
    with torch.inference_mode():
        samples = sample(model, 6)
        print(samples.shape, samples.dtype)
        samples_denormalized = denormalize_img(samples).cpu()
    # plot the images
    #px.imshow(samples_denormalized, facet_col = 0, title="Sample denoised images").show()
    imgs = gradient_images(6, samples.shape[1:])
    # plot_img_grid(imgs, title="Sample denoised images", cols=3)
    plot_img_grid(samples_denormalized, title="Sample denoised images", cols=3)
if MAIN:
    print("Printing sequential denoising")
    assert isinstance(model, DiffusionModel)
    with torch.inference_mode():
        samples = sample(model, 1, return_all_steps=True)[::10, :, :]
        samples_denormalized = denormalize_img(samples).cpu()
    plot_img_slideshow(samples_denormalized, title="Sample denoised image slideshow")

Generating multiple images


TypeError: unsqueeze(): argument 'input' (position 1) must be Tensor, not int