# Video diffusion models

In [2]:
# Run configuration ("Test 1")
image_size = 128        # passed as both width and height when GIFs are converted to tensors
frames = 10             # number of frames taken from every GIF
checkpoint_path = "./"
max_images = 120_000
batch_size = 32         # number of TSV rows consumed by one get_next_videos() call

In [3]:
# Install imagen dependencies
!pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
!pip install imagen_pytorch==1.16.5

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu116
Collecting torch==1.12.1+cu116
  Downloading https://download.pytorch.org/whl/cu116/torch-1.12.1%2Bcu116-cp39-cp39-linux_x86_64.whl (1904.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 GB[0m [31m767.9 kB/s[0m eta [36m0:00:00[0m0:01[0m00:01[0mm
[?25hCollecting torchvision==0.13.1+cu116
  Downloading https://download.pytorch.org/whl/cu116/torchvision-0.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl (23.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.5/23.5 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchaudio==0.12.1
  Downloading https://download.pytorch.org/whl/cu116/torchaudio-0.12.1%2Bcu116-cp39-cp39-linux_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
Installing collected packages: torch, torchvis

In [4]:
# GIF pre-processing

import numpy as np
from torchvision import transforms as T
from math import floor, fabs
from PIL import Image, ImageSequence


# Image mode that frames are converted to, keyed by the requested channel count.
CHANNELS_TO_MODE = {1: 'L', 3: 'RGB', 4: 'RGBA'}

def center_crop(img, new_width, new_height):
    """Crop ``img`` to ``new_width`` x ``new_height`` around its centre.

    The surplus on each axis is split between both sides; when the surplus is
    odd, the extra pixel is removed from the left/top side.

    Args:
        img: object exposing ``size`` as (width, height) and ``crop(box)``.
        new_width: width of the crop box.
        new_height: height of the crop box.

    Returns:
        The result of ``img.crop`` for the (left, top, right, bottom) box.
    """
    src_width, src_height = img.size[0], img.size[1]
    half_dx = (src_width - new_width) / 2
    half_dy = (src_height - new_height) / 2
    box = (
        int(np.ceil(half_dx)),
        int(np.ceil(half_dy)),
        src_width - int(np.floor(half_dx)),
        src_height - int(np.floor(half_dy)),
    )
    return img.crop(box)

def resize_crop_img(img, width, height):
    """Resize ``img`` so that it covers ``width`` x ``height``, then centre-crop it.

    The aspect ratio is kept: the image is scaled by the larger of the two
    target/source ratios, so that neither axis ends up smaller than the
    target, and the overflow on the other axis is cropped away by
    ``center_crop``.

    Args:
        img: image exposing ``size`` as (width, height) and
            ``resize(size, resample)``.
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        The resized image cropped to ``width`` x ``height``.
    """
    src_width, src_height = img.size
    # Scaling by the larger ratio guarantees that both axes reach the target,
    # also when the target itself is not square.
    scale = max(width / float(src_width), height / float(src_height))
    # int() truncates, and a float product such as 49 * (128 / 49) evaluates to
    # 127.99..., which would leave the image one pixel short of the target and
    # make the crop pad with black. Clamp both axes to the target size.
    scaled_width = max(width, int(src_width * scale))
    scaled_height = max(height, int(src_height * scale))
    img = img.resize((scaled_width, scaled_height), Image.Resampling.LANCZOS)
    return center_crop(img, width, height)

def transform_gif(img, new_width, new_height, frames, channels = 3):
    """Yield ``frames`` resized, cropped and mode-converted frames of ``img``.

    Output frame ``i`` is source frame ``i % img.n_frames``, so a source with
    fewer frames than requested wraps around to its first frame.

    Args:
        img: opened image exposing ``n_frames`` and ``seek``.
        new_width: width handed to ``resize_crop_img``.
        new_height: height handed to ``resize_crop_img``.
        frames: number of frames to yield.
        channels: key into ``CHANNELS_TO_MODE`` selecting the output mode.

    Yields:
        One converted image per requested frame.
    """
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    target_mode = CHANNELS_TO_MODE[channels]
    source_frame_count = img.n_frames
    for frame_idx in range(frames):
        # Position the image on the (wrapped) source frame before processing it.
        img.seek(frame_idx % source_frame_count)
        yield resize_crop_img(img, new_width, new_height).convert(target_mode)
        
# tensor of shape (channels, frames, height, width) -> gif
def video_tensor_to_gif(tensor, path, fps = 10, loop = 0, optimize = True):
    """Write a video tensor to ``path`` as an animated GIF and return its frames.

    Args:
        tensor: video tensor; it is split along dimension 1 and every slice is
            converted with ``T.ToPILImage()``.
        path: destination handed to the image ``save`` call.
        fps: frames per second; every frame is shown for ``int(1000 / fps)`` ms.
        loop: forwarded to ``save``.
        optimize: forwarded to ``save``.

    Returns:
        list: the converted images, in frame order.

    Raises:
        ValueError: if ``tensor`` has no slices along dimension 1.
    """
    print("Converting video tensors to GIF")
    # Materialise the frames in a list: a bare map() iterator would be consumed
    # by the unpacking below and the caller would get an exhausted iterator back.
    images = list(map(T.ToPILImage(), tensor.unbind(dim = 1)))
    if not images:
        raise ValueError("tensor contains no frames along dimension 1")
    first_img, *rest_imgs = images
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = int(1000/fps), loop = loop, optimize = optimize)
    print("Gif saved")
    return images

# gif -> (channels, frame, height, width) tensor
def gif_to_tensor(path, width = 256, height = 256, frames = 32, channels = 3, transform = T.ToTensor()):
    """Load the image at ``path`` and return its frames as one stacked tensor.

    Args:
        path: file to open with ``Image.open``.
        width: target frame width handed to ``transform_gif``.
        height: target frame height handed to ``transform_gif``.
        frames: number of frames to extract.
        channels: channel count handed to ``transform_gif``.
        transform: callable applied to every frame; its results are stacked.

    Returns:
        The per-frame results of ``transform`` stacked along dimension 1.
    """
    print("Converting GIF to video tensors")
    # Close the file deterministically. transform_gif is a lazy generator that
    # seeks inside the open image, so every frame has to be consumed before the
    # `with` block ends.
    with Image.open(path) as img:
        imgs = transform_gif(img, new_width = width, new_height = height, frames = frames, channels = channels)
        tensors = tuple(map(transform, imgs))
    return torch.stack(tensors, dim = 1)

In [5]:
import os
import torch
import shutil

# Local copy of the TGIF TSV (each row is "<gif url>\t<description>") and the
# text file that stores the index of the next row to read.
# NOTE(review): the extension is spelled `.tvs` although the upstream file is
# `.tsv`; get_videos opens the literal "train_data.tvs", so both places have to
# change together if this is renamed.
train_data = "./train_data.tvs"
train_index = "./train_index.txt"

# Download the dataset description file only when no local copy exists yet.
if not os.path.exists(train_data):
  !wget -O {train_data} https://raw.githubusercontent.com/raingo/TGIF-Release/master/data/tgif-v1.0.tsv

# Module-level state shared with get_videos / get_next_videos:
#   current_index - end index of the batch loaded most recently
#   texts         - descriptions of the loaded videos
#   list_videos   - video tensors, in the same order as `texts`
current_index = 0
texts = []
list_videos = []

def get_videos(index_start, index_end):
    """Download rows ``[index_start, index_end)`` of the TSV as video tensors.

    The module-level lists ``texts`` and ``list_videos`` are replaced by new
    lists holding, in matching order, the description and the tensor of every
    row that could be downloaded and converted. Rows that fail are reported
    and skipped, so the lists may be shorter than the requested range.

    Each GIF is fetched with the ``wget`` executable into ``download.gif`` in
    the working directory and converted with ``gif_to_tensor`` using the
    module-level ``image_size`` and ``frames`` settings.

    Args:
        index_start: zero-based index of the first row to load.
        index_end: zero-based index one past the last row to load.
    """
    # Local import so that this function does not rely on an import made in
    # another cell.
    import subprocess

    global texts
    global list_videos

    texts = []
    list_videos = []
    download_path = 'download.gif'

    with open(train_data) as fp:
        for i, line in enumerate(fp):
            if i < index_start:
                continue
            if i >= index_end:
                break
            try:
                # maxsplit=1 keeps a description that itself contains a tab
                # intact; rstrip only removes the line terminator.
                file_img, file_text = line.rstrip("\n").split("\t", 1)
                print(f"Downloading image {i}")
                # Pass the URL as a separate argument instead of interpolating
                # it into a shell command line: characters such as `&`, `?` or
                # `;` in a URL must not be interpreted by a shell.
                subprocess.run(
                    ["wget", "-O", download_path, "-o", "/dev/null", file_img],
                    check = True,
                )
                tensor = gif_to_tensor(download_path, width = image_size, height = image_size, frames = frames)
                # Append both only after the conversion succeeded so that the
                # two lists stay aligned.
                list_videos.append(tensor)
                texts.append(file_text)
            except Exception as ex:
                # Best effort: one broken row must not abort the whole batch.
                print(f"Skipping row {i}: {ex}")
            finally:
                # Remove the temporary file on failure as well.
                if os.path.exists(download_path):
                    os.remove(download_path)

def get_next_videos():
    """Load the next ``batch_size`` rows and advance the persisted position.

    The start row is read from the ``train_index`` file (the file is created
    with ``0`` when it does not exist). After ``get_videos`` returns, the end
    row is written back to that file and stored in the module-level
    ``current_index``.
    """
    global current_index

    if os.path.exists(train_index):
        with open(train_index, 'r') as fp:
            start = int(fp.readlines()[0])
    else:
        # First run: create the position file and begin with the first row.
        with open(train_index, 'w') as fp:
            fp.write("0")
        start = 0

    stop = start + batch_size
    get_videos(start, stop)

    # Persist the new position only after the batch was processed.
    with open(train_index, 'w') as fp:
        fp.write(str(stop))
    current_index = stop
    

--2022-12-06 10:58:29--  https://raw.githubusercontent.com/raingo/TGIF-Release/master/data/tgif-v1.0.tsv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18660908 (18M) [text/plain]
Saving to: ‘./train_data.tvs’


2022-12-06 10:58:35 (3.04 MB/s) - ‘./train_data.tvs’ saved [18660908/18660908]



In [5]:
# Load one batch, starting at the row index stored in the train_index file.
# NOTE(review): the training cell further down calls get_next_videos() again,
# which rebuilds texts/list_videos from the following rows, so the batch loaded
# here is not the one that gets trained on - confirm this call is intended.
get_next_videos()

Downloading image 32
--2022-12-06 10:09:12--  https://38.media.tumblr.com/902f71ae45f50dc0a35c04323de6495a/tumblr_nj4iqcVaFQ1s26nzro1_500.gif
Resolving 38.media.tumblr.com (38.media.tumblr.com)... 74.114.154.22, 74.114.154.18
Connecting to 38.media.tumblr.com (38.media.tumblr.com)|74.114.154.22|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://64.media.tumblr.com/902f71ae45f50dc0a35c04323de6495a/tumblr_nj4iqcVaFQ1s26nzro1_500.gif [following]
--2022-12-06 10:09:12--  https://64.media.tumblr.com/902f71ae45f50dc0a35c04323de6495a/tumblr_nj4iqcVaFQ1s26nzro1_500.gif
Resolving 64.media.tumblr.com (64.media.tumblr.com)... 192.0.77.3
Connecting to 64.media.tumblr.com (64.media.tumblr.com)|192.0.77.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 991818 (969K) [image/gif]
Saving to: ‘download.gif’


2022-12-06 10:09:12 (3.16 MB/s) - ‘download.gif’ saved [991818/991818]

Converting GIF to video tensors
Downloading image 3

In [6]:
import shutil
import torch
import datetime
import gc
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from imagen_pytorch.data import Dataset

# Directory that receives the checkpoint files written by save_checkpoint.
checkpoints_path = "./checkpoints"
# exist_ok=True replaces the separate exists()/mkdir() pair: there is no window
# between the check and the creation, and an existing directory is accepted.
os.makedirs(checkpoints_path, exist_ok = True)

# Text file that contains the path of the most recently saved checkpoint.
last_checkpoint_path = os.path.join(checkpoints_path, "last_checkpoint.txt")
# Number of checkpoints saved since this cell was executed.
num_saves = 0

def save_checkpoint(trainer: ImagenTrainer, unet, step):
    """Save the trainer state and keep only this newest checkpoint on disk.

    The checkpoint is written to ``checkpoints_path`` under a name built from
    ``unet``, ``step`` and the current date/time. Its path is recorded in the
    ``last_checkpoint_path`` file, the checkpoint recorded there before is
    deleted, and the running save counter is written to ``num_saves.txt``.

    Args:
        trainer: trainer whose ``save(path)`` method writes the checkpoint.
        unet: value embedded in the file name as the unet number.
        step: value embedded in the file name as the step.
    """
    global num_saves

    print("Saving checkpoint")
    current_datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    new_checkpoint_path = os.path.join(checkpoints_path, f"checkpoint-unet{unet}-step{step}-{current_datetime}.pt")
    trainer.save(new_checkpoint_path)

    # Remember the previous checkpoint before the pointer file is overwritten.
    previous_checkpoint_path = None
    if os.path.exists(last_checkpoint_path):
        with open(last_checkpoint_path, 'r') as fp:
            previous_checkpoint_path = fp.readline().strip()

    # Update the pointer first, so a failure during clean-up can never leave it
    # referring to a file that has already been deleted.
    with open(last_checkpoint_path, 'w') as fp:
        fp.write(new_checkpoint_path)

    # Delete the old checkpoint only if it is a different file that still
    # exists: the timestamp has one-second resolution, so two saves can map to
    # the same path, and the old file may have been removed by hand.
    if (
        previous_checkpoint_path
        and previous_checkpoint_path != new_checkpoint_path
        and os.path.exists(previous_checkpoint_path)
    ):
        os.remove(previous_checkpoint_path)

    num_saves += 1
    with open("num_saves.txt", 'w') as fp:
        fp.write(f"{num_saves}")

def load_checkpoint(trainer: ImagenTrainer):
    """Restore ``trainer`` from the checkpoint recorded in ``last_checkpoint_path``.

    Args:
        trainer: trainer whose ``load(path)`` method restores the state.

    Returns:
        The path of the checkpoint that was loaded, or ``None`` when no
        checkpoint is recorded or loading it failed (the failure is printed).
    """
    if not os.path.exists(last_checkpoint_path):
        return None
    with open(last_checkpoint_path, 'r') as fp:
        # Only the first line is meaningful; strip a possible line terminator.
        checkpoint_path = fp.readline().strip()
    if not checkpoint_path:
        print("Checkpoint pointer file is empty, nothing to load")
        return None
    try:
        print("Loading checkpoint")
        trainer.load(checkpoint_path)
    except Exception as ex:
        # Still best effort (training continues from scratch), but report why,
        # and do not swallow KeyboardInterrupt/SystemExit like a bare except.
        print(f"Could not load checkpoint {checkpoint_path}: {ex}")
        return None
    return checkpoint_path


Downloading:   0%|          | 0.00/605 [00:00<?, ?B/s]

    https://github.com/beartype/beartype#pep-585-deprecations
  warn(


In [7]:
# Placeholders for the objects that setup_trainer() creates.
unet1 = unet2 = imagen = trainer = None

def setup_trainer():
    """Create both U-Nets, the ElucidatedImagen model and its trainer.

    The new objects are stored in the module-level names ``unet1``, ``unet2``,
    ``imagen`` and ``trainer``, replacing whatever those names held before.
    The imagen model is moved to the GPU with ``.cuda()``.
    """
    global unet1, unet2, imagen, trainer

    # Settings that are identical for both U-Nets.
    shared_unet_kwargs = dict(
        dim = 64,
        cond_dim = 128,
        dim_mults = (1, 2, 4, 8),
    )

    unet1 = Unet3D(
        num_resnet_blocks = 3,
        layer_attns = (False, True, True, True),
        layer_cross_attns = (False, True, True, True),
        **shared_unet_kwargs
    )

    unet2 = Unet3D(
        num_resnet_blocks = (2, 4, 8, 8),
        layer_attns = (False, False, False, True),
        layer_cross_attns = (False, False, False, True),
        **shared_unet_kwargs
    )

    imagen_kwargs = dict(
        unets = (unet1, unet2),
        image_sizes = (16, 32),
        random_crop_sizes = (None, 16),
        num_sample_steps = 10,
        cond_drop_prob = 0.1,       # gives the probability of dropout for classifier-free guidance
        sigma_min = 0.002,          # min noise level
        sigma_max = (80, 160),      # max noise level, double the max noise level for upsampler
        sigma_data = 0.5,           # standard deviation of data distribution
        rho = 7,                    # controls the sampling schedule
        P_mean = -1.2,              # mean of log-normal distribution from which noise is drawn for training
        P_std = 1.2,                # standard deviation of log-normal distribution from which noise is drawn for training
        S_churn = 80,               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
        S_tmin = 0.05,
        S_tmax = 50,
        S_noise = 1.003,
    )
    imagen = ElucidatedImagen(**imagen_kwargs).cuda()

    trainer = ImagenTrainer(imagen)

def free_trainer():
    """Drop the models and the trainer, then release cached GPU memory.

    The module-level names ``unet1``, ``unet2``, ``imagen`` and ``trainer``
    are reset to ``None``. The function can be called repeatedly, and also
    before ``setup_trainer`` has ever run.
    """
    global unet1, unet2, imagen, trainer
    # Rebind to None instead of `del`: deleting the globals makes a second call
    # (or any later read of these names) fail with NameError, while rebinding
    # releases the references just the same.
    unet1 = unet2 = imagen = trainer = None
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# unet1 = Unet3D(dim = 128, dim_mults = (1, 2, 4, 8)).cuda()
# unet2 = Unet3D(dim = 128, dim_mults = (1, 2, 4, 8)).cuda()

# imagen = ElucidatedImagen(
#     unets = (unet1, unet2),
#     image_sizes = (16, 32),
#     random_crop_sizes = (None, 16),
#     num_sample_steps = 10,
#     cond_drop_prob = 0.1,
#     sigma_min = 0.002,                          # min noise level
#     sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
#     sigma_data = 0.5,                           # standard deviation of data distribution
#     rho = 7,                                    # controls the sampling schedule
#     P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
#     P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
#     S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
#     S_tmin = 0.05,
#     S_tmax = 50,
#     S_noise = 1.003,
# ).cuda()

In [8]:
# Train Unets

batch_size = 512
get_next_videos()
# Every download in the batch may have failed; stop with a clear message
# instead of the generic error torch.stack raises for an empty list.
if not list_videos:
    raise RuntimeError("No videos could be loaded for this batch; nothing to train on")

# imagen_pytorch numbers the unets starting from 1, so train unet 1, then unet 2.
for unet in range(1, 3):
    print(f"Training unet {unet}")
    # (batch, ...) tensor built from the per-video tensors, moved to the GPU.
    videos = torch.stack(list_videos, dim = 0).cuda()
    # A fresh trainer is built for every unet and restored from the newest checkpoint.
    setup_trainer()
    load_checkpoint(trainer)
    trainer(videos, texts = texts, unet_number = unet, max_batch_size = 32)
    trainer.update(unet_number = unet)
    save_checkpoint(trainer, unet, current_index)
    # Release the GPU memory before the next unet is set up.
    del videos
    free_trainer()



Downloading image 0
--2022-12-06 10:59:07--  https://38.media.tumblr.com/9f6c25cc350f12aa74a7dc386a5c4985/tumblr_mevmyaKtDf1rgvhr8o1_500.gif
Resolving 38.media.tumblr.com (38.media.tumblr.com)... 74.114.154.22, 74.114.154.18
Connecting to 38.media.tumblr.com (38.media.tumblr.com)|74.114.154.22|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://64.media.tumblr.com/9f6c25cc350f12aa74a7dc386a5c4985/tumblr_mevmyaKtDf1rgvhr8o1_500.gif [following]
--2022-12-06 10:59:07--  https://64.media.tumblr.com/9f6c25cc350f12aa74a7dc386a5c4985/tumblr_mevmyaKtDf1rgvhr8o1_500.gif
Resolving 64.media.tumblr.com (64.media.tumblr.com)... 192.0.77.3
Connecting to 64.media.tumblr.com (64.media.tumblr.com)|192.0.77.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1022700 (999K) [image/gif]
Saving to: ‘download.gif’


2022-12-06 10:59:07 (3.20 MB/s) - ‘download.gif’ saved [1022700/1022700]

Converting GIF to video tensors
Downloading image

In [9]:
# Train Unet 2
# Runs one training step for unet number 2 on the batch currently held in
# list_videos / texts, then saves a checkpoint and frees the GPU objects.
unet_number = 2
# Stack the per-video tensors along a new leading dimension and move them to the GPU.
videos = torch.stack(list_videos, dim = 0).cuda()
# Build fresh models and a trainer, then restore the checkpoint recorded in
# last_checkpoint.txt (if there is one).
setup_trainer()
load_checkpoint(trainer)

# Sanity check: batch tensor shape and number of captions.
print(videos.shape)
print(len(texts))

# Forward/backward pass (split into chunks of at most 32 samples), followed by
# the optimizer update for the same unet.
trainer(videos, texts = texts, unet_number = unet_number, max_batch_size = 32)
trainer.update(unet_number = unet_number)
# current_index (end row of the loaded batch) is embedded in the file name as the step.
save_checkpoint(trainer, unet_number, current_index)
# Drop the batch and the models so that the GPU memory can be reused.
del videos
free_trainer()

checkpoint loaded from ./checkpoints/checkpoint-unet1-step64-20221206-101018.pt
Loaded checkpoint
torch.Size([32, 3, 10, 128, 128])
32
Saving checkpoint
checkpoint saved to ./checkpoints/checkpoint-unet2-step64-20221206-101110.pt


In [None]:
# !pip install GPUtil

# from GPUtil import showUtilization as gpu_usage
# gpu_usage()    

In [None]:
# Train Unet 2

# trainer = ImagenTrainer(imagen)
# trainer(videos, texts = texts, unet_number = 2, max_batch_size = 1)
# trainer.update(unet_number = 2)

In [None]:
#end