# Video diffusion models

In [1]:
# Test 1
image_size = 128
frames = 10
max_images = 125782
download_batch_size = 128

In [2]:
# Install imagen dependencies
!pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
!pip install imagen_pytorch==1.16.5

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu116
Collecting torch==1.12.1+cu116
  Downloading https://download.pytorch.org/whl/cu116/torch-1.12.1%2Bcu116-cp39-cp39-linux_x86_64.whl (1904.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 GB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0mm
[?25hCollecting torchvision==0.13.1+cu116
  Downloading https://download.pytorch.org/whl/cu116/torchvision-0.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl (23.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.5/23.5 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchaudio==0.12.1
  Downloading https://download.pytorch.org/whl/cu116/torchaudio-0.12.1%2Bcu116-cp39-cp39-linux_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m59.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: torch, torchvisi

In [3]:
# GIF pre-processing

import numpy as np
from torchvision import transforms as T
from math import floor, fabs
from PIL import Image, ImageSequence


CHANNELS_TO_MODE = {
    1 : 'L',
    3 : 'RGB',
    4 : 'RGBA'
}

def center_crop(img, new_width, new_height): 
    width = img.size[0]
    height = img.size[1]
    left = int(np.ceil((width - new_width) / 2))
    right = width - int(np.floor((width - new_width) / 2))
    top = int(np.ceil((height - new_height) / 2))
    bottom = height - int(np.floor((height - new_height) / 2))
    return img.crop((left, top, right, bottom))

def resize_crop_img(img, width, height):
    # width < height
    if( img.size[0] < img.size[1]):
      wpercent = (width/float(img.size[0]))
      hsize = int((float(img.size[1])*float(wpercent)))
      img = img.resize((width, hsize), Image.Resampling.LANCZOS)
    else: # width >= height
      hpercent = (height/float(img.size[1]))
      wsize = int((float(img.size[0])*float(hpercent)))
      img = img.resize((wsize, height), Image.Resampling.LANCZOS)
    img = center_crop(img, width, height)
    # print(img.size[0])
    # print(img.size[1])
    return img

def transform_gif(img, new_width, new_height, frames, channels = 3):
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]
    gif_frames = img.n_frames
    for i in range(0, frames):
        img.seek(i % gif_frames)
        img_out = resize_crop_img(img, new_width, new_height)
        yield img_out.convert(mode)
        
# tensor of shape (channels, frames, height, width) -> gif
def video_tensor_to_gif(tensor, path, fps = 10, loop = 0, optimize = True):
    print("Converting video tensors to GIF")
    images = map(T.ToPILImage(), tensor.unbind(dim = 1))
    first_img, *rest_imgs = images
    print(1000/fps)
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = int(1000/fps), loop = loop, optimize = optimize)
    print("Gif saved")
    return images

# gif -> (channels, frame, height, width) tensor
def gif_to_tensor(path, width = 256, height = 256, frames = 32, channels = 3, transform = T.ToTensor()):
    print("Converting GIF to video tensors")
    img = Image.open(path)
    imgs = transform_gif(img, new_width = width, new_height = height, frames = frames, channels = channels)
    tensors = tuple(map(transform, imgs))
    return torch.stack(tensors, dim = 1)

In [4]:
import os
import torch
import shutil

train_data = "./train_data.tvs"
train_index = "./train_index.txt"

if not os.path.exists(train_data):
  !wget -O {train_data} https://raw.githubusercontent.com/raingo/TGIF-Release/master/data/tgif-v1.0.tsv

current_index = 0
texts = []
list_videos = []

def get_videos(index_start, index_end):
    global texts
    global list_videos
    
    texts = []
    list_videos = []
    max_iter = 100

    with open("train_data.tvs") as fp:
        for i, line in enumerate(fp):
            if i >= index_start and i< index_end :
                file_img, file_text = line.split("\t")
                try:
                    print(f"Downloading image {i}")
                    !wget -O download.gif -o /dev/null {file_img}
                    tensor = gif_to_tensor('download.gif', width = image_size, height = image_size, frames = frames)
                    list_videos.append(tensor)
                    file_text = file_text[:-1] # Remove \n
                    texts.append(file_text)
                    os.remove('download.gif')
                except Exception as ex:
                    print(ex)
                    pass
            elif i > index_end:
                break

def get_next_videos():
    global current_index
    index = 0
    if not os.path.exists(train_index):
        with open(train_index, 'w') as fp:
            fp.write("0")
    else:
        with open(train_index, 'r') as fp:
            index = int(fp.readlines()[0])
    index_end = index + download_batch_size
    get_videos(index, index_end)
    with open(train_index, 'w') as fp:
        fp.write(f"{index_end}")
    current_index = index_end
    

In [5]:
import shutil
import torch
import datetime
import gc
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from imagen_pytorch.data import Dataset

checkpoints_path = "./"
last_checkpoint_path = os.path.join(checkpoints_path, "last_checkpoint.txt")

def save_checkpoint(trainer: ImagenTrainer, unet, step):
    print("Saving checkpoint")
    current_datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    new_checkpoint_path = os.path.join(checkpoints_path, f"checkpoint-unet{unet}-step{step}-{current_datetime}.pt")
    trainer.save(new_checkpoint_path)
    if os.path.exists(last_checkpoint_path):
        with open(last_checkpoint_path, 'r') as fp:
            os.remove(fp.readlines()[0])
    with open(last_checkpoint_path, 'w') as fp:
        fp.write(new_checkpoint_path)

def load_checkpoint(trainer: ImagenTrainer):
    if not os.path.exists(last_checkpoint_path):
        return None
    with open(last_checkpoint_path, 'r') as fp:
        checkpoint_path = fp.readlines()[0]
        try:
            print("Loading checkpoint")
            trainer.load(checkpoint_path)
        except:
            return None


Downloading:   0%|          | 0.00/605 [00:00<?, ?B/s]

    https://github.com/beartype/beartype#pep-585-deprecations
  warn(


In [6]:
unet1 = Unet3D(
    dim = 64,
    cond_dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet3D(
    dim = 64,
    cond_dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 64),
    random_crop_sizes = (None, 16),
    num_sample_steps = 10,
    # timesteps = 1000,
    cond_drop_prob = 0.1,                       # gives the probability of dropout for classifier-free guidance.
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in apper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

trainer = ImagenTrainer(imagen)


The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/


In [8]:
# Train Unet 1
unet = 1
load_checkpoint(trainer)

while True:
    get_next_videos()
    if len(texts) == 0:
        break
    print("Generating tensor from videos")
    videos = torch.stack(list_videos, dim = 0).cuda()
    print(f"Training unet-{unet}")
    trainer(videos, texts = texts, unet_number = unet, max_batch_size = 32)
    trainer.update(unet_number = unet)
    del videos
    torch.cuda.empty_cache()
    print("Allocated memory")
    print(torch.cuda.memory_allocated())
    save_checkpoint(trainer, unet, current_index)

Loading checkpoint
Downloading image 5376
Converting GIF to video tensors
Downloading image 5377
Converting GIF to video tensors
Downloading image 5378
Converting GIF to video tensors
Downloading image 5379
Converting GIF to video tensors
Downloading image 5380
Converting GIF to video tensors
Downloading image 5381
Converting GIF to video tensors
Downloading image 5382
Converting GIF to video tensors
Downloading image 5383
Converting GIF to video tensors
Downloading image 5384
Converting GIF to video tensors
Downloading image 5385
Converting GIF to video tensors
Downloading image 5386
Converting GIF to video tensors
Downloading image 5387
Converting GIF to video tensors
Downloading image 5388
Converting GIF to video tensors
Downloading image 5389
Converting GIF to video tensors
Downloading image 5390
Converting GIF to video tensors
Downloading image 5391
Converting GIF to video tensors
Downloading image 5392
Converting GIF to video tensors
Downloading image 5393
Converting GIF to video

Downloading:   0%|          | 0.00/945M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Allocated memory
1850773504
Saving checkpoint
checkpoint saved to ./checkpoint-unet1-step5504-20221212-103737.pt


FileNotFoundError: [Errno 2] No such file or directory: './checkpoints/checkpoint-unet1-step5376-20221212-032038.pt'

In [7]:
# # Train Unet 2
# unet = 2
# load_checkpoint(trainer)

# while True:
#     get_next_videos()
#     if len(texts) == 0:
#         break
#     print("Generating tensor from videos")
#     videos = torch.stack(list_videos, dim = 0).cuda()
#     print(f"Training unet-{unet}")
#     trainer(videos, texts = texts, unet_number = unet, max_batch_size = 32)
#     trainer.update(unet_number = unet)
#     del videos
#     torch.cuda.empty_cache()
#     print("Allocated memory")
#     print(torch.cuda.memory_allocated())
#     save_checkpoint(trainer, unet, current_index)

The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/
Loading checkpoint
checkpoint loaded from ./checkpoints/checkpoint-unet1-step128-20221211-185101.pt
Downloading image 128
Converting GIF to video tensors
Downloading image 129
Converting GIF to video tensors
Downloading image 130
Converting GIF to video tensors
Downloading image 131
Converting GIF to video tensors
Downloading image 132
Converting GIF to video tensors
Downloading image 133
Converting GIF to video tensors
Downloading image 134
Converting GIF to video tensors
Downloading image 135
Converting GIF to video tensors
Downloading image 136
Converting GIF to video tensors
Downloading image 137
Converting GIF to video tensors
Downloading image 138
Converting GIF to video tensors
Downloading image 139
Converting GIF to video tensors
Downloading image 140
Converting GIF to video ten

Downloading:   0%|          | 0.00/945M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Allocated memory
3529052672
Saving checkpoint
checkpoint saved to ./checkpoints/checkpoint-unet2-step192-20221211-190038.pt


In [12]:
# !pip install GPUtil

# from GPUtil import showUtilization as gpu_usage
# gpu_usage()    

Collecting GPUtil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: GPUtil
  Building wheel for GPUtil (setup.py) ... [?25ldone
[?25h  Created wheel for GPUtil: filename=GPUtil-1.4.0-py3-none-any.whl size=7394 sha256=9e59b4827870b979608e22fafd27d86dddbe0e49710f6ae3fa0fc74616de8e71
  Stored in directory: /root/.cache/pip/wheels/2b/b5/24/fbb56595c286984f7315ee31821d6121e1b9828436021a88b3
Successfully built GPUtil
Installing collected packages: GPUtil
Successfully installed GPUtil-1.4.0
[0m| ID | GPU | MEM |
------------------
|  0 |  0% | 98% |


In [None]:
#end