# DDPM

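I will be implementing the Denoising Diffusion Probabilistic Model (DDPM) from the paper "Denoising Diffusion Probabilistic Models" by Ho et al. (2020).
Generation is intended for MNIST; note that the training cells below currently load CIFAR-10 (three-channel images).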

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch.nn.init as init


## The model

The U-Net architecture below is taken from https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/blob/main/Diffusion/Model.py

In [28]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)

In [29]:
device

device(type='cpu')

In [2]:
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class TimeEmbedding(nn.Module):
    """Map integer timestep indices to embedding vectors.

    A sinusoidal lookup table with ``T`` rows of width ``d_model`` is loaded
    into an ``nn.Embedding`` via ``from_pretrained`` and followed by a small
    MLP (Linear -> Swish -> Linear) that outputs vectors of width ``dim``.

    :param T: number of rows in the lookup table (one per timestep index)
    :param d_model: width of each sinusoidal table row
    :param dim: width of the MLP hidden layer and of the returned embedding
    """

    def __init__(self, T, d_model, dim):
        super().__init__()
        # Frequency ladder exp(-log(10000) * k / d_model) for k = 0, 2, 4, ...
        exponent = torch.arange(0, d_model, step=2) / d_model * np.log(10000)
        inv_freq = torch.exp(-exponent)
        steps = torch.arange(T).float()
        angle = steps[:, None] * inv_freq[None, :]
        # Pair sin/cos per frequency, then flatten each row to width d_model.
        table = torch.stack([torch.sin(angle), torch.cos(angle)], dim=-1)
        table = table.view(T, d_model)

        lookup = nn.Embedding.from_pretrained(table)
        self.timembedding = nn.Sequential(
            lookup,
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights and zero biases for every linear layer."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)

    def forward(self, t):
        """Return the embedding for each timestep index in ``t``."""
        return self.timembedding(t)


class DownSample(nn.Module):
    """Spatial down-sampling by a stride-2 3x3 convolution.

    The channel count is preserved (``in_ch`` in, ``in_ch`` out).
    """

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(
            in_channels=in_ch, out_channels=in_ch,
            kernel_size=3, stride=2, padding=1,
        )
        self.initialize()

    def initialize(self):
        """Xavier-uniform weight and zero bias for the convolution."""
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x, temb):
        # `temb` is accepted so every U-Net layer shares one call signature;
        # it is not used here.
        return self.main(x)


class UpSample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation, then
    apply a stride-1 3x3 convolution that keeps the channel count."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(
            in_channels=in_ch, out_channels=in_ch,
            kernel_size=3, stride=1, padding=1,
        )
        self.initialize()

    def initialize(self):
        """Xavier-uniform weight and zero bias for the convolution."""
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x, temb):
        # `temb` is accepted for a uniform layer signature but is not used.
        # The unpacking only serves to reject inputs that are not 4-D.
        _batch, _chan, _height, _width = x.shape
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(enlarged)


class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries/keys/values with 1x1
    convolutions, attended over all ``H * W`` positions, projected back with
    another 1x1 convolution and added to the input (residual connection).

    :param in_ch: number of channels of the input (and output) feature map
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        """Xavier-initialise all projections; the output one gets a tiny gain."""
        projections = (self.proj_q, self.proj_k, self.proj_v, self.proj)
        for conv in projections:
            init.xavier_uniform_(conv.weight)
            init.zeros_(conv.bias)
        # Re-draw the output projection with gain 1e-5 so that the block
        # starts out very close to the identity mapping.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        tokens = H * W
        normed = self.group_norm(x)

        # Queries and values become (B, HW, C); keys stay as (B, C, HW).
        query = self.proj_q(normed).permute(0, 2, 3, 1).view(B, tokens, C)
        key = self.proj_k(normed).view(B, C, tokens)
        value = self.proj_v(normed).permute(0, 2, 3, 1).view(B, tokens, C)

        # Scaled dot-product scores between every pair of positions.
        scores = torch.bmm(query, key) * (int(C) ** (-0.5))
        assert list(scores.shape) == [B, tokens, tokens]
        weights = F.softmax(scores, dim=-1)

        attended = torch.bmm(weights, value)
        assert list(attended.shape) == [B, tokens, C]
        # Back to the (B, C, H, W) layout expected by the 1x1 convolution.
        attended = attended.view(B, H, W, C).permute(0, 3, 1, 2)

        return x + self.proj(attended)


class ResBlock(nn.Module):
    """Residual block conditioned on a timestep embedding.

    Two GroupNorm -> Swish -> 3x3 conv stages, with the projected timestep
    embedding added to the feature map between them. The skip path is a 1x1
    convolution when the channel count changes and the identity otherwise.
    An ``AttnBlock`` is applied to the result when ``attn`` is true.

    :param in_ch: channels of the input feature map
    :param out_ch: channels of the output feature map
    :param tdim: width of the timestep embedding passed to ``forward``
    :param dropout: dropout probability used in the second stage
    :param attn: whether to append a self-attention block
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        # Only project the skip path when the channel count changes.
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
            if in_ch != out_ch
            else nn.Identity()
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        """Xavier-initialise every conv/linear layer in this block.

        The walk covers all sub-modules, including the convolutions of the
        attention block when one is present. The last convolution of
        ``block2`` is then re-drawn with a tiny gain.
        """
        trainable = (
            m for m in self.modules() if isinstance(m, (nn.Conv2d, nn.Linear))
        )
        for layer in trainable:
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        hidden = self.block1(x)
        # Broadcast the per-sample embedding over the spatial dimensions.
        hidden += self.temb_proj(temb)[:, :, None, None]
        hidden = self.block2(hidden)

        return self.attn(hidden + self.shortcut(x))


class UNet(nn.Module):
    """U-Net noise predictor conditioned on the diffusion timestep.

    The network has a down path and an up path joined by skip connections,
    with two residual blocks in the middle. The output has the same number
    of channels as the input.

    :param T: number of timesteps (size of the timestep embedding table)
    :param ch: base channel count; the timestep embedding width is ``4 * ch``
    :param ch_mult: per-level channel multipliers, one entry per resolution
    :param attn: indices into ``ch_mult`` of levels whose residual blocks
        get a self-attention block
    :param num_res_blocks: residual blocks per level on the down path
        (the up path uses ``num_res_blocks + 1`` per level)
    :param dropout: dropout probability inside each residual block
    :param img_ch: number of image channels consumed by the head and produced
        by the tail; defaults to 3, the previously hard-coded value
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout, img_ch=3):
        super().__init__()
        assert all(i < len(ch_mult) for i in attn), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(img_ch, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        # Output channels of every down-path layer, consumed in reverse by the
        # up path to size its skip-connection inputs.
        chs = [ch]
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            # No down-sampling after the last (coarsest) level.
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                # Input = current features concatenated with one skip tensor.
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            # No up-sampling after the finest level.
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, img_ch, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        """Xavier-initialise head and tail convolutions (tail with tiny gain)."""
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict the noise for images ``x`` at timestep indices ``t``.

        :param x: image batch with ``img_ch`` channels
        :param t: integer timestep index per batch element
        :return: tensor with the same shape as ``x``
        """
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling: remember every intermediate output for the skips.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling: only residual blocks consume a skip tensor.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

torch.Size([8, 3, 32, 32])


In [52]:
def gather_reshape(v, t, x_shape):
    """
    Pick one coefficient per batch element and shape it for broadcasting.

    :param v: 1-D tensor of per-timestep parameters (such as beta)
    :param t: 1-D tensor of indices, one per batch element; any integer-valued
        dtype and any device are accepted
    :param x_shape: shape of the tensor the result is broadcast against
        (to easily find x_t)
    :return: float tensor of shape ``[len(t), 1, ..., 1]`` with as many
        dimensions as ``x_shape``, placed on the module-level ``device``
    """
    # torch.gather requires an int64 index on the same device as `v`.
    index = t.to(device=v.device, dtype=torch.long)
    out = torch.gather(v, index=index, dim=0).float().to(device)
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))

In [54]:
class DiffusionLoss(nn.Module):
    """DDPM training objective: MSE between the true and the predicted noise.

    For a random timestep ``t`` per sample, the clean input is noised as
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`` and the
    model is asked to recover ``eps`` from ``(x_t, t)``.

    :param T: number of diffusion timesteps
    :param model: callable ``model(x_t, t)`` returning a tensor shaped like
        ``x_t``
    :param beta_1: first value of the linear beta schedule
    :param beta_T: last value of the linear beta schedule
    """

    def __init__(self, T, model, beta_1, beta_T):
        super().__init__()
        self.T = T
        self.model = model
        self.register_buffer('beta', torch.linspace(beta_1, beta_T, T))
        # alpha_bar_t = prod_{s <= t} (1 - beta_s)
        alpha_bar = torch.cumprod(1 - self.beta, 0)
        self.register_buffer('sqrt_alpha', alpha_bar.sqrt())
        # The noise scale is sqrt(1 - alpha_bar), not (1 - alpha_bar).
        self.register_buffer('sqrt_1m_alpha', (1 - alpha_bar).sqrt())

    def forward(self, train_x):
        """Return the scalar noise-prediction loss for a batch of clean inputs."""
        batch = train_x.shape[0]
        dev = train_x.device
        # Uniform integer timestep per sample, created on the data's device so
        # it can be passed to the model directly.
        t = torch.randint(self.T, (batch,), device=dev)
        noise = torch.randn_like(train_x)  # epsilon

        # Shape [B, 1, ..., 1] so the coefficients broadcast over the sample.
        bshape = [batch] + [1] * (train_x.dim() - 1)
        signal_scale = self.sqrt_alpha.to(dev)[t].view(bshape)
        noise_scale = self.sqrt_1m_alpha.to(dev)[t].view(bshape)

        x_t = signal_scale * train_x + noise_scale * noise
        return F.mse_loss(self.model(x_t, t), noise)
    

class DiffusionSampler(nn.Module):
    """Ancestral sampler for DDPM (reverse diffusion from noise to data).

    Each reverse step computes
    ``x_{t-1} = 1/sqrt(a_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
    + sqrt(beta_t) * z`` with ``a_t = 1 - beta_t``, ``eps = model(x_t, t)``
    and ``z`` standard normal noise (no noise is added at the final step).

    :param T: number of diffusion timesteps
    :param model: callable ``model(x_t, t)`` returning the predicted noise
    :param beta_1: first value of the linear beta schedule
    :param beta_T: last value of the linear beta schedule
    """

    def __init__(self, T, model, beta_1, beta_T):
        super().__init__()
        self.T = T
        self.model = model
        self.register_buffer('beta', torch.linspace(beta_1, beta_T, T))
        # Step variance sigma_t^2 = beta_t, with zero variance at t = 0.
        self.register_buffer(
            'variance', torch.cat([self.beta.new_zeros(1), self.beta[1:]]))
        a = 1 - self.beta
        alpha_bar = torch.cumprod(a, 0)
        # 1 / sqrt(a_t): uses the per-step a_t, not the cumulative product.
        self.register_buffer('coeff_prev', 1 / a.sqrt())
        # (1 - a_t) / (sqrt(a_t) * sqrt(1 - alpha_bar_t)), where 1 - a_t = beta_t.
        self.register_buffer(
            'coeff_noise', self.coeff_prev * self.beta / (1 - alpha_bar).sqrt())

    def forward(self, x):
        """Run all ``T`` reverse steps starting from ``x`` and return the result.

        :param x: starting tensor (batch first)
        :return: tensor shaped like ``x``, detached from autograd
        """
        batch = x.shape[0]
        coeff_prev = self.coeff_prev.to(x.device)
        coeff_noise = self.coeff_noise.to(x.device)
        std = torch.sqrt(self.variance.to(x.device))

        # Sampling needs no gradients; tracking them across T model calls
        # would only accumulate memory.
        with torch.no_grad():
            for step in reversed(range(self.T)):
                # Integer (int64) timestep per batch element for the model.
                t = torch.full((batch,), step, dtype=torch.long, device=x.device)
                # The final step (t == 0) is deterministic.
                noise = torch.randn_like(x) if step > 0 else torch.zeros_like(x)
                eps = self.model(x, t)
                mean = coeff_prev[step] * x - coeff_noise[step] * eps
                x = mean + std[step] * noise
        return x



In [None]:
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and std 0.5.
_channel_mean = (0.5, 0.5, 0.5)
_channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# CIFAR-10 training split, downloaded to ./data on first use.
train_dataset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

In [None]:
# Noise-prediction network, placed on the selected device.
model = UNet(
    T=1000,
    ch=64,
    ch_mult=(1, 2, 2, 4),
    attn=(0, 1, 2, 3),
    num_res_blocks=2,
    dropout=0.1,
).to(device)
# Training objective wrapping the model with a linear beta schedule.
loss_fn = DiffusionLoss(T=1000, model=model, beta_1=0.0001, beta_T=0.02)

optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Plain training loop: one optimizer step per mini-batch. The loss of every
# batch is printed, and the epoch summary line reports the loss of the last
# batch of that epoch.
num_epochs = 100
for epoch in tqdm(range(num_epochs)):
    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        loss = loss_fn(x)
        loss.backward()
        optimizer.step()
        print(loss.item())
    print(f'Epoch {epoch} Loss {loss.item()}')

In [44]:
# Dummy batch (2 images, 3 channels, 32x32) and timestep indices in {0, 1}
# for a quick shape check. Both are created on `device` so they can be fed
# to a model that was moved there.
x = torch.randn(2, 3, 32, 32, device=device)
t = torch.randint(2, (2,), device=device)

In [45]:
# Small two-timestep U-Net used only for the smoke test below.
model = UNet(
    T=2,
    ch=64,
    ch_mult=[2, 2, 2, 2],
    attn=[1, 3],
    num_res_blocks=2,
    dropout=0.1,
).to(device)

In [58]:
sum(p.numel() for p in model.parameters())

10951299

In [56]:
# Two-step sampler around the smoke-test model.
sampler = DiffusionSampler(T=2, model=model, beta_1=0.01, beta_T=0.1).to(device)

In [57]:
sampler(x)

tensor([[[[-5.1063e-01, -1.1846e+00,  1.9892e+00,  ..., -3.9252e-02,
           -1.0488e+00, -1.2237e+00],
          [ 1.4269e+00,  1.0801e+00,  8.1339e-01,  ..., -1.0344e+00,
           -1.6252e-01,  2.5414e+00],
          [-2.4462e-01,  2.0766e+00, -1.3673e-02,  ..., -6.4854e-01,
           -2.0041e-01, -1.1301e+00],
          ...,
          [-8.7303e-01, -1.9073e-01, -2.4993e-01,  ...,  3.0115e-01,
            5.7494e-01,  2.6118e-01],
          [ 4.3245e-01, -2.6581e-01,  3.0842e-01,  ..., -6.8901e-01,
            7.9575e-01, -3.0081e-01],
          [ 8.3180e-01,  1.4765e+00, -1.0293e+00,  ..., -9.5463e-01,
           -1.3505e+00,  1.8644e+00]],

         [[ 2.7015e+00, -1.1578e+00,  1.8737e-01,  ...,  5.5893e-01,
            2.7249e+00, -2.5220e+00],
          [-1.3063e+00, -1.4486e+00, -1.1028e+00,  ...,  7.4785e-01,
            2.0098e+00,  6.4185e-01],
          [-7.8450e-01, -2.3460e+00, -8.4807e-01,  ...,  1.4677e-01,
           -9.7963e-03, -1.0759e+00],
          ...,
     

In [68]:
# Sanity check of the loss scale: for two independent standard-normal
# tensors the expected MSE is Var(a - b) = 2.
a= torch.randn(2, 3, 32, 32)
b = torch.randn(2, 3, 32, 32)

F.mse_loss(a, b)

tensor(2.0528)