In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
import time
from timm.optim import Mars
from types import SimpleNamespace
from IPython.display import HTML
from types import SimpleNamespace
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import CenterCrop, RandomCrop
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from decord import VideoReader
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent

In [2]:
device = "cuda"
dataset = datasets.load_dataset("danjacobellis/davis").cast_column('video',datasets.Video()).with_format("torch")

In [3]:
checkpoint = torch.load('../../hf/autocodec/video_f8c48.pth', map_location="cpu",weights_only=False)

config = checkpoint['config']
state_dict = checkpoint['state_dict']
model = AutoCodecND(
    dim=3,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device).to(torch.bfloat16)
model.load_state_dict(state_dict)
model.train();

In [4]:
def pad3d(x, p, extra, small_dim_mode):
    b, c, f, h, w = x.shape
    extra_f, extra_h, extra_w = extra  # Unpack the extra tuple for each dimension
    
    for dim, size, extra_pad in zip(['f', 'h', 'w'], [f, h, w], [extra_f, extra_h, extra_w]):
        if small_dim_mode and size < p:
            pad1 = extra_pad
            pad2 = extra_pad
        else:
            t = math.ceil(size / p) * p
            pad_total = t - size
            pad1 = pad_total // 2
            pad2 = pad_total - pad1
            pad1 += extra_pad
            pad2 += extra_pad
        if dim == 'f':
            fp1, fp2 = pad1, pad2
        elif dim == 'h':
            hp1, hp2 = pad1, pad2
        elif dim == 'w':
            wp1, wp2 = pad1, pad2
            
    return torch.nn.functional.pad(
        x,
        pad=(wp1, wp2, hp1, hp2, fp1, fp2),
        mode="reflect"
    )

def center_crop_3d(x, f, h, w):
    assert x.ndim == 5
    _, _, F, H, W = x.shape
    front = (F - f) // 2
    back  = front + f
    top   = (H - h) // 2
    bottom = top + h
    left  = (W - w) // 2
    right = left + w
    return x[:, :, front:back, top:bottom, left:right]

In [None]:
target_w = 1920
target_h = 1080
PSNR_list = []
CR_list = []
pb = progress_bar(dataset['train'])
encode_time = 0
decode_time = 0
total_frames = 0
for sample in pb:
    
    video = sample['video']
    len_video = len(video)
    total_frames += len_video
    xr = video.get_batch(range(len_video))
    xr = einops.rearrange(xr, 'f h w c -> c f h w')
    x = []
    for i_frame in range(len_video):
        frame = xr[:, i_frame]
        pil_img = to_pil_image(frame)
        resized_img = pil_img.resize((target_w, target_h))
        tensor_frame = pil_to_tensor(resized_img).unsqueeze(1)
        x.append(tensor_frame)
    x = torch.cat(x, dim=1).unsqueeze(0)
    x = x / 127.5 - 1.0
    x = x.to(device).to(torch.bfloat16)
    x_orig = x.clone()
    x = pad3d(x, p=config.F, extra=(8,0,0), small_dim_mode=True)
    # encode
    with torch.no_grad():
        t0 = time.time()
        z = model.encode(x)
        latent = model.quantize.compand(z).round().to(torch.bfloat16)
        dt = time.time() - t0
        encode_time += dt

    # decode
    x_hat = []
    for i_chunk in range(latent.shape[2]-2):
        i1 = i_chunk; i2 = i_chunk+3
        latent_chunk = latent[:,:,i1:i2,:,:]
        with torch.no_grad():
            t0 = time.time()
            x_hat.append(model.decode(latent_chunk)[:,:,8:16].clamp(-1,1))
            dt = time.time() - t0
            decode_time += dt
    x_hat = torch.cat(x_hat,dim=2)
    
    _,_,f,h,w = x_orig.shape
    x_hat = center_crop_3d(x_hat,f,h,w)
    
    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5
    PSNR = []
    for i_frame in range(x_orig_01.shape[2]):
        mse = torch.nn.functional.mse_loss(x_orig_01[0, :, i_frame], x_hat_01[0, :, i_frame])
        PSNR.append(-10 * mse.log10().item())
    PSNR_list.append(PSNR)
    
    size_bytes = 0
    t0 = time.time()
    for chunk in latent_to_pil(einops.rearrange(latent[0], 'c f h w -> f c h w').cpu(),n_bits=8,C=3):
        buff = io.BytesIO()
        chunk.save(buff,format='webp',lossless=True)
        size_bytes += len(buff.getbuffer())
    dt = time.time() - t0
    encode_time += dt
    CR_list.append(x_orig.numel()/size_bytes)

    pb.comment = (f"PSNR: {np.mean(PSNR)}, CR:{CR_list[-1]}")

In [14]:
np.min([np.mean(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.float64(25.7921875)

In [15]:
np.mean([np.mean(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.float64(30.241360876081732)

In [16]:
np.max([np.mean(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.float64(40.760690789473685)

In [17]:
total_frames/encode_time

33.606962237145176

In [18]:
total_frames/decode_time

4.8869175532133635

In [19]:
np.argmax([np.mean(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.int64(65)

In [20]:
np.argmin([np.mean(per_frame_psnr) for per_frame_psnr in PSNR_list])

np.int64(63)

In [21]:
np.mean(CR_list)

np.float64(79.60111699678902)

---

In [None]:
video = dataset['train'][70]['video']
len_video = len(video)
total_frames += len_video
xr = video.get_batch(range(len_video))
xr = einops.rearrange(xr, 'f h w c -> c f h w')
x = []
for i_frame in range(len_video):
    frame = xr[:, i_frame]
    pil_img = to_pil_image(frame)
    resized_img = pil_img.resize((target_w, target_h))
    tensor_frame = pil_to_tensor(resized_img).unsqueeze(1)
    x.append(tensor_frame)
x = torch.cat(x, dim=1).unsqueeze(0)
x = x / 127.5 - 1.0
x = x.to(device).to(torch.bfloat16)
x_orig = x.clone()
x = pad3d(x, p=config.F, extra=(8,0,0), small_dim_mode=True)
# encode
with torch.no_grad():
    t0 = time.time()
    z = model.encode(x)
    latent = model.quantize.compand(z).round().to(torch.bfloat16)
    dt = time.time() - t0
    encode_time += dt

# decode
x_hat = []
for i_chunk in range(latent.shape[2]-2):
    i1 = i_chunk; i2 = i_chunk+3
    latent_chunk = latent[:,:,i1:i2,:,:]
    with torch.no_grad():
        t0 = time.time()
        x_hat.append(model.decode(latent_chunk)[:,:,8:16].clamp(-1,1))
        dt = time.time() - t0
        decode_time += dt
x_hat = torch.cat(x_hat,dim=2)

_,_,f,h,w = x_orig.shape
x_hat = center_crop_3d(x_hat,f,h,w)

x_orig_01 = x_orig / 2 + 0.5
x_hat_01 = x_hat / 2 + 0.5
PSNR = []
for i_frame in range(x_orig_01.shape[2]):
    mse = torch.nn.functional.mse_loss(x_orig_01[0, :, i_frame], x_hat_01[0, :, i_frame])
    PSNR.append(-10 * mse.log10().item())

size_bytes = 0
t0 = time.time()
for chunk in latent_to_pil(einops.rearrange(latent[0], 'c f h w -> f c h w').cpu(),n_bits=8,C=3):
    buff = io.BytesIO()
    chunk.save(buff,format='webp',lossless=True)
    size_bytes += len(buff.getbuffer())
dt = time.time() - t0
encode_time += dt
x_orig.numel()/size_bytes

In [None]:
np.mean(PSNR)

In [None]:
!rm -rf test/

In [24]:
import os
import subprocess
output_dir = 'test'
os.makedirs(output_dir, exist_ok=True)
num_frames = x_hat.shape[2]
for i_frame in range(num_frames):
    img = to_pil_image(x_hat[0, :, i_frame].to(torch.float) / 2 + 0.5)
    filename = os.path.join(output_dir, f"frame_{i_frame:04d}.jpg")
    img.save(filename, format="jpeg", quality=100)
output_video = 'output_mjpeg.avi'
ffmpeg_cmd = [
    'ffmpeg',
    '-y', 
    '-framerate', '24',
    '-i', os.path.join(output_dir, 'frame_%04d.jpg'),
    '-c:v', 'mjpeg',
    '-q:v', '1',
    output_video
]
subprocess.run(ffmpeg_cmd)

ffmpeg version 4.4.2-0ubuntu0.22.04.1+esm6 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.4.0-1ubuntu1~22.04)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1+esm6 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enabl

CompletedProcess(args=['ffmpeg', '-y', '-framerate', '24', '-i', 'test/frame_%04d.jpg', '-c:v', 'mjpeg', '-q:v', '1', 'output_mjpeg.avi'], returncode=0)