<a href="https://colab.research.google.com/github/LuthandoMaqondo/phenaki-pytorch/blob/luthando-contribution/notebooks/training.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup: imports, environment detection and Google Drive mount

In [1]:
import os
import sys
import platform
import requests
import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader
import imageio
from matplotlib import pyplot as plt, animation
from IPython.display import display, Image, HTML

# Make CUDA kernel launches synchronous so errors surface at the failing call.
# NOTE(review): this is a debugging aid and is expected to slow GPU training — confirm
# it should stay enabled for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Detect Google Colab by trying to import its helper package. Only a failed import
# means "not on Colab"; any other exception is a real error and must not be hidden.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

if IN_COLAB:
    # On Colab, work inside the user's Drive and (re)mount it.
    WORKING_DIR = '/content/drive/MyDrive/Colab Notebooks'
    drive.mount('/content/drive',  force_remount=True)
else:
    # Elsewhere, work relative to the current directory.
    WORKING_DIR = '.'

In [2]:
# Pick the compute device, preferring Apple MPS, then CUDA, then CPU.
# `is_available()` is used instead of `is_built()`: a PyTorch build can include MPS
# support on a machine that has no usable MPS device, in which case "mps" would be
# selected and the first `.to(device)` would fail.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Install The Model

In [3]:
if IN_COLAB:
    !git clone https://github.com/LuthandoMaqondo/phenaki-pytorch.git
    %cd /content/phenaki-pytorch
    !git checkout luthando-contribution
    !pip install -r requirements.txt
    # !pip install phenaki-pytorch


# Usage

### Training process

In [5]:
# from autovisual import DatasetConfig, VideoCustomDataset
# dataset_config_args = {
#     'refreshData': True,
#     # 'useCLIP': "openai/clip-vit-large-patch14", # Use the pretrained CLIP model for handling ALL Text inputs.
#     'structure': 'text_video_pair',
#     'tokenize_text': False,
#     # 'data_folder': f'.{WORKING_DIR}/datasets/Appimate',
#     'data_folder': 'https://appimate1storage.blob.core.windows.net/datasets/Appimate',
#     'data_json': "dataset.json",
#     'data_points': None, # None

#     'max_text_length': 77,
#     'max_num_frames': 6,
#     'resolution': 256,
#     'num_channels': 1, 
#     'normalize': True,
#     'scale_to': 0,#0.5,
#     'has_start_end_token': True,

#     'frame_rate': 2,
#     'frame_rate_ratio': 0.01,
#     'output_format': 'TCHW'
# }
# datasetConfig = DatasetConfig(**dataset_config_args, train=True)
# custom_dataset = VideoCustomDataset(datasetConfig)

In [6]:
class MockTextVideoDataset(Dataset):
    """Synthetic (video, caption) dataset for smoke-testing the training pipeline.

    Every item is a freshly sampled random video tensor paired with a caption
    that embeds the item index. No data is read from disk.

    Args:
        length: number of items the dataset reports via ``len()``.
        image_size: height and width (in pixels) of every generated frame.
        num_frames: number of frames in every generated video.
    """

    def __init__(
        self,
        length = 100,
        image_size = 256,
        num_frames = 17
    ):
        super().__init__()
        self.num_frames = num_frames
        self.image_size = image_size
        self.len = length

    def __len__(self):
        """Return the configured number of items."""
        return self.len

    def __getitem__(self, idx): # Video data of shape: CTHW
        """Return ``(video, caption)`` for item ``idx``.

        ``video`` is a random tensor of shape ``(3, num_frames, image_size, image_size)``
        and ``caption`` is ``'video caption <idx>'``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``. Without this, plain
                iteration over the dataset (which relies on ``IndexError`` to stop)
                would never terminate.
        """
        # Accept negative indices the way sequences do.
        if idx < 0:
            idx += self.len
        if not 0 <= idx < self.len:
            raise IndexError(f'index {idx} is out of range for dataset of length {self.len}')
        video = torch.randn(3, self.num_frames, self.image_size, self.image_size)
        caption = f'video caption {idx}'
        return video, caption

mock_dataset = MockTextVideoDataset()

In [7]:
# Fraction of the combined dataset used for training; the remainder is used for evaluation.
TRAIN_FRACTION = 0.9
# Fixed seed so the train/eval partition is identical on every Restart & Run All.
SPLIT_SEED = 0

# Combine every dataset that should take part in training into one.
full_dataset = ConcatDataset([
    mock_dataset,
    # custom_dataset
])

# train_len = int(len(full_dataset) * (datasetConfig.train_split) )
train_len = int(len(full_dataset) * TRAIN_FRACTION)
eval_len = len(full_dataset) - train_len
train_dataset, eval_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_len, eval_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Phenaki (Components)

In [8]:
from phenaki_pytorch import CViViT, CViViTTrainer, MaskGit, Phenaki, PhenakiTrainer


#### Train the Causal-ViViT

In [9]:
# Hyper-parameters of the C-ViViT video tokenizer, kept together so they are easy to tune.
cvivit_kwargs = dict(
    dim = 512,
    codebook_size = 65536,
    image_size = (256, 256),
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8,
)
cvivit = CViViT(**cvivit_kwargs).to(device)

# Folder with the training data, resolved relative to the user's home directory.
data_folder = os.path.expanduser("~/.cache/datasets")

# Settings of the tokenizer training run.
trainer_kwargs = dict(
    folder = data_folder,
    batch_size = 4,
    grad_accum_every = 4,
    # You can train on images first, before fine tuning on video, for sample efficiency.
    train_on_images = False,
    # Recommended to be turned on (keeps exponential moving averaged cvivit)
    # unless you don't have enough resources.
    use_ema = True,
    save_results_every = 100,
    save_model_every = 100,
    num_train_steps = 10000,
)
trainer = CViViTTrainer(cvivit, **trainer_kwargs)



training with dataset of 4768 samples and validating with randomly splitted 251 samples


In [10]:
# Run the C-ViViT training loop configured on `trainer`.
# NOTE(review): the output recorded with this notebook stops after "moov atom not found"
# followed by "ValueError: need at least one array to concatenate" — presumably an
# unreadable video file in the data folder; verify the dataset before re-running.
trainer.train()               # reconstructions and checkpoints will be saved periodically to ./results

04/03/2024 at 00:41:29
0: vae loss: 24695.996826171875 - discr loss: 12.016433715820312
0: saving to results
0: saving model to results
1: vae loss: 20232.353271484375 - discr loss: 10.936233758926392
2: vae loss: 17431.91796875 - discr loss: 10.75072431564331
3: vae loss: 17244.566650390625 - discr loss: 10.229156017303467
4: vae loss: 22994.951171875 - discr loss: 7.5924235582351685
5: vae loss: 18668.1982421875 - discr loss: 0.5987280756235123
6: vae loss: 17323.72412109375 - discr loss: 1443.1932678222656
7: vae loss: 22257.77197265625 - discr loss: 0.49602970853447914
8: vae loss: 21792.02587890625 - discr loss: 1.414475828409195
9: vae loss: 20138.2255859375 - discr loss: 4.482433319091797
10: vae loss: 15859.2939453125 - discr loss: 4.648132503032684
11: vae loss: 26069.21044921875 - discr loss: 4.876460671424866
12: vae loss: 13698.11669921875 - discr loss: 3.7602545022964478
13: vae loss: 20418.930908203125 - discr loss: 3.162919044494629
14: vae loss: 13798.666259765625 - dis

[mov,mp4,m4a,3gp,3g2,mj2 @ 0x55c0a467e040] moov atom not found


ValueError: need at least one array to concatenate

In [None]:
# from phenaki_pytorch.data import video_tensor_to_gif
# from IPython.display import display, Image
# for i, tensor in enumerate(final.unbind(dim = 0)):
    
#     print('real video:')
#     video_tensor_to_gif(real_frames[i].cpu(), 'original_video_'+str(i)+'.gif')
#     display(Image('original_video_'+str(i)+'.gif'))
    
#     print('reconstruction:')
#     video_tensor_to_gif(tensor.cpu(), 'reconstructed_video_'+str(i)+'.gif')
#     display(Image('reconstructed_video_'+str(i)+'.gif'))

#### Train the Phenaki

In [None]:
# if not ('cvivit' in locals()):
#     cvivit = CViViT(
#         dim = 512,
#         codebook_size = 65536,
#         image_size =  (256, 256),  # video with rectangular screen allowed
#         patch_size = 32,
#         temporal_patch_size = 2,
#         spatial_depth = 4,
#         temporal_depth = 4,
#         dim_head = 64,
#         heads = 8
#     )
#     cvivit.load('./results/vae.2600.pt')

# maskgit = MaskGit(
#                 dim=cvivit.dim,
#                 num_tokens=cvivit.codebook_size,
#                 max_seq_len = 1024,
#                 dim_context = 768,
#                 depth = 6
#             )
# phenaki = Phenaki(
#     cvivit = cvivit,
#     maskgit = maskgit,
#     self_token_critic= True  # set this to True
# ).to(device)
# phenaki_trainer = PhenakiTrainer(
#     phenaki,
#     batch_size=4,
#     num_frames=17,
#     train_lr=0.0001,
#     train_num_steps=2,
#     grad_accum_every = 2,
#     train_on_images=False,
#     save_and_sample_every=100,
#     num_samples=4,
#     dataset = train_dataset,
#     sample_texts_file_path = f"{'/content' if IN_COLAB else '/home/luthando/Desktop'}/phenaki-pytorch/data/sample_texts.txt" # each caption should be on a new line, during sampling, will be randomly drawn
# )

In [None]:
# phenaki_trainer.train()

### Testing process

In [None]:
# video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 128)

# # so in the paper, they do not really achieve 2 minutes of coherent video
# # at each new scene with new text conditioning, they condition on the previous K frames
# # you can easily achieve this with this framework as so

# video_prime = video[:, :, -3:] # (1, 3, 3, 256, 128) # say K = 3
# video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 128)

# # the total video
# entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 128)
# entire_video.shape # (1, 3, 17 + 14 = 31, 256, 128)

# # and so on...

In [None]:
# # ... above code

# from phenaki_pytorch import make_video

# entire_video, scenes = make_video(phenaki, texts = [
#     'a squirrel examines an acorn buried in the snow',
#     'a cat watches the squirrel from a frosted window sill',
#     'zoom out to show the entire living room, with the cat residing by the window sill'
# ], num_frames = (17, 14, 14), prime_lengths = (5, 5))

# entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)

# # scenes - List[Tensor[3]] - video segment of each scene

In [None]:



# # video = entire_video[0].permute(1, 2, 3, 0) # CTHW -> THWC
# # video = video.cpu().numpy()#.astype('uint8')

# # # fig = plt.figure()
# # fig = plt.figure(figsize=(2.2,2.2))  #Display size specification
# # fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
# # plt.axis('off')
# # im = plt.imshow(video[0, :, :, :])
# # plt.close()
# # def init():
# #     im.set_data(video[0, :, :, :])
# # def animate(i):
# #     im.set_data(video[i, :, :, :])
# #     return im
# # anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=200) # 200ms = 5 fps
# # display(HTML(anim.to_html5_video()))

# from phenaki_pytorch.data import video_tensor_to_gif
# from IPython.display import display, Image

# video = entire_video[0]#.permute(1, 2, 3, 0) # CTHW -> THWC
# video = video.cpu()#.numpy()#.astype('uint8')
# print('generated video:')
# video_tensor_to_gif(video, 'generated_video_.gif')
# display(Image('generated_video_.gif'))