In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
import time
import gc
from timm.optim import Mars
from types import SimpleNamespace
from IPython.display import HTML
from types import SimpleNamespace
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import CenterCrop, RandomCrop
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from decord import VideoReader
from huggingface_hub import snapshot_download
from cosmos_tokenizer.image_lib import ImageTokenizer

In [2]:
device = "cuda"
dataset = datasets.load_dataset("danjacobellis/davis").cast_column('video',datasets.Video()).with_format("torch")

In [3]:
model_path = snapshot_download(repo_id='nvidia/Cosmos-Tokenizer-DV4x8x8')
encoder = ImageTokenizer(checkpoint_enc=f'{model_path}/encoder.jit').to(device)
decoder = ImageTokenizer(checkpoint_dec=f'{model_path}/decoder.jit').to(device)

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
def pad3d(x, p, extra, small_dim_mode):
    b, c, f, h, w = x.shape
    extra_f, extra_h, extra_w = extra  # Unpack the extra tuple for each dimension
    
    for dim, size, extra_pad in zip(['f', 'h', 'w'], [f, h, w], [extra_f, extra_h, extra_w]):
        if small_dim_mode and size < p:
            pad1 = extra_pad
            pad2 = extra_pad
        else:
            t = math.ceil(size / p) * p
            pad_total = t - size
            pad1 = pad_total // 2
            pad2 = pad_total - pad1
            pad1 += extra_pad
            pad2 += extra_pad
        if dim == 'f':
            fp1, fp2 = pad1, pad2
        elif dim == 'h':
            hp1, hp2 = pad1, pad2
        elif dim == 'w':
            wp1, wp2 = pad1, pad2
            
    return torch.nn.functional.pad(
        x,
        pad=(wp1, wp2, hp1, hp2, fp1, fp2),
        mode="reflect"
    )

def center_crop_3d(x, f, h, w):
    assert x.ndim == 5
    _, _, F, H, W = x.shape
    front = (F - f) // 2
    back  = front + f
    top   = (H - h) // 2
    bottom = top + h
    left  = (W - w) // 2
    right = left + w
    return x[:, :, front:back, top:bottom, left:right]

In [6]:
F = (4,8)
target_w = 1920
target_h = 1080
PSNR_list = []
CR_list = []
pb = progress_bar(dataset['train'])
encode_time = 0
decode_time = 0
total_frames = 0
for sample in pb:
    
    video = sample['video']
    len_video = len(video)
    total_frames += len_video
    xr = video.get_batch(range(len_video))
    xr = einops.rearrange(xr, 'f h w c -> c f h w')
    x = []
    for i_frame in range(len_video):
        frame = xr[:, i_frame]
        pil_img = to_pil_image(frame)
        resized_img = pil_img.resize((target_w, target_h))
        tensor_frame = pil_to_tensor(resized_img).unsqueeze(1)
        x.append(tensor_frame)
    x = torch.cat(x, dim=1).unsqueeze(0)
    x = x / 127.5 - 1.0
    x_orig = x.clone()
    x = pad3d(x, p=F[1], extra=(F[0],0,0), small_dim_mode=True)

    x_hat = []
    for i_chunk in range(x.shape[2]//F[0] - 2):
        f1 = i_chunk * F[0]
        f2 = f1 + 3*F[0]
        x_chunk = x[:, :, f1:f2, :, :].to(torch.bfloat16).to(device)
        with torch.no_grad():
            t0 = time.time()
            z = encoder.encode(x_chunk)[0]
            encode_time += time.time()-t0
            gc.collect()
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()
            t0 = time.time()
            x_hat.append(decoder.decode(z).clamp(-1,1)[:,:,4:8].detach().cpu())
            decode_time += time.time() - t0
            gc.collect()
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()
    x_hat = torch.cat(x_hat,dim=2).to(torch.float)
    
    _,_,f,h,w = x_orig.shape
    x_hat = center_crop_3d(x_hat,f,h,w)
    
    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5
    PSNR = []
    for i_frame in range(x_orig_01.shape[2]):
        mse = torch.nn.functional.mse_loss(x_orig_01[0, :, i_frame], x_hat_01[0, :, i_frame])
        PSNR.append(-10 * mse.log10().item())
    PSNR_list.append(PSNR)
    
    size_bytes = 2*z.numel()*(x.shape[2]//F[0] - 2)
    CR_list.append(x_orig.numel()/size_bytes)

    pb.comment = (f"PSNR: {np.mean(PSNR)}, CR:{CR_list[-1]}")
    break

In [7]:
np.min([np.min(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.float64(24.13890838623047)

In [8]:
np.mean([np.mean(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.float64(26.649215337706774)

In [9]:
np.max([np.max(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.float64(28.22328805923462)

In [10]:
total_frames/encode_time

7.535644370218479

In [11]:
total_frames/decode_time

3.4176818879824324

In [12]:
np.mean(CR_list)

np.float64(89.45454545454545)