# Diffusion on MNIST

Goal: train a denoising diffusion model directly on the images (pixel space), i.e. without a latent space.

## References

* fastai 2022 / 2023 course part II:
    * [notebook 26](https://github.com/fastai/course22p2/blob/master/nbs/26_diffusion_unet.ipynb)
    * [lesson 19](https://course.fast.ai/Lessons/lesson19.html)

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import typing as T

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torchinfo
import tqdm
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

import random_neural_net_models.convolution_lecun1990 as conv_lecun1990
import random_neural_net_models.telemetry as telemetry
import random_neural_net_models.unet as unet
import random_neural_net_models.unet_with_noise as unet_with_noise
import random_neural_net_models.utils as utils

sns.set_theme()

In [None]:
DO_OVERFITTING_ONLY = True

In [None]:
mnist = fetch_openml("mnist_784", version=1, cache=True, parser="auto")

Setting seeds

In [None]:
utils.make_deterministic(42)

Getting device

In [None]:
def get_device() -> str:
    """Return the name of the torch device to run on.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = get_device()
device

In [None]:
X = mnist["data"]
y = mnist["target"]
X.shape, y.shape

Selecting a few images to overfit on (limiting to the number 5)

In [None]:
n0 = 32
n1 = 1_000
is_5 = y == "5"
X0, y0 = X.loc[is_5].iloc[:n0], y.loc[is_5].iloc[:n0]
X1, y1 = X.loc[is_5].iloc[n0 : n1 + n0], y.loc[is_5].iloc[n0 : n0 + n1]
X0.shape, X1.shape

## Defining dataset and dataloader

In [None]:
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_test = conv_lecun1990.DigitsDataset(X1, y1)

In [None]:
item = ds[0]
plt.imshow(item[0], cmap="gray", origin="upper")
plt.title(f"Label: {item[1]}")
plt.axis("off")
plt.tight_layout()

applying noise based on 
```python
def noisify(x0):
    device = x0.device
    sig = (torch.randn([len(x0)])*1.2-1.2).exp().to(x0).reshape(-1,1,1,1)
    noise = torch.randn_like(x0, device=device)
    c_skip,c_out,c_in = scalings(sig)
    noised_input = x0 + noise*sig
    target = (x0-c_skip*noised_input)/c_out
    return (noised_input*c_in,sig.squeeze()),target
```
from https://github.com/fastai/course22p2/blob/master/nbs/26_diffusion_unet.ipynb

In [None]:
def list_of_tuples_to_tensors(
    batch: T.List[T.Tuple[torch.Tensor, int]]
) -> T.Tuple[torch.Tensor, torch.Tensor]:
    """Turn a list of ``(image, label)`` pairs into two batched tensors.

    Args:
        batch: Sequence of ``(image tensor, integer label)`` tuples.

    Returns:
        A tuple ``(images, labels)``: the images stacked along a new leading
        dimension, and the labels as a 64-bit integer tensor.
    """
    # Transpose the list of pairs into one sequence of images and one of labels.
    image_seq, label_seq = zip(*batch)
    stacked_images = torch.stack(image_seq)
    return stacked_images, torch.tensor(label_seq, dtype=torch.int64)


SIG_DATA = 0.66


def get_cs(
    sig: torch.Tensor,
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the scaling coefficients for a given noise level.

    With ``total_var = sig**2 + SIG_DATA**2`` the three coefficients are:

    * ``c_skip = SIG_DATA**2 / total_var``: weight of the noisy image in the
      reconstruction (1 when ``sig == 0``, tends to 0 for large ``sig``).
    * ``c_out = sig * SIG_DATA / sqrt(total_var)``: weight of the model output
      in the reconstruction (0 when ``sig == 0``).
    * ``c_in = 1 / sqrt(total_var)``: factor applied to the noisy image before
      it is fed to the model.

    NOTE(review): this mirrors ``scalings`` in the referenced fastai notebook
    and looks like the preconditioning of Karras et al. (2022), with
    ``SIG_DATA`` presumably the standard deviation of the data - confirm.

    Args:
        sig: Noise level(s).

    Returns:
        The tuple ``(c_skip, c_out, c_in)``, each broadcast like ``sig``.
    """
    data_var = SIG_DATA**2
    total_var = sig**2 + data_var
    total_std = total_var.sqrt()
    return data_var / total_var, sig * SIG_DATA / total_std, 1 / total_std


def draw_sig_from_noise_prior(
    n: int, mean: float = -1.2, std: float = 1.2
) -> torch.Tensor:
    """Draw noise levels (the prior) from a log-normal distribution.

    ``log(sig)`` is sampled from a normal distribution with the given ``mean``
    and ``std``; the defaults reproduce the previously hard-coded values.

    Args:
        n: Number of noise levels to draw.
        mean: Mean of ``log(sig)``.
        std: Standard deviation of ``log(sig)``.

    Returns:
        A 1-D tensor of ``n`` strictly positive noise levels.
    """
    log_sig = std * torch.randn(n) + mean
    return log_sig.exp()


def draw_img_noise_given_sig(
    sig: torch.Tensor,
    images: T.Optional[torch.Tensor] = None,
    images_shape: T.Optional[T.Tuple[int, ...]] = None,
) -> torch.Tensor:
    """Draw standard-normal noise and scale it by the noise level ``sig``.

    The shape of the noise is taken from ``images`` when given, otherwise from
    ``images_shape``. ``images`` takes precedence if both are passed.

    Args:
        sig: Noise level(s); must broadcast against the noise shape.
        images: Optional tensor whose shape the noise should have.
        images_shape: Optional explicit shape, used when ``images`` is None.

    Returns:
        ``torch.randn(shape) * sig``.

    Raises:
        ValueError: If neither ``images`` nor ``images_shape`` is provided.
    """
    if images is not None:
        images_shape = images.shape
    elif images_shape is None:
        # Fail early instead of letting torch.randn(None) raise an opaque error.
        raise ValueError("Either `images` or `images_shape` must be provided.")

    return torch.randn(images_shape) * sig


def fudge_original_images(images: torch.Tensor) -> torch.Tensor:
    """Rescale pixel values linearly via ``images * 2 - 1``.

    For inputs in ``[0, 1]`` this yields values in ``[-1, 1]``.
    """
    doubled = images * 2
    return doubled - 1


def apply_noise(
    batch: T.List[T.Tuple[torch.Tensor, int]]
) -> T.Tuple[T.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Collate a batch into noised model inputs and the regression target.

    Intended as a ``collate_fn``. The labels in ``batch`` are discarded.

    Args:
        batch: List of ``(image, label)`` tuples.
            Assumes the images stack to ``(batch, height, width)`` so that the
            noise level reshaped to ``(-1, 1, 1)`` broadcasts - TODO confirm.

    Returns:
        ``((noisy_images, sig), target)`` where

        * ``noisy_images`` is the rescaled image plus noise, multiplied by
          ``c_in``,
        * ``sig`` is the 1-D tensor of per-image noise levels,
        * ``target`` is ``(image - c_skip * noisy) / c_out``, i.e. the quantity
          from which ``get_denoised_images`` recovers the clean image.
    """
    orig_images, _ = list_of_tuples_to_tensors(batch)
    orig_images = fudge_original_images(orig_images)

    # One noise level per image, drawn from the log-normal prior.
    sig = draw_sig_from_noise_prior(orig_images.shape[0])
    sig = sig.reshape(-1, 1, 1)

    c_skip, c_out, c_in = get_cs(sig)

    # Add noise of the drawn level to each image.
    noise = draw_img_noise_given_sig(sig, images=orig_images)
    noisy_images = orig_images + noise

    # The target is computed from the unscaled noisy image; only afterwards is
    # the model input multiplied by c_in.
    target_noise = (orig_images - c_skip * noisy_images) / c_out
    noisy_images = noisy_images * c_in

    # reshape(-1) rather than squeeze(): keeps a 1-D tensor even for a batch
    # of a single image, where squeeze() would return a 0-d tensor.
    sig = sig.reshape(-1)

    return (noisy_images, sig), target_noise


def get_denoised_images(
    noisy_images: torch.Tensor, predicted_noise: torch.Tensor, sig: torch.Tensor
) -> torch.Tensor:
    """Reconstruct de-noised images from model input, model output and noise level.

    Args:
        noisy_images: Noisy images already multiplied by ``c_in``; the scaling
            is undone here before the skip weight is applied.
        predicted_noise: Model output (or the training target).
        sig: Noise level(s), broadcastable against the images.

    Returns:
        ``predicted_noise * c_out + (noisy_images / c_in) * c_skip``.
    """
    c_skip, c_out, c_in = get_cs(sig)
    unscaled_noisy = noisy_images / c_in
    return predicted_noise * c_out + unscaled_noisy * c_skip

defining a dataloader

In [None]:
batch_size = n0
dataloader = DataLoader(
    ds, batch_size=batch_size, shuffle=False, collate_fn=apply_noise
)

inspecting the noisified images

In [None]:
(noisified_input_images, noise_levels), target_noise = next(iter(dataloader))

In [None]:
# Pick one image of the batch and show model input, target and reconstruction.
ix_img = 0
noisy_input_image = noisified_input_images[ix_img].cpu()
# NOTE(review): this rebinds `target_noise` from the whole batch to a single
# image, so re-running this cell without re-drawing the batch first indexes
# into the already selected image.
target_noise = target_noise[ix_img].cpu()

sig = noise_levels[ix_img].cpu()
c_skip, c_out, c_in = get_cs(sig)
# Same expression as in get_denoised_images, applied to the true target.
denoised_image = target_noise * c_out + (noisy_input_image / c_in) * c_skip

print(f"noise level: {noise_levels[ix_img]}")

# Three panels: model input, regression target, reconstruction from the target.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 7))
ax = axs[0]
ax.imshow(noisy_input_image, cmap="gray")
ax.set_title("Noisy input image")
ax.axis("off")
ax = axs[1]
ax.imshow(target_noise, cmap="gray")
ax.set_title("Target noise")
ax.axis("off")
ax = axs[2]
ax.imshow(denoised_image, cmap="gray")
ax.set_title("Denoised image")
ax.axis("off")
plt.show()

In [None]:
display(
    "noisy input",
    pd.Series(noisy_input_image.flatten().numpy()).describe(),
    "target noise",
    pd.Series(target_noise.flatten().numpy()).describe(),
    "denoised image",
    pd.Series(denoised_image.flatten().numpy()).describe(),
)

In [None]:
bins = np.linspace(-3, 3, 100)
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(10, 7), sharex=True)
ax = axs[0]
ax.hist(noisy_input_image.flatten(), bins=bins)
ax.set_title("Noisy input image")
ax = axs[1]
ax.hist(target_noise.flatten(), bins=bins)
ax.set_title("Target noise")
ax = axs[2]
ax.hist(denoised_image.flatten(), bins=bins)
ax.set_title("Denoised image")
plt.show()

## overfitting

defining the model

In [None]:
# U-Net with one input and one output channel and feature sizes (8, 16).
model = unet.UNetModel(
    in_channels=1,
    out_channels=1,
    list_num_features=(
        8,
        16,
    ),
    num_layers=2,
)
# Wrap the model for telemetry. The regex patterns presumably select, by module
# name, which activations / gradients / parameters are tracked - verify against
# telemetry.ModelTelemetry.
model = telemetry.ModelTelemetry(
    model,
    loss_names=("total",),
    activations_name_patterns=(".*act.*",),
    gradients_name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
    parameters_name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
    max_depth_search=10,
)
# Cast the parameters to float64 and move the model to the selected device.
model.double()
model.to(device);

In [None]:
torchinfo.summary(model, input_size=(1, 28, 28), dtypes=[torch.double])

In [None]:
opt = SGD(
    model.parameters(),
    lr=0.1,
)

In [None]:
loss_func = nn.MSELoss()

In [None]:
_iter = 0

training loop

In [None]:
n_epochs = 100

# Overfit the model on the batches of `dataloader`; the noise level returned by
# the collate function is ignored by this model.
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for (xb, _), yb in dataloader:
        # Inputs AND targets must be on the model's device, otherwise the loss
        # receives tensors on different devices when running on CUDA.
        xb = xb.to(device)
        yb = yb.to(device)
        x_pred = model(xb)

        loss = loss_func(x_pred, yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Record loss and parameter statistics for the telemetry plots.
        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

print("Done!")

plotting losses

In [None]:
model.draw_loss_history_train()

In [None]:
(input_images, noise_levels), target_noises = next(iter(dataloader))

In [None]:
model.eval();

inspecting predictions

In [None]:
input_images = input_images.to(device)
preds = model(input_images)
preds[0, :5, :5]

In [None]:
x_pred = preds.detach().cpu()  # .numpy()
x_pred[0, :3, :5]

In [None]:
ix_img = 0
noisy_input_image = input_images[ix_img].cpu()  # .numpy()
pred_noise = x_pred[ix_img]
target_noise = target_noises[ix_img].cpu()  # .numpy()
sig = noise_levels[ix_img].cpu()

c_skip, c_out, c_in = get_cs(sig)
denoised_image = target_noise * c_out + (noisy_input_image / c_in) * c_skip
pred_denoised_image = pred_noise * c_out + (noisy_input_image / c_in) * c_skip

print(f"noise level: {noise_levels[ix_img]}")
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 7))
ax = axs[0]
ax.imshow(noisy_input_image, cmap="gray")
ax.set_title("Noisy input image")
ax.axis("off")
ax = axs[1]
ax.imshow(denoised_image, cmap="gray")
ax.set_title("Ideal reconstructed image")
ax.axis("off")
ax = axs[2]
ax.imshow(pred_denoised_image, cmap="gray")
ax.set_title("Model reconstructed image")
ax.axis("off")
plt.show()

plotting gradients

In [None]:
model.draw_gradient_stats(yscale="log", figsize=(12, 20))

plotting activations

In [None]:
model.draw_activation_stats(yscale="log")

drawing histograms of the weights and biases across training iterations

In [None]:
model.draw_parameter_stats()

In [None]:
model.clean_hooks()

In [None]:
if DO_OVERFITTING_ONLY:
    raise SystemExit("Skipping training beyond overfitting.")

## including the noise level as input

In [None]:
noise = torch.linspace(-10, 10, 100)
emb = unet_with_noise.get_noise_level_embedding(noise, 8 * 4, max_period=1000)
print(emb.T.shape)
plt.imshow(emb.T)
plt.axis("off")
plt.tight_layout()

defining the model

In [None]:
# U-Net variant from `unet_with_noise`; it is later called with the noise level
# as a second argument. One input / output channel, feature sizes (8, 16).
model = unet_with_noise.UNetModel(
    in_channels=1,
    out_channels=1,
    list_num_features=(
        8,
        16,
    ),
    num_layers=2,
)
# Wrap the model for telemetry. The regex patterns presumably select, by module
# name, which activations / gradients / parameters are tracked - verify against
# telemetry.ModelTelemetry.
model = telemetry.ModelTelemetry(
    model,
    loss_names=("total",),
    activations_name_patterns=(".*act.*",),
    gradients_name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
    parameters_name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
    max_depth_search=10,
)
# Cast the parameters to float64 and move the model to the selected device.
model.double()
model.to(device);

In [None]:
# opt = SGD(
#     model.parameters(),
#     lr=0.1,
# )
opt = Adam(model.parameters(), lr=4e-3, eps=1e-5)

In [None]:
loss_func = nn.MSELoss()

In [None]:
_iter = 0

training loop

In [None]:
n_epochs = 100

# opt = SGD(
#     model.parameters(),
#     lr=0.1,
# )
# First training stage: a fresh Adam optimizer with lr=1e-2.
opt = Adam(model.parameters(), lr=1e-2, eps=1e-5)

model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for (xb, noise_levels), yb in dataloader:
        # Inputs, noise levels and targets must all be on the model's device,
        # otherwise the forward pass / loss fails when running on CUDA.
        xb = xb.to(device)
        noise_levels = noise_levels.to(device)
        yb = yb.to(device)
        x_pred = model(xb, noise_levels)

        loss = loss_func(x_pred, yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Record loss and parameter statistics for the telemetry plots.
        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

print("Done!")

In [None]:
n_epochs = 100

# opt = SGD(
#     model.parameters(),
#     lr=0.1,
# )
# Second training stage: a fresh Adam optimizer with lr=4e-3.
opt = Adam(model.parameters(), lr=4e-3, eps=1e-5)

model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for (xb, noise_levels), yb in dataloader:
        # Inputs, noise levels and targets must all be on the model's device,
        # otherwise the forward pass / loss fails when running on CUDA.
        xb = xb.to(device)
        noise_levels = noise_levels.to(device)
        yb = yb.to(device)
        x_pred = model(xb, noise_levels)

        loss = loss_func(x_pred, yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Record loss and parameter statistics for the telemetry plots.
        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

print("Done!")

In [None]:
n_epochs = 100

# opt = SGD(
#     model.parameters(),
#     lr=0.1,
# )
# Third training stage: a fresh Adam optimizer with lr=4e-4.
opt = Adam(model.parameters(), lr=4e-4, eps=1e-5)

model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for (xb, noise_levels), yb in dataloader:
        # Inputs, noise levels and targets must all be on the model's device,
        # otherwise the forward pass / loss fails when running on CUDA.
        xb = xb.to(device)
        noise_levels = noise_levels.to(device)
        yb = yb.to(device)
        x_pred = model(xb, noise_levels)

        loss = loss_func(x_pred, yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Record loss and parameter statistics for the telemetry plots.
        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

print("Done!")

plotting losses

In [None]:
model.draw_loss_history_train()

In [None]:
(input_images, noise_levels), target_noises = next(iter(dataloader))

In [None]:
model.eval();

inspecting predictions

In [None]:
input_images = input_images.to(device)
noise_levels = noise_levels.to(device)
preds = model(input_images, noise_levels)
preds[0, :5, :5]

In [None]:
x_pred = preds.detach().cpu()  # .numpy()
x_pred[0, :3, :5]

In [None]:
ix_img = 2
noisy_input_image = input_images[ix_img].cpu()  # .numpy()
pred_noise = x_pred[ix_img]
target_noise = target_noises[ix_img].cpu()  # .numpy()
sig = noise_levels[ix_img].cpu()

c_skip, c_out, c_in = get_cs(sig)
denoised_image = target_noise * c_out + (noisy_input_image / c_in) * c_skip
pred_denoised_image = pred_noise * c_out + (noisy_input_image / c_in) * c_skip

print(f"noise level: {noise_levels[ix_img]}")
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 7))
ax = axs[0]
ax.imshow(noisy_input_image, cmap="gray")
ax.set_title("Noisy input image")
ax.axis("off")
ax = axs[1]
ax.imshow(denoised_image, cmap="gray")
ax.set_title("Ideal reconstructed image")
ax.axis("off")
ax = axs[2]
ax.imshow(pred_denoised_image, cmap="gray")
ax.set_title("Model reconstructed image")
ax.axis("off")
plt.show()

sampling

noise levels based on
```python
def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7.):
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min**(1/rho)
    max_inv_rho = sigma_max**(1/rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho-max_inv_rho))**rho
    return torch.cat([sigmas, tensor([0.])]).cuda()
```

In [None]:
def sigmas_karras(
    n: int, sigma_min: float = 0.01, sigma_max: float = 80.0, rho: float = 7.0
) -> torch.Tensor:
    """Build a decreasing schedule of noise levels followed by a final zero.

    The schedule interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` and raises the result to the power ``rho``.

    Args:
        n: Number of non-zero noise levels.
        sigma_min: Last non-zero noise level.
        sigma_max: First noise level.
        rho: Exponent controlling the spacing of the levels.

    Returns:
        A 1-D tensor of length ``n + 1`` whose last entry is ``0.0``.
    """
    inv_rho = 1 / rho
    start = sigma_max**inv_rho
    stop = sigma_min**inv_rho
    ramp = torch.linspace(0, 1, n)
    sigmas = (start + ramp * (stop - start)) ** rho

    # Append the terminal zero noise level.
    return torch.cat((sigmas, torch.zeros(1)))


sigma_max = 0.5
sigs = sigmas_karras(100, sigma_max=sigma_max)
sigs.shape

In [None]:
sns.scatterplot(x=range(len(sigs)), y=sigs);

In [None]:
generative_sig = torch.tensor([sigma_max, sigma_max, sigma_max])
sampled_noise = draw_img_noise_given_sig(
    generative_sig.reshape(-1, 1, 1),
    images_shape=(generative_sig.shape[0], 28, 28),
)
sampled_noise.shape

In [None]:
model.eval();

denoising based on 
```python
def denoise(model, x, sig):
    sig = sig[None]
    c_skip,c_out,c_in = scalings(sig)
    return model((x*c_in, sig))*c_out + x*c_skip
    
def sample_lms(model, steps=100, order=4, sigma_max=80.):
    preds = []
    x = torch.randn(sz).cuda()*sigma_max
    sigs = sigmas_karras(steps, sigma_max=sigma_max)
    ds = []
    for i in progress_bar(range(len(sigs)-1)):
        sig = sigs[i]
        denoised = denoise(model, x, sig)
        d = (x-denoised)/sig
        ds.append(d)
        if len(ds) > order: ds.pop(0)
        cur_order = min(i+1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigs, i, j) for j in range(cur_order)]
        x = x + sum(coeff*d for coeff, d in zip(coeffs, reversed(ds)))
        preds.append(x)
    return preds
```

In [None]:
def denoise_with_model(
    model: telemetry.ModelTelemetry, images: torch.Tensor, sigs: torch.Tensor
) -> T.Tuple[T.List[torch.Tensor], T.List[torch.Tensor]]:
    """Denoise images with the model for a sequence of noise levels.

    For every noise level in ``sigs`` the current images are scaled by
    ``c_in``, passed through the model together with the noise level, and
    turned into a denoised estimate via ``get_denoised_images``. That estimate
    is the input of the next step.

    NOTE(review): this applies the denoiser repeatedly without re-noising
    between steps; it is not the linear multistep sampler of the referenced
    notebook.

    Args:
        model: Model called as ``model(images, noise_levels)``.
        images: Initial noisy images (unscaled).
            Assumes shape ``(batch, height, width)`` - TODO confirm.
        sigs: 1-D tensor of noise levels, applied in order.

    Returns:
        ``(noise_preds, denoised_preds)``: one CPU tensor per noise level with
        the raw model output and the denoised images respectively.
    """
    noise_preds = []
    denoised_preds = []

    # Run on the device the model lives on, so CPU-created noise and sigmas
    # also work with a model that was moved to the GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        images = images.to(first_param.device)
        sigs = sigs.to(first_param.device)

    # Pure inference: without no_grad the autograd graph would grow over all
    # steps because each step consumes the previous step's output.
    with torch.no_grad():
        for sig in tqdm.tqdm(sigs, total=len(sigs), desc="Sigmas"):
            _sigs = sig.repeat(images.shape[0])
            _sigs_bcast = _sigs.reshape(-1, 1, 1)

            # Scale the input at EVERY step: the model is fed c_in-scaled
            # images and get_denoised_images divides by c_in again.
            _, _, c_in = get_cs(_sigs_bcast)
            scaled_images = images * c_in

            pred_noise = model(scaled_images, _sigs)

            images = get_denoised_images(scaled_images, pred_noise, _sigs_bcast)

            noise_preds.append(pred_noise.detach().cpu())
            denoised_preds.append(images.detach().cpu())
    return noise_preds, denoised_preds

In [None]:
noise_preds, denoised_preds = denoise_with_model(
    model, sampled_noise.double(), sigs
)

In [None]:
ix_img = 0
ix_denoise = 5
noisy_input_image = sampled_noise[ix_img].cpu()
predicted_noise = noise_preds[ix_denoise][ix_img].cpu()
denoised_image = denoised_preds[ix_denoise][ix_img].cpu()

sig = sigs[ix_denoise].cpu()
c_skip, c_out, c_in = get_cs(sig)

print(f"noise level for denoising: {sig}")

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 7))
ax = axs[0]
ax.imshow(noisy_input_image, cmap="gray")
ax.set_title("Noisy input image")
ax.axis("off")
ax = axs[1]
ax.imshow(predicted_noise, cmap="gray")
ax.set_title("Predicted noise")
ax.axis("off")
ax = axs[2]
ax.imshow(denoised_image, cmap="gray")
ax.set_title("Denoised image")
ax.axis("off")
plt.show()

In [None]:
# TODO: the sampling above does not reliably generate the digit 5. It is unclear how many of the sigmas should be used, since the first few denoising steps sometimes give much better results than the last ones.