In [1]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import math
import torch
import torch.nn as nn
import random

from IPython.display import Image
from pathlib import Path
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm import tqdm

from src.models.vae import VAE
from src.const import DATA_PATH, SEED
from src.preprocess import get_data_loader
from src.models.utils import pos_encoding

BATCH_SIZE = 32

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
vae = VAE(device)
train_dataloader = get_data_loader(DATA_PATH, BATCH_SIZE, vae)

for idx, batch in enumerate(train_dataloader):
    break



In [4]:
class EmbeddingBlock(nn.Module):
    """Embedding block for UNet."""

    def __init__(self, n_steps, d_model):
        super(EmbeddingBlock, self).__init__()
        self.n_steps = n_steps
        self.t_embed = self.init_pos_encoding(d_model)
        # self.l1 = nn.Linear(16, 32)
        # self.l2 = nn.Linear(32, d_model)
        # self.silu = nn.SiLU()

    def init_pos_encoding(self, d_model):
        t_embed = nn.Embedding(self.n_steps, d_model)
        t_embed.weight.data = pos_encoding(self.n_steps, d_model)
        t_embed.requires_grad = False
        return t_embed

    def forward(self, t):
        t = self.t_embed(t)
        # t = self.l1(t)
        # t = self.silu(t)
        # t = self.l2(t)
        return t

class ConvBlock(nn.Module):
    """Convolutional block for UNet."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same")
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same")
        self.relu = nn.ReLU()
        self.bnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bnorm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bnorm(x)
        x = self.relu(x)
        return x

class DownsampleBlock(nn.Module):
    """Downsample block block for UNet."""

    def __init__(self, in_channels, out_channels):
        super(DownsampleBlock, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        x = self.conv(x)
        pool = self.pool(x)
        return x, pool

class UpsampleBlock(nn.Module):
    """Upsample block for UNet."""

    def __init__(self, in_channels, out_channels):
        super(UpsampleBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.conv = ConvBlock(2*out_channels, out_channels)

    def forward(self, x, down_tensor):
        x = self.upconv(x)
        x = torch.cat((x, down_tensor), dim=1)
        x = self.conv(x)
        return x

class UNet(nn.Module):
    """UNet model for diffusion."""

    def __init__(self, batch_size, n_steps, input_size=32, in_channels=4, first_layer_channels=64):
        super(UNet, self).__init__()

        self.batch_size = batch_size

        # input size
        self.s1 = input_size
        self.s2 = self.s1 // 2
        self.s3 = self.s2 // 2

        # number of channels
        self.ch0 = in_channels
        self.ch1 = first_layer_channels
        self.ch2 = self.ch1 * 2
        self.ch3 = self.ch2 * 2

        # embedding blocks
        self.em1 = EmbeddingBlock(n_steps, in_channels * self.s1 * self.s1)
        self.em2 = EmbeddingBlock(n_steps, self.ch1 * self.s2 * self.s2)
        self.em3 = EmbeddingBlock(n_steps, self.ch2 * self.s3 * self.s3)
        self.em4 = EmbeddingBlock(n_steps, self.ch3 * self.s3 * self.s3)
        self.em5 = EmbeddingBlock(n_steps, self.ch2 * self.s2 * self.s2)

        # downsample blocks
        self.e1 = DownsampleBlock(self.ch0, self.ch1)
        self.e2 = DownsampleBlock(self.ch1, self.ch2)

        # upsample blocks
        self.d1 = UpsampleBlock(self.ch3, self.ch2)
        self.d2 = UpsampleBlock(self.ch2, self.ch1)

        # middle conv block
        self.middle = ConvBlock(self.ch2, self.ch3)

        # output layer
        self.out = nn.Conv2d(self.ch1, self.ch0, kernel_size=1, padding="same")

    def forward(self, x, t):
        t1 = self.em1(t).view(-1, self.ch0, self.s1, self.s1)
        t2 = self.em2(t).view(-1, self.ch1, self.s2, self.s2)
        t3 = self.em3(t).view(-1, self.ch2, self.s3, self.s3)
        t4 = self.em4(t).view(-1, self.ch3, self.s3, self.s3)
        t5 = self.em5(t).view(-1, self.ch2, self.s2, self.s2)

        x1, pool1 = self.e1(x + t1)
        x2, pool2 = self.e2(pool1 + t2)
        x = self.middle(pool2 + t3)
        x = self.d1(x + t4, x2)
        x = self.d2(x + t5, x1)
        x = self.out(x)
        return x

In [5]:
class CosineScheduler():

    def __init__(self, n_steps, device):
        # Save the device
        self.device = device
        t_vals = torch.arange(0, n_steps, 1).to(torch.int).to(self.device)

        def f(t):
            s = 0.008
            return torch.clamp(torch.cos(((t/n_steps + s)/(1+s)) * (torch.pi/2))**2 /\
                torch.cos(torch.tensor((s/(1+s)) * (torch.pi/2)))**2,
                1e-10,
                0.999)

        # alpha_bar_t is defined directly from the scheduler
        self.a_bar_t = f(t_vals+1).to(self.device)
        self.a_bar_t1 = f((t_vals).clamp(0, torch.inf)).to(self.device)

        # beta_t and alpha_t are defined from a_bar_t
        self.beta_t = 1 - (self.a_bar_t / self.a_bar_t1)
        self.beta_t = torch.clamp(self.beta_t, 1e-10, 0.999).to(self.device)
        self.a_t = 1 - self.beta_t

        # Roots of a and a_bar
        self.sqrt_a_t = torch.sqrt(self.a_t).to(self.device)
        self.sqrt_a_bar_t = torch.sqrt(self.a_bar_t).to(self.device)
        self.sqrt_1_minus_a_bar_t = torch.sqrt(1-self.a_bar_t).to(self.device)
        self.sqrt_a_bar_t1 = torch.sqrt(self.a_bar_t1).to(self.device)

        # Beta tilde value
        self.beta_tilde_t = (((1 - self.a_bar_t1)/(1 - self.a_bar_t)) * self.beta_t).to(self.device)

        self.beta_t = self.beta_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.a_t = self.a_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.a_bar_t = self.a_bar_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.a_bar_t1 = self.a_bar_t1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.sqrt_a_t = self.sqrt_a_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.sqrt_a_bar_t = self.sqrt_a_bar_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.sqrt_1_minus_a_bar_t = self.sqrt_1_minus_a_bar_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.sqrt_a_bar_t1 = self.sqrt_a_bar_t1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        self.beta_tilde_t = self.beta_tilde_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

In [6]:
class IDDPM(nn.Module):
    """Improved DDPM model for diffusion."""

    def __init__(self, unet, device):
        super(IDDPM, self).__init__()
        self.unet = unet.to(device)
        self.device = device
        self.n_steps = unet.em1.n_steps
        self.scheduler = CosineScheduler(self.n_steps, self.device)
        self.out_mean = nn.Conv2d(4, 4, 3, padding=1, groups=4)
        self.out_var = nn.Conv2d(4, 4, 3, padding=1, groups=4)

    def forward(self, x, t, eps=None):
        if eps is None:
            eps = torch.randn(x.shape).to(self.device)
        x_with_noise = self.scheduler.sqrt_a_bar_t[t] * x + self.scheduler.sqrt_1_minus_a_bar_t[t] * eps
        return x_with_noise
    
    def backward(self, x, t):
        print(x.shape, t.shape)
        out = self.unet(x, t)
        noise, v = out[:, :4], out[:, :4]
        noise = self.out_mean(noise)
        v = self.out_var(v)
        return noise, v
    
    def noise_to_mean(self, epsilon, x_t, t, corrected=True):
        # Note: Corrected function from the following:
        # https://github.com/hojonathanho/diffusion/issues/5
        beta_t = self.scheduler.beta_t[t]
        sqrt_a_t = self.scheduler.sqrt_a_t[t]
        a_bar_t = self.scheduler.a_bar_t[t]
        sqrt_a_bar_t = self.scheduler.sqrt_a_bar_t[t]
        sqrt_1_minus_a_bar_t = self.scheduler.sqrt_1_minus_a_bar_t[t]
        a_bar_t1 = self.scheduler.a_bar_t1[t]
        sqrt_a_bar_t1 = self.scheduler.sqrt_a_bar_t1[t]

        mean = torch.where(t == 0,
            (1 / sqrt_a_t) * (x_t - (beta_t / sqrt_1_minus_a_bar_t) * epsilon),
            (sqrt_a_bar_t1 * beta_t) / (1 - a_bar_t) * \
                torch.clamp((1 / sqrt_a_bar_t) * x_t - (sqrt_1_minus_a_bar_t / sqrt_a_bar_t) * epsilon, -1, 1 ) + \
                (((1 - a_bar_t1) * sqrt_a_t) / (1 - a_bar_t)) * x_t
        )
        return mean
    
    def vs_to_variance(self, v, t):
        beta_t = self.scheduler.beta_t[t]
        beta_tilde_t = self.scheduler.beta_tilde_t[t]
        
        # Return the variance value
        return torch.exp(torch.clamp(v * torch.log(beta_t) + (1 - v) * torch.log(beta_tilde_t), torch.tensor(-30, device=self.device), torch.tensor(30, device=self.device)))
        

In [7]:

class IDDPMTrainer:
    def __init__(self, iddpm, device, lr=1e-4, Lambda=1e-3):
        self.n_steps = iddpm.n_steps
        self.Lambda = Lambda
        self.best_loss = float("inf")
        self.device = device
        self.model = iddpm
        self.t_vals = np.arange(0, self.n_steps, 1)
        self.t_dist = torch.distributions.uniform.Uniform(
            float(1) - float(0.499), float(self.n_steps) + float(0.499)
        )
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr, eps=1e-4)
        self.mse_loss = nn.MSELoss(reduction="none").to(self.device)
        self.losses = np.zeros((self.n_steps, 10))
        self.losses_ct = np.zeros(self.n_steps, dtype=int)

    def update_losses(self, loss_vlb, t):
        for t_val, loss in zip(t, loss_vlb):
            if self.losses_ct[t_val] == 10:
                self.losses[t_val] = np.concatenate((self.losses[t_val][1:], [loss]))
            else:
                self.losses[t_val, self.losses_ct[t_val]] = loss
                self.losses_ct[t_val] += 1

    def loss_simple(self, eps, eps_theta):
        return ((eps_theta - eps) ** 2).flatten(1, -1).mean(-1)

    # Formula derived from: https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    def loss_vlb_gauss(self, mean_real, mean_fake, var_real, var_fake):
        """KL divergence between two gaussians."""
        std_real = torch.sqrt(var_real)
        std_fake = torch.sqrt(var_fake)
        kl_div = (
            (
                torch.log(std_fake / std_real)
                + ((var_real) + (mean_real - mean_fake) ** 2) / (2 * (var_fake))
                - torch.tensor(1 / 2)
            )
            .flatten(1, -1)
            .mean(-1)
        )
        return kl_div

    def calc_losses(self, eps, eps_theta, var_theta, x, x_with_noise, t):

        mean_t_pred = self.model.noise_to_mean(eps_theta, x_with_noise, t, True)
        var_t_pred = self.model.vs_to_variance(var_theta, t)

        beta_t = self.model.scheduler.beta_t[t]
        a_bar_t = self.model.scheduler.a_bar_t[t]
        a_bar_t1 = self.model.scheduler.a_bar_t1[t]
        beta_tilde_t = self.model.scheduler.beta_tilde_t[t]
        sqrt_a_bar_t1 = self.model.scheduler.sqrt_a_bar_t1[t]
        sqrt_a_t = self.model.scheduler.sqrt_a_t[t]

        mean_t = ((sqrt_a_bar_t1 * beta_t) / (1 - a_bar_t)) * x + (
            (sqrt_a_t * (1 - a_bar_t1)) / (1 - a_bar_t)
        ) * x_with_noise

        loss_simple = self.loss_simple(eps, eps_theta)
        loss_vlb = (
            self.loss_vlb_gauss(
                mean_t, mean_t_pred.detach(), beta_tilde_t, var_t_pred
            )
            * self.Lambda
        )
        loss_hybrid = loss_simple + loss_vlb

        with torch.no_grad():
            t = t.detach().cpu().numpy()
            loss = loss_vlb.detach().cpu()
            self.update_losses(loss, t)

            if np.sum(self.losses_ct) == self.losses.size - 20:
                p_t = np.sqrt((self.losses**2).mean(-1))
                p_t = p_t / p_t.sum()
                loss = loss / torch.tensor(p_t[t], device=self.device)

        return loss_hybrid.mean(), loss_simple.mean(), loss_vlb.mean()

    def train(self, loader, n_epochs, model_store_path=None):
        self.model.train()
        n = len(loader.dataset)

        self.losses_comb = np.array([])
        self.losses_mean = np.array([])
        self.losses_var = np.array([])
        self.steps_list = np.array([])

        losses_comb_s = torch.tensor(0.0, requires_grad=False)
        losses_mean_s = torch.tensor(0.0, requires_grad=False)
        losses_var_s = torch.tensor(0.0, requires_grad=False)

        for epoch in range(n_epochs):
            for _, batch in enumerate(
                tqdm(loader, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")
            ):
                x = batch[0]
                if np.sum(self.losses_ct) == self.losses.size - 20:
                    # Weights for each value of t
                    p_t = np.sqrt((self.losses**2).mean(-1))
                    p_t = p_t / p_t.sum()
                    t = torch.tensor(
                        np.random.choice(self.t_vals, size=x.shape[0], p=p_t),
                        device=self.device,
                    )
                else:
                    t = self.t_dist.sample((x.shape[0],)).to(self.device)
                    t = torch.round(t).to(torch.long)

                with torch.no_grad():
                    eps = torch.randn_like(x).to(device)
                    x_with_noise = self.model(x, t, eps)

                print(x_with_noise.shape, t.shape)
                eps_theta, var_theta = self.model.backward(x_with_noise, t)
                loss, loss_mean, loss_var = self.calc_losses(
                    eps, eps_theta, var_theta, x, x_with_noise, t
                )
                loss = loss * x.shape[0] / n
                loss_mean = loss_mean * x.shape[0] / n
                loss_var = loss_var * x.shape[0] / n

                # Backprop the loss, but save the intermediate gradients
                loss.backward()

                losses_comb_s += loss.cpu().detach()
                losses_mean_s += loss_mean.cpu().detach()
                losses_var_s += loss_var.cpu().detach()

                self.optim.step()
                self.optim.zero_grad()

                self.losses_comb = np.append(self.losses_comb, losses_comb_s.item())
                self.losses_mean = np.append(self.losses_mean, losses_mean_s.item())
                self.losses_var = np.append(self.losses_var, losses_var_s.item())

                losses_comb_s *= 0
                losses_mean_s *= 0
                losses_var_s *= 0

            print(
                f"Loss at epoch {epoch + 1}: "
                + f"Combined: {round(self.losses_comb[-10:].mean(), 4)}    "
                f"Mean: {round(self.losses_mean[-10:].mean(), 4)}    "
                f"Variance: {round(self.losses_var[-10:].mean(), 6)}\n\n"
            )

            if model_store_path is not None and self.best_loss > loss:
                self.best_loss = loss
                torch.save(self.model.state_dict(), model_store_path)
                print("Best model ever (stored)")


In [8]:
iddpm = IDDPM(UNet(batch_size=BATCH_SIZE, n_steps=64), device).to(device)

In [9]:
trainer = IDDPMTrainer(iddpm, device)
trainer.train(train_dataloader, 10)

Epoch 1/10:   0%|[38;2;0;85;0m          [0m| 0/32 [00:00<?, ?it/s]

torch.Size([32, 4, 32, 32]) torch.Size([32])
torch.Size([32, 4, 32, 32]) torch.Size([32])


Epoch 1/10:   6%|[38;2;0;85;0m▋         [0m| 2/32 [00:05<01:16,  2.54s/it]

torch.Size([32, 4, 32, 32]) torch.Size([32])
torch.Size([32, 4, 32, 32]) torch.Size([32])


../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [14,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Epoch 1/10:   6%|[38;2;0;85;0m▋         [0m| 2/32 [00:07<01:51,  3.71s/it]

torch.Size([32, 4, 32, 32]) torch.Size([32])
torch.Size([32, 4, 32, 32]) torch.Size([32])





RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
