# Simple diffusion
This is a simple diffusion model heavily based on [minDiffusion](https://github.com/cloneofsimo/minDiffusion).

In [None]:
import os, random, time, math
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

## Constants
Constants used in the diffusion process and training are defined here. [Full reproducibility requires more than just setting the seed.](https://pytorch.org/docs/stable/notes/randomness.html)

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Output folders for model checkpoints and for generated sample grids.
MODEL_OUT_DIR = r"D:\Projs\proj\simple-diff\models"
SAMPLE_OUT_DIR = r"D:\Projs\proj\simple-diff\samples"

# Diffusion: number of noising steps and the two endpoints of the beta schedule.
NOISE_STEPS = 1_000
BETA_START = 0.0001
BETA_END = 0.02

# Training: optimizer learning rate, images per batch, passes over the dataloader.
INIT_LR = 0.00001
BATCH_SIZE = 32
EPOCHS = 3_000

# Low effort reproducibility: one shared seed for PyTorch's and Python's generators.
SEED = 1
torch.manual_seed(SEED)
random.seed(SEED)

## Dataset
The dataset used is CIFAR-10, a collection of 3-channel, 32 by 32 pixel images. The transform normalizes the input to the range \[-1, 1\]. For testing purposes, a single batch can be used instead of the full dataset.

In [None]:
# Convert each image to a tensor, then shift/scale all three channels with
# mean 0.5 and std 0.5 so that pixel values end up in [-1, 1].
data_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
# CIFAR-10 training split, downloaded into ../../datasets when not already there.
dataset = CIFAR10(
    "../../datasets",
    train=True,
    download=True,
    transform=data_transforms,
)
# shuffle=True reshuffles the images every epoch; without it each epoch would
# replay exactly the same batches in exactly the same order.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Test model by overfitting to one batch: sampling is restricted to the first
# BATCH_SIZE dataset indices (visited in random order), so every epoch consists
# of exactly one batch made of the same images.
subset_indices = range(BATCH_SIZE)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=SubsetRandomSampler(subset_indices),
)

## Diffusion
The diffusion class implements forward diffusion (noise image generation) and sampling (denoising) using a linear schedule.

In [None]:
class Diffusion():
    """Forward (noising) and reverse (sampling) diffusion on a linear beta schedule.

    Every per-timestep coefficient is precomputed once as a 1-D tensor with
    ``noise_steps + 1`` entries and looked up by timestep index. Training draws
    timesteps from ``1..noise_steps``; sampling walks the same range backwards.

    Args:
        noise_steps: Number of diffusion steps T.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        device: Device on which the schedule tensors and all sampled noise live.
    """

    def __init__(self, noise_steps=NOISE_STEPS, beta_start=BETA_START, beta_end=BETA_END, device=DEVICE):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.device = device

        self.alpha, self.beta, self.alcum = self.linear_schedule()
        # Derived coefficients, cached so forward() and sample() only index them.
        self.sqrt_beta = torch.sqrt(self.beta)
        self.sqrt_alcum = torch.sqrt(self.alcum)
        self.rec_sqrt_alpha = 1. / torch.sqrt(self.alpha)
        self.sqrt_1sub_alcum = torch.sqrt(1 - self.alcum)
        self.beta_over_sqrt_1sub_alcum = self.beta / self.sqrt_1sub_alcum

    def linear_schedule(self):
        """Build the linear schedule.

        Returns:
            Tuple ``(alpha, beta, alpha_cum)`` of 1-D tensors with
            ``noise_steps + 1`` entries each: ``beta`` is linearly spaced from
            ``beta_start`` to ``beta_end``, ``alpha = 1 - beta`` and
            ``alpha_cum`` is the cumulative product of ``alpha``.
        """
        beta = torch.linspace(self.beta_start, self.beta_end, self.noise_steps + 1, device=self.device)
        alpha = 1. - beta
        alpha_cum = torch.cumprod(alpha, dim=0)
        return alpha, beta, alpha_cum

    def forward(self, x_0):
        """Noise a batch of clean inputs to independently drawn random timesteps.

        Args:
            x_0: Batch of clean inputs; the first dimension is the batch.

        Returns:
            Tuple ``(x_t, noise, t)`` where ``t`` holds one timestep in
            ``1..noise_steps`` per batch element, ``noise`` is standard normal
            noise shaped like ``x_0`` and
            ``x_t = sqrt(alpha_cum[t]) * x_0 + sqrt(1 - alpha_cum[t]) * noise``.
        """
        t = torch.randint(1, self.noise_steps + 1, (x_0.shape[0],), device=self.device)
        noise = torch.randn_like(x_0, device=self.device)
        # One coefficient per batch element, broadcast over all remaining dims.
        coef_shape = (-1,) + (1,) * (x_0.dim() - 1)
        signal_scale = self.sqrt_alcum[t].view(coef_shape)
        noise_scale = self.sqrt_1sub_alcum[t].view(coef_shape)
        x_t = signal_scale * x_0 + noise_scale * noise
        return x_t, noise, t

    def sample(self, sample_num, img_dims, model):
        """Generate samples by iteratively denoising pure Gaussian noise.

        The model is switched to eval mode and left there; gradients are
        disabled for the whole loop.

        Args:
            sample_num: Number of samples to generate.
            img_dims: Shape of a single sample, e.g. ``(3, 32, 32)``.
            model: Callable ``model(x, t)`` returning the predicted noise for
                batch ``x`` at integer timesteps ``t``.

        Returns:
            Tensor of shape ``(sample_num, *img_dims)`` on ``self.device``.
        """
        model.eval()
        with torch.no_grad():
            x = torch.randn((sample_num, *img_dims), device=self.device)
            # Walk t = noise_steps, ..., 1 so sampling covers exactly the
            # timesteps that forward() trains on.
            for i in range(self.noise_steps, 0, -1):
                t = torch.full(size=(sample_num,), fill_value=i, device=self.device)
                predicted_noise = model(x, t)
                # All samples share the timestep, so scalar coefficients suffice.
                x = self.rec_sqrt_alpha[i] * (x - self.beta_over_sqrt_1sub_alcum[i] * predicted_noise)
                # No fresh noise on the final step, otherwise the returned
                # samples would carry noise that is never removed.
                if i > 1:
                    x = x + self.sqrt_beta[i] * torch.randn_like(x)
        return x


# Shared diffusion process used by both the training and the sampling cells.
diffusion = Diffusion(
    noise_steps=NOISE_STEPS,
    beta_start=BETA_START,
    beta_end=BETA_END,
    device=DEVICE,
)

## Model
The denoising model is a simple U-Net structure based on [minDiffusion](https://github.com/cloneofsimo/minDiffusion). It uses sinusoidal position embeddings to encode time steps.

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal timestep encoding, after https://huggingface.co/blog/annotated-diffusion.

    Maps a 1-D tensor of timesteps to a 2-D tensor with one row per timestep:
    the first ``dim // 2`` columns are sines and the last ``dim // 2`` columns
    are cosines of the timestep multiplied by geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        n_freqs = self.dim // 2
        # Log-spacing between neighbouring frequencies; they run from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        positions = torch.arange(n_freqs, device=time.device)
        freqs = torch.exp(positions * -log_step)
        # Outer product: rows follow the timesteps, columns follow the frequencies.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [None]:
"""
    Simple convolutional network based on NaiveUnet from minDiffusion.
"""
class ConvBlock(nn.Module):
    """Three 3x3 convolutions, each followed by GroupNorm and ReLU.

    Every convolution uses kernel 3, stride 1 and padding 1, so the spatial
    size of the input is preserved; only the channel count changes.

    Args:
        input_channels: Channels of the incoming feature map.
        output_channels: Channels produced by each convolution in the block.
    """

    def __init__(self, input_channels=3, output_channels=3):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # GroupNorm requires the channel count to be divisible by the number of
        # groups. gcd(8, C) is 8 whenever C is a multiple of 8 (unchanged
        # behavior) and otherwise picks the largest valid divisor of 8, which
        # makes the default output_channels=3 constructible (1 group).
        num_groups = math.gcd(8, self.output_channels)

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_channels, self.output_channels, 3, 1, 1),
            nn.GroupNorm(num_groups, self.output_channels),
            nn.ReLU(),
            nn.Conv2d(self.output_channels, self.output_channels, 3, 1, 1),
            nn.GroupNorm(num_groups, self.output_channels),
            nn.ReLU(),
            nn.Conv2d(self.output_channels, self.output_channels, 3, 1, 1),
            nn.GroupNorm(num_groups, self.output_channels),
            nn.ReLU()
        )

    def forward(self, x):
        """Apply the three conv/norm/ReLU stages to ``x``."""
        return self.conv(x)


class DownBlock(nn.Module):
    """Encoder stage: a ConvBlock followed by max pooling with kernel size 2."""

    def __init__(self, input_channels=3, output_channels=3):
        super().__init__()
        self.input_channels, self.output_channels = input_channels, output_channels

        # Kept as positional Sequential entries so parameter names stay conv.0.*.
        stages = [
            ConvBlock(input_channels, output_channels),
            nn.MaxPool2d(2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        pooled = self.conv(x)
        return pooled


class UpBlock(nn.Module):
    """Decoder stage: merge a skip connection, upsample, then refine.

    ``forward`` concatenates ``x`` and ``residual`` along the channel axis, so
    ``input_channels`` must equal the sum of their channel counts. The result
    goes through a stride-2 transposed convolution and two ConvBlocks.
    """

    def __init__(self, input_channels=3, output_channels=3):
        super().__init__()
        self.input_channels, self.output_channels = input_channels, output_channels

        upsample = nn.ConvTranspose2d(input_channels, output_channels, 2, 2)
        refine_a = ConvBlock(output_channels, output_channels)
        refine_b = ConvBlock(output_channels, output_channels)
        self.conv = nn.Sequential(upsample, refine_a, refine_b)

    def forward(self, x, residual):
        merged = torch.cat((x, residual), dim=1)
        return self.conv(merged)


class SimpleNet(nn.Module):
    """Small U-Net that predicts noise from a noisy image and its timestep.

    Layout: a stem ConvBlock, three DownBlocks, a 4x4 average-pool bottleneck
    to which the timestep embedding is added, a 4x transposed-conv upsample,
    three UpBlocks fed with the matching encoder outputs, and a final 3x3
    convolution over the decoder output concatenated with the stem features.

    Args:
        input_channels: Channels of the input image.
        output_channels: Channels of the predicted noise.
        hidden_size: Base feature width; deeper stages use twice this width.
    """

    def __init__(self, input_channels=3, output_channels=3, hidden_size=128):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.hidden_size = hidden_size
        narrow, wide = hidden_size, 2 * hidden_size

        self.first = ConvBlock(input_channels, narrow)
        # Sinusoidal encoding of t followed by a two-layer MLP, all of width `wide`.
        self.sin_time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(wide),
            nn.Linear(wide, wide),
            nn.ReLU(),
            nn.Linear(wide, wide),
        )

        # Encoder.
        self.down1 = DownBlock(narrow, narrow)
        self.down2 = DownBlock(narrow, wide)
        self.down3 = DownBlock(wide, wide)

        # Bottleneck: pool down, then undo the pooling with a transposed conv.
        self.avg = nn.Sequential(nn.AvgPool2d(4), nn.ReLU())
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(wide, wide, 4, 4),
            nn.GroupNorm(8, wide),
            nn.ReLU(),
        )

        # Decoder; each input width is the upsampled path plus the skip connection.
        self.up1 = UpBlock(2 * wide, wide)
        self.up2 = UpBlock(2 * wide, narrow)
        self.up3 = UpBlock(wide, narrow)

        self.out = nn.Conv2d(wide, output_channels, 3, 1, 1)

    def forward(self, x, t):
        stem = self.first(x)
        # Reshape the embedding to (batch, channels, 1, 1) so it broadcasts over space.
        t_emb = self.sin_time_mlp(t)
        t_emb = t_emb.view(-1, 2 * self.hidden_size, 1, 1)

        skip1 = self.down1(stem)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)

        bottleneck = self.avg(skip3) + t_emb
        h = self.up0(bottleneck)

        h = self.up1(h, skip3)
        h = self.up2(h, skip2)
        h = self.up3(h, skip1)
        return self.out(torch.cat((h, stem), dim=1))

In [None]:
# RGB in, RGB noise prediction out, base width 128; placed on the configured device.
model = SimpleNet(input_channels=3, output_channels=3, hidden_size=128).to(DEVICE)

In [None]:
# Adam over every model parameter, starting from the configured learning rate.
opt = Adam(params=model.parameters(), lr=INIT_LR)
# Mean squared error between the true and the predicted noise.
loss_fn = nn.MSELoss(reduction="mean")

## Training
Training draws a random timestep for every image, noises the image to that timestep, and trains the model to predict the added noise with an MSE loss. The test model was trained by overfitting on a single batch to check model functionality.

In [None]:
# One pass over the dataloader per epoch; each step noises the batch to random
# timesteps and regresses the model output onto the noise that was added.
for epoch in range(EPOCHS):
    model.train()
    progress = tqdm(dataloader)
    for clean_images, _ in progress:
        opt.zero_grad()
        clean_images = clean_images.to(DEVICE)
        noisy_images, true_noise, timesteps = diffusion.forward(clean_images)
        predicted_noise = model(noisy_images, timesteps)
        loss = loss_fn(true_noise, predicted_noise)
        loss.backward()
        opt.step()
        # Show the current epoch and the latest batch loss on the progress bar.
        progress.set_postfix(EPOCH=epoch, MSE=loss.item())
print("Training loop finished")

In [None]:
# Create the checkpoint folder on demand so that a finished training run is not
# lost to a missing directory.
os.makedirs(MODEL_OUT_DIR, exist_ok=True)
# The file name is the current Unix time in milliseconds.
torch.save(model.state_dict(), f"{MODEL_OUT_DIR}/{time.time_ns() // 1_000_000}.pt")

## Sampling
Sampling is done in a batch in an attempt to generate some acceptable images.

In [None]:
# Create the sample folder on demand so the save below cannot fail on a missing directory.
os.makedirs(SAMPLE_OUT_DIR, exist_ok=True)
model.eval()
with torch.no_grad():
    # Eight 3x32x32 samples, arranged four per row.
    sample = diffusion.sample(8, (3, 32, 32), model)
    # value_range=(-1, 1) matches the range the training images were normalized to.
    grid = make_grid(sample, normalize=True, value_range=(-1, 1), nrow=4)
    # The file name carries the current Unix time in milliseconds.
    save_image(grid, f"{SAMPLE_OUT_DIR}/sample_cifar_{time.time_ns() // 1_000_000}.png")
print("Images saved")

## Loading
Loading the saved model.

In [None]:
model = SimpleNet(3,3,128).to(DEVICE)
# map_location remaps the stored tensors onto DEVICE, so a checkpoint written on
# a GPU machine still loads when only the CPU is available (and vice versa).
model.load_state_dict(torch.load(f"{MODEL_OUT_DIR}/1686675513310.pt", map_location=DEVICE))