In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

In [2]:
def imshow(img):
    """Display an image tensor that was normalized with mean 0.5 / std 0.5.

    The tensor is mapped back with ``img * 0.5 + 0.5``, clipped to [0, 1],
    converted to a NumPy array with its first axis moved last, min-max
    stretched, and shown with ``plt.imshow`` / ``plt.show``.

    Args:
        img: torch tensor. ``squeeze(0)`` drops the first dimension when it has
            size 1; the result must be 3-D (assumes channels-first layout such
            as (1, C, H, W) — the first remaining axis is moved last).

    Returns:
        None. The image is only displayed.
    """
    img = img * 0.5 + 0.5
    img = img.squeeze(0)
    npimg = img.clip(0, 1).detach().cpu().numpy().transpose(1, 2, 0)

    # Stretch the contrast to the full [0, 1] range.
    npimg = npimg - npimg.min()
    peak = npimg.max()
    # A constant image has peak == 0 after the shift; dividing would produce
    # NaNs (0 / 0), so the stretch is skipped in that case.
    if peak > 0:
        npimg = npimg / peak

    plt.imshow(npimg)
    plt.show()

In [3]:
class UNet(nn.Module):
    """Small convolutional encoder/decoder with two skip connections.

    Encoder: three 5x5 convolutions (3 -> 32 -> 64 -> 64 channels); the outputs
    of the first two are kept as skips and then max-pooled by 2.
    Decoder: three 5x5 convolutions; before the second and third the feature
    map is upsampled by 2 and concatenated (channel axis) with the most recent
    unused skip. Two 3x3 convolutions map the result back to 3 channels.

    The input is a 4-D tensor with 3 channels. Because of the two pool /
    upsample pairs, height and width need to be divisible by 4 for the skip
    tensors to line up with the upsampled features.
    """

    def __init__(self):
        super().__init__()

        def conv5(c_in, c_out):
            # 5x5 convolution that preserves the spatial size.
            return nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)

        encoder_channels = [(3, 32), (32, 64), (64, 64)]
        # Decoder inputs 128 and 64 account for the concatenated skips.
        decoder_channels = [(64, 64), (128, 32), (64, 32)]

        self.down_layers = nn.ModuleList(
            [conv5(c_in, c_out) for c_in, c_out in encoder_channels]
        )
        self.up_layers = nn.ModuleList(
            [conv5(c_in, c_out) for c_in, c_out in decoder_channels]
        )
        self.final_conv1 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.final_conv2 = nn.Conv2d(16, 3, kernel_size=3, padding=1)

        self.act = nn.LeakyReLU()
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Run the network; the output has the same shape as the input."""
        encoder_feats = []
        pooled_stages = 2  # only the first two encoder stages are pooled

        for depth, conv in enumerate(self.down_layers):
            x = self.act(conv(x))
            if depth >= pooled_stages:
                continue
            encoder_feats.append(x)
            x = self.downscale(x)

        # The first decoder conv works at bottleneck resolution; every later
        # one first upsamples and merges the matching encoder features.
        bottleneck_conv, *decoder_convs = self.up_layers
        x = self.act(bottleneck_conv(x))
        for conv in decoder_convs:
            x = torch.cat((self.upscale(x), encoder_feats.pop()), dim=1)
            x = self.act(conv(x))

        hidden = self.act(self.final_conv1(x))
        return self.final_conv2(hidden)

In [4]:
class DiffusionParameters(nn.Module):
    """Precomputed noise-schedule tables for a linear beta schedule.

    Every table is a 1-D tensor with ``T + 1`` entries (indices ``0..T``),
    derived from ``betas = linspace(beta_0, beta_T, T + 1)``:

    * ``betas``, ``sqrt_betas``
    * ``alphas = 1 - betas``, ``sqrt_alphas``, ``sqrt_alphas_inv = 1 / sqrt_alphas``
    * ``alphas_bar = cumprod(alphas)``, ``sqrt_alphas_bar``
    * ``one_minus_alphas_bar``, ``sqrt_one_minus_alphas_bar``
    * ``complex_mult = betas / sqrt_one_minus_alphas_bar``

    The tables are registered as non-persistent buffers, so they follow the
    module through ``.to()`` / ``.cuda()`` / ``.double()`` while the module's
    ``state_dict`` stays empty. They are still read as plain attributes,
    e.g. ``params.alphas_bar[t]``.

    Args:
        beta_0: first value of the beta schedule.
        beta_T: last value of the beta schedule.
        T: number of diffusion steps; tables have ``T + 1`` entries.
        device: device on which the tables are created.
    """

    def __init__(self, beta_0, beta_T, T, device):
        super().__init__()

        betas = torch.linspace(beta_0, beta_T, T + 1, device=device)
        alphas = 1 - betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        one_minus_alphas_bar = 1.0 - alphas_bar
        sqrt_one_minus_alphas_bar = torch.sqrt(one_minus_alphas_bar)
        sqrt_alphas = torch.sqrt(alphas)

        tables = {
            'betas': betas,
            'sqrt_betas': torch.sqrt(betas),
            'alphas': alphas,
            'alphas_bar': alphas_bar,
            'one_minus_alphas_bar': one_minus_alphas_bar,
            'sqrt_one_minus_alphas_bar': sqrt_one_minus_alphas_bar,
            'complex_mult': betas / sqrt_one_minus_alphas_bar,
            'sqrt_alphas_bar': torch.sqrt(alphas_bar),
            'sqrt_alphas': sqrt_alphas,
            'sqrt_alphas_inv': 1.0 / sqrt_alphas,
        }
        # Plain tensor attributes on an nn.Module are ignored by .to()/.cuda();
        # buffers are moved with the module. persistent=False keeps them out
        # of state_dict, which was empty before as well.
        for name, table in tables.items():
            self.register_buffer(name, table, persistent=False)

In [5]:
# Hyper-parameters shared by the cells below.
config = dict(
    batch_size=128,   # DataLoader batch size
    num_epoch=500,    # passes over the training subset
    beta_0=1e-4,      # first value of the linear beta schedule
    beta_T=0.02,      # last value of the linear beta schedule
    T=1000,           # number of diffusion steps
    lr=1e-4,          # Adam learning rate
)

In [6]:
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Images become float tensors in [0, 1] and are then shifted/scaled to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])


def _label_indices(dataset, label):
    """Return the positions in ``dataset.targets`` whose value equals ``label``."""
    return np.flatnonzero(np.array(dataset.targets) == label)


# Only samples with label 3 are kept (hence the *_cats names).
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_dataset_cats = Subset(train_dataset, _label_indices(train_dataset, 3))
train_loader = DataLoader(
    train_dataset_cats,
    batch_size=config['batch_size'],
    shuffle=True,
)

test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_dataset_cats = Subset(test_dataset, _label_indices(test_dataset, 3))
test_loader = DataLoader(
    test_dataset_cats,
    batch_size=config['batch_size'],
    shuffle=False,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43400612.21it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [7]:
# Network that is trained to predict the noise added to an image.
model = UNet()
model = model.to(device)

# Noise-schedule tables built from the values in `config`.
diffusion_params = DiffusionParameters(
    beta_0=config['beta_0'],
    beta_T=config['beta_T'],
    T=config['T'],
    device=device,
)

optimizer = torch.optim.Adam(params=model.parameters(), lr=config['lr'])
# NOTE(review): T_max is 1000 scheduler steps while config['num_epoch'] is 500
# and the scheduler is stepped once per epoch, so only the first half of the
# cosine curve is traversed — confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
loss_fn = nn.MSELoss(reduction='mean')

In [8]:
def sample_from_noise(model, diffusion_params, image_shape, device):
    """Generate one image from Gaussian noise and display it with ``imshow``.

    Starting from standard-normal noise of shape ``(1, *image_shape)``, the
    loop walks ``t = T, T-1, ..., 1`` and applies

        x <- sqrt_alphas_inv[t] * (x - complex_mult[t] * model(x))

    followed, for every step except the last (``t == 1``), by adding
    ``sqrt_betas[t]`` times fresh standard-normal noise.

    Args:
        model: ``nn.Module`` called as ``model(x)``; it is switched to eval
            mode and left there.
        diffusion_params: object exposing 1-D tables ``betas``, ``sqrt_betas``,
            ``sqrt_alphas_inv`` and ``complex_mult`` indexed by step.
        image_shape: shape of a single image without the batch dimension,
            e.g. ``(3, 32, 32)``; any sequence of ints is accepted.
        device: device on which sampling runs.

    Returns:
        None. The final image is only displayed.
    """
    # The tables hold one entry per index 0..T, so the step count is taken
    # from the schedule actually passed in rather than from a notebook global
    # that may describe a different schedule.
    num_steps = diffusion_params.betas.shape[0] - 1
    single_shape = tuple(image_shape)

    model.eval()
    with torch.no_grad():
        image_t = torch.randn((1, *single_shape), device=device)

        for t in range(num_steps, 0, -1):
            model_out = model(image_t)
            image_t = diffusion_params.sqrt_alphas_inv[t] * (
                image_t - diffusion_params.complex_mult[t] * model_out
            )

            # The last step (t == 1) is deterministic: no noise is added.
            if t > 1:
                noise = torch.randn(single_shape, device=device)
                image_t = image_t + diffusion_params.sqrt_betas[t] * noise

        imshow(image_t)

In [None]:
# Training: the model learns to predict the noise that was mixed into each
# image at a randomly chosen diffusion step.
for epoch in range(config['num_epoch']):
    model.train()
    losses = []
    for batch, _ in tqdm(train_loader):
        batch = batch.to(device)

        # Target noise, created directly on the batch's device.
        noise = torch.randn_like(batch)

        # One step index per sample, drawn from 1..T inclusive. The schedule
        # tables have T + 1 entries and sampling visits indices T..1, so the
        # training range has to cover the same indices (randint's upper bound
        # is exclusive). The trailing singleton dims let the looked-up
        # coefficients broadcast over channels and pixels.
        timesteps = torch.randint(
            1, config['T'] + 1, (batch.shape[0], 1, 1, 1), device=device
        )

        # Forward process: blend the clean image with noise at the chosen step.
        batch_with_noise = batch * diffusion_params.sqrt_alphas_bar[timesteps] + \
                           noise * diffusion_params.sqrt_one_minus_alphas_bar[timesteps]

        output = model(batch_with_noise)
        # Mean squared error between predicted and true noise (loss_fn is
        # nn.MSELoss). An unsquared norm would not be minimized by the
        # conditional mean of the noise, which is what the sampling update
        # plugs in.
        loss = loss_fn(output, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # The learning-rate schedule advances once per epoch.
    lr_scheduler.step()
    print(f'Epoch: {epoch}, Loss: {np.mean(losses)}')

    # Show a generated sample every 25 epochs to monitor progress.
    if (epoch + 1) % 25 == 0:
        model.eval()
        sample_from_noise(model, diffusion_params, batch.shape[1:], device)

In [None]:
# Draw and display one 3x32x32 sample from the trained model.
sample_from_noise(
    model,
    diffusion_params,
    image_shape=(3, 32, 32),
    device=device,
)