In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
import av
import time
import os
from IPython.display import HTML
from types import SimpleNamespace
from timm.optim import Mars
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import ToPILImage, PILToTensor, CenterCrop, RandomCrop
from autocodec.codec import AutoCodecND, pil_to_latent, latent_to_pil

In [2]:
# Load the dataset, cast its 'video' column to the datasets.Video feature,
# and request torch-formatted outputs.
dataset = (
    datasets.load_dataset("danjacobellis/nuscenes_front")
    .cast_column('video', datasets.Video())
    .with_format("torch")
)

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/115 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/115 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/74 [00:00<?, ?it/s]

In [3]:
device = "cuda"

# Literal hyperparameters, listed in their original definition order so the
# namespace's attribute order is unchanged.
config = SimpleNamespace(
    F=8,
    latent_dim=96,
    input_channels=3,
    lightweight_encode=True,
    lightweight_decode=False,
    encoder_depth=6,
    λ=1e-2,
    ema_decay=0.999,
    consistency_start=0.05,
    consistency_loss=1.0,
    max_lr=1e-3,
)

# Learning-rate floor is derived from the peak learning rate.
config.min_lr = config.max_lr / 1e3
config.lr_pow = 6
config.n_frames = 32

# Ten sizes, each a multiple of 8, taken from cubes of evenly spaced values in [2, 2.75].
config.progressive_sizes = [
    8 * int(side) for side in np.linspace(2, 2.75, 10) ** 3
]
print(config.progressive_sizes)

config.batch_size = 4
config.num_workers = 20
config.epochs = 300
# Full batches per epoch (floor division) times the number of epochs.
config.total_steps = config.epochs * (
    dataset['train'].num_rows // config.batch_size
)
config.checkpoint = "../hf/dance/LF_nuscenes_f8c96_v1.0.pth"

[64, 72, 80, 88, 96, 112, 120, 136, 144, 160]


In [4]:
# Build the 3-D codec from the config; J is log2 of the downsampling factor F.
model = AutoCodecND(
    dim=3,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
    post_filter = False
).to(device)

if config.checkpoint:
    # map_location makes loading independent of the device the checkpoint was
    # saved from (e.g. a GPU index that does not exist on this machine).
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoint files from a trusted source.
    checkpoint = torch.load(
        config.checkpoint,
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(checkpoint['state_dict'])

print(f"{sum(p.numel() for p in model.parameters())/1e6} M parameters")

268.818528 M parameters


In [5]:
def collate_fn(batch,epoch):
    """Assemble a batch of randomly cropped video clips scaled to [-1, 1].

    For every sample a random run of ``config.n_frames`` consecutive frames is
    fetched, a random spatial window of side ``get_epoch_size(epoch)`` is cut
    out, and the last axis is moved to the front. The clips are stacked,
    converted to float and mapped via ``value / 127.5 - 1.0``.

    Args:
        batch: iterable of samples; each ``sample['video']`` must support
            ``len()`` and ``get_batch(indices)``.
        epoch: forwarded to ``get_epoch_size`` to choose the crop size.

    Returns:
        Float tensor with one stacked entry per sample.
        Assumes ``get_batch`` returns frames laid out as (T, H, W, C), which
        would make the result (B, C, T, H, W) -- TODO confirm.
    """
    clip_len = config.n_frames
    crop_h = get_epoch_size(epoch)
    crop_w = get_epoch_size(epoch)

    clips = []
    for sample in batch:
        video = sample['video']
        # Random temporal window of clip_len consecutive frames.
        first = random.randint(0, len(video) - clip_len)
        frames = video.get_batch(range(first, first + clip_len))

        # Random spatial window; axes 1 and 2 are treated as height and width.
        full_h, full_w = frames.shape[1], frames.shape[2]
        top = random.randint(0, full_h - crop_h)
        left = random.randint(0, full_w - crop_w)
        crop = frames[:, top:top + crop_h, left:left + crop_w, :]

        # Move the trailing axis to the front before stacking.
        clips.append(crop.permute(3, 0, 1, 2))

    stacked = torch.stack(clips, dim=0)
    return stacked.to(torch.float)/127.5 - 1.0

In [19]:
from IPython.display import display, clear_output
from ipywidgets import interact

# Put the model in evaluation mode before encoding.
model.eval()
# NOTE: this overwrites the n_frames value set in the config cell. collate_fn
# reads config.n_frames at call time, so every later call fetches 256 frames.
config.n_frames=256
def get_epoch_size(epoch):
    """Return the crop side length for ``epoch``; fixed at 256 for every epoch."""
    fixed_crop_size = 256
    return fixed_crop_size
# Build a one-sample batch from validation item 1 (epoch argument 0) and move it to the device.
x = collate_fn(dataset['validation'].select([1]),0).to(device)
# Element count of the input batch, captured before `x` is rebound below.
orig_dim = x.numel()

with torch.no_grad():
    z = model.encode(x)
    # Compand the encoder output, round to integers, and shift by +128 (held as int16).
    latent = model.quantize.compand(z).round().to(torch.int16) + 128
    # NOTE: `x` is rebound here from the input batch to the first sample's latent
    # as a uint8 NumPy array; later cells that read `x` see the latent, not the video.
    # NOTE(review): the int16 -> uint8 cast does not clamp; this presumably relies on
    # the shifted values already lying in [0, 255] -- confirm.
    x = latent.cpu().to(torch.uint8).numpy()[0]
# latent = einops.rearrange(latent, 'b c d h w -> b (c d) h w').cpu()

# webp = latent_to_pil(latent, n_bits=8, C=3)
# # display(webp[0])
# buff = io.BytesIO()
# webp[0].save(buff, format='WEBP', lossless=True)
# size_bytes = len(buff.getbuffer())

# print(f"Compressed size: {size_bytes / 1e3:.2f} KB")
# print(f"Compression ratio: {orig_dim / size_bytes:.2f}x")
# print(f"Dimension reduction: {orig_dim / latent.numel():.2f}x")

# latent_decoded = pil_to_latent(webp, N=latent.shape[1], n_bits=8, C=3)
# latent_decoded = einops.rearrange(latent_decoded, 'b (c d) h w -> b c d h w', d=32).to(device)

# with torch.no_grad():
#     x_hat = model.decode(latent_decoded).clamp(-1, 1)

# mse = torch.nn.functional.mse_loss(x, x_hat)
# PSNR = -10 * mse.log10().item() + 6.02

# print(f"PSNR: {PSNR:.2f} dB")

# original_volume = x[0].cpu().numpy()
# reconstructed_volume = x_hat[0].cpu().numpy()

# def show_slices(slice_idx):
#     clear_output(wait=True)
#     orig_slice = ToPILImage()(einops.rearrange(original_volume[:,slice_idx],'c h w -> h w c')/2+0.5)
#     recon_slice = ToPILImage()(einops.rearrange(reconstructed_volume[:,slice_idx],'c h w -> h w c')/2+0.5)
#     display(orig_slice)
#     display(recon_slice)

# interact(show_slices, slice_idx=(0, original_volume.shape[1]-1, 1));

In [32]:
def tile_grayscale(x, f):
    """Lay out the first 96 channels of frame ``f`` as one 320x320 uint8 mosaic.

    Args:
        x: array indexed as ``x[channel, frame, row, col]`` with 32x32 tiles.
        f: index along axis 1 of ``x``.

    Returns:
        (320, 320) uint8 array: a 10x10 grid of 32x32 tiles filled row-major;
        the last four grid cells stay zero.
    """
    tiles = x[:, f, :, :]
    mosaic = np.zeros((320, 320), dtype=np.uint8)
    for idx in range(96):
        # Row-major position of this channel in the 10-wide grid.
        row, col = divmod(idx, 10)
        mosaic[row * 32:(row + 1) * 32, col * 32:(col + 1) * 32] = tiles[idx]
    return mosaic

def tile_3ch(x, f):
    """Split frame ``f`` into 32 groups of 3 channels and tile them into 3 planes.

    Args:
        x: array indexed as ``x[channel, frame, row, col]``; the selected frame
            must reshape to (32, 3, 32, 32).
        f: index along axis 1 of ``x``.

    Returns:
        List of three (192, 192) uint8 arrays. Plane ``k`` holds member ``k`` of
        every group on a 6x6 grid of 32x32 tiles filled row-major; the last
        four grid cells stay zero.
    """
    triplets = x[:, f, :, :].reshape(32, 3, 32, 32)
    planes = [np.zeros((192, 192), dtype=np.uint8) for _ in range(3)]
    for idx in range(32):
        row, col = divmod(idx, 6)
        rows = slice(row * 32, (row + 1) * 32)
        cols = slice(col * 32, (col + 1) * 32)
        # Member k of this group goes to the same tile position in plane k.
        for plane, tile in zip(planes, triplets[idx]):
            plane[rows, cols] = tile
    return planes

def tile_4ch(x, f):
    """Split frame ``f`` into 24 groups of 4 channels and tile them into 4 planes.

    Args:
        x: array indexed as ``x[channel, frame, row, col]``; the selected frame
            must reshape to (24, 4, 32, 32).
        f: index along axis 1 of ``x``.

    Returns:
        List of four (160, 160) uint8 arrays. Plane ``k`` holds member ``k`` of
        every group on a 5x5 grid of 32x32 tiles filled row-major; the last
        grid cell stays zero.
    """
    quads = x[:, f, :, :].reshape(24, 4, 32, 32)
    planes = [np.zeros((160, 160), dtype=np.uint8) for _ in range(4)]
    for idx in range(24):
        row, col = divmod(idx, 5)
        rows = slice(row * 32, (row + 1) * 32)
        cols = slice(col * 32, (col + 1) * 32)
        # Member k of this group goes to the same tile position in plane k.
        for plane, tile in zip(planes, quads[idx]):
            plane[rows, cols] = tile
    return planes

# Encoder table: label -> encoder name passed to add_stream plus its option strings.
codecs = {
    'ffv1': {
        'codec': 'ffv1',
        'options': {'level': '3', 'coder': '1'},
    },
    'h264': {
        'codec': 'libx264',
        'options': {'qp': '0', 'preset': 'veryslow'},
    },
    'hevc': {
        'codec': 'libx265',
        'options': {'x265-params': 'lossless=1'},
    },
    'vp9': {
        'codec': 'vp9',
        'options': {'lossless': '1'},
    },
}

# Layout table: label -> tiling function, the frame size it produces, and the
# pixel format used for the stream.
options = {
    'grayscale': {
        'tile_func': tile_grayscale,
        'width': 320,
        'height': 320,
        'pix_fmt': 'gray8',
    },
    '3ch': {
        'tile_func': tile_3ch,
        'width': 192,
        'height': 192,
        'pix_fmt': 'yuv444p',
    },
    '4ch': {
        'tile_func': tile_4ch,
        'width': 160,
        'height': 160,
        'pix_fmt': 'yuva444p',
    },
}

# Encode and measure
# (layout, codec) pairs that are skipped as unsupported combinations.
unsupported = {
    ('grayscale', 'vp9'),
    ('4ch', 'h264'),
    ('4ch', 'hevc'),
    ('4ch', 'vp9'),
}

# The tiling functions index frames along axis 1 of `x`, so take the frame
# count from the data instead of hard-coding it.
num_frames = x.shape[1]

results = {}
for opt_name, opt in options.items():
    for codec_name, codec in codecs.items():
        if (opt_name, codec_name) in unsupported:
            continue

        filename = f"{opt_name}_{codec_name}.mkv"
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()

        with av.open(filename, 'w') as container:
            stream = container.add_stream(codec['codec'], rate=1)
            stream.width = opt['width']
            stream.height = opt['height']
            stream.pix_fmt = opt['pix_fmt']
            stream.options = codec['options']

            for f in range(num_frames):
                frame_data = opt['tile_func'](x, f)
                if opt['pix_fmt'] == 'gray8':
                    # Single 2D array for grayscale
                    frame = av.VideoFrame.from_ndarray(frame_data, format=opt['pix_fmt'])
                else:
                    # Planar format: create empty frame and copy planes
                    frame = av.VideoFrame(opt['width'], opt['height'], opt['pix_fmt'])
                    for plane_idx, plane_data in enumerate(frame_data):
                        frame.planes[plane_idx].update(plane_data)

                container.mux(stream.encode(frame))

            # Flush encoder
            container.mux(stream.encode(None))

        encoding_time = time.perf_counter() - start_time
        # Bytes divided by 1024*1024 (reported under the 'size (MB)' key).
        file_size = os.path.getsize(filename) / (1024 * 1024)
        results[f"{opt_name}_{codec_name}"] = {'time (s)': encoding_time, 'size (MB)': file_size}

# Print results
for key, value in results.items():
    print(f"{key}: Time = {value['time (s)']:.2f}s, Size = {value['size (MB)']:.2f}MB")

x265 [info]: frame I:      1, Avg QP:4.00  kb/s: 336.70  
x265 [info]: frame P:      7, Avg QP:4.00  kb/s: 359.32  
x265 [info]: frame B:     24, Avg QP:4.00  kb/s: 363.05  
x265 [info]: Weighted P-Frames: Y:0.0% UV:0.0%
x265 [info]: consecutive B-frames: 12.5% 0.0% 12.5% 25.0% 50.0% 
x265 [info]: lossless compression ratio 2.27::1

encoded 32 frames in 21.67s (1.48 fps), 361.41 kb/s, Avg QP:4.00
x265 [info]: HEVC encoder version 3.5+1-f0c1022b6
x265 [info]: build info [Linux][GCC 10.2.1][64 bit] 8bit+10bit+12bit
x265 [info]: using cpu capabilities: MMX2 SSE2Fast LZCNT SSSE3 SSE4.2 AVX FMA3 BMI2 AVX2
x265 [info]: Unknown profile, Level-8.5 (Main tier)
x265 [info]: Thread pool created using 32 threads
x265 [info]: Slices                              : 1
x265 [info]: frame threads / pool features       : 5 / wpp(5 rows)
x265 [info]: Coding QT: max CU size, min CU size : 64 / 8
x265 [info]: Residual QT: max TU size, max depth : 32 / 1 inter / 1 intra
x265 [info]: ME / range / subpel / mer

grayscale_ffv1: Time = 0.04s, Size = 1.33MB
grayscale_h264: Time = 0.18s, Size = 1.35MB
grayscale_hevc: Time = 0.22s, Size = 1.38MB
3ch_ffv1: Time = 0.04s, Size = 1.34MB
3ch_h264: Time = 0.20s, Size = 1.37MB
3ch_hevc: Time = 0.32s, Size = 1.39MB
3ch_vp9: Time = 1.55s, Size = 1.33MB
4ch_ffv1: Time = 0.03s, Size = 1.35MB
