In [None]:
import os
import numpy as np
from PIL import Image
import cv2
import json
import pickle
import torch
import random
from IPython.display import Image as ipython_image
from models.transform import MotionVectorProcessor, extract_motions
from diffusers.utils import export_to_video, export_to_gif
from models import build_video_detokenizer
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

In [None]:
# Local directory holding the Video-LaVIT checkpoint. Set the
# VIDEOLAVIT_MODEL_PATH environment variable to point at your copy instead of
# editing this cell; the fallback below is the original author's path.
model_path = os.environ.get("VIDEOLAVIT_MODEL_PATH", "/home/jinyang06/models/VideoLaVIT-v1")
detokenizer_weight = os.path.join(model_path, 'video_3d_unet.bin')

# Seed every RNG this notebook draws from: Python's `random`, torch, and numpy
# (sample_video_clips picks the clip start with np.random.choice, so without a
# numpy seed the sampled clip changes on every run).
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device_id = 0
torch.cuda.set_device(device_id)

# Load the video detokenizer in fp16 with the 3D-UNet weights and move it to the
# current CUDA device (selected by set_device above).
model = build_video_detokenizer(model_path, model_dtype='fp16', pretrained_weight=detokenizer_weight)
model = model.to("cuda")

# Frame size passed to the detokenizer, and the clip length (in frames)
# that sample_video_clips extracts.
width = 576
height = 320
max_frames = 24

# Motion vectors are processed at 1/8 of the frame resolution
# (presumably the decoder's latent downsampling factor — TODO confirm).
motion_transform = MotionVectorProcessor(width=width // 8, height=height // 8)

# One bicubic resize shared by both pipelines: `pil_transform` keeps PIL images
# (used for GIF previews), `image_transform` also converts them to tensors.
_resize = transforms.Resize((height, width), interpolation=InterpolationMode.BICUBIC)
pil_transform = transforms.Compose([_resize])
image_transform = transforms.Compose([_resize, transforms.ToTensor()])

In [None]:
def sample_video_clips(video_path, rng=None, gop_size=12, fps=12):
    """Sample a random ``max_frames``-long clip from a video, with its motion vectors.

    The video is decoded by ``extract_motions`` at ``fps`` frames per second. A frame
    is an eligible clip start when it is an I-frame, when (for videos with more than
    one I-frame) the next I-frame follows exactly ``gop_size`` frames later, and when
    ``max_frames`` frames fit before the end of the video. One eligible start is
    picked at random.

    Args:
        video_path: Path to the input video file.
        rng: Optional numpy random generator (e.g. ``np.random.default_rng(0)``) used
            to pick the clip start. When ``None`` the global ``np.random`` state is
            used, so seed it for reproducible sampling.
        gop_size: Required distance between consecutive I-frames; the default 12
            means one I-frame followed by 11 P-frames.
        fps: Frame rate passed to ``extract_motions``.

    Returns:
        ``(pil_frames, frame_tensors, motions)``:
        ``pil_frames`` - list of ``max_frames`` RGB PIL images after ``pil_transform``;
        ``frame_tensors`` - stacked ``image_transform`` outputs rescaled from [0, 1]
        to [-1, 1];
        ``motions`` - stacked motion vectors after ``motion_transform``.

    Raises:
        ValueError: if the video has no I-frame, or no I-frame qualifies as a start.
    """
    frames, motions, frame_types = extract_motions(video_path, raw_file=True, temp_dir='./tmp', fps=fps)
    total_frames = len(frame_types)
    start_indexs = np.where(np.array(frame_types) == 'I')[0]

    if len(start_indexs) == 0:
        raise ValueError(f"Empty Start indexs: {video_path}")

    # Keep only I-frames that open a complete GOP (the next I-frame is exactly
    # gop_size frames later). The last I-frame has no successor to compare with,
    # so it is excluded whenever there is more than one I-frame.
    if len(start_indexs) > 1:
        gop_is_complete = (start_indexs[1:] - start_indexs[:-1]) == gop_size
        filter_start_indexs = start_indexs[:-1][gop_is_complete]
    else:
        filter_start_indexs = start_indexs

    # Filter out starts whose clip would run past the end of the video.
    filter_start_indexs = filter_start_indexs[filter_start_indexs + max_frames <= total_frames]

    if len(filter_start_indexs) == 0:
        raise ValueError(f"Empty Filtered Start indexs: {video_path}")

    chooser = np.random if rng is None else rng
    start_index = chooser.choice(filter_start_indexs)
    indices = np.arange(start_index, start_index + max_frames)

    # NOTE(review): assumes each motion array is laid out (H, W, C); the transpose
    # moves channels first for torch — confirm against extract_motions.
    motion_tensors = torch.stack(
        [torch.from_numpy(motions[i].transpose((2, 0, 1))) for i in indices]
    ).float()
    motion_tensors = motion_transform(motion_tensors)

    rgb_frames = [Image.fromarray(frames[i]).convert("RGB") for i in indices]
    pil_frames = [pil_transform(frame) for frame in rgb_frames]
    frame_tensors = torch.stack([image_transform(frame) for frame in rgb_frames])
    # ToTensor yields values in [0, 1]; rescale to [-1, 1].
    frame_tensors = 2.0 * frame_tensors - 1.0
    return pil_frames, frame_tensors, motion_tensors

In [None]:
video_path = 'demo/31200691.mp4'

# Sample one clip and preview the resized ground-truth frames as a GIF.
video_frames, video_tensors, motions = sample_video_clips(video_path)
output_video_path = "original.gif"
export_to_gif(video_frames, output_video_path)
with open(output_video_path, 'rb') as gif_file:
    display(ipython_image(gif_file.read()))

# The first frame becomes the reference keyframe (the slice keeps a leading axis
# of size 1); the motions get a matching leading axis, presumably the batch axis.
keyframe = video_tensors[0:1]
motions = motions.unsqueeze(0)

In [None]:
# Reconstruct the clip from the keyframe plus motion vectors. num_frames follows
# max_frames so it always matches the length of the sampled clip.
frames = model.reconstruct_from_token(
    keyframe.to("cuda"), motions.to("cuda"),
    decode_chunk_size=8,
    width=width, height=height, num_frames=max_frames,
    noise_aug_strength=0.02, cond_on_ref_frame=True,
    use_linear_guidance=True, max_guidance_scale=3.0, min_guidance_scale=1.0,
)[0]
output_video_path = "reconstruct.gif"
export_to_gif(frames, output_video_path)
with open(output_video_path, 'rb') as gif_file:
    display(ipython_image(gif_file.read()))