<a href="https://colab.research.google.com/github/LuthandoMaqondo/phenaki-pytorch/blob/main/notebooks/training.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount the drive

In [1]:
import os
import sys
import platform
import requests
import torch

try:
    from google.colab import drive
    IN_COLAB = True
except:
    WORKING_DIR = '.'
    IN_COLAB = False
if IN_COLAB:
    WORKING_DIR = '/content/drive/MyDrive/Colab Notebooks'
    drive.mount('/content/drive',  force_remount=True)
if IN_COLAB:
    sys.path.insert(0, WORKING_DIR)
else:
    # The actual code is one level higher in folder depth/structure, so we're elevating this notebook.
    sys.path.insert(0,f".{WORKING_DIR}/")

# Install The Model

In [2]:
!pip install phenaki-pytorch



# Usage

### Training process

#### Train teh AutoEncoder

In [4]:
from phenaki_pytorch import CViViT, CViViTTrainer, MaskGit, Phenaki

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
).cuda()

data_folder = os.path.expanduser(f"{WORKING_DIR}/datasets/Appimate/train") if IN_COLAB else os.path.expanduser(f"~/.cache/Appimate/train")
trainer = CViViTTrainer(
    cvivit,
    folder = data_folder,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False,  # you can train on images first, before fine tuning on video, for sample efficiency
    use_ema = False,          # recommended to be turned on (keeps exponential moving averaged cvivit) unless if you don't have enough resources
    num_train_steps = 10
)
# trainer.train()               # reconstructions and checkpoints will be saved periodically to ./results

AssertionError: Torch not compiled with CUDA enabled

#### Train the Phenaki

In [5]:
maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

videos = torch.randn(3, 3, 17, 256, 128).cuda() # (batch, channels, frames, height, width)
mask = torch.ones((3, 17)).bool().cuda() # [optional] (batch, frames) - allows for co-training videos of different lengths as well as video and images in the same batch

num_ephochs = 10
for epoch in range(0, num_ephochs):
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles'
    ]

    loss = phenaki(videos, texts = texts, video_frame_mask = mask)
    loss.backward()
    # do the above for many steps, then ...

### Testing process

In [None]:
video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 128)

# so in the paper, they do not really achieve 2 minutes of coherent video
# at each new scene with new text conditioning, they condition on the previous K frames
# you can easily achieve this with this framework as so

video_prime = video[:, :, -3:] # (1, 3, 3, 256, 128) # say K = 3
video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 128)

# the total video
entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 128)

# and so on...