# CViViT forward pass

In [None]:
import torch
from phenaki_pytorch.cvivit import CViViT
from phenaki_pytorch.cvivit_trainer import CViViTTrainer

# Build the CViViT video tokenizer and a trainer around it, then restore the
# checkpoint via trainer.load('CVIVIT/') and switch the model to eval mode.
cvivit = CViViT(
    dim=512,  # embedding size
    codebook_size=8192,  # number of codebook entries
    image_size=128,  # frame height and width (H = W)
    patch_size=8,  # spatial patch size
    local_vgg=True,
    wandb_mode='disabled',
    temporal_patch_size=2,  # temporal patch size (frames per patch)
    spatial_depth=4,  # nb of layers in the spatial transformer
    temporal_depth=4,  # nb of layers in the temporal transformer
    dim_head=64,  # size of each attention head (heads * dim_head = 512)
    heads=8,  # nb of heads for multi-head attention
    # feed-forward expansion factor; assuming the usual lucidrains FeedForward,
    # the MLP hidden size is dim * ff_mult = 512 * 4 = 2048 (TODO confirm)
    ff_mult=4,
    commit_loss_w=1.,  # commit loss weight
    gen_loss_w=1.,  # generator loss weight
    perceptual_loss_w=1.,  # vgg loss weight
    i3d_loss_w=1.,  # i3d loss weight
    recon_loss_w=10.,  # reconstruction loss weight
    use_discr=0,  # whether to use a stylegan loss or not
    gp_weight=10,  # presumably the discriminator gradient-penalty weight — TODO confirm
)

trainer = CViViTTrainer(
    cvivit,
    folder='.',
    batch_size=1,
    force_cpu=False,
    wandb_mode='disabled',
    train_on_images=False,
    # gradient accumulation steps: effectively multiplies the batch size
    grad_accum_every=4,
    # recommended to be turned on (keeps an exponential moving average of
    # cvivit) unless you don't have enough resources
    use_ema=False,
    num_train_steps=100000,
    lr=0.0001,  # learning rate
    wd=0.0001,  # weight decay
    max_grad_norm=10,  # gradient clipping
    # start the warmup at this factor of the lr
    linear_warmup_start_factor=0.5,
    # nb of iterations for the warm up
    linear_warmup_total_iters=10000,
    # nb of iterations for the cosine annealing
    cosine_annealing_T_max=100000,
    cosine_annealing_eta_min=0.00005,  # lr at the end of annealing
    results_folder='results/',
    inference=True
)


trainer.load('CVIVIT/')
trainer.vae.eval()

In [None]:
import cv2

from torchvision import transforms as T
from einops import rearrange
import numpy as np

import os

# Encode every example clip to CViViT codebook ids and decode it back.
# Produces `real_frames` (list of (c, f, h, w) input tensors, RGB in [0, 1]) and
# `final` (reconstructions concatenated along the batch dimension) for the
# display cell below.
VIDEO_DIR = 'example_cvivit/obvious/'
CLIP_PREFIX = 'obvious____split_'
CLIP_SUFFIX = '.mp4'


def _read_video_bgr(path):
    """Read every frame of the video at *path* with OpenCV.

    Returns a uint8 numpy array of shape (frames, height, width, channels) in
    OpenCV's native BGR channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot open *path*.
        ValueError: if no frame could be decoded from the video.
    """
    video = cv2.VideoCapture(path)
    # VideoCapture does not raise on a missing/unreadable file; check explicitly.
    if not video.isOpened():
        raise FileNotFoundError(f'could not open video: {path}')
    frames = []
    try:
        while True:
            ok, frame = video.read()
            if not ok:
                break
            frames.append(frame)
    finally:
        video.release()  # free the decoder handle even if reading fails
    if not frames:
        raise ValueError(f'no frames decoded from video: {path}')
    return np.stack(frames, axis=0)


def _bgr_video_to_tensor(frames_bgr):
    """Convert an (f, h, w, c) uint8 BGR array to a (1, c, f, h, w) float RGB tensor in [0, 1]."""
    frames = rearrange(frames_bgr, 'f h w c -> c f h w')
    frames = frames[[2, 1, 0]]  # BGR -> RGB by reordering the channel axis
    frames = torch.from_numpy(np.ascontiguousarray(frames)).float() / 255.
    return rearrange(frames, 'c f h w -> 1 c f h w')


# Count only the numbered clips, so stray files in the folder are ignored.
num_clips = sum(
    1 for name in os.listdir(VIDEO_DIR)
    if name.startswith(CLIP_PREFIX) and name.endswith(CLIP_SUFFIX)
)
if num_clips == 0:
    raise FileNotFoundError(
        f'no {CLIP_PREFIX}*{CLIP_SUFFIX} clips found in {VIDEO_DIR}')

# Feed inputs to whatever device the model weights live on instead of
# hard-coding CUDA.
device = next(trainer.vae.parameters()).device

real_frames = []
reconstructions = []
# Pure inference: no_grad avoids keeping an autograd graph alive for every
# reconstruction stored in `final`.
with torch.no_grad():
    # Clips are numbered from 1 and visited in numeric order.
    for clip_number in range(1, num_clips + 1):
        path = os.path.join(VIDEO_DIR, f'{CLIP_PREFIX}{clip_number}{CLIP_SUFFIX}')
        frames_torch = _bgr_video_to_tensor(_read_video_bgr(path))

        real_frames.append(frames_torch[0])

        codebook_ids = trainer.vae(frames_torch.to(device), return_only_codebook_ids=True)
        reconstructions.append(trainer.vae.decode_from_codebook_indices(codebook_ids))

# A single concatenation instead of re-stacking the accumulator every iteration.
final = torch.cat(reconstructions, dim=0)

In [None]:
from phenaki_pytorch.data import video_tensor_to_gif
from IPython.display import display, Image
def _save_and_show_gif(video, gif_path):
    """Write *video* to *gif_path* as a GIF, then render that file inline."""
    video_tensor_to_gif(video, gif_path)
    display(Image(gif_path))


# Show each original clip next to its reconstruction, one pair per index.
for idx, reconstruction in enumerate(final.unbind(dim=0)):
    print('real video:')
    _save_and_show_gif(real_frames[idx].cpu(), f'original_video_{idx}.gif')

    print('reconstruction:')
    _save_and_show_gif(reconstruction.cpu(), f'reconstructed_video_{idx}.gif')