In [1]:
import sys
import os
import pytorch_lightning as pl
from src.videogpt.vqvae import VQVAE
from pytorch_lightning.callbacks import ModelCheckpoint
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import torch
from torchvision.io import read_video, read_video_timestamps
import math
import torch.nn.functional as F

##### Setting 2 (25 Friends_vqvae)

Training configuration:

- `max_epochs`: 400
- `n_codes`: 1024
- `embedding_dim`: 128
- `downsample`: 4 8 8
- `sequence_length`: 16

In [2]:
# Load the model that was trained on Season 1.
# The checkpoint path can be overridden with the FRIENDS_VQVAE_CKPT environment variable
# instead of editing this absolute, cluster-specific path.
CKPT_PATH = os.environ.get(
    "FRIENDS_VQVAE_CKPT",
    "/lustre06/project/6002071/sana4471/Second_work_narval/H5/video_transformer-main/video_transformer-main/model/Friends_VQVAE/lightning_logs/version_12038476/checkpoints/epoch=211-step=12083.ckpt",
)
model = VQVAE.load_from_checkpoint(CKPT_PATH)
# Switch to inference mode: the encoder contains BatchNorm3d layers, which in training mode
# normalise with per-batch statistics and update their running statistics.
# `eval()` returns the module itself, so the cell still displays the model repr.
model.eval()

VQVAE(
  (encoder): Encoder(
    (convs): ModuleList(
      (0): SamePadConv3d(
        (conv): Conv3d(3, 240, kernel_size=(4, 4, 4), stride=(2, 2, 2))
      )
      (1): SamePadConv3d(
        (conv): Conv3d(240, 240, kernel_size=(4, 4, 4), stride=(2, 2, 2))
      )
      (2): SamePadConv3d(
        (conv): Conv3d(240, 240, kernel_size=(4, 4, 4), stride=(1, 2, 2))
      )
    )
    (conv_last): SamePadConv3d(
      (conv): Conv3d(240, 240, kernel_size=(3, 3, 3), stride=(1, 1, 1))
    )
    (res_stack): Sequential(
      (0): AttentionResidualBlock(
        (block): Sequential(
          (0): BatchNorm3d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
          (2): SamePadConv3d(
            (conv): Conv3d(240, 120, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
          )
          (3): BatchNorm3d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU()
          (5): SamePadConv3d(
            (conv

In [3]:
def preprocess(video, resolution, sequence_length=None):
    """Turn a raw uint8 video into a square, centre-cropped, zero-centred float clip.

    Args:
        video: tensor laid out as (T, H, W, C) holding pixel values in {0, ..., 255}.
        resolution: side length, in pixels, of the square output frames.
        sequence_length: if given, keep only the first ``sequence_length`` frames;
            an AssertionError is raised if it exceeds the number of frames in ``video``.

    Returns:
        A contiguous float tensor laid out as (C, T, resolution, resolution), where T is
        ``sequence_length`` (or the original frame count), with values shifted by -0.5.
    """
    frames = video.permute(0, 3, 1, 2).float() / 255.  # (T, C, H, W), scaled to [0, 1]
    n_frames, _, height, width = frames.shape

    # Temporal crop: keep the leading frames only.
    if sequence_length is not None:
        assert sequence_length <= n_frames
        frames = frames[:sequence_length]

    # Resize so that the shorter spatial side equals `resolution`, preserving aspect ratio.
    ratio = resolution / min(height, width)
    target_size = ((resolution, math.ceil(width * ratio)) if height < width
                   else (math.ceil(height * ratio), resolution))
    frames = F.interpolate(frames, size=target_size, mode='bilinear',
                           align_corners=False)

    # Keep the central resolution x resolution window of every frame.
    new_h, new_w = frames.shape[-2:]
    top = (new_h - resolution) // 2
    left = (new_w - resolution) // 2
    frames = frames[..., top:top + resolution, left:left + resolution]

    clip = frames.transpose(0, 1).contiguous()  # (C, T, H, W)
    clip -= 0.5
    return clip

In [4]:
# Test on Season 2 (the model was trained on Season 1).
sequence_length = 10560  # frames of the episode considered; the clip must lie within them
resolution = 128
video_filename = 'friends_s2e01a.mkv'
# Select part of the video: frames [Start_frame, End_frame).
Start_frame, End_frame = 500, 600

pts = read_video_timestamps(video_filename, pts_unit='sec')[0]
available_frames = min(sequence_length, len(pts))
if not 0 <= Start_frame < End_frame <= available_frames:
    raise ValueError(
        f"clip [{Start_frame}, {End_frame}) is outside the first "
        f"{available_frames} frames of {video_filename}"
    )
# Decode only up to the last frame of the clip. Decoding the whole `sequence_length`
# range would hold ~11 GB of 480x720x3 uint8 frames just to slice out this clip.
video = read_video(video_filename, pts_unit='sec', start_pts=pts[0], end_pts=pts[End_frame - 1])[0]
sub_video = video[Start_frame:End_frame]
assert sub_video.shape[0] == End_frame - Start_frame, "decoder returned fewer frames than requested"

In [5]:
sub_video.size()  # raw clip, laid out (T, H, W, C)

torch.Size([100, 480, 720, 3])

In [6]:
# Preprocess every frame of the selected clip (resize + centre-crop to `resolution`, shift by -0.5)
# and add a leading batch dimension: (T, H, W, C) -> (1, C, T, H, W).
# The clip length is taken from the clip itself rather than hard-coded, so changing
# Start_frame/End_frame does not require editing this cell.
sub_video1_2 = preprocess(sub_video, resolution, sequence_length=sub_video.shape[0]).unsqueeze(0)

In [7]:
sub_video1_2.size()  # preprocessed clip with batch dimension, laid out (1, C, T, H, W)

torch.Size([1, 3, 100, 128, 128])

In [8]:
# Reconstruct the clip with the VQ-VAE: encode to discrete codes, then decode back to pixels.
# NOTE: this rebinds `sub_video` (raw THWC frames) to the preprocessed (1, C, T, H, W) tensor,
# because the visualisation cell concatenates `sub_video` with the reconstruction. Re-running
# the preprocessing cell after this one would therefore fail; re-run from the clip-loading cell.
sub_video = sub_video1_2
# Inference mode: the encoder contains BatchNorm3d layers, which in training mode normalise with
# this single clip's statistics and overwrite their running statistics (even under no_grad).
# Any other behaviour gated on `self.training` inside VQVAE is switched off as well.
vqvae = model.eval()
with torch.no_grad():
    encodings = vqvae.encode(sub_video)
    video_recon = vqvae.decode(encodings)
    # Keep the reconstruction in the same [-0.5, 0.5] range as the preprocessed input.
    video_recon = torch.clamp(video_recon, -0.5, 0.5)

In [9]:
# Side-by-side comparison: real clip (left) and reconstruction (right), joined along the width axis.
videos = torch.cat((sub_video, video_recon), dim=-1)
videos = videos[0].permute(1, 2, 3, 0)  # CTHW -> THWC (first clip of the batch)
# Undo the -0.5 shift and map back to 8-bit pixels. Round before casting: plain truncation
# turns values such as 254.9999 into 254, biasing every pixel downwards.
videos = ((videos + 0.5) * 255).round().cpu().numpy().astype('uint8')

fig, ax = plt.subplots()
ax.set_title('real (left), reconstruction (right)')
ax.axis('off')
im = ax.imshow(videos[0])
plt.close(fig)  # suppress the static first-frame figure; only the animation below is shown


def init():
    """Reset the image to the first frame and return the artists to redraw."""
    im.set_data(videos[0])
    return (im,)


def animate(i):
    """Show frame ``i`` and return the artists to redraw."""
    im.set_data(videos[i])
    return (im,)


anim = animation.FuncAnimation(fig, animate, init_func=init, frames=videos.shape[0], interval=200)  # 200ms = 5 fps
HTML(anim.to_html5_video())