In [None]:
!wget https://huggingface.co/danjacobellis/dance/resolve/main/video_f8c48_24f.pth

In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
import time
from timm.optim import Mars
from types import SimpleNamespace
from IPython.display import HTML
from types import SimpleNamespace
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import CenterCrop, RandomCrop
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from decord import VideoReader
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent

In [2]:
device = "cuda"
dataset = datasets.load_dataset("danjacobellis/davis").cast_column('video',datasets.Video()).with_format("torch")

In [27]:
checkpoint = torch.load('checkpoint_f8c48_784.pth', map_location="cpu",weights_only=False)
config = checkpoint['config']
state_dict = checkpoint['state_dict']
model = AutoCodecND(
    dim=3,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device).to(torch.bfloat16)
model.load_state_dict(state_dict)
model.eval();

In [38]:
target_w = 1920
target_h = 1080
n_frames = 24

def pad3d(x, p, extra, small_dim_mode):
    b, c, f, h, w = x.shape
    for dim, size in zip(['f', 'h', 'w'], [f, h, w]):
        if small_dim_mode and size < p:
            pad1 = extra
            pad2 = extra
        else:
            t = math.ceil(size / p) * p
            pad_total = t - size
            pad1 = pad_total // 2
            pad2 = pad_total - pad1
            pad1 += extra
            pad2 += extra
        if dim == 'f':
            fp1, fp2 = pad1, pad2
        elif dim == 'h':
            hp1, hp2 = pad1, pad2
        elif dim == 'w':
            wp1, wp2 = pad1, pad2
    return torch.nn.functional.pad(
        x,
        pad=(wp1, wp2, hp1, hp2, fp1, fp2),
        mode="reflect"
    )

In [39]:
PSNR_list = []
CR_list = []
pb = progress_bar(dataset['train'])
total_time = 0
for sample in pb:
    
    video = sample['video']
    len_video = len(video)
    xr = video.get_batch(range(len_video))
    xr = einops.rearrange(xr, 'f h w c -> c f h w')
    x = []
    for i_frame in range(n_frames):
        frame = xr[:, i_frame]
        pil_img = to_pil_image(frame)
        resized_img = pil_img.resize((target_w, target_h))
        tensor_frame = pil_to_tensor(resized_img).unsqueeze(1)
        x.append(tensor_frame)
    x = torch.cat(x, dim=1).unsqueeze(0)
    x = x / 127.5 - 1.0
    x = x.to(device).to(torch.bfloat16)
    x_orig = x.clone()
    x = pad3d(x, p=config.F, extra=0, small_dim_mode=True)
    
    with torch.no_grad():
        t0 = time.time()
        z = model.encode(x)
        latent = model.quantize.compand(z).round().to(torch.bfloat16)
        x_hat = model.decode(latent).clamp(-1,1)
        dt = time.time() - t0
        total_time += dt
    
    x_cropped = []
    x_hat_cropped  = []
    cc = CenterCrop((target_h,target_w))
    for i_frame in range(n_frames):
        x_hat_cropped.append(cc(x_hat[:,:,i_frame]).unsqueeze(2))
    x_hat = torch.cat(x_hat_cropped,dim=2)
    
    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5
    PSNR = []
    for i_frame in range(x_orig_01.shape[2]):
        mse = torch.nn.functional.mse_loss(x_orig_01[0, :, i_frame], x_hat_01[0, :, i_frame])
        PSNR.append(-10 * mse.log10().item())
    PSNR_list.append(PSNR)
    
    size_bytes = 0
    t0 = time.time()
    for chunk in latent_to_pil(einops.rearrange(latent[0], 'c f h w -> f c h w').cpu(),n_bits=8,C=3):
        buff = io.BytesIO()
        chunk.save(buff,format='webp',lossless=True)
        size_bytes += len(buff.getbuffer())
    dt = time.time() - t0
    total_time += dt
    CR_list.append(x_orig.numel()/size_bytes)

    pb.comment = (f"PSNR: {np.mean(PSNR)}, CR:{CR_list[-1]}")

---

24x240x240

In [18]:
np.mean([np.mean(per_frame_psnr) for per_frame_psnr in PSNR])

np.float64(27.115885416666668)

In [19]:
np.mean([np.mean(per_frame_cr) for per_frame_cr in CR_list])

np.float64(67.60885125372157)

In [20]:
target_w*target_h*24*90/total_time/1e6

68.95044907525985

---

24x432x240

In [12]:
np.mean([np.mean(per_frame_psnr) for per_frame_psnr in PSNR])

np.float64(25.338541666666668)

In [13]:
np.mean([np.mean(per_frame_cr) for per_frame_cr in CR_list])

np.float64(93.7856267152748)

In [14]:
target_w*target_h*24*90/total_time/1e6

64.01907287532129

---

540

In [24]:
np.mean([np.mean(per_frame_psnr) for per_frame_psnr in PSNR])

np.float64(27.734375)

In [25]:
np.mean([np.mean(per_frame_cr) for per_frame_cr in CR_list])

np.float64(78.7022754068909)

In [26]:
target_w*target_h*24*90/total_time/1e6

59.96029137687159

---

784

In [35]:
np.mean([np.mean(per_frame_psnr) for per_frame_psnr in PSNR])

np.float64(32.708333333333336)

In [36]:
np.mean([np.mean(per_frame_cr) for per_frame_cr in CR_list])

np.float64(68.16791729653269)

In [37]:
target_w*target_h*24*90/total_time/1e6

96.5282218012801

---

784->1080

In [40]:
np.mean([np.mean(per_frame_psnr) for per_frame_psnr in PSNR])

np.float64(33.606770833333336)

In [41]:
np.mean([np.mean(per_frame_cr) for per_frame_cr in CR_list])

np.float64(72.18689670043877)

In [42]:
target_w*target_h*24*90/total_time/1e6

88.52410005860598