<a href="https://colab.research.google.com/github/LuthandoMaqondo/phenaki-pytorch/blob/main/notebooks/training.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google Drive (Colab only)

In [2]:
import os
import sys
import platform
import requests
import torch

# Detect whether this notebook runs inside Google Colab. Only a missing module
# means "not Colab"; any other error while importing google.colab should surface
# instead of being silently swallowed by a bare `except:`.
try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # In Colab the project code and datasets live on the user's Google Drive.
    WORKING_DIR = '/content/drive/MyDrive/Colab Notebooks'
    drive.mount('/content/drive', force_remount=True)
    sys.path.insert(0, WORKING_DIR)
else:
    WORKING_DIR = '.'
    # The actual code is one level higher in folder depth/structure than this
    # notebook, so f".{WORKING_DIR}/" (== "../") puts the parent dir on the import path.
    sys.path.insert(0, f".{WORKING_DIR}/")

# Install the libraries

In [7]:
!pip install phenaki-pytorch
!pip install git+https://xxxxxxxxxxxx@github.com/AppimateSA/AutoVisual.git

# Usage

### Training process

In [9]:
from autovisual import DatasetConfig, VideoCustomDataset

# Fixed seed so the train/eval split below is identical on every run.
SPLIT_SEED = 42

dataset_config_args = {
    'refreshData': True,
    # 'useCLIP': "openai/clip-vit-large-patch14", # Use the pretrained CLIP model for handling ALL Text inputs.
    'structure': 'text_video_pair',
    # Local alternative: f'.{WORKING_DIR}/datasets/Appimate'
    'data_folder': 'https://appimate1storage.blob.core.windows.net/datasets/Appimate',
    'data_json': "dataset.json",
    'data_points': None,

    'max_text_length': 77,
    # NOTE(review): PhenakiTrainer below is configured with num_frames=17 — confirm these agree.
    'max_num_frames': 6,
    # NOTE(review): CViViT below is built with image_size=(256, 256) — confirm these agree.
    'resolution': 64,
    'num_channels': 1,
    'normalize': True,
    'scale_to': 0,  # alternative value tried: 0.5
    'has_start_end_token': True,

    'frame_rate': 2,
    'frame_rate_ratio': 0.01,
    'output_format': 'TCHW'
}
datasetConfig = DatasetConfig(**dataset_config_args, train=True)
full_dataset = VideoCustomDataset(datasetConfig)

# The first `train_split` fraction becomes the training set; the remainder is the eval set.
train_len = int(len(full_dataset) * datasetConfig.train_split)
eval_len = len(full_dataset) - train_len
# Without an explicit generator, random_split draws from the (never seeded) global
# torch RNG, so each run would train/evaluate on different samples.
train_dataset, eval_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_len, eval_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

#### Train the C-ViViT

In [15]:
from phenaki_pytorch import CViViT, CViViTTrainer, MaskGit, Phenaki, PhenakiTrainer

# Where the CViViT trainer reads its training videos: Google Drive in Colab,
# otherwise the local user cache.
if IN_COLAB:
    data_folder = os.path.expanduser(f"{WORKING_DIR}/datasets/Appimate/train")
else:
    data_folder = os.path.expanduser("~/.cache/datasets/Appimate/train")

# Video tokenizer; the Phenaki cell below pairs this same `cvivit` with MaskGit.
# NOTE(review): the dataset config above uses resolution 64, while image_size here
# is (256, 256) — confirm which is intended.
cvivit = CViViT(
    dim=512,
    codebook_size=65536,
    image_size=(256, 256),
    patch_size=32,
    temporal_patch_size=2,
    spatial_depth=4,
    temporal_depth=4,
    dim_head=64,
    heads=8,
).cuda()

trainer = CViViTTrainer(
    cvivit,
    folder=data_folder,
    batch_size=4,
    grad_accum_every=4,
    # Training on images first and then fine-tuning on video is an option for
    # sample efficiency; here we train directly on video.
    train_on_images=False,
    # Keeps an exponential-moving-average copy of cvivit; recommended unless
    # resources are too tight.
    use_ema=True,
    num_train_steps=10,
)

# Reconstructions and checkpoints are saved periodically to ./results.
trainer.train()

#### Train Phenaki (MaskGit on top of the trained C-ViViT)

In [11]:

# MaskGit's vocabulary must cover every CViViT codebook id, so num_tokens has to
# equal the codebook_size passed to CViViT (65536). With the previous value (5000),
# ids >= 5000 index past MaskGit's token embedding / output logits, and the mask
# token id (== num_tokens) collides with a real codebook entry.
VIDEO_CODEBOOK_SIZE = 65536

maskgit = MaskGit(
    num_tokens = VIDEO_CODEBOOK_SIZE,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

phenaki_trainer = PhenakiTrainer(
    phenaki,
    batch_size=4,
    # NOTE(review): the dataset config uses max_num_frames=6 — confirm this frame count agrees.
    num_frames=17,
    train_lr=0.0001,
    train_num_steps=1000,
    train_on_images=False,
    save_and_sample_every=100,
    num_samples=4,
    dataset = train_dataset,
    # Each caption on its own line; during sampling one is drawn at random.
    # NOTE(review): assumes the phenaki-pytorch repo is cloned under /content (Colab).
    sample_texts_file_path = "/content/phenaki-pytorch/data/sample_texts.txt"
)
phenaki_trainer.train()

### Sampling videos from text (inference)

In [None]:
# How many trailing frames of the previous scene condition the next one ("K").
NUM_PRIME_FRAMES = 3

# First scene. Shape comments follow (batch, channels, frames, height, width); frames
# are on dim 2, and height/width come from CViViT's image_size=(256, 256).
video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.)  # (1, 3, 17, 256, 256)

# In the paper they do not really achieve 2 minutes of coherent video: at each new
# scene with new text conditioning, they condition on the previous K frames.
# This framework supports the same thing through `prime_frames`.
video_prime = video[:, :, -NUM_PRIME_FRAMES:]  # (1, 3, K, 256, 256)
video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14)  # (1, 3, 14, 256, 256)

# The total video: concatenate the scenes along the frame axis (dim 2).
entire_video = torch.cat((video, video_next), dim = 2)  # (1, 3, 17 + 14, 256, 256)

# and so on...