In [1]:
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from video_diffusion_pytorch import GaussianDiffusion, Trainer, Unet3D

In [2]:
@dataclass
class Config:
    """Settings shared by the model-building and sampling cells.

    Each attribute carries a type annotation on purpose: ``@dataclass`` only
    turns *annotated* class attributes into fields. Without annotations the
    generated ``__init__`` takes no arguments, so ``Config(batch_size=4)``
    raises ``TypeError`` and ``repr(Config())`` shows no values.
    """

    # Value forwarded as `image_size` when the diffusion wrapper is built.
    image_size: int = 64
    # Number of videos requested from the sampler per call.
    batch_size: int = 1
    # Total number of videos to generate in the sampling loop.
    num_samples: int = 1

config = Config()

In [3]:
# Denoising network; it is handed to GaussianDiffusion in the following cell.
# NOTE(review): the meaning of `dim` / `dim_mults` is defined by
# video_diffusion_pytorch.Unet3D — check its documentation before tuning them.
model = Unet3D(dim=64, dim_mults=(1, 2, 4, 8))

In [4]:
# Wrap the network in the diffusion process, then move the result to the GPU.
# `num_frames` and `image_size` match the trailing dimensions of the sampled
# tensor shown at the end of the notebook (..., 20, 64, 64).
diffusion = GaussianDiffusion(
    model,
    image_size=config.image_size,
    num_frames=20,
    timesteps=1000,
    loss_type='l1',
)
diffusion = diffusion.cuda()

In [10]:
# Build the trainer around the diffusion model.
# NOTE(review): the dataset location is an absolute, machine-specific path.
trainer = Trainer(
    diffusion,
    '/home/s_gladkykh/thesis/gif_dataset_64',
    train_batch_size=12,
    train_lr=1e-4,
    save_and_sample_every=1000,
    train_num_steps=700000,
    gradient_accumulate_every=2,
    ema_decay=0.995,
    amp=True,
)

# Restore the saved training state from the checkpoint file: the step
# counter, the model weights, the EMA weights and the grad-scaler state.
data = torch.load("results/model-97.pt")
trainer.step = data['step']
trainer.model.load_state_dict(data['model'])
trainer.ema_model.load_state_dict(data['ema'])
trainer.scaler.load_state_dict(data['scaler'])

found 1710 videos as gif files at /home/s_gladkykh/thesis/gif_dataset_64


In [11]:
def create_gif(arr, gif_path, duration=100, size=64):
    """Write a sequence of frames to an animated GIF.

    Args:
        arr: Iterable of frames. Each frame is multiplied by 255 and converted
            to ``uint8`` before being passed to ``Image.fromarray``, so values
            are expected in ``[0, 1]``; anything outside that range is clipped
            instead of wrapping around during the integer cast.
            Presumably each frame is (height, width, channels) — TODO confirm.
        gif_path: Destination handed to ``Image.save``.
        duration: Display time per frame (milliseconds in Pillow's GIF writer).
        size: Unused. Kept only so existing calls that pass it keep working.

    Raises:
        IndexError: If ``arr`` contains no frames.
    """
    image_list = [
        Image.fromarray(np.clip(np.asarray(frame) * 255, 0, 255).astype(np.uint8))
        for frame in arr
    ]

    image_list[0].save(
        gif_path,
        save_all=True,
        append_images=image_list[1:],
        # Was hard-coded to 100, which silently ignored the caller's argument.
        duration=duration,
        loop=1,
    )

In [12]:
if config.batch_size < 1:
    # A non-positive batch size would never advance the counter below.
    raise ValueError("config.batch_size must be at least 1")

# Create the output folder up front so a missing directory does not surface
# only after a long sampling run has finished.
output_dir = Path("test")
output_dir.mkdir(parents=True, exist_ok=True)

generated_samples = 0
while generated_samples < config.num_samples:
    # Never request more videos than are still missing.
    current_batch = min(config.batch_size, config.num_samples - generated_samples)
    gen = trainer.model.sample(batch_size=current_batch)
    for j in range(current_batch):
        # Move the first axis of each sample to the end so that iterating the
        # array inside create_gif yields one frame at a time.
        frames = gen[j].permute(1, 2, 3, 0).cpu().numpy()
        create_gif(frames, str(output_dir / f"{generated_samples + j}.gif"))
    # Advance by the number of videos actually written. The previous `+= 1`
    # made file names collide and over-sampled whenever batch_size > 1.
    generated_samples += current_batch

sampling loop time step: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:31<00:00, 10.92it/s]


In [15]:
# Inspect the shape of the most recently sampled batch
# (presumably batch, channels, frames, height, width — TODO confirm).
gen.shape

torch.Size([1, 3, 20, 64, 64])