In [1]:
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from tqdm import tqdm

In [2]:
from dataclasses import dataclass, asdict

@dataclass
class ModelConfig:
    nT: int = 1000
    beta_s: float = 1e-4
    beta_e: float = 2e-2
    img_dim: int = 32
    n_channels: int = 1

cfg_m = ModelConfig()

In [39]:
class DDPM(nn.Module):
    def __init__(self, nT, beta_s, beta_e, img_dim, n_channels):
        super().__init__()
        conv_block = lambda in_c, out_c: nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=7, padding=3),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU()
        )
        self.model = nn.Sequential(
            conv_block(n_channels, 64),
            conv_block(64, 128),
            conv_block(128, 256),
            conv_block(256, 512),
            conv_block(512, 256),
            conv_block(256, 128),
            conv_block(128, 64),
            nn.Conv2d(64, n_channels, kernel_size=3, padding=1)
        )
        self.img_dims = (n_channels, img_dim, img_dim)

        self.nT = nT
        beta = torch.linspace(beta_s, beta_e, nT)  # Linear schedule
        alpha = 1.0 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)

        var_schedule = {
            'sqrt_alpha_bar': alpha_bar.sqrt(),
            'sqrt_one_minus_alpha_bar': torch.sqrt(1.0 - alpha_bar),
            'rsqrt_alpha': alpha.rsqrt(),
            'beta_rsqrt_omab': beta * torch.rsqrt(1.0 - alpha_bar),
            'sigma': beta.sqrt()
        }
        for name, tensor in var_schedule.items():
            self.register_buffer(name, tensor.reshape(-1, 1, 1, 1))

    def forward(self, x0, eps, t):
        x_t = self.sqrt_alpha_bar[t, ...] * x0 + self.sqrt_one_minus_alpha_bar[t, ...] * eps
        eps_pred = self.model(x_t)  # self.model(x_t, t)
        return eps_pred

    @torch.inference_mode()
    def sample(self, n_samples, n_steps=None):
        if n_steps is None:
            n_steps = self.nT
        x_t = torch.randn([n_samples, *self.img_dims], device=self.sigma.device)

        for t in reversed(range(n_steps)):
            z = torch.randn_like(x_t) if t > 0 else 0.0
            eps = self.model(x_t)  # self.model(x_t, torch.full([n_samples], t, device=x_t.device))
            x_t = self.rsqrt_alpha[t, ...] * (x_t - self.beta_rsqrt_omab[t, ...] * eps) + self.sigma[t, ...] * z

        return x_t

torch.manual_seed(3985)
ddpm = DDPM(**asdict(cfg_m)).to('cuda')

In [47]:
def ddpm_schedules(beta1: float, beta2: float, T: int):
    """
    Returns pre-computed schedules for DDPM sampling, training process.
    """
    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }

var_sched = ddpm_schedules(1e-4, 2e-2, 1000)
var_sched = {k: v.to('cuda') for k, v in var_sched.items()}

In [48]:
x_t = torch.randn([1, *ddpm.img_dims], device='cuda')
print(f'x_0: {x_t.min().item():.4f}, {x_t.max().item():.4f}, {x_t.mean().item():.4f}')

for t in reversed(range(1000)):
    z = torch.rand_like(x_t)
    eps = ddpm.model(x_t)
    if (t + 1) % 50 == 0:
        print('eps: ', f'{eps.min().item():.4f}, {eps.max().item():.4f}, {eps.mean().item():.4f}')

    rsqrt_alpha_t = var_sched['oneover_sqrta'][t].reshape(-1, 1, 1, 1)
    if (t + 1) % 50 == 0:
        print(f'rsqrt_alpha_{t}: {rsqrt_alpha_t.min().item():.4f}, {rsqrt_alpha_t.max().item():.4f}, {rsqrt_alpha_t.mean().item():.4f}')

    beta_rsqrt_omab_t = var_sched['mab_over_sqrtmab'][t].reshape(-1, 1, 1, 1)
    if (t + 1) % 50 == 0:
        print(f'beta_rsqrt_omab_{t}: {beta_rsqrt_omab_t.min().item():.4f}, {beta_rsqrt_omab_t.max().item():.4f}, {beta_rsqrt_omab_t.mean().item():.4f}')

    sigma_t = var_sched['sqrt_beta_t'][t].reshape(-1, 1, 1, 1)
    if (t + 1) % 50 == 0:
        print(f'sigma_{t}: {sigma_t.min().item():.4f}, {sigma_t.max().item():.4f}, {sigma_t.mean().item():.4f}')

    x_t = rsqrt_alpha_t * (x_t - beta_rsqrt_omab_t * eps) + sigma_t * z
    if (t + 1) % 50 == 0:
        print(f'x_{t}: {x_t.min().item():.4f}, {x_t.max().item():.4f}, {x_t.mean().item():.4f}')
        print('='*30)

x_0: -3.1213, 3.0137, -0.0284
eps:  -1.9394, 0.9716, -0.4507
rsqrt_alpha_999: 1.0101, 1.0101, 1.0101
beta_rsqrt_omab_999: 0.0200, 0.0200, 0.0200
sigma_999: 0.1414, 0.1414, 0.1414
x_999: -3.0654, 3.1292, 0.0520
eps:  -1.5674, 0.6231, -0.4503
rsqrt_alpha_949: 1.0096, 1.0096, 1.0096
beta_rsqrt_omab_949: 0.0190, 0.0190, 0.0190
sigma_949: 0.1378, 0.1378, 0.1378
x_949: -0.5556, 10.4634, 5.1287
eps:  -1.5160, 0.8430, -0.4477
rsqrt_alpha_899: 1.0091, 1.0091, 1.0091
beta_rsqrt_omab_899: 0.0180, 0.0180, 0.0180
sigma_899: 0.1341, 0.1341, 0.1341
x_899: 3.5703, 21.3412, 13.0226
eps:  -1.6866, 0.6745, -0.4476
rsqrt_alpha_849: 1.0086, 1.0086, 1.0086
beta_rsqrt_omab_849: 0.0170, 0.0170, 0.0170
sigma_849: 0.1304, 0.1304, 0.1304
x_849: 9.0905, 38.4132, 24.8721
eps:  -1.7596, 0.6730, -0.4483
rsqrt_alpha_799: 1.0081, 1.0081, 1.0081
beta_rsqrt_omab_799: 0.0160, 0.0160, 0.0160
sigma_799: 0.1265, 0.1265, 0.1265
x_799: 16.9520, 63.2330, 42.0922
eps:  -1.7434, 0.6748, -0.4489
rsqrt_alpha_749: 1.0076, 1.0076, 1

In [42]:
x_t = torch.randn([1, *ddpm.img_dims], device='cuda')
print(f'x_0: {x_t.min().item():.4f}, {x_t.max().item():.4f}, {x_t.mean().item():.4f}')

for t in reversed(range(1000)):
    z = torch.rand_like(x_t)
    eps = ddpm.model(x_t)
    if (t + 1) % 50 == 0:
        print('eps: ', f'{eps.min().item():.4f}, {eps.max().item():.4f}, {eps.mean().item():.4f}')

    rsqrt_alpha_t = ddpm.rsqrt_alpha[t, ...]
    if (t + 1) % 50 == 0:
        print(f'rsqrt_alpha_{t}: {rsqrt_alpha_t.min().item():.4f}, {rsqrt_alpha_t.max().item():.4f}, {rsqrt_alpha_t.mean().item():.4f}')

    beta_rsqrt_omab_t = ddpm.beta_rsqrt_omab[t, ...]
    if (t + 1) % 50 == 0:
        print(f'beta_rsqrt_omab_{t}: {beta_rsqrt_omab_t.min().item():.4f}, {beta_rsqrt_omab_t.max().item():.4f}, {beta_rsqrt_omab_t.mean().item():.4f}')

    sigma_t = ddpm.sigma[t, ...]
    if (t + 1) % 50 == 0:
        print(f'sigma_{t}: {sigma_t.min().item():.4f}, {sigma_t.max().item():.4f}, {sigma_t.mean().item():.4f}')

    x_t = rsqrt_alpha_t * (x_t - beta_rsqrt_omab_t * eps) + sigma_t * z
    if (t + 1) % 50 == 0:
        print(f'x_{t}: {x_t.min().item():.4f}, {x_t.max().item():.4f}, {x_t.mean().item():.4f}')
        print('='*30)

x_0: -3.0920, 3.1590, -0.0616
eps:  -1.7370, 0.5680, -0.4516
rsqrt_alpha_999: 1.0102, 1.0102, 1.0102
beta_rsqrt_omab_999: 0.0200, 0.0200, 0.0200
sigma_999: 0.1414, 0.1414, 0.1414
x_999: -3.1001, 3.2815, 0.0168
eps:  -1.8719, 0.6611, -0.4499
rsqrt_alpha_949: 1.0096, 1.0096, 1.0096
beta_rsqrt_omab_949: 0.0190, 0.0190, 0.0190
sigma_949: 0.1379, 0.1379, 0.1379
x_949: 0.1748, 10.5276, 5.0767
eps:  -1.7749, 0.5422, -0.4513
rsqrt_alpha_899: 1.0091, 1.0091, 1.0091
beta_rsqrt_omab_899: 0.0180, 0.0180, 0.0180
sigma_899: 0.1342, 0.1342, 0.1342
x_899: 4.6673, 22.8883, 12.9272
eps:  -1.9163, 0.5978, -0.4504
rsqrt_alpha_849: 1.0086, 1.0086, 1.0086
beta_rsqrt_omab_849: 0.0170, 0.0170, 0.0170
sigma_849: 0.1304, 0.1304, 0.1304
x_849: 11.1083, 42.1722, 24.7351
eps:  -1.8862, 0.6170, -0.4510
rsqrt_alpha_799: 1.0081, 1.0081, 1.0081
beta_rsqrt_omab_799: 0.0160, 0.0160, 0.0160
sigma_799: 0.1266, 0.1266, 0.1266
x_799: 20.1733, 69.3169, 41.9050
eps:  -1.7639, 0.6398, -0.4500
rsqrt_alpha_749: 1.0076, 1.0076, 1