<a href="https://colab.research.google.com/github/LuthandoMaqondo/phenaki-pytorch/blob/luthando-contribution/notebooks/training.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount the drive

In [1]:
import os
import sys
import platform
import requests
import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader

# Ask CUDA to run kernel launches synchronously, so a device-side error is
# reported at the call that triggered it (debugging aid; slows execution).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Detect Google Colab by probing for its helper package. Only ImportError is
# caught, so unrelated failures (KeyboardInterrupt, a broken install, ...) are
# no longer silently treated as "not running on Colab".
try:
    from google.colab import drive
except ImportError:
    # Local run: work relative to the current directory.
    IN_COLAB = False
    WORKING_DIR = '.'
else:
    # Colab run: keep data on Google Drive and (re)mount it.
    IN_COLAB = True
    WORKING_DIR = '/content/drive/MyDrive/Colab Notebooks'
    drive.mount('/content/drive', force_remount=True)

In [2]:
# Pick the compute device: Apple GPU (MPS) first, then CUDA, otherwise CPU.
# `is_available()` is used instead of `is_built()`: the latter only says that
# this PyTorch binary was compiled with MPS support, not that an MPS device can
# actually be used on the current machine.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Install The Model

In [3]:
if IN_COLAB:
    !git clone https://github.com/LuthandoMaqondo/phenaki-pytorch.git
    %cd /content/phenaki-pytorch
    !git checkout luthando-contribution
    !pip install -r requirements.txt
    # !pip install phenaki-pytorch
    
# !pip install git+https://xxxxxxxxxxxx@github.com/AppimateSA/AutoVisual.git


# Usage

### Training process

In [4]:
import os
import warnings

# SECURITY: the Azure Storage connection string embeds the account key and must
# never be written into the notebook. Provide it from outside instead, e.g.
# export AZURE_BLOB_STORAGE_CONN_STR before starting Jupyter, or inject it from
# a secrets manager. The key that used to be committed in this cell should be
# treated as compromised and rotated.
if not os.environ.get("AZURE_BLOB_STORAGE_CONN_STR"):
    warnings.warn(
        "AZURE_BLOB_STORAGE_CONN_STR is not set; anything that relies on it "
        "will fail until it is provided through the environment.",
        stacklevel=2,
    )

In [5]:
# from autovisual import DatasetConfig, VideoCustomDataset
# dataset_config_args = {
#     'refreshData': True,
#     # 'useCLIP': "openai/clip-vit-large-patch14", # Use the pretrained CLIP model for handling ALL Text inputs.
#     'structure': 'text_video_pair',
#     'tokenize_text': False,
#     # 'data_folder': f'.{WORKING_DIR}/datasets/Appimate',
#     'data_folder': 'https://appimate1storage.blob.core.windows.net/datasets/Appimate',
#     'data_json': "dataset.json",
#     'data_points': None, # None

#     'max_text_length': 77,
#     'max_num_frames': 6,
#     'resolution': 256,
#     'num_channels': 1, 
#     'normalize': True,
#     'scale_to': 0,#0.5,
#     'has_start_end_token': True,

#     'frame_rate': 2,
#     'frame_rate_ratio': 0.01,
#     'output_format': 'TCHW'
# }
# datasetConfig = DatasetConfig(**dataset_config_args, train=True)
# custom_dataset = VideoCustomDataset(datasetConfig)

In [6]:
class MockTextVideoDataset(Dataset):
    """Synthetic (video, caption) dataset for smoke-testing the training code.

    Every item is a freshly sampled random video plus a caption derived from
    the requested index; nothing is read from disk.

    Args:
        length: number of items reported by ``len()``.
        image_size: height and width of each (square) frame.
        num_frames: number of frames per video.
    """

    def __init__(
        self,
        length = 100,
        image_size = 256,
        num_frames = 17
    ):
        super().__init__()
        self.num_frames = num_frames
        self.image_size = image_size
        self.len = length

    def __len__(self):
        """Return the configured number of items."""
        return self.len

    def __getitem__(self, idx):
        """Return ``(video, caption)`` for position ``idx``.

        ``video`` is a ``torch.randn`` tensor of shape
        ``(3, num_frames, image_size, image_size)`` (CTHW) and ``caption`` is
        the string ``'video caption {idx}'``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``. Without this,
                plain iteration (``for item in dataset``) would never stop,
                because Python ends index-based iteration on ``IndexError``.
        """
        if not -self.len <= idx < self.len:
            raise IndexError(
                f'index {idx} is out of range for a dataset of length {self.len}'
            )
        video = torch.randn(3, self.num_frames, self.image_size, self.image_size)
        caption = f'video caption {idx}'
        return video, caption


mock_dataset = MockTextVideoDataset()

In [7]:
# Fraction of the samples used for training; the rest is held out for evaluation.
TRAIN_FRACTION = 0.9
# Fixed seed so the train/eval split is the same on every run of the notebook.
SPLIT_SEED = 42

# Combine every available dataset into one (only the mock data is enabled).
full_dataset = ConcatDataset([
    mock_dataset,
    # custom_dataset
])

# train_len = int(len(full_dataset) * (datasetConfig.train_split) )
train_len = int(len(full_dataset) * TRAIN_FRACTION)
eval_len = len(full_dataset) - train_len
train_dataset, eval_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_len, eval_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Sanity check: the two subsets together cover the whole dataset.
assert len(train_dataset) + len(eval_dataset) == len(full_dataset)

#### Train the C-ViViT

In [8]:
from phenaki_pytorch import CViViT, CViViTTrainer

# Build the C-ViViT model and move it onto the selected device.
cvivit = CViViT(
    dim=512,
    codebook_size=65536,
    image_size=(256, 256),
    patch_size=32,
    temporal_patch_size=2,
    spatial_depth=4,
    temporal_depth=4,
    dim_head=64,
    heads=8,
).to(device)

# Training data location: the Drive-backed folder on Colab, otherwise a cache
# directory under the user's home.
if IN_COLAB:
    data_folder = os.path.expanduser(f"{WORKING_DIR}/datasets/Appimate/train")
else:
    data_folder = os.path.expanduser("~/.cache/datasets")

trainer = CViViTTrainer(
    cvivit,
    folder=data_folder,
    batch_size=4,
    grad_accum_every=4,
    # Set to True to train on images first, before fine-tuning on video,
    # for sample efficiency.
    train_on_images=False,
    # Keeps an exponential moving average of the cvivit; recommended unless
    # resources are too limited.
    use_ema=True,
    save_results_every=100,
    save_model_every=100,
    num_train_steps=10000,
)



training with dataset of 494 samples and validating with randomly splitted 26 samples


In [9]:
# Start C-ViViT training with the trainer configured above
# (num_train_steps = 10000; results and the model are saved every 100 steps).
trainer.train()               # reconstructions and checkpoints will be saved periodically to ./results

0: vae loss: 19032.031982421875 - discr loss: 11.95794153213501
0: saving to results
0: saving model to results
1: vae loss: 20086.9287109375 - discr loss: 10.997035503387451
2: vae loss: 25521.232421875 - discr loss: 10.768080949783325
3: vae loss: 27216.9951171875 - discr loss: 9.937973260879517
4: vae loss: 18230.06396484375 - discr loss: 20.128325939178467
5: vae loss: 22421.19140625 - discr loss: 8.170398712158203
6: vae loss: 17686.709228515625 - discr loss: 9.332286357879639
7: vae loss: 16220.058837890625 - discr loss: 11.962282180786133
8: vae loss: 18154.435546875 - discr loss: 9.596046209335327
9: vae loss: 12768.8271484375 - discr loss: 10.142586469650269
10: vae loss: 18826.08056640625 - discr loss: 10.69494104385376
11: vae loss: 17306.86376953125 - discr loss: 10.83767557144165
12: vae loss: 24295.339599609375 - discr loss: 10.943338871002197
13: vae loss: 17165.032958984375 - discr loss: 11.077337980270386
14: vae loss: 18728.271606445312 - discr loss: 11.16306400299072

#### Train the Phenaki

In [None]:
from phenaki_pytorch import Phenaki, PhenakiTrainer

# Root of the phenaki-pytorch checkout. Set the PHENAKI_REPO_DIR environment
# variable to point at your own clone; when it is unset, the previously
# hardcoded locations are used (the Colab clone directory, or the original
# author's desktop folder).
repo_dir = os.environ.get(
    "PHENAKI_REPO_DIR",
    f"{'/content' if IN_COLAB else '/home/luthando/Desktop'}/phenaki-pytorch",
)
# Each caption is on its own line; captions are drawn at random during sampling.
sample_texts_file_path = os.path.join(repo_dir, "data", "sample_texts.txt")

# Wrap the C-ViViT model built earlier in a Phenaki model on the same device.
phenaki = Phenaki(
    cvivit = cvivit,
    self_token_critic= True  # set this to True
).to(device)

# num_frames matches the default clip length of MockTextVideoDataset (17).
phenaki_trainer = PhenakiTrainer(
    phenaki,
    batch_size=4,
    num_frames=17,
    train_lr=0.0001,
    train_num_steps=2,
    grad_accum_every = 2,
    train_on_images=False,
    save_and_sample_every=100,
    num_samples=4,
    dataset = train_dataset,
    sample_texts_file_path = sample_texts_file_path
)

In [None]:
# Run Phenaki training with the trainer configured above (train_num_steps = 2).
phenaki_trainer.train()

### Testing process

In [None]:
# video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 128)

# # so in the paper, they do not really achieve 2 minutes of coherent video
# # at each new scene with new text conditioning, they condition on the previous K frames
# # you can easily achieve this with this framework as so

# video_prime = video[:, :, -3:] # (1, 3, 3, 256, 128) # say K = 3
# video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 128)

# # the total video
# entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 128)

# # and so on...

In [None]:
# # ... above code

# from phenaki_pytorch import make_video

# entire_video, scenes = make_video(phenaki, texts = [
#     'a squirrel examines an acorn buried in the snow',
#     'a cat watches the squirrel from a frosted window sill',
#     'zoom out to show the entire living room, with the cat residing by the window sill'
# ], num_frames = (17, 14, 14), prime_lengths = (5, 5))

# entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)

# # scenes - List[Tensor[3]] - video segment of each scene