## Score-based modelling on MNIST

In this practical session, we will train a score-based model on MNIST. You will have to implement: 

1) The forward pass of a U-Net network. 

Reminder: a U-Net is a multi-scale neural network whose encoder and decoder are linked by skip connections that concatenate feature maps. See https://en.wikipedia.org/wiki/U-Net 

2) The loss function of a score-based network (denoising objective): 
\begin{align}
  \mathcal{L}_\theta &= \sum^T_{t=1} \mathbb{E}_{x_0\sim p_{\mathrm{data}}, x_t \sim p_{\sigma_t}(x_t \mid x_0)} \left[ \left\| s_{\theta}(x_t, t) - \nabla_{x_t} \log p_{\sigma_t} (x_t \mid x_0) \right\|^2_2 \right] \\ &= \sum^T_{t=1} \mathbb{E}_{x_0\sim p_{\mathrm{data}}, x_t \sim p_{\sigma_t}(x_t \mid x_0)} \left[ \left\| s_{\theta}(x_t, t) + \frac{x_t - x_0}{\sigma_t^2} \right\|^2_2 \right]
\end{align}

3) The annealed Langevin dynamics sampling algorithm, which runs Langevin dynamics successively over the $T$ noise levels (cf. the slides). 

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Preprocessing: resize to 32 pixels, convert to a tensor, then apply
# Normalize with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training and held-out splits, downloaded into ./data when missing.
training_set = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.MNIST(
    root='./data', train=False, transform=transform, download=True)

# Small loaders (4 images per batch); only the training one is shuffled.
training_loader = torch.utils.data.DataLoader(
    dataset=training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_set, batch_size=4, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    """Display an image tensor inline with matplotlib.

    Args:
        img: Image tensor with the channel axis first. Its values are mapped
            back with ``x / 2 + 0.5`` before plotting.
        one_channel: If True, average over the first axis and draw the
            result with the "Greys" colormap; otherwise move the first axis
            last and draw the image as-is.
    """
    if one_channel:
        img = img.mean(dim=0)
    pixels = (img / 2 + 0.5).numpy()  # unnormalize, then hand over to numpy
    if not one_channel:
        plt.imshow(np.transpose(pixels, (1, 2, 0)))
    else:
        # Intensities are inverted before the colormap is applied.
        plt.imshow(-pixels+1, cmap="Greys")
    plt.show()

# Fetch one batch (images and their labels) from the training loader
dataiter = iter(training_loader)
images, labels = next(dataiter)

# Tile the batch into a single image grid and display it in greyscale
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# PyTorch models inherit from torch.nn.Module
class DoubleConv(nn.Module):
    """Residual block: conv3x3 => GroupNorm => GELU => conv3x3 => GroupNorm,
    added to a shortcut branch and passed through a final GELU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        mid_channels: Channels produced by the first convolution; falls back
            to ``out_channels`` when left empty.
        n_groups: Number of groups used by every GroupNorm layer.

    The shortcut is the identity when ``in_channels == out_channels``;
    otherwise it is a 3x3 convolution followed by GroupNorm so that both
    branches have the same number of channels.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, n_groups=8):
        super().__init__()
        hidden = mid_channels or out_channels

        # Main branch (built first so parameter creation order is unchanged).
        main_branch = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.GroupNorm(n_groups, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(n_groups, out_channels),
        ]
        self.double_conv = nn.Sequential(*main_branch)
        self.act = nn.GELU()

        # Shortcut branch.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(n_groups, out_channels),
            )

    def forward(self, x):
        """Return ``act(double_conv(x) + shortcut(x))``."""
        return self.act(self.double_conv(x) + self.shortcut(x))


class Down(nn.Module):
    """Downscaling with a strided convolution, then a DoubleConv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Kernel 4, stride 2, padding 1; the channel count is left unchanged.
        self.down_conv = nn.Conv2d(in_channels, in_channels, 4, stride=2, padding=1)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        """Apply the strided convolution followed by the DoubleConv."""
        downsampled = self.down_conv(x)
        return self.conv(downsampled)


class Up(nn.Module):
    """Bilinear 2x upsampling, concatenation with a skip tensor, then DoubleConv.

    Args:
        in_channels: Channels of the concatenated tensor fed to the
            DoubleConv (skip channels plus upsampled channels).
        out_channels: Channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # The first convolution of the DoubleConv halves the channel count.
        self.conv = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it after ``x2`` on the channel axis, and convolve."""
        upsampled = self.up(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Output head: a single 1x1 convolution from ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project the feature map ``x`` with the 1x1 convolution."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """Three-level U-Net used as a noise-conditional score network.

    The convolutional body does not look at the noise level; the conditioning
    comes only from dividing the raw output by the selected ``sigma``.

    Args:
        n_channels: Number of channels of the input images; the output has
            the same number of channels.
        sigmas: Tensor of noise levels, indexed in ``forward`` by the
            noise-level indices ``y``.
    """

    def __init__(self, n_channels,sigmas):
        super().__init__()
        self.n_channels = n_channels

        # Encoder: 32 -> 64 -> 128 -> 128 channels, halving resolution 3 times.
        self.inc = DoubleConv(n_channels, 32,n_groups=1)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 128)
        # Decoder: each Up receives the concatenation of the upsampled tensor
        # and the matching encoder output (128+128, 64+64, 32+32 channels).
        self.up1 = Up(256, 64)
        self.up2 = Up(128, 32)
        self.up3 = Up(64, 32)
        self.outc = OutConv(32, n_channels)

        # Registered as a non-persistent buffer so the noise levels follow the
        # module through .to() / .cuda() while state_dict() stays unchanged.
        self.register_buffer('sigmas', sigmas, persistent=False)

    def forward(self, x, y):
        """Return the score estimate for noisy images ``x`` at noise levels ``y``.

        Args:
            x: Batch of images with ``n_channels`` channels. The three
                stride-2 stages followed by three 2x upsamplings require
                spatial sizes for which the skip tensors line up (e.g. 32x32).
            y: Integer tensor holding one index into ``sigmas`` per batch
                element.

        Returns:
            Tensor with the same shape as ``x``: the network output divided
            by the selected noise level of each sample.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.outc(x)

        # Reshape to (batch, 1, ..., 1) so each sample is divided by its own sigma.
        used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))

        x = x / used_sigmas

        return x

def get_sigmas(sigma_begin,sigma_end,nb_sigma):
    """Return ``nb_sigma`` noise levels geometrically spaced between two bounds.

    The levels are evenly spaced in log-space from ``sigma_begin`` to
    ``sigma_end`` (both included) and returned as a 1-D float32 tensor.
    """
    log_levels = np.linspace(np.log(sigma_begin), np.log(sigma_end), nb_sigma)
    levels = np.exp(log_levels)
    return torch.tensor(levels).float()

def anneal_dsm_score_estimation(scorenet, samples, sigma_index, sigmas, anneal_power=2.):
    """Denoising score-matching loss with per-sample noise levels.

    Args:
        scorenet: Callable ``scorenet(noisy_samples, sigma_index)`` returning
            a score estimate with the same number of elements per sample as
            ``samples``.
        samples: Batch of clean data; the first axis is the batch axis.
        sigma_index: One index into ``sigmas`` per batch element.
        sigmas: Tensor of noise levels.
        anneal_power: Each per-sample loss is multiplied by
            ``sigma ** anneal_power``.

    Returns:
        Scalar tensor: the weighted loss averaged over the batch.
    """
    batch = samples.shape[0]
    trailing_ones = [1] * (samples.dim() - 1)
    # One sigma per sample, shaped (batch, 1, ..., 1) to broadcast over the data.
    sigma = sigmas[sigma_index].view(batch, *trailing_ones)

    noisy = samples + torch.randn_like(samples) * sigma
    # Score of the Gaussian perturbation kernel: -(x_t - x_0) / sigma^2.
    target = - 1 / (sigma ** 2) * (noisy - samples)
    predicted = scorenet(noisy, sigma_index)

    residual = predicted.view(predicted.shape[0], -1) - target.view(target.shape[0], -1)
    per_sample = 1 / 2. * (residual ** 2).sum(dim=-1) * sigma.squeeze() ** anneal_power
    return per_sample.mean(dim=0)

In [None]:
cuda_available = torch.cuda.is_available()
print("GPUs available." if cuda_available else "No GPUs available.")

# Training hyper-parameters
batch_size = 64
learning_rate = 0.0005
nb_epochs = 1

# Noise schedule: nb_sigma levels going from sigma_begin down to sigma_end
sigma_begin = 10
sigma_end = 0.01
nb_sigma = 500

sigma_levels = get_sigmas(sigma_begin, sigma_end, nb_sigma)
if cuda_available:
    sigma_levels = sigma_levels.cuda(0)
print('hierarchy of sigma : ',sigma_levels)

# Single-channel score network sharing the noise schedule
model = UNet(1, sigma_levels)
if cuda_available:
    model = model.cuda(0)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Rebuild the data loaders with the training batch size
training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=batch_size, shuffle=False)

You should reach a loss of approximately 30–35 within 5 epochs! Move on to the sampling part only once you have reached this loss level; otherwise the generated samples will not look good. Ideally, the loss should go below 30. If you have access to GPUs, train for more epochs and increase the capacity (feature-map width) of the U-Net.

In [None]:
# The mean loss is reported once every `freq_print` batches.
freq_print = 50

for e in range(nb_epochs):
    running_loss = 0.

    for i, (inputs, _) in enumerate(training_loader):
        # One random noise-level index per image; the labels are not used.
        sigma_index = torch.randint(0, nb_sigma, (inputs.shape[0],))

        if cuda_available:
            inputs, sigma_index = inputs.cuda(0), sigma_index.cuda(0)

        # Standard optimisation step: reset gradients, evaluate the
        # denoising loss, back-propagate, update the weights.
        optimizer.zero_grad()
        loss = anneal_dsm_score_estimation(model, inputs, sigma_index, sigma_levels)
        loss.backward()
        optimizer.step()

        # Accumulate, and report only on every freq_print-th batch.
        running_loss += loss.item()
        if (i + 1) % freq_print != 0:
            continue
        last_loss = running_loss / freq_print  # loss per batch
        print('epoch {}  batch {} loss: {}'.format(e + 1, i + 1, last_loss))
        running_loss = 0.

In [None]:
@torch.no_grad()
def sampling(model,sigma_levels,n_per_sigma=5,step_lr=0.0000062,cuda_available=False,
             n_samples=32,image_shape=(1,32,32)):
    """Generate samples with annealed Langevin dynamics.

    Starting from Gaussian noise scaled by the first noise level, the
    function walks through ``sigma_levels`` in order and, at each level,
    performs ``n_per_sigma`` updates

        x <- x + step * model(x, index) + sqrt(2 * step) * z,   z ~ N(0, I)

    with ``step = step_lr * (sigma / sigma_levels[-1]) ** 2``.

    Args:
        model: Callable ``model(x, sigma_index)`` returning a score estimate
            shaped like ``x``; ``sigma_index`` is a long tensor holding the
            current level index once per sample.
        sigma_levels: 1-D tensor of noise levels, visited from first to last.
        n_per_sigma: Number of Langevin updates per noise level.
        step_lr: Base step size, used as-is at the last noise level.
        cuda_available: If True, the samples are moved to CUDA device 0.
        n_samples: Number of samples to generate.
        image_shape: Shape of one sample, without the batch axis.

    Returns:
        Tensor of shape ``(n_samples, *image_shape)`` holding the final
        state. Gradients are not tracked.
    """
    # Drawn on the CPU and then moved, so the random stream does not depend
    # on the target device.
    x = torch.randn((n_samples, *image_shape))
    if cuda_available:
        x = x.cuda(0)
    x = x * sigma_levels[0]

    for i, sigma in enumerate(sigma_levels):
        # Same noise-level index for every sample, created on x's device.
        sigma_index = torch.full((x.shape[0],), i, dtype=torch.long, device=x.device)
        step_size = step_lr * (sigma / sigma_levels[-1]) ** 2
        noise_scale = torch.sqrt(step_size * 2)
        for _ in range(n_per_sigma):
            noise = torch.randn_like(x)
            score_estimation = model(x, sigma_index)
            x = x + step_size * score_estimation + noise_scale * noise
    return x

# Generate a batch of samples with annealed Langevin dynamics
x = sampling(model,sigma_levels,cuda_available=cuda_available)
# Bring the samples back to the CPU before they are converted for display
if cuda_available:
    x = x.detach().cpu()
# Clamp to [-1, 1], the value range produced by the dataset normalization
x = torch.clip(x,-1,1)
# Tile the samples into a single image grid and show it in greyscale
img_grid = torchvision.utils.make_grid(x)
matplotlib_imshow(img_grid, one_channel=True)