In [1]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from abc import ABC, abstractmethod
from typing import Callable

In [2]:
"""
MSE with per-sample weighting
pred, target, and weights should all have shape (B)
"""
def loss(pred, target, weights):
    se = (pred - target) ** 2
    return (weights.unsqueeze(1) * se).mean()

"""
From https://arxiv.org/pdf/2011.13456; target is ∇ log p(xt | x0)
"""
def sbgm_loss(pred, target, sigma_t):
    return loss(pred, target, sigma_t.pow(2))

In [None]:
def ddpm_forward(x0, T, alpha_bar, sigma):
    """Noise a clean batch with the closed-form DDPM forward process.

    For each sample a step t is drawn uniformly from {1, ..., T-1} and

        xt = sqrt(alpha_bar[t]) * x0 + sigma[t] * eps,   eps ~ N(0, I).

    Step 0 is never drawn. With a schedule whose first beta is 0, sigma[0] is 0
    and the score target -eps / sigma would be undefined at that step.

    Args:
        x0: clean batch of shape (B,) or (B, ...).
        T: number of diffusion steps; must be >= 2 so {1, ..., T-1} is non-empty.
        alpha_bar: per-step cumulative signal coefficients, read via torch.take
            (i.e. by flat index).
        sigma: per-step noise standard deviations, read the same way.

    Returns:
        (t, xt, eps, sigma_t):
            t is an int64 tensor of shape (B,);
            xt and eps have x0's shape;
            sigma_t has shape (B,).
    """
    B = x0.shape[0]
    t = torch.randint(1, T, (B,), device=alpha_bar.device)
    alpha_bar_t = torch.take(alpha_bar, t)
    sigma_t = torch.take(sigma, t)
    # Noise follows x0's full shape. The previous torch.randn(B) only worked for
    # 1-D x0; for 1-D x0 the draws here are the same.
    eps = torch.randn(x0.shape, device=x0.device)
    # Reshape the per-sample coefficients (B,) -> (B, 1, ..., 1) so they
    # broadcast over any trailing feature dims of x0.
    coef_shape = (B,) + (1,) * (x0.dim() - 1)
    xt = torch.sqrt(alpha_bar_t).reshape(coef_shape) * x0 + sigma_t.reshape(coef_shape) * eps
    return t, xt, eps, sigma_t

In [None]:
class SBGM():
    """Precomputed DDPM-style noise schedule used for score-based training.

    From a per-step noise schedule ``beta`` (indexed along dim 0) it stores:
      * ``T``         -- number of diffusion steps (``beta``'s length along dim 0)
      * ``alpha_bar`` -- running product of ``1 - beta`` along dim 0
      * ``sigma``     -- ``sqrt(1 - alpha_bar)``, the noise std at each step
    """

    def __init__(self, beta: torch.Tensor):
        retain = 1 - beta  # fraction of signal kept at each individual step
        self.T = beta.shape[0]
        self.alpha_bar = retain.cumprod(dim=0)
        self.sigma = (1 - self.alpha_bar).sqrt()

    def forward(self, x0):
        """Return ``(t, xt, eps, sigma_t)`` for a clean batch by delegating to ``ddpm_forward``."""
        return ddpm_forward(x0, self.T, self.alpha_bar, self.sigma)

    def drift(self, eps, sigma):
        """Score target ``-eps / sigma``.

        This equals ∇ log p(xt | x0) when xt = sqrt(alpha_bar) * x0 + sigma * eps.
        """
        return -(eps / sigma)

# Schedule / model hyperparameters (tune here rather than inline).
N_STEPS = 250         # number of diffusion steps T
HIDDEN_DIM = 8        # width of the score network's hidden layer
LEARNING_RATE = 0.01

# Linear beta schedule from 0 to 1. Because beta[0] == 0, sigma[0] == 0, so
# step 0 must never be sampled: the drift target divides by sigma.
beta = torch.linspace(0, 1, N_STEPS)
sbgm = SBGM(beta)

# Score network: its input is the pair (xt, t) stacked into 2 features per
# sample, and its output is a single score estimate per sample.
model = nn.Sequential(
    nn.Linear(2, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [14]:
# Bimodal 1-D data distribution: an equal-weight mixture of N(-4, 1) and N(4, 1).
data_dist = torch.distributions.MixtureSameFamily(
    mixture_distribution=torch.distributions.Categorical(
        probs=torch.tensor([1.0, 1.0])),  # Categorical normalizes to [0.5, 0.5]
    component_distribution=torch.distributions.Normal(
        loc=torch.tensor([-4.0, 4.0]),
        scale=torch.tensor([1.0, 1.0])),
)

In [16]:
EPOCHS = 100
B = 50000

model.train(True)
for epoch in range(EPOCHS):
    optimizer.zero_grad()
    x0 = data_dist.sample(torch.Size([B]))
    t, xt, eps, sigma_t = sbgm.forward(x0)
    # t is an integer step index. Cast it explicitly so the stacked (B, 2)
    # network input shares xt's float dtype.
    # NOTE(review): t is fed unnormalized (1..T-1) while xt is O(1-10);
    # scaling by sbgm.T may improve conditioning.
    features = torch.stack((xt, t.to(xt.dtype)), dim=1)
    pred = model(features)
    target = sbgm.drift(eps, sigma_t).unsqueeze(1)
    # Deliberately NOT named `loss`. This cell runs at notebook top level, so
    # assigning `loss` would overwrite the global `loss` function that
    # sbgm_loss calls. The next epoch would then fail with
    # "TypeError: 'Tensor' object is not callable".
    train_loss = sbgm_loss(pred, target, sigma_t)  # already a scalar mean
    if epoch % 10 == 0:
        print(f"Epoch {epoch} loss {train_loss.item()}")
    train_loss.backward()
    optimizer.step()

Epoch 0 loss 2700.250244140625
Epoch 10 loss 63.83141326904297
Epoch 20 loss 130.23263549804688
Epoch 30 loss 16.963125228881836
Epoch 40 loss 8.12024211883545
Epoch 50 loss 5.364798545837402
Epoch 60 loss 0.8837070465087891
Epoch 70 loss 1.3042747974395752
Epoch 80 loss 0.5459372401237488
Epoch 90 loss 0.5773472189903259
