In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [4]:
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

X_test = torch.Tensor( testset.data ) / 255.0# - 0.5
y_test = torch.Tensor( testset.targets ).long()
X_train = torch.Tensor( trainset.data ) / 255.0# - 0.5
y_train = torch.Tensor( trainset.targets ).long()

# train_dataset = TensorDataset(X_train, y_train)

train_data = DataLoader(trainset, batch_size=256, shuffle=True, drop_last=True)
train_data_iter = iter(train_data)

test_data = DataLoader(testset, batch_size=256, shuffle=True, drop_last=True)
test_data_iter = iter(test_data)

# Noise scheduling-- how fast do we want to add noise?

We want to arrive at an isotropic Gaussian distribution: one with $0$ mean and equal covariances.

In [5]:
def linear_schedule(timesteps, start=1e-4, end=.02):
    betas = torch.linspace(start, end, timesteps)
    return betas

def cosine_schedule(timesteps, s=0.008):
    def f(t):
        return torch.cos((t / timesteps + s) / (1 + s) * 0.5 * torch.pi) ** 2
    x = torch.linspace(0, timesteps, timesteps + 1)
    alphas_cumprod = f(x) / f(torch.tensor([0]))
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    betas = torch.clip(betas, 0.0001, 0.999)
    return betas

In [6]:
class Diffusion():
    def __init__(self, T: int, model : nn.Module, optimizer, device='cpu', noise_model=''):
        assert device in {'cpu', 'gpu'}, ValueError

        self.T = T
        self.model = model
        self.optimizer = optimizer
        self.device = device

        if noise_model == 'cosine':
            self.beta = cosine_schedule(timesteps=T)
        else: self.beta = linear_schedule(timesteps=T)

        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, axis=0)

        self.loss_function = torch.nn.MSELoss()

    def training(self, batch_size=256):
        X_batch, _ = next(train_data_iter)
        X_0 = X_batch.to(self.device)
        
        t = torch.randint(1, self.T + 1, (batch_size,), device=self.device)
        epsilon = torch.randn_like(X_0)

        alpha_bar_t = self.alpha_bar[t-1][:, None, None, None]
        epsilon_hat = self.model(torch.sqrt(alpha_bar_t) * X_0 + torch.sqrt(1 - alpha_bar_t) * epsilon, t)

        loss = self.loss_function(epsilon, epsilon_hat)
        loss.backward()

        self.optimizer.step()
        self.optimizer.zero_grad()

        return loss.item()

    @torch.no_grad()
    def testing(self, batch_size=256):
        X_batch, _ = next(test_data_iter)
        X_0 = X_batch.to(self.device)
        
        t = torch.randint(1, self.T + 1, (batch_size,), device=self.device)
        epsilon = torch.randn_like(X_0)

        alpha_bar_t = self.alpha_bar[t-1][:, None, None, None]
        epsilon_hat = self.model(torch.sqrt(alpha_bar_t) * X_0 + torch.sqrt(1 - alpha_bar_t) * epsilon, t)

        loss = self.loss_function(epsilon, epsilon_hat)

        return loss.item()

    @torch.no_grad()
    def sampling(self, batch_size=256, image_channels=1, img_size=(28, 28)):
        assert len(img_size) == 2, ValueError

        X = torch.randn((batch_size, image_channels, *img_size), device=self.device)

        for t in range(self.T, 0, -1):
            z = torch.randn_like(X) if t > 1 else torch.zeros_like(X)
            
            beta_t = self.beta[t-1][:, None, None, None]
            alpha_t = self.alpha[t-1][:, None, None, None]
            alpha_bar_t = self.alpha_bar[t-1][:, None, None, None]

            # alpha_t_minus = self.alpha[t-2][:, None, None, None]
            # alpha_bar_t_minus = self.alpha_bar[t-2][:, None, None, None]

            mu = 1/torch.sqrt(alpha_t) * (X - (1 - alpha_t)/(torch.sqrt(1 - alpha_bar_t)) * self.noise_model(X, t))
            sigma = torch.sqrt(beta_t)
            # mu = (1 - alpha_bar_t_minus)*torch.sqrt(alpha_t)/(1 - alpha_bar_t) * X + (1-alpha_t)*torch.sqrt(alpha_bar_t_minus)/(1 - alpha_bar_t) * self.model(X)
            # sigma = 

            X = mu + sigma * z
            

In [14]:
from torch import nn
import math


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()
        
    def forward(self, x, t, ):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 1
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )
        
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])
        
        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)           
            x = up(x, t)
        return self.output(x)

model = SimpleUnet()

In [15]:
T = 300

lr = .001
model = UNet()

optimizer = optim.Adam(model.parameters(), lr=lr)

d = Diffusion(T=T, model=model, optimizer=optimizer)

In [16]:
lim = 50

for epoch in range(lim):
    loss = d.training()

AssertionError: 