In [None]:
"""
Extremely Minimalistic Implementation of DDPM
https://arxiv.org/abs/2006.11239
Everything is self-contained (except for PyTorch, torchvision, tqdm and tonic, of course).
"""

from typing import Dict, Tuple
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from IPython.display import Image

In [None]:
def ddpm_schedules(beta1: float, beta2: float, T: int) -> Dict[str, torch.Tensor]:
    """
    Returns pre-computed schedules for the DDPM training and sampling processes.

    The noise variance beta_t is interpolated linearly from ``beta1`` (at t = 0)
    to ``beta2`` (at t = T), so every returned tensor has ``T + 1`` entries and
    can be indexed directly by a timestep t in [0, T].

    Args:
        beta1: Variance at t = 0. Must satisfy 0 < beta1 < beta2.
        beta2: Variance at t = T. Must satisfy beta1 < beta2 < 1.
        T: Number of diffusion steps. Must be positive.

    Returns:
        A dict of 1-D float32 tensors of length ``T + 1``:
            "alpha_t":     1 - beta_t
            "sqrt_beta_t": sqrt(beta_t)
            "alphabar_t":  cumulative product of alpha_t up to and including t

    Raises:
        AssertionError: if the betas are not strictly increasing inside (0, 1),
            or if T is not positive.
    """
    # The lower bound matters: a non-positive beta1 would give alpha_t >= 1 and
    # sqrt() of a negative number further down the line.
    assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
    # T == 0 would divide by zero below and silently fill the schedules with NaN.
    assert T > 0, "T must be a positive integer"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    alphabar_t = torch.cumprod(alpha_t, dim=0)

    return {
        "alpha_t": alpha_t,
        "sqrt_beta_t": sqrt_beta_t,
        "alphabar_t": alphabar_t,
    }

In [None]:
# Sanity check: print the schedule tensors for a tiny T (each has T + 1 = 11 entries).
print(ddpm_schedules(beta1=1e-4, beta2=0.02, T=10))

In [None]:
class DDPM(nn.Module):
    """
    Denoising Diffusion Probabilistic Model (https://arxiv.org/abs/2006.11239).

    Wraps a noise-prediction network and provides the training loss
    (``forward``, Algorithm 1) and ancestral sampling (``sample``, Algorithm 2).

    Args:
        autoencoder_model: Network called as ``model(x_t, t)`` whose output is
            used as the noise estimate for ``x_t``. ``t`` is a float tensor of
            shape (batch,) holding the timestep divided by ``n_T``.
        betas: ``(beta1, beta2)`` endpoints handed to ``ddpm_schedules``.
        n_T: Number of diffusion steps T.
        criterion: Loss comparing the true noise with the predicted noise.
            The default instance is created once, when the class is defined.
    """

    def __init__(
        self,
        autoencoder_model: nn.Module,
        betas: Tuple[float, float],
        n_T: int,
        criterion: nn.Module = nn.MSELoss(),
    ) -> None:
        super().__init__()
        self.autoencoder_model = autoencoder_model

        # register_buffer allows us to freely access these tensors by name
        # (self.alpha_t, self.sqrt_beta_t, self.alphabar_t) and moves them
        # together with the module on .to(device).
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.criterion = criterion

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Implements Algorithm 1 from the paper.

        Draws a random timestep and noise per sample, builds the noised input
        x_t, and returns the loss between the true noise and the noise the
        model predicts from x_t.

        Args:
            x: Clean batch. The schedule is broadcast with three trailing
                singleton axes, so x is expected to be 4-D (batch first).

        Returns:
            The value of ``criterion(eps, model(x_t, t / n_T))``.
        """
        # t ~ Uniform{1, ..., n_T}. torch.randint excludes its upper bound, so
        # n_T + 1 is required for t = n_T (where sampling starts) to be trained.
        t = torch.randint(1, self.n_T + 1, (x.shape[0],), device=x.device)
        eps = torch.randn_like(x)  # eps ~ N(0, 1)

        alphabar = self.alphabar_t[t, None, None, None]
        x_t = torch.sqrt(alphabar) * x + torch.sqrt(1 - alphabar) * eps

        return self.criterion(eps, self.autoencoder_model(x_t, t / self.n_T))

    def sample(self, n_sample: int, size, device) -> torch.Tensor:
        """
        Implements Algorithm 2 from the paper.

        Starts from pure Gaussian noise and runs the reverse process from
        step n_T down to step 1.

        Args:
            n_sample: Number of samples to generate.
            size: Shape of a single sample (without the batch axis).
            device: Device on which the samples are generated.

        Returns:
            Tensor of shape ``(n_sample, *size)``.
        """
        # x_T ~ N(0, 1), created directly on the target device.
        x_i = torch.randn(n_sample, *size, device=device)

        for i in range(self.n_T, 0, -1):
            # No noise is added on the very last step.
            z = torch.randn(n_sample, *size, device=device) if i > 1 else 0
            # Same conditioning format as in forward(): one normalized
            # timestep per sample, as a tensor rather than a Python float.
            t = torch.full((n_sample,), i / self.n_T, device=device)
            eps = self.autoencoder_model(x_i, t)
            x_i = (
                (1 / torch.sqrt(self.alpha_t[i])) *
                (x_i - eps * (1 - self.alpha_t[i]) / torch.sqrt(1 - self.alphabar_t[i]))
                + self.sqrt_beta_t[i] * z
            )

        return x_i


In [None]:
def blk(ic: int, oc: int) -> nn.Sequential:
    """
    Builds one convolutional block: Conv2d -> BatchNorm2d -> LeakyReLU.

    The convolution uses a 7x7 kernel with padding 3, so the spatial size of
    the input is preserved.

    Args:
        ic: Number of input channels.
        oc: Number of output channels.

    Returns:
        The three layers wrapped in an ``nn.Sequential``.
    """
    return nn.Sequential(
        nn.Conv2d(ic, oc, 7, padding=3),
        nn.BatchNorm2d(oc),
        nn.LeakyReLU(),
    )

class AutoEncoderModel(nn.Module):
    """
    Placeholder noise-prediction network: a plain stack of convolutional blocks.

    A U-Net-like model would be the proper choice, but any model that maps an
    input back to a tensor of the same shape is enough here. The channel count
    widens from ``n_channel`` up to 512 and narrows back down again, ending in
    a 3x3 convolution that restores ``n_channel`` channels.
    """

    def __init__(self, n_channel: int) -> None:
        super().__init__()
        # Channel widths of consecutive blocks; each neighbouring pair is one blk().
        widths = [n_channel, 64, 128, 256, 512, 256, 128, 64]
        stages = [blk(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        head = nn.Conv2d(64, n_channel, 3, padding=1)
        self.conv = nn.Sequential(*stages, head)

    def forward(self, x, t) -> torch.Tensor:
        # t is accepted for interface compatibility but not used yet
        # (the paper conditions on it through positional embeddings).
        out = self.conv(x)
        return out

In [None]:
import tonic
import torchvision

def train_frames(epochs:int=100, diffusion_steps:int=1000, lr=2e-4, device="cuda:0", batch_size=12) -> None:
    """
    Trains a DDPM on grayscale frames of the DAVIS "shapes" recordings.

    After every epoch a grid of 4 generated samples is written to
    ``./ddpm_sample_<epoch>.png`` and displayed, and the model weights are
    written to ``./ddpm_mnist.pth`` (overwritten each epoch; the file name is
    kept as-is so existing checkpoints keep loading).

    Relies on ``display`` being available, as it is inside a notebook.

    Args:
        epochs: Number of passes over the frames.
        diffusion_steps: Number of diffusion steps T of the DDPM.
        lr: Learning rate of the Adam optimizer.
        device: Device used for training and sampling.
        batch_size: Number of frames per training batch.
    """
    ddpm = DDPM(autoencoder_model=AutoEncoderModel(1), betas=(1e-4, 0.02), n_T=diffusion_steps)
    ddpm.to(device)

    transform = torchvision.transforms.Compose([
        torch.as_tensor,
        # add a channel axis and scale by the maximum value of the recording
        lambda x: x.float().unsqueeze(1) / x.max(),
        torchvision.transforms.CenterCrop((180, 180)),
        torchvision.transforms.Resize((90, 90)),
        # convert to [-1, 1] range; std is a 1-tuple, matching the mean
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])

    def frame_transform(data):
        # Only the frames are used; events and imu are dropped.
        events, imu, frames = data
        return transform(frames['frames'])

    dataset = tonic.datasets.DAVISDATA(save_to="data", recording=["shapes_6dof", "shapes_translation" , "shapes_rotation"], transform=frame_transform)
    frames1 = dataset[0][0]
    frames2 = dataset[1][0]
    frames3 = dataset[2][0]
    frames = torch.cat((frames1, frames2, frames3))
    print(frames.shape)
    dataloader = torch.utils.data.DataLoader(frames, batch_size=batch_size, shuffle=True, num_workers=4)

    optim = torch.optim.Adam(ddpm.parameters(), lr=lr)

    for i in range(epochs):
        ddpm.train()

        progress_bar = tqdm(dataloader)
        loss_current = None  # exponential moving average of the loss, for display only
        for x in progress_bar:
            optim.zero_grad()
            x = x.to(device)
            loss = ddpm(x)
            loss.backward()
            if loss_current is None:
                loss_current = loss.item()
            else:
                loss_current = 0.9 * loss_current + 0.1 * loss.item()
            progress_bar.set_description(f"loss: {loss_current:.4f}")
            optim.step()

        ddpm.eval()
        with torch.no_grad():
            xh = ddpm.sample(4, (1, 90, 90), device)
            # The training data lives in [-1, 1], while save_image clamps to
            # [0, 1]. Map the samples back to [0, 1] first, otherwise every
            # negative value is clipped to black.
            xh = (xh.clamp(-1, 1) + 1) / 2
            grid = make_grid(xh, nrow=4)
            save_image(grid, f"./ddpm_sample_{i}.png")
            display(Image(f"./ddpm_sample_{i}.png", width=600, height=600))

            # save model
            torch.save(ddpm.state_dict(), "./ddpm_mnist.pth")

    # With epochs == 0 no sample image was written, so there is nothing to show.
    if epochs > 0:
        display(Image(f"./ddpm_sample_{epochs-1}.png", width=448, height=448))

In [None]:
# Short training run; writes ./ddpm_sample_<epoch>.png after every epoch and ./ddpm_mnist.pth.
train_frames(epochs=10, diffusion_steps=500) # diffusion_steps aka T