In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from torch.optim import Adam
from tqdm import tqdm

from helpers import sample_batch, MLP

In [None]:
def plot(model, filename):
    """Save a 2x3 scatter figure comparing the forward and reverse processes.

    Top row (blue): a fresh data batch, then that batch diffused by
    ``model.forward_process`` for ``T/2`` and ``T`` steps.
    Bottom row (red): entries ``0``, ``T/2`` and ``T`` of the list returned by
    ``model.sample`` (index 0 is the end of the reverse chain, index ``T`` the
    initial noise).

    param model: object exposing ``T``, ``forward_process(x0, t)`` and
                 ``sample(batch_size)`` (e.g. ``DiffusionModel``)
    param filename: path the figure is written to; the parent directory must
                    already exist
    """
    N = 5_000
    # Derive the plotted steps from the model instead of assuming T == 40.
    last_step = model.T
    mid_step = max(1, last_step // 2)

    # Pure visualisation: no autograd graph is needed for 5000 x T samples.
    with torch.no_grad():
        x0 = sample_batch(N)
        samples = model.sample(N)
        data = [
            [x0, model.forward_process(x0, mid_step)[-1], model.forward_process(x0, last_step)[-1]],
            [samples[0], samples[mid_step], samples[last_step]],
        ]

    titles = ['t=0', 't=T/2', 't=T']
    row_labels = [r"$q(\mathbf{x}^{(0..T)})$", r"$p(\mathbf{x}^{(0..T)})$"]
    row_colours = ["b", "r"]
    # Common x and y limits
    x_limits = (-2, 2)
    y_limits = (-2, 2)

    fig, ax = plt.subplots(nrows=len(data), ncols=len(titles), figsize=(10, 6))
    for i, row in enumerate(data):
        ax[i, 0].set_ylabel(row_labels[i])
        for j, points in enumerate(row):
            xy = points.detach().cpu().numpy()
            panel = ax[i, j]
            panel.scatter(xy[:, 0], xy[:, 1], alpha=0.1, c=row_colours[i], s=4)
            panel.set_title(titles[j])
            panel.set_xlim(x_limits)
            panel.set_ylim(y_limits)
            # Set the aspect on this panel; plt.gca() would only ever hit the
            # last axes created by plt.subplots.
            panel.set_aspect("equal")

    fig.tight_layout()
    fig.savefig(filename, bbox_inches="tight")
    plt.close(fig)  # close this figure explicitly so repeated calls do not leak

In [2]:
class DiffusionModel():
    """Gaussian diffusion process with a learned reverse chain.

    The noise schedule ``betas`` is a sigmoid ramp over ``T`` steps;
    ``alphas = 1 - betas`` and ``alphas_bar`` is their cumulative product.
    ``model`` is called as ``model(xt, step_index)`` with a zero-based step
    index and must return ``(mu, std)`` of the reverse transition.
    """

    def __init__(self, T, model: nn.Module, dim=2):
        """
        param T: number of diffusion steps
        param model: callable returning (mu, std) for (xt, zero-based step)
        param dim: dimensionality of one data point
        """
        self.betas = torch.sigmoid(torch.linspace(-18, 10, T)) * (3e-1 - 1e-5) + 1e-5
        self.alphas = 1 - self.betas
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.T = T
        self.model = model
        self.dim = dim

    def _check_step(self, t):
        """Assert that the 1-based step ``t`` lies in [1, T]."""
        assert t > 0, "should be greater than zero"
        assert t <= self.T, f"t should be lower or equal than {self.T}"

    def forward_process(self, x0, t):
        """Diffuse ``x0`` for ``t`` steps and compute the posterior q(x_{t-1} | x_t, x_0).

        param x0: clean data batch
        param t: number of diffusion steps, 1 <= t <= T (int or one-element tensor)
        return: (mu_q, std_q, xt) where xt ~ q(x_t | x_0) and mu_q / std_q
                parametrise q(x_{t-1} | x_t, x_0)
        """
        self._check_step(t)

        idx = t - 1  # index start at zero
        alpha_bar_t = self.alphas_bar[idx]
        # For the first step the "previous" cumulative product is the empty
        # product (1). Indexing alphas_bar[idx - 1] would wrap to the last step.
        if idx == 0:
            alpha_bar_prev = torch.ones_like(alpha_bar_t)
        else:
            alpha_bar_prev = self.alphas_bar[idx - 1]

        mu = torch.sqrt(alpha_bar_t) * x0
        std = torch.sqrt(1 - alpha_bar_t)
        epsilon = torch.randn_like(x0)
        xt = mu + epsilon * std  # data ~ N(mu, std)

        denom = 1 - alpha_bar_t
        m1 = torch.sqrt(alpha_bar_prev) * self.betas[idx] / denom
        m2 = torch.sqrt(self.alphas[idx]) * (1 - alpha_bar_prev) / denom
        mu_q = m1 * x0 + m2 * xt
        std_q = torch.sqrt((1 - alpha_bar_prev) / denom * self.betas[idx])

        return mu_q, std_q, xt

    def reverse_process(self, xt, t):
        """Draw x_{t-1} from the learned reverse transition.

        param xt: batch at step t
        param t: 1-based step, 1 <= t <= T (int or one-element tensor)
        return: (mu, std, sample) with sample = mu + eps * std
        """
        self._check_step(t)

        idx = t - 1  # index start at zero
        mu, std = self.model(xt, idx)
        epsilon = torch.randn_like(xt)
        return mu, std, mu + epsilon * std  # data ~ N(mu, std)

    @torch.no_grad()
    def sample(self, batch_size):
        """Run the reverse chain from pure noise.

        Gradients are disabled: this is inference only.

        param batch_size: number of points to generate
        return: list of T + 1 tensors; index 0 is the end of the reverse
                chain, index T is the initial noise
        """
        x = torch.randn(batch_size, self.dim)

        # adding the starting noise already makes the list composed of T + 1 elements
        samples = [x]
        for t in range(self.T, 0, -1):
            # Edge effect of the diffusion model: the last step (t == 1) is a no-op
            if t != 1:
                _, _, x = self.reverse_process(x, t)
            samples.append(x)

        return samples[::-1]  # reverse results in the list

    def get_loss(self, x0):
        """Mean KL(q || p) at one uniformly sampled step t in [2, T].

        Requires T >= 2.

        param x0: batch [batch_size, self.dim]
        return: scalar tensor to minimise
        """
        # sample t over the model's own horizon rather than a fixed 40
        t = torch.randint(2, self.T + 1, (1,))
        mu_q, sigma_q, xt = self.forward_process(x0, t)
        mu_p, sigma_p, _ = self.reverse_process(xt.float(), t)
        # KL divergence between two Gaussians, KL(q||p); the -1/2 constant does
        # not affect gradients but makes the value a true KL (zero when p == q)
        KL = (
            torch.log(sigma_p) - torch.log(sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5
        )
        return KL.mean()

In [None]:
def train(diffusion_model, optimizer, batch_size, epochs):
    """Optimise ``diffusion_model`` for ``epochs`` iterations.

    Each iteration draws one fresh batch with ``sample_batch`` and takes one
    optimiser step on ``diffusion_model.get_loss``. Every 100 iterations the
    loss curve is written to ``train_loss_epoch_{epoch}.png`` and a sample
    snapshot to ``./train_images/train_epoch_{epoch}.png``.

    param diffusion_model: object exposing ``get_loss(x0)``, plus whatever ``plot`` needs
    param optimizer: optimiser over the model parameters
    param batch_size: number of points per iteration
    param epochs: number of iterations (one batch each)
    return: list with the loss value of every iteration
    """
    snapshot_dir = "./train_images"
    # savefig does not create missing directories
    os.makedirs(snapshot_dir, exist_ok=True)

    training_loss = []
    for epoch in tqdm(range(epochs)):
        x0 = sample_batch(batch_size)
        loss = diffusion_model.get_loss(x0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_loss.append(loss.item())

        if epoch % 100 == 0:
            # Use a dedicated figure and close it; otherwise curves pile up on
            # the implicit current figure when plt.show() does not clear it.
            fig, loss_ax = plt.subplots()
            loss_ax.plot(training_loss)
            fig.savefig(f"train_loss_epoch_{epoch}.png")
            plt.show()
            plt.close(fig)
            plot(diffusion_model, f"{snapshot_dir}/train_epoch_{epoch}.png")

    return training_loss