In [1]:
import sys
sys.path.append('../')

%load_ext autoreload
%autoreload 2

In [12]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
import einops
from pathlib import Path
import imageio
from IPython.display import Video

from genie.tokenizer import VideoTokenizer
from genie.dataset import LightningPlatformer2D

## Load Model

In [3]:
sa_enc_desc = (
    ('spacetime_downsample', {
        'in_channels' : 3,
        'kernel_size' : 3,
        'out_channels' : 512,
        'time_factor' : 1,
        'space_factor' : 4,
    }),
    ('space-time_attn', {
        'n_rep' : 8,
        'n_head': 8,
        'd_head': 64,
    }),
)

sa_dec_desc = (
    ('space-time_attn', {
        'n_rep' : 8,
        'n_head': 8,
        'd_head': 64,
    }),
    ('depth2spacetime_upsample', {
        'in_channels' : 512,
        'kernel_size' : 3,
        'out_channels' : 3,
        'time_factor' : 1,
        'space_factor' : 4,
    })
)

sa_disc_desc = {
    'inp_size': [64, 64],
    'model_dim': 128,
    'num_heads': 8,
    'dim_mults': [1, 2, 4],
    'down_step': [None, 2, 2],
    'inp_channels': 3,
    'kernel_size': 3,
    'num_groups': 8,
    'act_fn': 'leaky',
    'use_blur': True,
    'use_attn': True,
}

In [4]:
video_tokenizer = VideoTokenizer(
    enc_desc=sa_enc_desc,
    dec_desc=sa_dec_desc,
    disc_kwargs=sa_disc_desc,
    d_codebook=10,
    n_codebook=1,
    lfq_bias=True,
    lfq_frac_sample=1,
    lfq_commit_weight= 0.25,
    lfq_entropy_weight= 0.01,
    lfq_diversity_weight= 1.,
    optimizer=torch.optim.AdamW,
    perceptual_model='vgg16',
    perc_feat_layers=('features.6', 'features.13', 'features.18', 'features.25'),
    gan_discriminate='frames',
    gan_frames_per_batch=4,
    gan_loss_weight=0.1,
    perc_loss_weight=0.1,
    quant_loss_weight=0.1,
)

In [5]:
ckpt_path = '/home/sm/PycharmProjects/open-genie/open-genie/video_tokenizer1/checkpoints/last.ckpt'
ckpt = torch.load(ckpt_path)

  ckpt = torch.load(ckpt_path)


In [6]:
tokenizer = video_tokenizer.load_state_dict(ckpt['state_dict'])

## Create Dataset

In [7]:
dataset = LightningPlatformer2D(
    root='/home/sm/Datasets/open-genie',
    env_name='Coinrun',
    padding=None,
    randomize=True,
    transform=None,
    num_frames=16,
    batch_size=8,
    output_format='c t h w',
    num_workers=1
)

dataset.setup('fit')

In [8]:
val_dataset = dataset.valid_dataset

In [9]:
val_dataset[0].size()

torch.Size([3, 16, 64, 64])

In [21]:
def show_video(episode_index, dataset):
    frames = dataset[episode_index]
    Path("outputs/examples").mkdir(parents=True, exist_ok=True)
    video_path = f"outputs/examples/{episode_index}.mp4"
    return show_video_frames(frames, video_path)

In [22]:
def show_video_frames(frames, video_path):
    frames = einops.rearrange(frames, 'c t h w -> t c h w')
    frames = [(frame * 255).type(torch.uint8).permute(1, 2, 0) for frame in frames]
    frames = [frame.numpy() for frame in frames]
    imageio.mimsave(video_path, frames, fps=10)
    
    return video_path

In [23]:
video_path = show_video(0, val_dataset)
Video(video_path, embed=True, width=360, height=360)

In [24]:
tokens, idxs = video_tokenizer.tokenize(val_dataset[0].unsqueeze(0))

In [25]:
rec_video = video_tokenizer.decode(tokens)

In [27]:
rec_video_path = show_video_frames(rec_video.squeeze(0), video_path = f"outputs/examples/rec.mp4")
Video(rec_video_path, embed=True, width=360, height=360)