In [97]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from abc import ABC, abstractmethod
from typing import Callable

In [163]:
def sbgm_loss(pred, target, sigma):
    """Weighted denoising score-matching loss.

    From https://arxiv.org/pdf/2011.13456: the regression target is the
    conditional score ∇ log p(x_t | x_0), and each sample's squared error is
    weighted by λ(t) = σ_t².

    Args:
        pred: model output with the batch as leading dimension, shape (B, ...).
        target: score target, same shape as ``pred``.
        sigma: per-sample noise std σ_t, shape (B,).

    Returns:
        Scalar tensor: mean over all elements of σ_t² · (pred - target)².
    """
    se = (pred - target) ** 2
    # Reshape sigma² to (B, 1, ..., 1) so it broadcasts along the batch axis only.
    # A fixed `unsqueeze(1)` would turn a 1-D `se` of shape (B,) into a silent
    # (B, B) outer product and give a wrong mean.
    w = sigma.pow(2).reshape(-1, *([1] * (se.dim() - 1)))
    return (w * se).mean()

class SBGM():
    """Discrete variance-preserving forward (noising) process.

    With alpha_bar_t = prod_{s<=t} (1 - beta_s), a clean sample x0 is perturbed as
        x_t = sqrt(alpha_bar_t) * x0 + sigma_t * eps,
        sigma_t = sqrt(1 - alpha_bar_t),  eps ~ N(0, I).
    """

    def __init__(self, beta: torch.Tensor):
        # beta: 1-D noise schedule; its length is the number of steps T.
        self.T = beta.shape[0]
        self.alpha_bar = torch.cumprod(1 - beta, dim=0)

    def forward(self, x0):
        """Noise a batch ``x0`` at per-sample random timesteps t in [1, T-1].

        Args:
            x0: clean samples with the batch as leading dimension, shape (B, ...).

        Returns:
            (t, xt, (eps, sigma_t)) where t is a (B,) int64 tensor, xt and eps
            have the shape of x0, and sigma_t has shape (B,).
        """
        B = x0.shape[0]
        # t = 0 is never sampled: when beta[0] == 0, sigma_0 would be 0 and
        # drift() would divide by zero.
        t = torch.randint(1, self.T, (B,), device=x0.device)
        alpha_bar_t = self.alpha_bar.to(x0.device)[t]
        eps = torch.randn_like(x0)
        sigma_t = torch.sqrt(1 - alpha_bar_t)
        # Broadcast the per-sample scalars over any trailing (feature) dims of x0.
        per_sample = (B,) + (1,) * (x0.dim() - 1)
        xt = (torch.sqrt(alpha_bar_t).reshape(per_sample) * x0
              + sigma_t.reshape(per_sample) * eps)
        return t, xt, (eps, sigma_t)

    def drift(self, eps, sigma):
        """Conditional score ∇_{x_t} log p(x_t | x0) = -eps / sigma_t.

        ``sigma`` (shape (B,)) is broadcast over the trailing dimensions of ``eps``.
        """
        sigma = sigma.reshape(sigma.shape + (1,) * (eps.dim() - sigma.dim()))
        return -eps / sigma

SEED = 0
# Seed before any randomness (weight init here, batch sampling in training) so the
# notebook reproduces its outputs under Restart & Run All.
torch.manual_seed(SEED)

# Linear noise schedule with T = 250 steps and beta in [0, 1].
# NOTE(review): beta reaches 1.0 at the last step (alpha_bar == 0 there), and
# alpha_bar decays quickly under this schedule, so most sampled timesteps are
# close to pure noise. DDPM uses beta in [1e-4, 0.02]; confirm this is intended.
beta = torch.linspace(0, 1, 250)
sbgm = SBGM(beta)

# Score network: the training loop feeds it (x_t, t) per sample (2 features)
# and regresses a single output against the score target.
model = nn.Sequential(
    nn.Linear(2, 8),
    nn.ReLU(),
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
)

In [164]:
# AdamW over every parameter of the score network, learning rate 1e-3.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

In [165]:
# Target data distribution: equal-weight 1-D mixture of N(-4, 1) and N(4, 1).
# Probabilities are given as explicit floats rather than integer weights that
# rely on Categorical's implicit normalisation, and torch.tensor replaces the
# legacy torch.Tensor constructor.
data_dist = torch.distributions.MixtureSameFamily(
    torch.distributions.Categorical(probs=torch.tensor([0.5, 0.5])),
    torch.distributions.Normal(
        loc=torch.tensor([-4.0, 4.0]),
        scale=torch.tensor([1.0, 1.0]),
    ),
)

In [157]:
# Sanity check: stacking two 1-D tensors on the last axis pairs x[i] with t[i]
# in row i — the (x_t, t) layout the training loop feeds the model.
x = torch.arange(1.0, 6.0)
t = x.neg()
torch.stack([x, t], dim=-1)

tensor([[ 1., -1.],
        [ 2., -2.],
        [ 3., -3.],
        [ 4., -4.],
        [ 5., -5.]])

In [166]:
# Each "epoch" is a single gradient step on a freshly sampled batch of B points
# (there is no fixed dataset to iterate over).
EPOCHS = 50
B = 5000  # batch size drawn from data_dist per step

model.train(True)
for epoch in range(EPOCHS):
    optimizer.zero_grad()
    print(f"EPOCH {epoch}")
    x0 = data_dist.sample(torch.Size([B]))
    t, xt, (eps, sigma) = sbgm.forward(x0)
    # The raw integer timestep is the model's second input feature. Cast it
    # explicitly rather than relying on torch.stack's int->float promotion.
    # NOTE(review): t is unnormalised (values up to T-1), so it is much larger
    # than x_t; any sampling code must feed t in the same raw form.
    model_input = torch.stack((xt, t.to(xt.dtype)), dim=1)
    pred = model(model_input)
    target = sbgm.drift(eps, sigma).unsqueeze(1)
    # sbgm_loss already reduces to a scalar mean; no further reduction needed.
    loss = sbgm_loss(pred, target, sigma)
    print(f"Loss {loss.item()}")
    loss.backward()
    optimizer.step()

EPOCH 0
Loss 34.119232177734375
EPOCH 1
Loss 31.85466957092285
EPOCH 2
Loss 28.50072479248047
EPOCH 3
Loss 26.924367904663086
EPOCH 4
Loss 24.780254364013672
EPOCH 5
Loss 23.694568634033203
EPOCH 6
Loss 21.406719207763672
EPOCH 7
Loss 19.524141311645508
EPOCH 8
Loss 18.236099243164062
EPOCH 9
Loss 16.27016258239746
EPOCH 10
Loss 14.662297248840332
EPOCH 11
Loss 13.245734214782715
EPOCH 12
Loss 11.882158279418945
EPOCH 13
Loss 10.823841094970703
EPOCH 14
Loss 9.448290824890137
EPOCH 15
Loss 8.580032348632812
EPOCH 16
Loss 7.491120338439941
EPOCH 17
Loss 6.657615661621094
EPOCH 18
Loss 5.996732234954834
EPOCH 19
Loss 5.246511936187744
EPOCH 20
Loss 4.552018642425537
EPOCH 21
Loss 4.004706382751465
EPOCH 22
Loss 3.421097755432129
EPOCH 23
Loss 2.9624276161193848
EPOCH 24
Loss 2.5680902004241943
EPOCH 25
Loss 2.2404026985168457
EPOCH 26
Loss 1.9472671747207642
EPOCH 27
Loss 1.7287918329238892
EPOCH 28
Loss 1.529773235321045
EPOCH 29
Loss 1.3920978307724
EPOCH 30
Loss 1.2824277877807617
EPO