### Check if CUDA is enabled

In [2]:
!nvidia-smi

Thu Dec  1 21:45:45 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.73.05    Driver Version: 510.73.05    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 5000     Off  | 00000000:00:05.0 Off |                  Off |
| 33%   35C    P8    13W / 230W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Utility functions to resize and crop GIFs



In [1]:
import numpy as np
from torchvision import transforms as T
from math import floor, fabs
from PIL import Image, ImageSequence


CHANNELS_TO_MODE = {
    1 : 'L',
    3 : 'RGB',
    4 : 'RGBA'
}

def center_crop(img, new_width, new_height): 
    width = img.size[0]
    height = img.size[1]
    left = int(np.ceil((width - new_width) / 2))
    right = width - int(np.floor((width - new_width) / 2))
    top = int(np.ceil((height - new_height) / 2))
    bottom = height - int(np.floor((height - new_height) / 2))
    return img.crop((left, top, right, bottom))

def resize_crop_img(img, width, height):
    # width < height
    if( img.size[0] < img.size[1]):
      wpercent = (width/float(img.size[0]))
      hsize = int((float(img.size[1])*float(wpercent)))
      img = img.resize((width, hsize), Image.LANCZOS)
    else: # width >= height
      hpercent = (height/float(img.size[1]))
      wsize = int((float(img.size[0])*float(hpercent)))
      img = img.resize((wsize, height), Image.LANCZOS)
    img = center_crop(img, width, height)
    # print(img.size[0])
    # print(img.size[1])
    return img

def transform_gif(img, new_width, new_height, frames, channels = 3):
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]
    gif_frames = img.n_frames
    for i in range(0, frames):
        img.seek(i % gif_frames)
        img_out = resize_crop_img(img, new_width, new_height)
        yield img_out.convert(mode)
        
# tensor of shape (channels, frames, height, width) -> gif
def video_tensor_to_gif(tensor, path, fps = 10, loop = 0, optimize = True):
    print("Converting video tensors to GIF")
    images = map(T.ToPILImage(), tensor.unbind(dim = 1))
    first_img, *rest_imgs = images
    print(1000/fps)
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = int(1000/fps), loop = loop, optimize = optimize)
    print("Gif saved")
    return images

# gif -> (channels, frame, height, width) tensor
def gif_to_tensor(path, width = 256, height = 256, frames = 32, channels = 3, transform = T.ToTensor()):
    print("Converting GIF to video tensors")
    img = Image.open(path)
    imgs = transform_gif(img, new_width = width, new_height = height, frames = frames, channels = channels)
    tensors = tuple(map(transform, imgs))
    return torch.stack(tensors, dim = 1)

# tensor = gif_to_tensor('example.gif', width = 128, height = 128, max_frames = 20).cuda()
# gif = video_tensor_to_gif(tensor, 'example2.gif', fps = 5)
# print(gif)


## Dataset import and normalization

In [2]:
import torch

texts = []
videos = torch.empty(0)
frames = 10
width = 256
height = 256

In [3]:
!wget -O train_data.tvs https://raw.githubusercontent.com/raingo/TGIF-Release/master/data/tgif-v1.0.tsv
!mkdir train
!cd train

--2022-12-03 11:07:58--  https://raw.githubusercontent.com/raingo/TGIF-Release/master/data/tgif-v1.0.tsv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18660908 (18M) [text/plain]
Saving to: ‘train_data.tvs’


2022-12-03 11:08:02 (4.81 MB/s) - ‘train_data.tvs’ saved [18660908/18660908]

mkdir: cannot create directory ‘train’: File exists


In [6]:
# Download all images
texts = []
list_videos = []
max_iter = 100

with open('train_data.tvs') as train_file:
  i = 0;
  for line in train_file:
    file_img, file_text = line.split("\t")
    file_text = file_text[:-1] # Remove \n
    !wget -O train/{i}.gif {file_img}
    tensor = gif_to_tensor(f'train/{i}.gif', width = width, height = height, frames = frames)
    print(tensor.shape)
    list_videos.append(tensor)
    # video_tensor_to_gif(tensor, f'train/out_{i}.gif', fps = 5)
    texts.append(file_text)
    i+=1
    if(i==max_iter): break

videos = torch.stack(list_videos, dim = 0).cuda()
# videos = videos.cuda()
print(videos.shape)

--2022-12-01 19:20:02--  https://38.media.tumblr.com/9f6c25cc350f12aa74a7dc386a5c4985/tumblr_mevmyaKtDf1rgvhr8o1_500.gif
Resolving 38.media.tumblr.com (38.media.tumblr.com)... 74.114.154.22, 74.114.154.18
Connecting to 38.media.tumblr.com (38.media.tumblr.com)|74.114.154.22|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://64.media.tumblr.com/9f6c25cc350f12aa74a7dc386a5c4985/tumblr_mevmyaKtDf1rgvhr8o1_500.gif [following]
--2022-12-01 19:20:02--  https://64.media.tumblr.com/9f6c25cc350f12aa74a7dc386a5c4985/tumblr_mevmyaKtDf1rgvhr8o1_500.gif
Resolving 64.media.tumblr.com (64.media.tumblr.com)... 192.0.77.3
Connecting to 64.media.tumblr.com (64.media.tumblr.com)|192.0.77.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1022700 (999K) [image/gif]
Saving to: ‘train/0.gif’


2022-12-01 19:20:02 (1.66 MB/s) - ‘train/0.gif’ saved [1022700/1022700]

Converting GIF to video tensors
torch.Size([3, 10, 256, 256])


  img = img.resize((wsize, height), Image.LANCZOS)


--2022-12-01 19:20:03--  https://38.media.tumblr.com/9ead028ef62004ef6ac2b92e52edd210/tumblr_nok4eeONTv1s2yegdo1_400.gif
Resolving 38.media.tumblr.com (38.media.tumblr.com)... 74.114.154.18, 74.114.154.22
Connecting to 38.media.tumblr.com (38.media.tumblr.com)|74.114.154.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://64.media.tumblr.com/9ead028ef62004ef6ac2b92e52edd210/tumblr_nok4eeONTv1s2yegdo1_400.gif [following]
--2022-12-01 19:20:08--  https://64.media.tumblr.com/9ead028ef62004ef6ac2b92e52edd210/tumblr_nok4eeONTv1s2yegdo1_400.gif
Resolving 64.media.tumblr.com (64.media.tumblr.com)... 192.0.77.3
Connecting to 64.media.tumblr.com (64.media.tumblr.com)|192.0.77.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2088110 (2.0M) [image/gif]
Saving to: ‘train/1.gif’


2022-12-01 19:20:09 (5.03 MB/s) - ‘train/1.gif’ saved [2088110/2088110]

Converting GIF to video tensors
torch.Size([3, 10, 256, 256])


  img = img.resize((width, hsize), Image.LANCZOS)


--2022-12-01 19:20:09--  https://38.media.tumblr.com/9f43dc410be85b1159d1f42663d811d7/tumblr_mllh01J96X1s9npefo1_250.gif
Resolving 38.media.tumblr.com (38.media.tumblr.com)... 74.114.154.18, 74.114.154.22
Connecting to 38.media.tumblr.com (38.media.tumblr.com)|74.114.154.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://64.media.tumblr.com/9f43dc410be85b1159d1f42663d811d7/tumblr_mllh01J96X1s9npefo1_250.gif [following]
--2022-12-01 19:20:09--  https://64.media.tumblr.com/9f43dc410be85b1159d1f42663d811d7/tumblr_mllh01J96X1s9npefo1_250.gif
Resolving 64.media.tumblr.com (64.media.tumblr.com)... 192.0.77.3
Connecting to 64.media.tumblr.com (64.media.tumblr.com)|192.0.77.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 566152 (553K) [image/gif]
Saving to: ‘train/2.gif’


2022-12-01 19:20:10 (1.68 MB/s) - ‘train/2.gif’ saved [566152/566152]

Converting GIF to video tensors
torch.Size([3, 10, 256, 256])
--2022-12-01

In [4]:
# Get pre-downloades images

texts = []
list_videos = []
max_iter = 100

with open('train_data.tvs') as train_file:
  i = 0;
  for line in train_file:
    file_img, file_text = line.split("\t")
    file_text = file_text[:-1] # Remove \n
    tensor = gif_to_tensor(f'train/{i}.gif', width = width, height = height, frames = frames)
    list_videos.append(tensor)
    texts.append(file_text)
    i+=1
    if(i==max_iter): break

videos = torch.stack(list_videos, dim = 0).cuda()
#torch.save(videos, 'videos.pt')
print(videos.shape)

Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors


  img = img.resize((wsize, height), Image.LANCZOS)
  img = img.resize((width, hsize), Image.LANCZOS)


Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converting GIF to video tensors
Converti

## Train

In [5]:
# Install imagen dependencies
#!pip install torch
!pip install imagen_pytorch==1.16.5

import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from imagen_pytorch.data import Dataset

Collecting imagen_pytorch==1.16.5
  Downloading imagen_pytorch-1.16.5-py3-none-any.whl (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 kB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ema-pytorch>=0.0.3
  Downloading ema_pytorch-0.1.2-py3-none-any.whl (4.2 kB)
Collecting beartype
  Downloading beartype-0.11.0-py3-none-any.whl (702 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m702.5/702.5 kB[0m [31m71.0 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-warmup
  Downloading pytorch_warmup-0.1.1-py3-none-any.whl (6.6 kB)
Collecting kornia
  Downloading kornia-0.6.8-py2.py3-none-any.whl (551 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m551.1/551.1 kB[0m [31m73.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
Collecti

Downloading:   0%|          | 0.00/605 [00:00<?, ?B/s]

    https://github.com/beartype/beartype#pep-585-deprecations
  warn(


In [8]:
unet1 = Unet3D(dim = 128, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 128, dim_mults = (1, 2, 4, 8)).cuda()


# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (128, 256),
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in apper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()


print(videos.shape)

trainer = ImagenTrainer(imagen)
# for u in (1, 2):
#     loss = trainer(videos, texts = texts, unet_number = u, max_batch_size = 1)
#     trainer.update(unet_number = u)

loss = trainer(videos, texts = texts, unet_number = 1,  max_batch_size = 2)
trainer.update(unet_number = 1)

torch.Size([100, 3, 10, 256, 256])


RuntimeError: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 15.75 GiB total capacity; 13.28 GiB already allocated; 483.44 MiB free; 14.21 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [7]:
unet1 = Unet3D(dim = 128, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 128, dim_mults = (1, 2, 4, 8)).cuda()

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 32),
    random_crop_sizes = (None, 16),
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in apper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# texts = [
#     'a whale breaching from afar',
#     'young girl blowing out candles on her birthday cake',
#     'fireworks with blue and green sparkles',
#     'dust motes swirling in the morning sunshine on the windowsill'
# ]

# #videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)

# videos = ['prova', 'sadasd','sadasd','sadasd']
print(videos.shape)

trainer = ImagenTrainer(imagen)
trainer(videos, texts = texts, unet_number = 1, max_batch_size = 1)
trainer.update(unet_number = 1)

torch.Size([100, 3, 10, 32, 32])


Downloading:   0%|          | 0.00/945M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

In [7]:
# Test
texts_tests = ['cat']
videos_out = trainer.sample(texts = texts_tests, video_frames = 20)
print(videos_out.shape)
video_tensor_to_gif(videos_out[0], f'out.gif', fps = 5)


unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([1, 3, 20, 128, 128])
Converting video tensors to GIF
200.0
Gif saved


<map at 0x7f6c5e09fd90>

# TODO

In [15]:
import torch
torch.cuda.empty_cache()

In [None]:
import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer

unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

# elucidated imagen, which contains the unets above (base unet and super resoluting ones)

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 32),
    random_crop_sizes = (None, 16),
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in apper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# mock videos (get a lot of this) and text encodings from large T5

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
]

videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)

# feed images into imagen, training each unet in the cascade
# for this example, only training unet 1

trainer = ImagenTrainer(imagen)
trainer(videos, texts = texts, unet_number = 1)
trainer.update(unet_number = 1)

videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames

videos.shape # (4, 3, 20, 32, 32)

print(type(videos))

Downloading:   0%|          | 0.00/990M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.86k [00:00<?, ?B/s]

unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

<class 'torch.Tensor'>
