# Video diffusion models

In [None]:
# Download a previously saved checkpoint into the working directory. The file name follows the
# pattern written by save_checkpoint(), and load_checkpoint() scans "./" for *.pt files, so this
# is what lets training resume.
# NOTE(review): the URL appears to embed a signed, time-limited token - confirm it is still valid
# and consider attaching the checkpoint as a dataset/input instead of hard-coding it here.
!wget https://www.kaggleusercontent.com/kf/116808580/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..jsredvRKL5mb488KjnL_ag.Q6U8Eq6OWh4aijOVHH_YHitw8nG5nzKNuviS8PRAFG0HHi0fKQsU_CbQJc3rsT0w-u2eh_yG4hLX6qcIe7vEY3KmCSusbmxHFjjXah5Sl3CrU0pkDdH66MunAuEBExsPiITkJej3kjisldzwzKHbewtykhcj_sqclyXqLc4rATXfkNy6O6BhLwuiyzQrN98VlcMpUgBCDXLJvew38DZ8MtqA0GsKFjVfA_FZjxfVBetPeOGogWuraTBfNV4b5alWjSf0NekBqE2tB1MOQ2889KCUwgeo57zMEUeyoV45Id30BfHWuClTMpnLxg3xfzEh03OgKWvrLXCnm1Xt4lJG1v1jyZirKLNHE_WAjb5emN_Xu96iQxX920HeFfB4DddxzbBuptjdnZnp0Ln9wZW0UMPVUN_JvgMM4IcwqQ6yb1jq7FS1R7SkfUR7HjMBIF53wZFW7-SzAI3Auer3qxs_bkU6xZk_XQfHTdjeGiCOv6M_2BlTFW6kXmqR5erSUPi9-kqwcg7ZDwVK3Znu9SRoqSyzW1Ja0eSbpz7hFsW6W3YtpKpQivpqWPslGMCfRGMMWIa1XUCI6wrYFB-iXtmlvc8cVug4zI4XlgJDRdJwS225qy9n6TCRQtQ18S5dKNYTq3x-e1hzE35aAGik6Y_SNA.-wfOeyfpHsSR9AZlhtEuCw/checkpoint-e1_35-s1_600000-e2_36-s2_29695-1674197094.pt

In [None]:
# Params
# Side length in pixels; passed as both width and height when GIFs are converted to tensors.
image_size = 128
# Number of frames read from each GIF (transform_gif wraps around when the GIF is shorter).
frames = 12
# Number of dataset lines requested per call to get_next_videos().
process_batch_size = 128

# Wall-clock budget for the training loop, in minutes: 12 hours minus 10 minutes.
max_run_minutes = 12 * 60 - 10

In [None]:
import time
import os
# Reference timestamp for the max_run_minutes budget checked at the top of the training loop.
start_time = time.time()

## Install dependencies

In [None]:
!pip install imagen_pytorch==1.16.5 --no-cache-dir

## Utility functions to resize and crop GIFs

In [None]:
# GIF pre-processing

import numpy as np
from torchvision import transforms as T
from math import floor, fabs
from PIL import Image, ImageSequence


# Maps a channel count to the PIL mode string handed to Image.convert() in transform_gif.
CHANNELS_TO_MODE = {1: 'L', 3: 'RGB', 4: 'RGBA'}

def center_crop(img, new_width, new_height):
    """Crop `img` to a `new_width` x `new_height` box centred on the image.

    Args:
        img: object exposing `.size` as (width, height) and `.crop(box)` (a PIL image in this file).
        new_width: width of the crop box.
        new_height: height of the crop box.

    Returns:
        Whatever `img.crop` returns for the (left, top, right, bottom) box.

    When the size difference along an axis is odd, the left/top edge is rounded up and the
    right/bottom trim is rounded down, so the box always spans exactly the requested size.
    """
    src_width, src_height = img.size[0], img.size[1]
    half_dx = (src_width - new_width) / 2
    half_dy = (src_height - new_height) / 2
    box = (
        int(np.ceil(half_dx)),
        int(np.ceil(half_dy)),
        src_width - int(np.floor(half_dx)),
        src_height - int(np.floor(half_dy)),
    )
    return img.crop(box)

def resize_crop_img(img, width, height):
    """Scale `img` so it covers a `width` x `height` box, then centre-crop it to that box.

    Args:
        img: PIL image (anything with `.size` and `.resize(size, resample)`).
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        The resized image after `center_crop(img, width, height)`.

    The aspect ratio is preserved. The scale factor is the larger of the two per-axis ratios so
    that both sides of the resized image are at least as large as the target; the previous
    "shorter side decides" rule only guaranteed that for square targets.
    """
    src_width, src_height = img.size
    scale = max(width / float(src_width), height / float(src_height))
    # int() truncation of the float product can land one pixel short of the target
    # (e.g. int(49 * (128 / 49)) == 127), which would make the crop box extend past the
    # resized image. Clamp each side so the resized image always covers the crop box.
    new_size = (
        max(width, int(src_width * scale)),
        max(height, int(src_height * scale)),
    )
    img = img.resize(new_size, Image.Resampling.LANCZOS)
    return center_crop(img, width, height)

def transform_gif(img, new_width, new_height, frames, channels = 3):
    """Lazily yield `frames` frames of an animated image, resized and mode-converted.

    Args:
        img: opened multi-frame PIL image (uses `.n_frames`, `.seek`, `.size`, `.convert`).
        new_width: required frame width.
        new_height: required frame height.
        frames: number of frames to yield; when the image has fewer frames the index wraps
            around (`i % n_frames`), so the animation repeats from its first frame.
        channels: key into CHANNELS_TO_MODE selecting the output mode.

    Yields:
        One converted frame per iteration.

    Raises:
        AssertionError: if `channels` is not a key of CHANNELS_TO_MODE (raised on first iteration,
            because this is a generator).
    """
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]
    gif_frames = img.n_frames
    for i in range(frames):
        img.seek(i % gif_frames)
        # Compare BOTH dimensions: the old check tested the width twice, so a frame with the
        # right width but the wrong height was passed through without being resized.
        if img.size[0] != new_width or img.size[1] != new_height:
            print("Resizing")
            img_out = resize_crop_img(img, new_width, new_height)
        else:
            img_out = img
        yield img_out.convert(mode)
        
# tensor of shape (channels, frames, height, width) -> gif
def video_tensor_to_gif(tensor, path, fps = 10, loop = 0, optimize = True):
    """Write a video tensor to `path` as an animated GIF and return its frames.

    Args:
        tensor: video tensor; it is split along dim 1, one slice per frame.
        path: destination file path.
        fps: frames per second; each frame is shown for int(1000 / fps) milliseconds.
        loop: forwarded to PIL's GIF writer as `loop`.
        optimize: forwarded to PIL's GIF writer as `optimize`.

    Returns:
        list of the PIL images that were written, in frame order.

    Raises:
        ValueError: if the tensor contains no frames.
    """
    print("Converting video tensors to GIF")
    # Materialise the frames: a lazy map() would be consumed by the unpacking below and the
    # function would hand an already-exhausted iterator back to the caller.
    images = list(map(T.ToPILImage(), tensor.unbind(dim = 1)))
    if not images:
        raise ValueError("tensor contains no frames to write")
    first_img, *rest_imgs = images
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = int(1000/fps), loop = loop, optimize = optimize)
    print("Gif saved")
    return images

# gif -> (channels, frame, height, width) tensor
def gif_to_tensor(path, width = 256, height = 256, frames = 32, channels = 3, transform = T.ToTensor()):
    """Load an animated GIF and stack its frames into a single tensor.

    Args:
        path: path of the GIF file.
        width: frame width passed to transform_gif.
        height: frame height passed to transform_gif.
        frames: number of frames to read (transform_gif repeats the GIF if it is shorter).
        channels: channel count passed to transform_gif.
        transform: callable applied to every frame; its results are stacked along dim 1.

    Returns:
        torch.Tensor produced by `torch.stack(..., dim = 1)` over the transformed frames.

    Note: `torch` is imported in a later cell of this notebook, so that cell must have run
    before this function is called.
    """
    print("Converting GIF to video tensors")
    # Image.open keeps the underlying file open; the context manager releases the handle even
    # if decoding fails, which matters when many GIFs are processed from a thread pool.
    with Image.open(path) as img:
        imgs = transform_gif(img, new_width = width, new_height = height, frames = frames, channels = channels)
        # transform_gif is a generator: consume it while the file is still open.
        tensors = tuple(map(transform, imgs))
    return torch.stack(tensors, dim = 1)


## Utility functions to process dataset

In [None]:
import os
import torch
import shutil
import urllib
import traceback

from concurrent.futures import ThreadPoolExecutor, wait
import time
import threading

# Step count at which update_config() considers an epoch finished.
dataset_size = 600000
# Index file read line by line; each line is split on a tab into a GIF file name and a caption.
train_data = "../input/webvid-gif-600k/dataset_600K.tsv"
# Prefix prepended to the GIF file names from train_data.
data_dir = "../input/webvid-gif-600k/data/"

# Line offset into train_data at which the next batch starts.
current_step = 0
# Captions and video tensors of the batch currently loaded; worker threads append to both
# under `lock`, so entry i of one list belongs to entry i of the other.
texts = []
list_videos = []


# Serialises the paired appends to list_videos/texts made from worker threads.
lock = threading.Lock()
# Shared thread pool used to convert GIFs to tensors concurrently.
executor = ThreadPoolExecutor(max_workers=15)

def process_parallel(index, file_img, file_text):
    """Worker task: decode one GIF and append it, with its caption, to the current batch.

    Args:
        index: line number of the record in train_data (currently unused, kept for callers).
        file_img: GIF file name relative to data_dir.
        file_text: caption, optionally still carrying the trailing newline of its line.

    Failures are best-effort: the traceback is printed and the record is simply left out of
    the batch, so one unreadable GIF does not abort the whole batch.
    """
    try:
        print(f"Processing image {file_img}")
        tensor = gif_to_tensor(data_dir + file_img, width = image_size, height = image_size, frames = frames)
        # Strip only a trailing newline; blindly dropping the last character would eat a real
        # character on the final line of a file that does not end with "\n".
        file_text = file_text.rstrip("\n")
        # Append both under the lock so list_videos[i] and texts[i] always describe the same record.
        with lock:
            list_videos.append(tensor)
            texts.append(file_text)
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit are not swallowed.
        traceback.print_exc()
    
def get_videos_parallel(index_start, index_end):
    """Load the records on lines [index_start, index_end) of train_data into the batch globals.

    Resets the global `texts` and `list_videos`, submits one process_parallel task per
    well-formed line to the shared executor and blocks until all of them have finished.

    Args:
        index_start: first line (0-based, inclusive) to load.
        index_end: line (exclusive) at which to stop.

    Returns:
        int: number of lines consumed from the requested range, including malformed lines and
        lines whose GIF later failed to decode. 0 means the range is past the end of the file.
    """
    global texts
    global list_videos

    texts = []
    list_videos = []

    consumed = 0
    futures = []
    with open(train_data) as fp:
        for i, line in enumerate(fp):
            # Stop as soon as the range is exhausted (the old `i > index_end` read one extra line).
            if i >= index_end:
                break
            if i < index_start:
                continue
            consumed += 1
            # Split on the FIRST tab only so captions that contain tabs stay intact.
            file_img, separator, file_text = line.partition("\t")
            if not separator:
                print(f"Skipping malformed line {i}: {line!r}")
                continue
            futures.append(executor.submit(process_parallel, i, file_img, file_text))
    wait(futures)
    return consumed

def get_next_videos():
    """Load the next process_batch_size records and advance current_step past them.

    current_step moves by the number of lines consumed, not by the number of videos that were
    decoded successfully; otherwise a single unreadable GIF would make the next batch start too
    early and re-read lines that were already used.
    """
    global current_step
    current_step += get_videos_parallel(current_step, current_step + process_batch_size)


## Utility functions to save and load checkpoints

In [None]:
import shutil
import torch
import time
import gc
import os
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from imagen_pytorch.data import Dataset

# Directory scanned for *.pt checkpoints and written to by save_checkpoint().
checkpoints_path = "./"
# Path of the checkpoint most recently loaded or saved ("" until there is one).
checkpoint_path = ""

# If a checkpoint is found, the values below are overwritten at runtime by update_config().
# Epoch counters for unet 1 and unet 2.
epoch_1 = 0
epoch_2 = 0

# Dataset position recorded for unet 1 and unet 2 at their last save.
step_1 = 0
step_2 = 0

# Number of the unet being trained; passed as unet_number to the trainer.
train_unet = 1

def save_checkpoint(trainer: ImagenTrainer):
    """Replace the current checkpoint file with a freshly saved one.

    Records current_step as the step counter of the unet being trained, then writes the trainer
    state to `checkpoints_path` under a name that encodes both epoch counters, both step
    counters and the current Unix timestamp. The global `checkpoint_path` is updated to the
    new file.

    Args:
        trainer: trainer whose `.save(path)` writes the checkpoint.
    """
    global checkpoint_path, step_1, step_2

    print("Saving checkpoint")
    stamp = int(time.time())

    # NOTE(review): the previous file is deleted BEFORE the new one is written, so a failed
    # save leaves no checkpoint behind - presumably done to keep disk usage low; confirm.
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)

    if train_unet == 1:
        step_1 = current_step
    else:
        step_2 = current_step

    file_name = f"checkpoint-e1_{epoch_1}-s1_{step_1}-e2_{epoch_2}-s2_{step_2}-{stamp}.pt"
    checkpoint_path = os.path.join(checkpoints_path, file_name)
    trainer.save(checkpoint_path)

def update_config(checkpoint):
    """Restore the training counters from a checkpoint file name and pick the unet to train.

    The name is expected to look like
    ``checkpoint-e1_<epoch_1>-s1_<step_1>-e2_<epoch_2>-s2_<step_2>-<timestamp>.pt``.

    Selection rule: unet 1 is trained when unet 2 has completed at least as many epochs as
    unet 1, otherwise unet 2. If the selected unet's saved step has reached dataset_size, its
    epoch counter is incremented and current_step restarts at 0; otherwise current_step
    resumes from the saved step.

    Args:
        checkpoint: checkpoint path or file name.
    """
    global epoch_1, epoch_2, step_1, step_2, train_unet, current_step

    stem = checkpoint.replace(".pt", "")
    fields = stem.split("checkpoint-")[1].split("-")

    def _number(position, prefix):
        # Drop the "e1_"/"s1_"/... tag and parse what is left as an integer.
        return int(fields[position].replace(prefix, ""))

    epoch_1 = _number(0, "e1_")
    step_1 = _number(1, "s1_")
    epoch_2 = _number(2, "e2_")
    step_2 = _number(3, "s2_")

    for message in (
        "Loaded configuration",
        f"Epoch unet 1: {epoch_1}",
        f"Steps unet 1: {step_1}",
        f"Epoch unet 2: {epoch_2}",
        f"Steps unet 2: {step_2}",
    ):
        print(message)

    if epoch_2 >= epoch_1:
        train_unet = 1
        if step_1 < dataset_size:
            current_step = step_1
        else:
            current_step = 0
            epoch_1 += 1
    else:
        train_unet = 2
        if step_2 < dataset_size:
            current_step = step_2
        else:
            current_step = 0
            epoch_2 += 1

    print(f"Unet {train_unet} selected")
    print(f"Current step: {current_step}")
    
def config_new_epoch():
    """Start a new epoch for the unet being trained.

    Increments that unet's epoch counter (epoch_1 when train_unet is 1, epoch_2 otherwise)
    and rewinds current_step to 0.
    """
    global epoch_1, epoch_2, current_step

    if train_unet != 1:
        epoch_2 += 1
    else:
        epoch_1 += 1
    current_step = 0
        
def load_checkpoint(trainer: ImagenTrainer):
    """Load the newest checkpoint in `checkpoints_path` into `trainer`, if there is one.

    Candidates are files whose name contains "checkpoint-" and ends in ".pt"; the newest is the
    one with the largest trailing integer timestamp. Files that do not follow that naming
    scheme are ignored instead of aborting the scan. When a checkpoint is found the trainer
    state is loaded and the epoch/step globals are restored through update_config(); otherwise
    every counter is reset for a run from scratch.

    Args:
        trainer: trainer whose `.load(path)` restores the saved state.

    Returns:
        None.
    """
    global checkpoint_path
    global epoch_1
    global epoch_2
    global step_1
    global step_2
    global train_unet
    global current_step
    print("Loading checkpoint")
    timestamp = -1
    for file in os.listdir(checkpoints_path):
        # update_config() needs the "checkpoint-" marker to parse the name, so any other
        # .pt file (e.g. unrelated model weights) cannot be a resumable checkpoint.
        if not file.endswith('.pt') or "checkpoint-" not in file:
            continue
        try:
            new_timestamp = int((file.split("-")[-1]).replace(".pt", ""))
        except ValueError:
            print(f"Ignoring checkpoint file with unexpected name: {file}")
            continue
        if new_timestamp > timestamp:
            checkpoint_path = os.path.join(checkpoints_path, file)
            timestamp = new_timestamp
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found -> starting from scratch")
        epoch_1 = 0
        epoch_2 = 0
        step_1 = 0
        step_2 = 0
        current_step = 0
        train_unet = 1
        return None
    trainer.load(checkpoint_path)
    update_config(checkpoint_path)

In [None]:
# First unet handed to ElucidatedImagen (unet_number 1 in the training loop).
unet1 = Unet3D(
    dim = 64,
    cond_dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

# Second unet handed to ElucidatedImagen (unet_number 2); attention only in the last layer.
unet2 = Unet3D(
    dim = 64,
    cond_dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# Two-stage model built from the unets above and moved to the GPU with .cuda().
# Tuple-valued arguments carry one entry per unet, in the order given in `unets`.
# NOTE(review): image_sizes is (16, 64) while the data cell prepares 128-pixel frames -
# presumably the library resizes inputs per stage; confirm against the imagen_pytorch docs.
imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 64),
    random_crop_sizes = (None, 16),
    num_sample_steps = 64,
    cond_drop_prob = 0.1,                       # gives the probability of dropout for classifier-free guidance.
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# Trainer wrapper used for the forward/update steps, checkpoint save/load and sampling below.
trainer = ImagenTrainer(imagen)


## Train Unet

In [None]:
# Train
# Resume from the newest checkpoint (if any), then keep feeding batches to the selected unet
# until the wall-clock budget is used up. A checkpoint is written at every epoch rollover and
# once more when the loop ends.
load_checkpoint(trainer)
while True:
    # Stop once the run has lasted max_run_minutes (checked once per batch).
    if time.time() - start_time >= max_run_minutes * 60:
        break
    get_next_videos()
    if len(texts) == 0:
        # An empty batch is treated as the end of the dataset: persist progress, start a new
        # epoch and load the first batch again.
        save_checkpoint(trainer)
        config_new_epoch()
        get_next_videos()
        if len(texts) == 0:
            # Still nothing to train on (missing/empty dataset or every GIF failed to decode).
            # torch.stack() raises on an empty list, which would skip the final checkpoint
            # below, so leave the loop cleanly instead.
            print("No videos could be loaded at the start of a new epoch -> stopping")
            break
    print("Generating tensor from videos")
    videos = torch.stack(list_videos, dim = 0).cuda()
    print(f"Training Unet {train_unet}")
    trainer(videos, texts = texts, unet_number = train_unet, max_batch_size = 32)
    trainer.update(unet_number = train_unet)
    # Drop the reference to the batch tensor before the next batch is built.
    del videos
save_checkpoint(trainer)

In [None]:
# Prompts to sample from; one video is generated per entry.
texts_sample = ['A red cat']
# Reload the newest checkpoint so sampling uses the latest saved weights.
load_checkpoint(trainer)
# NOTE(review): 24 frames are requested here while training batches use frames = 12 - confirm
# this mismatch is intentional.
videos_out = trainer.sample(texts = texts_sample, video_frames = 24)
print(videos_out.shape)
# Write the first sampled video to out.gif at 5 frames per second.
video_tensor_to_gif(videos_out[0], f'out.gif', fps = 5)

In [None]:
# !pip install GPUtil

# from GPUtil import showUtilization as gpu_usage
# gpu_usage()    

In [None]:
#end