In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
import av
import time
import os
import gc
from IPython.display import HTML
from types import SimpleNamespace
from timm.optim import Mars
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import ToPILImage, PILToTensor, CenterCrop, RandomCrop
from autocodec.codec import AutoCodecND, pil_to_latent, latent_to_pil
from IPython.display import display, clear_output
from ipywidgets import interact
from piq import ssim, psnr

In [2]:
device="cuda"
dataset = datasets.load_dataset("danjacobellis/nuscenes_front",split='validation').cast_column('video', datasets.Video()).with_format("torch")
subset = dataset.select(range(40)).filter(lambda s: len(s['video'])>=256)
checkpoint = torch.load("../hf/dance/LF_nuscenes_f8c24_v1.2.pth",weights_only=False)
config = checkpoint['config']

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/115 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/115 [00:00<?, ?it/s]

In [3]:
model = AutoCodecND(
    dim=3,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
    post_filter = config.post_filter
).to(device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
print(f"{sum(p.numel() for p in model.parameters())/1e6} M parameters")

275.498072 M parameters


In [19]:
# video = subset[0]['video']
# ℓ = 256
# x = video.get_batch(range(ℓ)).permute(3, 0, 1, 2)
# h = 256; w = 256
# x = CenterCrop((h,w))(x).unsqueeze(0).to(device)
# x = x.to(torch.float)/127.5 - 1.0

# orig_dim = x.numel()
# with torch.no_grad():
#     z = model.encode(x)
#     latent = model.quantize.compand(z).round()
# latent = einops.rearrange(latent, 'b c d h w -> b (c d) h w').cpu()

# webp = latent_to_pil(latent, n_bits=8, C=3)
# display(webp[0])
# buff = io.BytesIO()
# webp[0].save(buff, format='WEBP', lossless=True)
# size_bytes = len(buff.getbuffer())

# print(f"Compressed size: {size_bytes / 1e3:.2f} KB")
# print(f"Compression ratio: {orig_dim / size_bytes:.2f}x")
# print(f"Dimension reduction: {orig_dim / latent.numel():.2f}x")

# latent_decoded = pil_to_latent(webp, N=latent.shape[1], n_bits=8, C=3)
# latent_decoded = einops.rearrange(latent_decoded, 'b (c d) h w -> b c d h w', d=32).to(device)

# with torch.no_grad():
#     x_hat = model.decode(latent_decoded).clamp(-1, 1)

# mse = torch.nn.functional.mse_loss(x, x_hat)
# PSNR = -10 * mse.log10().item() + 6.02

# print(f"PSNR: {PSNR:.2f} dB")

In [31]:
def eval_video_codec(sample):

    video = sample['video']
    num_frames = 256
    x_orig = video.get_batch(range(num_frames)).permute(3, 0, 1, 2)
    h = 256; w = 256;
    x_orig_cropped = CenterCrop((h,w))(x_orig).unsqueeze(0).to(device)
    x = x_orig_cropped.to(torch.float)/127.5 - 1.0
    
    orig_dim = x.numel()
    with torch.no_grad():
        z = model.encode(x)
        latent = model.quantize.compand(z).round()
    latent = einops.rearrange(latent, 'b c d h w -> b (c d) h w').cpu()
    
    webp = latent_to_pil(latent, n_bits=8, C=3)
    buff = io.BytesIO()
    webp[0].save(buff, format='WEBP', lossless=True)
    size_bytes = len(buff.getbuffer())
    latent_decoded = pil_to_latent(webp, N=latent.shape[1], n_bits=8, C=3)
    latent_decoded = einops.rearrange(latent_decoded, 'b (c d) h w -> b c d h w', d=32).to(device)
    
    with torch.no_grad():
        x_hat = model.decode(latent_decoded).clamp(-1, 1)
    
    psnr_values = []
    ssim_values = []
    
    for frame_idx in range(x.shape[1]):
        x_frame_orig = x[:, :, frame_idx]
        x_frame_hat = x_hat[:, :, frame_idx]
    
        mse = torch.nn.functional.mse_loss(x_frame_orig, x_frame_hat)
        psnr_val = -10 * mse.log10().item() + 6.02 # PSNR in dB for [-1, 1] range. However, original code uses 20*log10(2). Let's use that.
        psnr_val = -10 * mse.log10().item() + 20*np.log10(2)
        psnr_values.append(psnr_val)
    
        x_frame_orig_01 = x_frame_orig / 2 + 0.5
        x_frame_hat_01 = x_frame_hat / 2 + 0.5
        ssim_val = ssim(x_frame_orig_01, x_frame_hat_01, data_range=1.0).item() # piq ssim returns scalar
        ssim_values.append(ssim_val)
    
    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)
    video_cr = orig_dim / size_bytes

    return {
        'video_PSNR': avg_psnr,
        'video_SSIM': avg_ssim,
        'video_CR': video_cr
    }

metrics = [
    'video_PSNR',
    'video_SSIM',
    'video_CR'
]

In [38]:
results_dataset = subset.select(range(2)).map(eval_video_codec)
print("mean\n---")
for metric in metrics:
    μ = results_dataset[metric].mean()
    print(f"{metric}: {μ}")

print("median\n---")
for metric in metrics:
    μ = results_dataset[metric].median()
    print(f"{metric}: {μ}")

mean
---
video_PSNR: 33.52024459838867
video_SSIM: 0.8803455829620361
video_CR: 128.71095275878906
median
---
video_PSNR: 33.52024459838867
video_SSIM: 0.8803455233573914
video_CR: 128.71095275878906
