### Loading Packages and Checking Device

In [None]:
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm.auto import tqdm

from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

In [None]:
# Training on a GPU is advised since it speeds things up significantly
# (~3 hours on CPU, but only ~1 hour on GPU). Fall back to the CPU otherwise.
if torch.cuda.is_available():
    dev = torch.device("cuda")
else:
    dev = torch.device("cpu")
print(dev)

### Loading MNIST-Dataset and Preprocessing

In [None]:
# Preprocessing pipeline:
#   * pad each 28x28 digit by 2 pixels per side (-> 32x32) so the image size
#     works better with the pooling steps of the U-Net,
#   * convert to a tensor with values in 0..1,
#   * rescale to -1..1, which matches the range of Gaussian noise better.
pre_process = transforms.Compose(
    [
        transforms.Pad(padding=2),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data if missing.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=pre_process,
    download=True,
)
# Shuffled mini-batches of 128 images.
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

### Creating Noise Schedule and Forward Process

In [None]:
# Noise schedule: T diffusion steps with linearly increasing variances (betas).
# NOTE(review): the DDPM paper (Ho et al., 2020) starts its linear schedule at
# 1e-4 — confirm that 2e-4 is the intended lower bound here.
T = 1000
betas = torch.linspace(start=0.0002, end=0.02, steps=T, device=dev)
alphas = 1.0 - betas
# alpha_bars[t] is the running product alphas[0] * ... * alphas[t].
alpha_bars = alphas.cumprod(dim=0)

# Forward process in closed form: jump from x0 straight to xt.
def noising(x0, t, noise):
    """Return the noised sample xt for clean images x0 at timesteps t.

    Computes ``sqrt(alpha_bars[t]) * x0 + sqrt(1 - alpha_bars[t]) * noise``.

    Args:
        x0: batch of clean images; the per-sample factors are expanded with
            three trailing singleton axes, i.e. laid out for a 4-D batch.
        t: 1-D tensor of indices into ``alpha_bars``, one per sample.
        noise: tensor broadcastable against ``x0`` (the training loop passes
            ``torch.randn_like(x0)``).

    Returns:
        The noised batch ``xt``.
    """
    alpha_bar_t = alpha_bars[t]

    # Add trailing singleton axes so each sample gets its own scalar factor.
    signal_scale = alpha_bar_t.sqrt()[:, None, None, None]
    noise_scale = (1 - alpha_bar_t).sqrt()[:, None, None, None]

    return signal_scale * x0 + noise_scale * noise

### Neural Network for Noise-Prediction in the Reverse Process

In [None]:
# Our model architecture is a standard U-Net. We don't use self-attention or residual connections.

class DoubleConv(nn.Module):
    """Two Conv-BatchNorm-ReLU stages followed by an additive time shift.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (of both convolutions).
        emb_dim: size of the time embedding fed to ``forward``.
    """

    def __init__(self, in_ch, out_ch, emb_dim=64):
        super().__init__()
        # Projects the time embedding to one bias value per output channel.
        self.time_mlp = nn.Linear(emb_dim, out_ch)

        # 3x3 convolutions with padding 1 keep the spatial size unchanged.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

    def forward(self, x, t):
        """Convolve ``x`` and add the projected time embedding ``t`` per channel."""
        features = self.conv(x)
        channel_shift = self.time_mlp(t)
        # Broadcast the (batch, out_ch) shift over both spatial axes.
        return features + channel_shift[:, :, None, None]


class Up(nn.Module):
    """Upsample a feature map by 2 and concatenate it with a skip connection.

    Args:
        in_ch: channels of the map to upsample.
        out_ch: channels after the transposed convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up_scale = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, and stack along channels."""
        upsampled = self.up_scale(x1)

        # Size gap to the skip connection along height (axis 2) and width (axis 3).
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]

        # F.pad expects [left, right, top, bottom]; any odd remainder goes to
        # the right/bottom side.
        left = gap_w // 2
        top = gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        return torch.cat([upsampled, x2], dim=1)


class DownLayer(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a time-conditioned DoubleConv.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by the convolution block.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Halves height and width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, t):
        """Downsample ``x`` and convolve it, conditioned on time embedding ``t``."""
        pooled = self.pool(x)
        return self.conv(pooled, t)


class UpLayer(nn.Module):
    """Decoder stage: upsample, merge with the skip connection, then DoubleConv.

    ``Up`` brings the incoming map to ``out_ch`` channels; the convolution
    block is built for ``in_ch`` input channels, so the concatenated tensor
    (upsampled map + skip connection) must have ``in_ch`` channels.

    Args:
        in_ch: channels of the map to upsample (and of the concatenated tensor).
        out_ch: channels produced by this stage.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = Up(in_ch, out_ch)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2, t):
        """Upsample ``x1``, join it with skip ``x2``, and convolve conditioned on ``t``."""
        merged = self.up(x1, x2)
        return self.conv(merged, t)

# To condition the net on the timestep t, we 'translate' t into a 'language' the net understands better.
class SinusoidalPositionEmbeddings(nn.Module):
    """Encode timesteps as interleaved sine/cosine features.

    For timestep ``t`` and frequency index ``i`` (``0 <= i < dim // 2``) with
    ``w_i = 10000 ** (-i / (dim // 2))``, the embedding stores ``sin(t * w_i)``
    in column ``2 * i`` and ``cos(t * w_i)`` in column ``2 * i + 1``.

    Args:
        dim: size of the embedding vector; must be even so that every
            frequency gets both a sine and a cosine column.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim=64):
        super().__init__()
        # An odd dim would otherwise only surface later as a shape mismatch
        # when the sine columns are written in forward().
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")
        self.dim = dim

    def forward(self, time):
        """Return the embeddings for a 1-D tensor of timesteps.

        Args:
            time: 1-D tensor of timesteps (integer or floating point).

        Returns:
            Tensor of shape ``(len(time), dim)`` on the same device as ``time``.
        """
        device = time.device
        half_dim = self.dim // 2

        # Geometrically spaced frequencies from 1 down towards 1/10000.
        factor = torch.log(torch.tensor(10_000.0, device=device)) / half_dim
        w = torch.exp(-torch.arange(half_dim, device=device) * factor)

        # Integer timesteps are cast explicitly instead of relying on
        # mixed-dtype promotion inside torch.outer.
        if not time.is_floating_point():
            time = time.float()
        phase_tensor = torch.outer(time, w)

        emb = torch.empty(len(time), self.dim, device=device)
        emb[:, 0::2] = torch.sin(phase_tensor)
        emb[:, 1::2] = torch.cos(phase_tensor)
        return emb

# Final architecture
class DenoisingUNet(nn.Module):
    """U-Net that predicts the noise contained in a noised image.

    The network takes single-channel images, goes through three
    pooling/encoder stages (64 -> 128 -> 256 -> 512 channels) and three
    decoder stages with skip connections, and maps back to one channel with a
    1x1 convolution. Every convolution block receives the time embedding.
    """

    def __init__(self):
        super().__init__()

        self.t_embedding = SinusoidalPositionEmbeddings()

        # Encoder
        self.conv1 = DoubleConv(1, 64)
        self.down1 = DownLayer(64, 128)
        self.down2 = DownLayer(128, 256)
        self.down3 = DownLayer(256, 512)
        # Decoder
        self.up1 = UpLayer(512, 256)
        self.up2 = UpLayer(256, 128)
        self.up3 = UpLayer(128, 64)
        # 1x1 convolution back to a single output channel
        self.last_conv = nn.Conv2d(64, 1, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` given the timesteps ``t``.

        Args:
            x: batch of noised images.
            t: timesteps, one per sample; embedded before use.

        Returns:
            The output of the final 1x1 convolution (predicted noise).
        """
        t_emb = self.t_embedding(t)

        # Encoder path; each result is kept as a skip connection.
        skip1 = self.conv1(x, t_emb)
        skip2 = self.down1(skip1, t_emb)
        skip3 = self.down2(skip2, t_emb)
        bottom = self.down3(skip3, t_emb)

        # Decoder path; merge with the skips in reverse order.
        h = self.up1(bottom, skip3, t_emb)
        h = self.up2(h, skip2, t_emb)
        h = self.up3(h, skip1, t_emb)

        return self.last_conv(h)

### Sampling Procedure

In [None]:
# Sampling function to generate images
def sample(n=64):
    """Generate ``n`` images by running the reverse process from pure noise.

    Starts from Gaussian noise of shape ``(n, 1, 32, 32)`` and walks the
    timesteps from ``T - 1`` down to ``0``, using the module-level ``model``
    to predict the noise at every step. The model's train/eval mode is not
    changed here; the callers in this file switch it to eval beforehand.

    Args:
        n: number of images to generate.

    Returns:
        Tensor of shape ``(n, 1, 32, 32)`` on ``dev``.
    """
    xt = torch.randn(n, 1, 32, 32, device=dev)

    for t in range(T - 1, -1, -1):
        t_tensor = torch.full((n,), t, dtype=torch.long, device=dev)
        with torch.no_grad():
            eps_pred = model(xt, t_tensor)

        # No fresh noise is added on the very last step (t == 0).
        if t > 0:
            z = torch.randn_like(xt)
        else:
            z = torch.zeros_like(xt)

        # Remove the predicted noise, rescale, then add noise with variance betas[t].
        eps_scale = (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t])
        mean = 1 / torch.sqrt(alphas[t]) * (xt - eps_scale * eps_pred)
        xt = mean + torch.sqrt(betas[t]) * z

    return xt

### Plotting Samples

In [None]:
def plot_samples(x0, title='Samples', save=True):
    """Show a batch of generated images as a grid and optionally save it.

    Args:
        x0: batch of images with values nominally in -1..1; values outside
            that range are clipped after rescaling.
        title: figure title; also used as the file name when saving.
        save: if truthy, write the figure to ``media/<title>.png`` (the
            ``media`` directory is created when it does not exist yet).
    """
    samples = x0.detach().cpu()
    samples = (samples + 1) / 2  # Rescale to [0, 1]
    samples = samples.clamp(0, 1)

    grid = make_grid(samples)
    plt.figure(figsize=(12, 12))
    # make_grid returns channels first; imshow wants channels last.
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
    plt.title(title)
    if save:
        # Without this, savefig raises FileNotFoundError on a fresh checkout
        # where the output directory has not been created.
        import os

        os.makedirs("media", exist_ok=True)
        plt.savefig(f"media/{title}.png")
    plt.show()

### Training Loop with Performance Tracking

In [None]:
# Parameters
# The model lives on `dev`; Adam minimises the mean squared error between the
# predicted and the actually added noise.
model = DenoisingUNet().to(dev)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_fn = nn.MSELoss()
epochs=100

# Training loop (takes about 1 hour to finish)
for epoch in range(epochs):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    # Running SUM of the batch losses in this epoch (not the mean).
    epoch_loss=0
    for b in pbar:
        # Only the first element of the batch (the images) is used.
        x0 = b[0].to(dev)
        optimizer.zero_grad()

        # One random timestep in [0, T) per image of the batch.
        t = torch.randint(0, T, (x0.shape[0],), device=dev)

        # Noise the clean images to timestep t with freshly drawn Gaussian noise.
        eps = torch.randn_like(x0)
        xt = noising(x0, t, eps)

        # The network has to recover the noise that was added.
        eps_pred = model(xt, t)

        loss = loss_fn(eps_pred, eps)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        pbar.set_postfix(epoch_loss=epoch_loss)

    # Plotting the samples every 10th epoch to see learning progression
    if (epoch+1)%10==0:
        # Switch to eval mode for sampling, then back to train mode afterwards.
        model.eval()
        with torch.no_grad():
            x0_pred = sample(n=8)
            # save=False: intermediate grids are only displayed, not written to disk.
            plot_samples(x0_pred, title=f"Samples at Epoch: {epoch+1}", save=False)
        model.train()

### Generating Images

In [None]:
# Sample pictures from our model and plot them.
# The model is put into eval mode first; sample() draws its default of 64 images
# and plot_samples() (save=True by default) also writes the grid to
# "media/Generated Samples.png".
model.eval()
x0_pred = sample()
plot_samples(x0_pred, title="Generated Samples")