In [1]:
import importlib
import tiny_video_dit
import dit_training
import video_dataset
import torch
from torch.utils.data import DataLoader, random_split

# Re-execute the local modules so edits made to their source files are picked
# up without restarting the kernel.
importlib.reload(tiny_video_dit)
importlib.reload(dit_training)
importlib.reload(video_dataset)

# These from-imports must run *after* the reloads above: they bind the names
# to the objects of the freshly reloaded modules rather than to stale ones.
from tiny_video_dit import VideoDiT
from dit_training import train_video_dit
from video_dataset import VideoDataset

In [3]:
import gc

# Force a full garbage-collection pass before building the dataset and model.
# gc.collect() returns the number of unreachable objects it found, which is
# the integer this cell displays.
gc.collect()

478

In [4]:
# Device handle passed to the model, the training routine and the sampler in
# the cells below; this notebook runs everything on the CPU.
device = torch.device("cpu")

In [5]:
# Seed the global RNG so that model initialisation is repeatable across
# "Restart Kernel -> Run All", and use a dedicated seeded generator for the
# split so the validation set is the same on every run (otherwise validation
# clips from one session can end up in the training set of the next).
SEED = 0
torch.manual_seed(SEED)

dataset = VideoDataset("downsampled_psft_videos")

# 80/20 train/validation split derived from the dataset length instead of the
# hard-coded 800/200, which only works when the dataset has exactly 1000
# items (random_split raises if the sizes do not sum to len(dataset)).
train_size = (len(dataset) * 4) // 5
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Only the training loader shuffles; validation order stays fixed.
trainloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
valloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)

model = VideoDiT().to(device)

In [6]:
# Training hyperparameters: number of passes over the training loader and the
# learning rate handed to train_video_dit.
num_epochs, lr = 15, 1e-4

train_video_dit(
    model,
    trainloader,
    valloader,
    num_epochs=num_epochs,
    lr=lr,
    device=device,
)

Epoch 1/15: 100%|██████████| 200/200 [04:59<00:00,  1.50s/batch]


Epoch 1 loss: 1.0123408102989198


Epoch 2/15: 100%|██████████| 200/200 [04:52<00:00,  1.46s/batch]


Epoch 2 loss: 0.9958772939443589


Epoch 3/15: 100%|██████████| 200/200 [05:59<00:00,  1.80s/batch]


Epoch 3 loss: 0.8654862490296363


Epoch 4/15: 100%|██████████| 200/200 [04:48<00:00,  1.44s/batch]


Epoch 4 loss: 0.7111909911036491


Epoch 5/15: 100%|██████████| 200/200 [04:51<00:00,  1.46s/batch]


Epoch 5 loss: 0.6146021865308284


Epoch 6/15: 100%|██████████| 200/200 [04:52<00:00,  1.46s/batch]


Epoch 6 loss: 0.5564515973627567


Epoch 7/15: 100%|██████████| 200/200 [04:46<00:00,  1.43s/batch]


Epoch 7 loss: 0.5080086004734039


Epoch 8/15: 100%|██████████| 200/200 [05:00<00:00,  1.50s/batch]


Epoch 8 loss: 0.476918338984251


Epoch 9/15: 100%|██████████| 200/200 [04:51<00:00,  1.46s/batch]


Epoch 9 loss: 0.4288316252827644


Epoch 10/15: 100%|██████████| 200/200 [04:51<00:00,  1.46s/batch]


Epoch 10 loss: 0.4170646706223488


Epoch 11/15: 100%|██████████| 200/200 [05:11<00:00,  1.56s/batch]


Epoch 11 loss: 0.3806696970760822


Epoch 12/15: 100%|██████████| 200/200 [05:21<00:00,  1.61s/batch]


Epoch 12 loss: 0.3441214190423489


Epoch 13/15: 100%|██████████| 200/200 [05:18<00:00,  1.59s/batch]


Epoch 13 loss: 0.32711549147963526


Epoch 14/15: 100%|██████████| 200/200 [05:14<00:00,  1.57s/batch]


Epoch 14 loss: 0.31020453676581383


Epoch 15/15: 100%|██████████| 200/200 [05:21<00:00,  1.61s/batch]

Epoch 15 loss: 0.2993907058238983
Finished training!





In [8]:
# Persist only the model's state dict (not the whole module object) to disk.
torch.save(obj=model.state_dict(), f="video_dit_model.pth")

In [9]:
import imageio
from dit_training import beta_schedule, unpatchify_video, alpha_schedules

def compute_sigma(t, alphas_cum):
    """Return ``sqrt(1 - alphas_cum[t])``.

    Args:
        t: Index into ``alphas_cum``.
        alphas_cum: Indexable tensor; presumably the cumulative product of
            the per-step alphas (name-based assumption, confirm against
            ``dit_training.alpha_schedules``).

    Returns:
        A tensor holding the square root of one minus the selected entry.
    """
    one_minus_alpha_cum = 1.0 - alphas_cum[t]
    return torch.sqrt(one_minus_alpha_cum)

def reverse_diffusion(model, T, first_frame, device, video_shape=(1, 16, 64, 64, 3)):
    """Generate a video by running the DDPM reverse process from pure noise.

    Starting from Gaussian noise of shape ``video_shape``, the model is
    queried once per timestep (``T - 1`` down to ``0``) and the sample is
    updated with the standard DDPM ancestral step

        x_{t-1} = (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
                  + sqrt(beta_t) * z        (z = 0 on the final step)

    NOTE(review): this assumes the model was trained to predict the noise
    ``eps`` of the usual forward process with the same ``beta_schedule(T)``;
    confirm against ``dit_training.train_video_dit``.

    Args:
        model: Network called as ``model(x_t, timesteps, first_frame)``; its
            output is passed through ``unpatchify_video`` before use.
        T: Number of diffusion steps; also the length requested from
            ``beta_schedule``.
        first_frame: Conditioning tensor forwarded to the model on every
            step. It is moved to ``device`` here.
        device: Device on which sampling runs.
        video_shape: Shape of the generated sample. The default matches the
            shape that was previously hard-coded.

    Returns:
        The final sample clipped to ``[0, 1]``.
    """
    raw_betas = beta_schedule(T)
    _, alphas_cum, _ = alpha_schedules(raw_betas)
    alphas_cum = alphas_cum.to(device)
    betas = torch.as_tensor(raw_betas).to(device)

    first_frame = first_frame.to(device)
    batch_size = video_shape[0]

    # Put the model in inference mode for sampling and restore its previous
    # mode afterwards so a later training cell is unaffected.
    was_training = model.training
    model.eval()
    try:
        # Without no_grad the autograd graph would chain through all T
        # steps, steadily growing memory for a result that is never
        # back-propagated.
        with torch.no_grad():
            noisy_video = torch.randn(video_shape, device=device)

            for t in range(T - 1, -1, -1):
                timesteps = torch.full(
                    (batch_size,), t, device=device, dtype=torch.long
                )
                pred_noise = unpatchify_video(
                    model(noisy_video, timesteps, first_frame)
                )

                beta_t = betas[t]
                alpha_t = 1.0 - beta_t
                alpha_cum_t = alphas_cum[t]

                # Posterior mean of x_{t-1} given x_t and the predicted noise.
                noisy_video = (
                    noisy_video - beta_t / torch.sqrt(1.0 - alpha_cum_t) * pred_noise
                ) / torch.sqrt(alpha_t)

                # Noise is re-injected on every step except the last, using
                # the reverse-process std sqrt(beta_t) rather than the
                # forward-process std sqrt(1 - abar_t).
                if t > 0:
                    noisy_video = noisy_video + torch.sqrt(beta_t) * torch.randn_like(
                        noisy_video
                    )
    finally:
        model.train(was_training)

    denoised_video = noisy_video.clip(0, 1)
    return denoised_video

# Number of diffusion steps used when sampling.
T = 100

# Take index 0 of the first validation batch and give it back a leading axis
# of size 1.
# NOTE(review): assumes the loader yields one video tensor per batch, so this
# picks a single clip - confirm against VideoDataset.__getitem__.
val_batch = next(iter(valloader))
sample_video = torch.unsqueeze(val_batch[0], dim=0)

# Index 0 along axis 1 of the sample is used as the conditioning frame.
first_frame = sample_video[:, 0, :, :, :]

In [10]:
# Run the sampler for T steps on `device`, conditioned on `first_frame`.
denoised = reverse_diffusion(model, T, first_frame, device)

