In [1]:
# Restrict this kernel to a single GPU device (a MIG slice, judging by the "MIG-" UUID).
# The UUID is specific to the machine this notebook was written on: edit or remove
# this line when running elsewhere. CUDA_VISIBLE_DEVICES only takes effect if it is
# set before torch initialises CUDA, which is why this cell precedes the imports.
%env CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8

env: CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8


In [2]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
from timm.optim import Mars
from types import SimpleNamespace
from IPython.display import HTML
from types import SimpleNamespace
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import CenterCrop, RandomCrop
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from decord import VideoReader
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent

In [None]:
device = "cuda"

# DAVIS clips from the Hugging Face hub: cast the `video` column to the
# `datasets.Video` feature and return rows in torch format.
dataset = (
    datasets.load_dataset("danjacobellis/davis")
    .cast_column("video", datasets.Video())
    .with_format("torch")
)

In [None]:
# Restore the pretrained 3-D AutoCodec from its checkpoint.
# weights_only=False: the checkpoint carries a non-tensor `config` object (read by
# attribute below), which the restricted loader would likely reject. Full unpickling
# can execute arbitrary code, so only load checkpoints from a trusted source.
checkpoint = torch.load(
    '../../hf/dance/video_f8c24.pth',
    map_location="cpu",
    weights_only=False,
)
config = checkpoint['config']
state_dict = checkpoint['state_dict']

# J = log2(F); presumably F is the overall downsampling factor and a power of two.
model = AutoCodecND(
    dim=3,
    input_channels=config.input_channels,
    J=int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth=config.encoder_depth,
    encoder_kernel_size=config.encoder_kernel_size,
    decoder_depth=config.decoder_depth,
    lightweight_encode=config.lightweight_encode,
    lightweight_decode=config.lightweight_decode,
)
model = model.to(device)
model.load_state_dict(state_dict)
model.eval();

In [None]:
def _centered_pad(size, multiple):
    """Split the padding that rounds ``size`` up to a multiple of ``multiple``.

    Returns ``(before, after)`` with the smaller half first, so the original
    content stays as centred as possible.
    """
    total = -size % multiple  # == ceil(size / multiple) * multiple - size, in exact ints
    before = total // 2
    return before, total - before


def pad3d(x, p=8, mode="reflect"):
    """Pad the last three dims of a 5-D tensor up to multiples of ``p``.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape ``(b, c, f, h, w)``.
    p : int
        Target multiple for the ``f``, ``h`` and ``w`` dims (default 8).
    mode : str
        Padding mode forwarded to ``torch.nn.functional.pad`` (default "reflect").

    Returns
    -------
    torch.Tensor
        Tensor whose last three dims are each the smallest multiple of ``p`` that
        is at least the original size. The extra is split as evenly as possible,
        with ``floor(total / 2)`` placed before the data.

    Raises
    ------
    ValueError
        If ``p`` is not a positive integer.

    Notes
    -----
    Reflect padding requires every pad amount to be smaller than the dim it pads.
    For inputs too small for that (e.g. a clip with only a few frames), this
    falls back to "replicate" instead of raising.
    """
    if p < 1:
        raise ValueError(f"p must be a positive integer, got {p}")
    _, _, f, h, w = x.shape
    sizes = (f, h, w)
    pads = [_centered_pad(n, p) for n in sizes]
    if mode == "reflect" and any(max(pad) >= n for pad, n in zip(pads, sizes)):
        mode = "replicate"
    (fp1, fp2), (hp1, hp2), (wp1, wp2) = pads
    # F.pad lists pad amounts starting from the last dim: (w, h, f).
    return torch.nn.functional.pad(x, pad=(wp1, wp2, hp1, hp2, fp1, fp2), mode=mode)

In [None]:
# Encode one DAVIS clip, decode it back in temporal chunks, and report PSNR.
sample = dataset['train'][1]
video = sample['video']
len_video = len(video)
x1080 = video.get_batch(range(len_video))
x1080 = einops.rearrange(x1080,'f h w c -> c f h w')

# Resize each frame to 1920x1080 through PIL and stack into a (1, c, f, h, w) batch.
x = []
for i_frame in range(x1080.shape[1]):
    frame = x1080[:,i_frame]
    x.append(pil_to_tensor(to_pil_image(frame).resize((1920,1080))).unsqueeze(1))
x = torch.cat(x,dim=1).unsqueeze(0)
x = x/127.5 - 1.0  # maps [0, 255] -> [-1, 1]
x = x.to(device)
# Record the unpadded (frames, height, width) so the metric can skip pad3d's border.
orig_fhw = tuple(x.shape[2:])
x = pad3d(x)
with torch.no_grad():
    z = model.encode(x)
    # Compand the encoder output and round it to integers; keep the latent on CPU.
    latent = model.quantize.compand(z).round().cpu()

# Decode a few latent frames at a time (presumably to bound GPU memory use).
decode_bs = 6
x_hat_list = []
with torch.no_grad():
    for start_idx in range(0, latent.shape[2], decode_bs):
        end_idx = min(start_idx + decode_bs, latent.shape[2])
        # Basic slicing rather than indexing with a range object (advanced indexing).
        latent_batch = latent[:, :, start_idx:end_idx].to(device)
        x_hat_batch = model.decode(latent_batch).clamp(-1, 1)
        x_hat_list.append(x_hat_batch)
    x_hat = torch.cat(x_hat_list, dim=2)
# Fail loudly rather than letting mse_loss broadcast mismatched shapes.
assert x_hat.shape == x.shape, f"decoded shape {tuple(x_hat.shape)} != input shape {tuple(x.shape)}"

x_orig_01 = x / 2 + 0.5
x_hat_01 = x_hat / 2 + 0.5

# Score only the original content. pad3d places floor(total_pad / 2) before the
# data on each padded axis, so the valid window starts there. Including the
# padded border would bias the PSNR.
valid = (...,) + tuple(
    slice((padded - orig) // 2, (padded - orig) // 2 + orig)
    for padded, orig in zip(x.shape[2:], orig_fhw)
)
mse = torch.nn.functional.mse_loss(x_orig_01[0][valid], x_hat_01[0][valid])
PSNR = -10 * mse.log10().item()
PSNR