# Score Based Generative Models Using Stochastic Differential Equations

This notebook reuses the DDPM U-Net architecture as the score network within the variance-preserving (VP) SDE framework.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
import numpy as np
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
import os
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Architecture

In [None]:
class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class TimeEmbedding(nn.Module):
    """Fixed sinusoidal timestep table followed by a learned two-layer MLP.

    Args:
        T: number of discrete timesteps (rows of the sinusoidal table).
        d_model: width of the sinusoidal table; assumed even (sin/cos pairs).
        dim: width of the returned embedding.
    """

    def __init__(self, T, d_model, dim):
        super().__init__()
        # Frequencies 10000^(-2i / d_model), as in Transformer positional encodings.
        emb = torch.arange(0, d_model, step=2) / d_model * np.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        # Interleave sin/cos on the last axis, then flatten to (T, d_model).
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),  # frozen lookup table
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        """Embed timesteps.

        Integer ``t`` is used directly as a row index into the table.
        Floating-point ``t`` is treated as continuous time in [0, 1] (as produced
        by the SDE code) and mapped to the nearest index ``round(t * (T - 1))``;
        ``nn.Embedding`` itself only accepts integer indices.
        """
        table = self.timembedding[0]
        t = torch.as_tensor(t, device=table.weight.device)
        if torch.is_floating_point(t):
            last = table.num_embeddings - 1
            t = (t * last).round().long().clamp(0, last)
        return self.timembedding(t)


class DownSample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution (channel count unchanged)."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x, temb):
        # `temb` is accepted only so every U-Net block shares one call signature.
        return self.main(x)


class UpSample(nn.Module):
    """Double the spatial resolution: nearest-neighbour upsampling then a 3x3 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is accepted only so every U-Net block shares one call signature.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(x)


class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions, added back residually.

    ``in_ch`` must be divisible by 32 because of the GroupNorm(32, ...) layer.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for conv in (self.proj_q, self.proj_k, self.proj_v, self.proj):
            init.xavier_uniform_(conv.weight)
            init.zeros_(conv.bias)
        # Near-zero output projection so the block starts out as (almost) the identity.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        batch, channels, height, width = x.shape
        normed = self.group_norm(x)
        # (B, C, H, W) -> (B, H*W, C) for queries/values and (B, C, H*W) for keys.
        queries = self.proj_q(normed).flatten(2).transpose(1, 2)
        keys = self.proj_k(normed).flatten(2)
        values = self.proj_v(normed).flatten(2).transpose(1, 2)

        # Scaled dot-product scores between every pair of spatial positions.
        scores = torch.bmm(queries, keys) * (int(channels) ** (-0.5))
        assert list(scores.shape) == [batch, height * width, height * width]
        weights = F.softmax(scores, dim=-1)

        attended = torch.bmm(weights, values)
        assert list(attended.shape) == [batch, height * width, channels]
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)
        return x + self.proj(attended)


class ResBlock(nn.Module):
    """Residual block conditioned on a time embedding, with optional self-attention.

    Two (GroupNorm -> Swish -> Conv) stages with the projected time embedding
    added in between, a 1x1-conv shortcut when the channel count changes, and an
    AttnBlock applied to the sum when ``attn`` is true.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
            if in_ch != out_ch
            else nn.Identity()
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        # NOTE: self.modules() includes any AttnBlock submodule, so its convs
        # (including its near-zero output projection) are re-initialised here too.
        weighted = [m for m in self.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
        for layer in weighted:
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)
        # Near-zero final conv so the residual branch starts almost silent.
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        # Broadcast the per-sample time embedding over all spatial positions.
        time_bias = self.temb_proj(temb)[:, :, None, None]
        hidden = self.block1(x) + time_bias
        hidden = self.block2(hidden)
        return self.attn(hidden + self.shortcut(x))


class UNet(nn.Module):
    """DDPM-style U-Net for single-channel images.

    Maps an image batch ``x`` of shape (B, 1, H, W) and timesteps ``t`` to an
    output with one channel (the head is Conv2d(1, ch) and the tail is
    Conv2d(now_ch, 1)). Spatial size is preserved provided H and W are divisible
    by 2 ** (len(ch_mult) - 1) — assumed, not checked.

    Args:
        T: number of timesteps in the sinusoidal time-embedding table.
        ch: base channel count; every width used must be divisible by 32
            because of the GroupNorm(32, ...) layers.
        ch_mult: per-resolution channel multipliers (one level per entry).
        attn: indices into ``ch_mult`` whose ResBlocks get self-attention.
        num_res_blocks: ResBlocks per level on the down path (the up path uses
            one extra per level to consume the additional skip connection).
        dropout: dropout probability inside every ResBlock.
    """
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(1, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # channel count of every skip tensor, consumed in reverse by the up path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            # No downsampling after the lowest-resolution level.
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # The bottleneck always uses attention in its first block, regardless of `attn`.
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                # Input = current features concatenated with the matching skip tensor.
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        # Every skip channel recorded on the way down must have been consumed.
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 1, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        # Near-zero output conv: the network's initial output is close to zero.
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling; keep every intermediate (including the head output) as a skip.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling; only ResBlocks consume a skip tensor (UpSample does not).
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

## SDE

For now, only the VP SDE is implemented, as it gives better results.

In [9]:
class VPSDE:
    """
    Variance Preserving Stochastic Differential Equation (Song et al., 2021):

        dx(t) = -1/2 beta(t) x(t) dt + sqrt(beta(t)) dW(t),   t in [0, 1].

    ``beta_0`` and ``beta_T`` are the endpoints of the *discrete* DDPM noise
    schedule over ``T`` steps (e.g. 1e-4 and 0.02 with T=1000). The continuous
    rate is the DDPM limit ``beta(t) = T * (beta_0 + (beta_T - beta_0) * t)``
    (0.1 -> 20 for those values), which drives x(1) close to N(0, I). Using the
    discrete values directly as continuous rates would leave
    x(1) ~= 0.995 * x(0) + 0.1 * z, far from the Gaussian prior sampling starts from.
    """
    def __init__(self, beta_0, beta_T, T):
        self.beta_0 = beta_0
        self.beta_T = beta_T
        self.T = T
        # Discrete DDPM schedule (used by ancestral sampling and the Langevin corrector).
        self.betas = torch.linspace(beta_0, beta_T, T).to(device)
        self.alpha = 1 - self.betas
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.coeff_prev_diffusion = 1 / self.alpha.sqrt()
        self.coeff_noise_diffusion = self.coeff_prev_diffusion * (1 - self.alpha) / (1 - self.alpha_bar).sqrt()

    @property
    def beta_min(self):
        """Continuous rate beta(0) = T * beta_0."""
        return self.T * self.beta_0

    @property
    def beta_max(self):
        """Continuous rate beta(1) = T * beta_T."""
        return self.T * self.beta_T

    @staticmethod
    def _expand(v, x):
        """Return ``v`` (scalar or per-sample (B,) tensor) shaped to broadcast against ``x``."""
        v = torch.as_tensor(v, dtype=x.dtype, device=x.device)
        if v.dim() == 0:
            return v
        return v.reshape(-1, *([1] * (x.dim() - 1)))

    def beta(self, t):  # continuous beta
        return self.beta_min + (self.beta_max - self.beta_min) * t

    def sde(self, x, t):
        """
        x: torch.Tensor(B, C, H, W)
        t: torch.Tensor(B)
        Return the forward drift f(x, t) and diffusion g(t) (shaped to broadcast with x).
        """
        beta_t = self._expand(self.beta(t), x)
        return -0.5 * beta_t * x, torch.sqrt(beta_t)

    def dist(self, x, t):
        """
        x: torch.Tensor(B, C, H, W) - the original x
        t: torch.Tensor(B) in [0, 1] - the time
        Return the mean and standard deviation of x(t) given x(0).
        The std has shape (B, 1, 1, 1) so it broadcasts against x.
        """
        t = self._expand(t, x)
        # log of the mean coefficient: -1/2 * integral_0^t beta(s) ds
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        mean = torch.exp(log_mean_coeff) * x
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
        return mean, std

    def reverse_sde(self, model, x, t):
        """
        x: torch.Tensor(B, C, H, W)
        t: torch.Tensor(B) in [0, 1]
        model: Score model
        Return the drift f - g^2 * score and the diffusion g of the reverse-time SDE.
        """
        drift, diffusion = self.sde(x, t)
        return drift - diffusion ** 2 * model(x, t), diffusion

    def reverse_pf_ode(self, model, x, t):
        """
        Return the drift f - 1/2 g^2 * score of the probability-flow ODE.
        """
        drift, diffusion = self.sde(x, t)
        return drift - 0.5 * diffusion ** 2 * model(x, t)

## The model trainer

In [13]:
class SDELoss(nn.Module):
    """Denoising score-matching loss for a continuous-time SDE.

    Uses the lambda(t) = sigma(t)^2 weighting: with ``x_t = mean + sigma(t) * z``
    the model's score is trained so that ``sigma(t) * s(x_t, t) ~= -z``.

    Args:
        sde: object exposing ``dist(x, t) -> (mean, std)``.
        eps: smallest time sampled. At t=0, sigma=0 makes the weighted prediction
            identically zero, so the loss there is ||z||^2 and cannot be reduced.
    """

    def __init__(self, sde, eps=1e-5):
        super().__init__()
        self.sde = sde
        self.eps = eps

    def forward(self, model, x):
        """
        model: Score model
        x: torch.Tensor(B, C, H, W) - train data.
        """
        batch = x.shape[0]
        # Sample on the data's device so this works whatever device the model uses.
        t = torch.rand(batch, device=x.device) * (1.0 - self.eps) + self.eps
        noise = torch.randn_like(x)
        mean, std = self.sde.dist(x, t)
        # Make the per-sample std broadcast over (C, H, W).
        std = std.reshape(batch, *([1] * (x.dim() - 1)))
        x_t = mean + std * noise
        score_prediction = model(x_t, t)
        return F.mse_loss(score_prediction * std, -noise)


## Train

In [None]:
model_config = {
    "state": "train",
    "epoch": 50,
    "batch_size": 64,
    "T": 1000,  # number of discrete timesteps
    "channel": 32,
    "channel_mult": [1, 2],
    "attn": [],
    "num_res_blocks": 2,
    "dropout": 0.15,
    "lr": 5e-4,
    "multiplier": 2.,
    "beta_1": 1e-4,  # first discrete beta (passed to VPSDE as beta_0)
    "beta_T": 0.02,  # last discrete beta
    "img_size": 28,
    "grad_clip": 1.,
    # Follow the notebook-wide `device` so training also runs on CPU-only machines.
    "device": str(device),
    "training_load_weight": None,
    "save_weight_dir": "./Checkpoints/",
    "test_load_weight": "ckpt_49_.pt",
    "sampled_dir": "",
    "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
    "sampledImgName": "SampledNoGuidenceImgs.png",
    "nrow": 8,
    "show_process": True,
    "corrector_steps": 2,
    "corrector_step_size": 0.1,
    # Number of predictor-corrector steps read by `sample()`.
    "pc_steps": 1000,
}

In [None]:
def train(model_config):
    """Train the UNet score model on MNIST with the VP-SDE score-matching loss.

    Architecture, optimiser, noise schedule and I/O settings come from
    ``model_config``. A checkpoint ``ckpt_{epoch}_.pt`` is written into
    ``save_weight_dir`` after every epoch.
    """
    train_device = model_config["device"]
    model = UNet(
        T=model_config["T"],
        ch=model_config["channel"],
        ch_mult=model_config["channel_mult"],
        attn=model_config["attn"],
        num_res_blocks=model_config["num_res_blocks"],
        dropout=model_config["dropout"]
    ).to(train_device)

    sde = VPSDE(model_config["beta_1"], model_config["beta_T"], model_config["T"])
    sde_loss = SDELoss(sde)
    optimizer = optim.Adam(model.parameters(), lr=model_config["lr"])
    model.train()

    train_dataset = MNIST(
        root="./data", train=True, download=True,
        transform=transforms.Compose([
            transforms.Resize(model_config["img_size"]),
            transforms.ToTensor()
        ])
    )
    train_loader = DataLoader(
        train_dataset, batch_size=model_config["batch_size"], shuffle=True
    )

    save_dir = model_config["save_weight_dir"]
    os.makedirs(save_dir, exist_ok=True)

    if model_config["training_load_weight"]:
        # map_location lets a checkpoint saved on another device be resumed here.
        state = torch.load(
            model_config["training_load_weight"], map_location=train_device, weights_only=True)
        model.load_state_dict(state)

    for epoch in range(model_config["epoch"]):
        total_loss, num_batches = 0.0, 0
        for x, _ in tqdm(train_loader):
            x = x.to(train_device)
            optimizer.zero_grad()
            loss = sde_loss(model, x)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), model_config["grad_clip"])
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        if model_config["show_process"] and num_batches:
            # Report the epoch mean rather than the noisy last-batch loss.
            print(f"Epoch {epoch}, Loss: {total_loss / num_batches}")
        torch.save(model.state_dict(), os.path.join(save_dir, f"ckpt_{epoch}_.pt"))

    

## Sampling

In [None]:
import abc

class Predictor(abc.ABC):
    """Interface for the predictor half of a predictor-corrector sampler.

    Subclasses must implement :meth:`predictor_step`; instantiating a subclass
    that does not override it raises ``TypeError``.
    """

    def __init__(self, model, sde):
        self.model = model  # score model called as model(x, t)
        self.sde = sde

    @abc.abstractmethod
    def predictor_step(self, prev_x, t):
        """Move ``prev_x`` one step backwards in time from ``t`` and return the new sample."""
        raise NotImplementedError


class Corrector(abc.ABC):
    """Interface for the corrector half of a predictor-corrector sampler.

    Subclasses must implement :meth:`corrector_step`; instantiating a subclass
    that does not override it raises ``TypeError``.
    """

    def __init__(self, model, sde):
        self.model = model  # score model called as model(x, t)
        self.sde = sde

    @abc.abstractmethod
    def corrector_step(self, prev_x, t):
        """Refine ``prev_x`` at fixed time ``t`` and return the corrected sample."""
        raise NotImplementedError

In [None]:
class ReverseDiffusionPredictor(Predictor):
    """Euler-Maruyama predictor for the reverse-time SDE.

    One step from ``t`` to ``t - dt``:

        x <- x - [f(x, t) - g(t)^2 * s(x, t)] * dt + g(t) * sqrt(dt) * z

    where ``sde.reverse_sde`` supplies the bracketed drift and g(t).
    """
    def __init__(self, model, sde):
        super().__init__(model, sde)

    def predictor_step(self, prev_x, t, dt=None):
        """
        prev_x: torch.Tensor(B, C, H, W)
        t: torch.Tensor(B) in [0, 1]
        dt: positive step size; defaults to 1 / sde.T.
        """
        if dt is None:
            dt = 1.0 / self.sde.T
        # reverse_sde already includes the -g^2 * score term, so the score must not
        # be added again; a single model evaluation per step is enough.
        drift, diffusion = self.sde.reverse_sde(self.model, prev_x, t)
        noise = torch.randn_like(prev_x)
        return prev_x - drift * dt + diffusion * (dt ** 0.5) * noise
        

class LangevinDynamicsCorrector(Corrector):
    """
    Langevin-dynamics corrector, based on Algorithm 5 (PC sampling, VP SDE) in the paper:

        eps_i = 2 * alpha_i * (r * ||z|| / ||s||)^2
        x     = x + eps_i * s + sqrt(2 * eps_i) * z

    with r = ``corrector_step_size`` (the paper's signal-to-noise ratio).
    """
    def __init__(self, model, sde, corrector_step_size):
        super().__init__(model, sde)
        self.corrector_step_size = corrector_step_size

    def corrector_step(self, prev_x, t):
        """
        prev_x: torch.Tensor(B, C, H, W)
        t: torch.Tensor(B) (or a scalar) in [0, 1]
        """
        batch = prev_x.shape[0]
        score = self.model(prev_x, t)
        noise = torch.randn_like(prev_x)
        score_norm = torch.linalg.norm(score.reshape(batch, -1), dim=1)
        noise_norm = torch.linalg.norm(noise.reshape(batch, -1), dim=1)

        # Map continuous t to a discrete index in [0, T-1]; t * T would index one
        # past the end of alpha at t == 1.
        alphas = self.sde.alpha
        last = self.sde.T - 1
        idx = torch.as_tensor(t, dtype=alphas.dtype, device=alphas.device)
        idx = (idx * last).round().long().clamp(0, last)
        alpha_t = alphas[idx].to(prev_x.device)

        # clamp_min guards against a zero score norm (infinite step size).
        ratio = self.corrector_step_size * noise_norm / score_norm.clamp_min(1e-12)
        step_size = 2 * alpha_t * ratio ** 2
        step_size = step_size.reshape(batch, *([1] * (prev_x.dim() - 1)))
        return prev_x + step_size * score + torch.sqrt(2 * step_size) * noise
    
        

def predictor_corrector_step(model_config, predictor, corrector, x, t):
    """Apply one predictor update followed by ``model_config['corrector_steps']`` corrector updates.

    predictor: Predictor
    corrector: Corrector
    x: torch.Tensor(B, C, H, W)
    t: torch.Tensor(B)
    Returns the updated sample.
    """
    sample_x = predictor.predictor_step(x, t)
    num_corrections = model_config['corrector_steps']
    for _step in range(num_corrections):
        sample_x = corrector.corrector_step(sample_x, t)
    return sample_x


In [25]:
def sample(model_config):
    """Predictor-corrector sampling from the trained score model.

    Starts from N(0, I) noise and integrates the reverse VP SDE on a uniform
    time grid from t=1 down to a small ``sampling_eps``. The grid has
    ``model_config["pc_steps"]`` points (defaulting to ``T``). The final samples
    and the initial noise are saved as nrow x nrow image grids.
    """
    model = UNet(
        T=model_config["T"],
        ch=model_config["channel"],
        ch_mult=model_config["channel_mult"],
        attn=model_config["attn"],
        num_res_blocks=model_config["num_res_blocks"],
        dropout=model_config["dropout"]
    ).to(device)

    ckpt = torch.load(os.path.join(
        model_config["save_weight_dir"], model_config["test_load_weight"]), map_location=device, weights_only=True)
    model.load_state_dict(ckpt)
    model.eval()

    sde = VPSDE(model_config["beta_1"], model_config["beta_T"], model_config["T"])
    predictor = ReverseDiffusionPredictor(model, sde)
    corrector = LangevinDynamicsCorrector(model, sde, model_config["corrector_step_size"])

    num_images = model_config["nrow"] ** 2
    img_size = model_config["img_size"]
    num_steps = model_config.get("pc_steps", model_config["T"])
    # Stop just above t=0, where sigma(t) -> 0 and the score is ill-conditioned.
    sampling_eps = 1e-3

    x = torch.randn(num_images, 1, img_size, img_size, device=device)
    noisy_images = x.cpu()
    # No gradients are needed; without this, autograd graphs pile up across all steps.
    with torch.no_grad():
        for t_scalar in torch.linspace(1.0, sampling_eps, num_steps):
            # Samplers expect a per-sample (B,) tensor of continuous times in [0, 1].
            t = torch.full((num_images,), float(t_scalar), device=device)
            x = predictor_corrector_step(model_config, predictor, corrector, x, t)

    sampled_images = torch.clamp(x.cpu(), 0, 1)
    noisy_images = torch.clamp(noisy_images, 0, 1)

    save_image(sampled_images, os.path.join(model_config["sampled_dir"], model_config["sampledImgName"]),
               nrow=model_config["nrow"])
    save_image(noisy_images, os.path.join(model_config["sampled_dir"], model_config["sampledNoisyImgName"]),
               nrow=model_config["nrow"])

tensor([[[[1., 1., 4., 1., 3.],
          [1., 3., 3., 4., 3.],
          [1., 4., 3., 4., 0.],
          [4., 4., 0., 3., 4.]],

         [[2., 1., 0., 2., 4.],
          [4., 3., 4., 4., 0.],
          [4., 4., 2., 1., 1.],
          [0., 2., 0., 0., 3.]],

         [[2., 4., 4., 3., 0.],
          [3., 1., 1., 2., 3.],
          [0., 0., 3., 3., 4.],
          [2., 2., 1., 3., 1.]]],


        [[[3., 4., 1., 1., 2.],
          [4., 4., 4., 2., 1.],
          [3., 0., 4., 2., 0.],
          [1., 1., 0., 2., 1.]],

         [[1., 4., 2., 3., 2.],
          [3., 2., 1., 1., 4.],
          [3., 1., 2., 1., 2.],
          [0., 3., 3., 3., 0.]],

         [[3., 1., 1., 0., 0.],
          [4., 4., 1., 0., 4.],
          [0., 3., 2., 0., 3.],
          [3., 0., 2., 1., 0.]]]])
tensor([[1., 1., 4., 1., 3., 1., 3., 3., 4., 3., 1., 4., 3., 4., 0., 4., 4., 0.,
         3., 4., 2., 1., 0., 2., 4., 4., 3., 4., 4., 0., 4., 4., 2., 1., 1., 0.,
         2., 0., 0., 3., 2., 4., 4., 3., 0., 3., 1., 1.

tensor([426.0000, 327.0000])

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [17]:
def diffusion_sample(model_config):
    """Ancestral (DDPM-style) sampling driven by the trained score model.

    The score output is turned into a noise prediction
    ``eps = -sqrt(1 - alpha_bar_i) * s(x, i)`` and plugged into the DDPM update

        x_{i-1} = (x_i - beta_i / sqrt(1 - alpha_bar_i) * eps) / sqrt(alpha_i) + sqrt(beta_i) * z

    with no noise added on the final step. Saves an nrow x nrow grid of samples.
    """
    model = UNet(
        T=model_config["T"],
        ch=model_config["channel"],
        ch_mult=model_config["channel_mult"],
        attn=model_config["attn"],
        num_res_blocks=model_config["num_res_blocks"],
        dropout=model_config["dropout"]
    ).to(device)

    ckpt = torch.load(os.path.join(
        model_config["save_weight_dir"], model_config["test_load_weight"]), map_location=device, weights_only=True)
    model.load_state_dict(ckpt)
    model.eval()
    sde = VPSDE(model_config["beta_1"], model_config["beta_T"], model_config["T"])

    num_images = model_config["nrow"] ** 2
    x = torch.randn(num_images, 1, model_config["img_size"], model_config["img_size"]).to(device)
    with torch.no_grad():
        for time in reversed(range(model_config["T"])):
            # Integer step index: the model's time embedding is a lookup table over [0, T).
            t = torch.full((num_images,), time, dtype=torch.long, device=device)
            # Discrete marginal std, consistent with the discrete DDPM coefficients below.
            std = torch.sqrt(1 - sde.alpha_bar[time])
            noise_pred = -model(x, t) * std
            mean = sde.coeff_prev_diffusion[time] * x - sde.coeff_noise_diffusion[time] * noise_pred
            if time > 0:
                x = mean + torch.sqrt(sde.betas[time]) * torch.randn_like(x)
            else:
                x = mean

    sampled_images = x.cpu()
    sampled_images = torch.clamp(sampled_images, 0, 1)
    save_image(sampled_images, model_config["sampled_dir"] + model_config["sampledImgName"], nrow=model_config["nrow"])

TypeError: new_zeros() missing 1 required positional arguments: "size"

tensor([[[[1.0000e-04]]]])
