<a href="https://colab.research.google.com/github/LuthandoMaqondo/phenaki-pytorch/blob/luthando-contribution/notebooks/training.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount the drive

In [2]:
import os
import sys
import platform
import requests
import torch

# Detect whether the notebook is running on Google Colab.
# Only a failed import of `google.colab` means "not on Colab"; catching ImportError
# (rather than a bare `except:`) lets unrelated errors and KeyboardInterrupt surface.
try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # On Colab, work out of the mounted Google Drive folder and make it importable.
    WORKING_DIR = '/content/drive/MyDrive/Colab Notebooks'
    drive.mount('/content/drive',  force_remount=True)
    sys.path.insert(0, WORKING_DIR)
else:
    WORKING_DIR = '.'
    # The actual code is one level higher in folder depth/structure, so we're elevating this notebook.
    # With WORKING_DIR == '.', the entry added to sys.path is "../".
    sys.path.insert(0, f".{WORKING_DIR}/")

In [4]:
# Pick the compute device: MPS first, then CUDA, then CPU.
# `is_built()` only reports that this PyTorch binary was compiled with MPS support;
# `is_available()` also checks that the current machine can actually use it, so we
# never select "mps" on a host where moving tensors to it would fail.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

'mps'

# Install the model

In [5]:
if IN_COLAB:
    !git clone https://github.com/LuthandoMaqondo/phenaki-pytorch.git
    %cd /content/phenaki-pytorch
    !git checkout luthando-contribution
    !pip install -r requirements.txt
    # !pip install phenaki-pytorch
    !pip install git+https://xxxxxxxxxxxx@github.com/AppimateSA/AutoVisual.git


# Usage

### Training process

In [8]:
import os


In [9]:
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from autovisual import DatasetConfig, VideoCustomDataset
# Keyword arguments forwarded unchanged to `DatasetConfig` (together with train=True).
dataset_config_args = dict(
    refreshData=False,
    # Optional: use the pretrained CLIP model for handling ALL text inputs.
    # useCLIP="openai/clip-vit-large-patch14",
    structure='text_video_pair',
    tokenize_text=False,
    # Alternative source: a local folder instead of the remote one.
    # data_folder=f'.{WORKING_DIR}/datasets/Appimate',
    data_folder='https://appimate1storage.blob.core.windows.net/datasets/Appimate',
    data_json="dataset.json",
    data_points=None,

    max_text_length=77,
    max_num_frames=6,
    resolution=64,
    num_channels=1,
    normalize=True,
    scale_to=0,  # previously tried: 0.5
    has_start_end_token=True,

    frame_rate=2,
    frame_rate_ratio=0.01,
    output_format='TCHW',
)

# Build the config object, then the dataset that is driven by it.
datasetConfig = DatasetConfig(train=True, **dataset_config_args)
custom_dataset = VideoCustomDataset(datasetConfig)

Downloading 'dataset.json' (from Source)
Downloading 'vocab.txt' (from Source)


Downloading videos (from Source) : 100%|████████████████████████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 3/3 [00:20<00:00,  7.00s/it]


Dataset prepared, 'dataset.json' file gives us: 34 text & 34 video datapoints.






In [10]:
class MockTextVideoDataset(Dataset):
    """Synthetic (video, caption) dataset for exercising the training pipeline.

    Every item is a freshly sampled random video tensor paired with a fixed caption,
    so no files are read.

    Args:
        length: number of items reported by ``len()``.
        image_size: height and width of each generated frame.
        num_frames: number of frames in each generated video.
    """

    def __init__(
        self,
        length = 100,
        image_size = 256,
        num_frames = 17
    ):
        super().__init__()
        self.num_frames = num_frames
        self.image_size = image_size
        self.len = length

    def __len__(self):
        """Return the configured number of items."""
        return self.len

    def __getitem__(self, idx):
        """Return ``(video, caption)`` for position ``idx``.

        ``video`` is a ``torch.randn`` tensor of shape
        ``(3, num_frames, image_size, image_size)``; ``caption`` is always
        ``'video caption'``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``. Without this, iterating
                the dataset directly (``for item in dataset`` / ``list(dataset)``)
                would never terminate, because that protocol stops on IndexError.
        """
        if not -self.len <= idx < self.len:
            raise IndexError(f"index {idx} is out of range for dataset of length {self.len}")
        video = torch.randn(3, self.num_frames, self.image_size, self.image_size)
        caption = 'video caption'
        return video, caption

mock_dataset = MockTextVideoDataset()

In [13]:
# Datasets that take part in training. Only the mock dataset is active;
# append `custom_dataset` to the list to include it as well.
full_dataset = ConcatDataset([mock_dataset])

# Random train/eval split: the train share comes from `datasetConfig.train_split`
# (truncated to a whole number of samples); the remainder goes to evaluation.
train_len = int(datasetConfig.train_split * len(full_dataset))
eval_len = len(full_dataset) - train_len
train_dataset, eval_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_len, eval_len],
)

In [17]:
# Loader over the training split: batches of 4, shuffled, using 2 worker processes.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Uncomment to inspect a single batch:
# for i, batch in enumerate(train_dataloader):
#     # text_in, visuals_out = batch
#     # print('text_in: ', text_in)
#     # print('visuals_out: ', visuals_out)
#     break

#### Train the C-ViViT

In [20]:
from phenaki_pytorch import CViViT, CViViTTrainer, MaskGit, Phenaki, PhenakiTrainer

# C-ViViT model, created and moved to the selected device.
cvivit = CViViT(
    dim=512,
    codebook_size=65536,
    image_size=(256, 256),
    patch_size=32,
    temporal_patch_size=2,
    spatial_depth=4,
    temporal_depth=4,
    dim_head=64,
    heads=8,
).to(device)

# Training videos: under WORKING_DIR on Colab, otherwise in the user's cache directory.
if IN_COLAB:
    data_folder = os.path.expanduser(f"{WORKING_DIR}/datasets/Appimate/train")
else:
    data_folder = os.path.expanduser("~/.cache/datasets/Appimate/train")

trainer = CViViTTrainer(
    cvivit,
    folder=data_folder,
    batch_size=4,
    grad_accum_every=4,
    # You can train on images first, before fine-tuning on video, for sample efficiency.
    train_on_images=False,
    # Recommended on (keeps an exponential moving average of the cvivit) unless resources are tight.
    use_ema=True,
    num_train_steps=10,
)

# Uncomment to run training; reconstructions and checkpoints will be saved periodically to ./results
# trainer.train()

training with dataset of 57 samples and validating with randomly splitted 4 samples


#### Train Phenaki

In [6]:
# NOTE(review): `num_tokens` here is 5000 while the C-ViViT above is created with
# codebook_size = 65536 — confirm the two are meant to differ.
maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
).to(device)

# Combine the video tokenizer and the MaskGit transformer into the Phenaki model.
phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).to(device)

# Locate the sampling captions relative to the repository instead of a machine-specific
# absolute path: on Colab the repo is cloned to /content/phenaki-pytorch, while locally
# this notebook sits one folder below the repository root.
repo_root = "/content/phenaki-pytorch" if IN_COLAB else os.path.abspath("..")
sample_texts_file_path = os.path.join(repo_root, "data", "sample_texts.txt")

phenaki_trainer = PhenakiTrainer(
    phenaki,
    batch_size=4,
    num_frames=17,
    train_lr=0.0001,
    train_num_steps=2,
    grad_accum_every = 2,
    train_on_images=False,
    save_and_sample_every=100,
    num_samples=4,
    dataset = train_dataset,
    sample_texts_file_path = sample_texts_file_path # each caption should be on a new line, during sampling, will be randomly drawn
)
phenaki_trainer.train()

### Testing process

In [None]:
# First scene: sample 17 frames from a text prompt.
# Presumably shaped (1, 3, 17, H, W), with H and W set by the cvivit image_size — TODO confirm.
video = phenaki.sample(
    texts = 'a squirrel examines an acorn',
    num_frames = 17,
    cond_scale = 5.,
)

# In the paper they do not really achieve 2 minutes of coherent video: at each new scene
# with new text conditioning, they condition on the previous K frames.
# The same can be done with this framework, as shown below.

# Take the last K = 3 frames of the first scene as the priming frames.
video_prime = video[:, :, -3:]

# Second scene: 14 frames under a new prompt, primed on those frames.
video_next = phenaki.sample(
    texts = 'a cat watches the squirrel from afar',
    prime_frames = video_prime,
    num_frames = 14,
)

# The total video: both scenes joined along dim 2 (17 + 14 frames).
entire_video = torch.cat([video, video_next], dim = 2)

# ...and so on for further scenes.