In [1]:
"""
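Minimal DDPM (denoising diffusion probabilistic model) training on DAVIS dataset frames.
Paper: https://arxiv.org/abs/2006.11239
Not self-contained: DDPM, UNet and the noise schedule are imported from the local `event_diffusion` module, and the data is loaded through `tonic` and `torchvision`.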
"""

import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from tqdm import tqdm
from IPython.display import Image, display

from event_diffusion import DDPM, ddpm_schedules, UNet

In [2]:
# Inspect the noise schedule with a tiny T so every entry fits on screen.
# beta1/beta2 are the same values handed to DDPM(betas=...) in the training cell.
print(
    ddpm_schedules(
        beta1=1e-4,
        beta2=0.02,
        T=10,
    )
)

{'alpha_t': tensor([0.9999, 0.9979, 0.9959, 0.9939, 0.9919, 0.9900, 0.9880, 0.9860, 0.9840,
        0.9820, 0.9800]), 'sqrt_beta_t': tensor([0.0100, 0.0457, 0.0639, 0.0779, 0.0898, 0.1002, 0.1097, 0.1184, 0.1266,
        0.1342, 0.1414]), 'alphabar_t': tensor([0.9999, 0.9978, 0.9937, 0.9877, 0.9797, 0.9699, 0.9582, 0.9448, 0.9296,
        0.9129, 0.8946])}


In [3]:
import tonic
import torchvision

def _scale_frames_to_unit(x: torch.Tensor) -> torch.Tensor:
    """Add a channel axis at dim 1 and divide a frame stack by its maximum value.

    An all-zero stack is returned unscaled instead of being divided by zero,
    which would otherwise turn every pixel of the stack into NaN.
    """
    scaled = x.float().unsqueeze(1)
    peak = scaled.max()
    return scaled / peak if peak > 0 else scaled


def _load_davis_frames(recording, transform, frame_size: int) -> torch.Tensor:
    """Load DAVIS frame stacks for `recording`, transform each, and concatenate them along dim 0."""

    def frame_transform(data):
        # Each sample's data is unpacked as (events, imu, frames); only the frames are used.
        _events, _imu, frames = data
        return transform(frames["frames"])

    dataset = tonic.datasets.DAVISDATA(save_to="data", recording=recording, transform=frame_transform)
    # Collect first and concatenate once: torch.cat inside the loop re-copies every
    # previously loaded frame on each iteration (quadratic in the dataset size).
    stacks = [imgs for imgs, _target in dataset]
    if not stacks:
        return torch.empty(0, 1, frame_size, frame_size)
    return torch.cat(stacks)


def _train_one_epoch(ddpm, dataloader, optim, device):
    """Run one optimisation pass over `dataloader`; return the smoothed loss (None if the loader is empty)."""
    ddpm.train()
    progress_bar = tqdm(dataloader)
    loss_ema = None
    for x in progress_bar:
        optim.zero_grad()
        loss = ddpm(x.to(device))
        loss.backward()
        optim.step()
        # Exponential moving average (weight 0.1 on the newest batch) for a steadier readout.
        loss_ema = loss.item() if loss_ema is None else 0.9 * loss_ema + 0.1 * loss.item()
        progress_bar.set_description(f"loss: {loss_ema:.4f}")
    return loss_ema


def _save_and_show_samples(ddpm, device, frame_size: int, path: str, n_samples: int = 4) -> None:
    """Sample `n_samples` frames from the model, save them as one grid image at `path`, and display it."""
    ddpm.eval()
    with torch.no_grad():
        samples = ddpm.sample(n_samples, (1, frame_size, frame_size), device)
    # Training frames are normalised to [-1, 1]; map samples back to [0, 1] first,
    # because save_image clamps to [0, 1] and would otherwise discard the lower half.
    grid = make_grid((samples * 0.5 + 0.5).clamp(0, 1), nrow=4)
    save_image(grid, path)
    display(Image(path, width=600, height=600))


def train_frames(epochs:int=100, diffusion_steps:int=1000, lr=2e-4, device="cuda:0", batch_size=12,
                 autoencoder_model=None, recording="all", checkpoint_path="./ddpm_mnist.pth") -> None:
    """Train a DDPM on DAVIS frames, saving a sample grid and a checkpoint after every epoch.

    Args:
        epochs: number of passes over the loaded frames.
        diffusion_steps: number of diffusion steps T, passed to DDPM as ``n_T``.
        lr: Adam learning rate.
        device: torch device for the model and the batches.
        batch_size: frames per optimisation step.
        autoencoder_model: network wrapped by ``DDPM``; defaults to ``UNet(1, 1)``.
        recording: recording name(s) forwarded to ``tonic.datasets.DAVISDATA``.
        checkpoint_path: file that ``ddpm.state_dict()`` is written to (overwritten each epoch).
    """
    frame_size = 90  # frames are center-cropped to 180x180, then resized to this square size

    if autoencoder_model is None:
        # NOTE(review): the recorded run fails inside ``ddpm(x)`` with
        # "forward() takes 2 positional arguments but 3 were given". This suggests DDPM
        # calls its model with (x_t, t) while UNet.forward accepts only x. Pass a
        # time-conditioned model via ``autoencoder_model``. TODO confirm in event_diffusion.
        autoencoder_model = UNet(1, 1)
    ddpm = DDPM(autoencoder_model=autoencoder_model, betas=(1e-4, 0.02), n_T=diffusion_steps)
    ddpm.to(device)

    transform = torchvision.transforms.Compose([
        torch.as_tensor,
        _scale_frames_to_unit,
        torchvision.transforms.CenterCrop((180, 180)),
        torchvision.transforms.Resize((frame_size, frame_size)),
        torchvision.transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
    ])

    frames = _load_davis_frames(recording, transform, frame_size)
    print(frames.shape)
    dataloader = torch.utils.data.DataLoader(frames, batch_size=batch_size, shuffle=True, num_workers=4)

    optim = torch.optim.Adam(ddpm.parameters(), lr=lr)
    os.makedirs("./images", exist_ok=True)  # save_image does not create missing directories

    for i in range(epochs):
        _train_one_epoch(ddpm, dataloader, optim, device)
        _save_and_show_samples(ddpm, device, frame_size, f"./images/ddpm_sample_{i}.png")
        torch.save(ddpm.state_dict(), checkpoint_path)

    if epochs > 0:  # with no epochs there is no sample image to show
        display(Image(f"./images/ddpm_sample_{epochs-1}.png", width=600, height=600))

In [4]:
train_frames(
    epochs=30,
    diffusion_steps=1000,  # aka T, the number of diffusion steps (DDPM's n_T)
    batch_size=64,
)

torch.Size([27259, 1, 90, 90])


  0%|          | 0/426 [00:00<?, ?it/s]


TypeError: forward() takes 2 positional arguments but 3 were given