In [3]:
import os
import sys

# Directory (relative to the notebook's current working directory) that is put
# at the front of the import search path.
WORKING_DIR = "../"

# Rebuild sys.path with WORKING_DIR first and any earlier copies dropped, so
# re-running this cell never stacks duplicate entries.
sys.path[:] = [WORKING_DIR] + [entry for entry in sys.path if entry != WORKING_DIR]

In [6]:
# Prerequisite (run once in the kernel's environment): %pip install phenaki-pytorch

# Usage

In [8]:
import torch
from phenaki_pytorch import CViViT, CViViTTrainer

In [9]:
# Use the GPU when one is available; otherwise fall back to the CPU instead of
# crashing on a hard-coded `.cuda()` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the CViViT model that the trainer below optimises.
cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
).to(device)

# Directory handed to the trainer as its dataset folder ("~" is expanded to the
# current user's home directory). A plain string is enough: there is nothing to interpolate.
data_folder = os.path.expanduser("~/.cache/Appimate")

trainer = CViViTTrainer(
    cvivit,
    folder = data_folder,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False,  # you can train on images first, before fine-tuning on video, for sample efficiency
    use_ema = False,          # recommended to be turned on (keeps an exponential moving average of cvivit) unless you don't have enough resources
    num_train_steps = 10000
)

# Training is intentionally left disabled; uncomment to start it.
# trainer.train()               # reconstructions and checkpoints will be saved periodically to ./results


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /Users/luthandomaqondo/.cache/torch/hub/checkpoints/vgg16-397923af.pth
 45%|████▌     | 238M/528M [06:10<10:40, 474kB/s]    

In [None]:
import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki

import glob

# Use the GPU when one is available; otherwise fall back to the CPU instead of
# crashing on a hard-coded `.cuda()` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the trainer cell builds CViViT with image_size = 256 while this
# cell uses (256, 128) — confirm the checkpoint loaded below was trained with a
# configuration that matches this model.
cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = (256, 128),  # video with rectangular screen allowed
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

# `cvivit.load` is given a checkpoint *file* in the upstream example
# ('/path/to/trained/cvivit.pt'); the previous code passed the
# "~/.cache/Appimate" directory, which is the training-data folder. Pick the most
# recently written checkpoint from ./results, where the trainer cell's comment says
# checkpoints are saved.
# NOTE(review): assumes checkpoints use a ".pt" suffix — confirm against the trainer.
results_folder = "./results"
checkpoint_files = glob.glob(os.path.join(results_folder, "*.pt"))
if not checkpoint_files:
    raise FileNotFoundError(
        f"no CViViT checkpoint (*.pt) found in {results_folder!r}; run trainer.train() first"
    )
model_path = max(checkpoint_files, key = os.path.getmtime)
cvivit.load(model_path)
# cvivit.load('/path/to/trained/cvivit.pt')

# NOTE(review): num_tokens (5000) differs from the CViViT codebook_size (65536)
# above — confirm the two are meant to differ.
maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

# Combine the CViViT and MaskGit models and move the result to the selected device.
phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).to(device)

# Place the dummy batch on whatever device the model already lives on, rather
# than assuming CUDA is present.
device = next(phenaki.parameters()).device

videos = torch.randn(3, 3, 17, 256, 128).to(device) # (batch, channels, frames, height, width)
mask = torch.ones((3, 17)).bool().to(device) # [optional] (batch, frames) - allows for co-training videos of different lengths as well as video and images in the same batch

# One text prompt per video in the batch.
texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

# A single forward/backward pass; no optimizer step is taken here.
loss = phenaki(videos, texts = texts, video_frame_mask = mask)
loss.backward()

# do the above for many steps, then ...

# First scene: sample 17 frames from a single text prompt.
video = phenaki.sample(
    texts = 'a squirrel examines an acorn',
    num_frames = 17,
    cond_scale = 5.0,
) # (1, 3, 17, 256, 128)

# The paper does not really achieve 2 minutes of coherent video in one pass:
# each new scene (with new text conditioning) is conditioned on the previous K frames.
# The same scheme is reproduced here.

num_prime_frames = 3 # K
video_prime = video[:, :, -num_prime_frames:] # (1, 3, 3, 256, 128)

# Second scene: continue from the priming frames under a new prompt.
video_next = phenaki.sample(
    texts = 'a cat watches the squirrel from afar',
    prime_frames = video_prime,
    num_frames = 14,
) # (1, 3, 14, 256, 128)

# Join both scenes along dim 2 (the frame axis in the shape comments above).
entire_video = torch.cat([video, video_next], dim = 2) # (1, 3, 17 + 14, 256, 128)

# and so on...