In [1]:
import io
import time
import math
import torch
import datasets
import einops
import PIL.Image
import numpy as np
import torch.nn as nn
from types import SimpleNamespace
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor

In [2]:
device = "cuda"
aviris = datasets.load_dataset("danjacobellis/aviris_1k_val", split="validation")

checkpoint = torch.load('../../hf/autocodec/hyper_f8c8.pth', map_location="cpu", weights_only=False)
config = checkpoint['config']
state_dict = checkpoint['state_dict']

model = AutoCodecND(
    dim=3,
    input_channels=config.input_channels,
    J=int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth=config.encoder_depth,
    encoder_kernel_size=config.encoder_kernel_size,
    decoder_depth=config.decoder_depth,
    lightweight_encode=config.lightweight_encode,
    lightweight_decode=config.lightweight_decode,
).to(device).to(torch.bfloat16)
model.load_state_dict(state_dict)
model.eval();

Resolving data files:   0%|          | 0/66 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/66 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/48 [00:00<?, ?it/s]

In [3]:
def pad3d(x, p, extra, small_dim_mode):
    b, c, f, h, w = x.shape
    extra_f, extra_h, extra_w = extra  # Unpack the extra tuple for each dimension
    
    for dim, size, extra_pad in zip(['f', 'h', 'w'], [f, h, w], [extra_f, extra_h, extra_w]):
        if small_dim_mode and size < p:
            pad1 = extra_pad
            pad2 = extra_pad
        else:
            t = math.ceil(size / p) * p
            pad_total = t - size
            pad1 = pad_total // 2
            pad2 = pad_total - pad1
            pad1 += extra_pad
            pad2 += extra_pad
        if dim == 'f':
            fp1, fp2 = pad1, pad2
        elif dim == 'h':
            hp1, hp2 = pad1, pad2
        elif dim == 'w':
            wp1, wp2 = pad1, pad2
            
    return torch.nn.functional.pad(
        x,
        pad=(wp1, wp2, hp1, hp2, fp1, fp2),
        mode="reflect"
    )

def center_crop_3d(x, f, h, w):
    assert x.ndim == 5
    _, _, F, H, W = x.shape
    front = (F - f) // 2
    back  = front + f
    top   = (H - h) // 2
    bottom = top + h
    left  = (W - w) // 2
    right = left + w
    return x[:, :, front:back, top:bottom, left:right]

In [4]:
sample = aviris[0]

img = sample['image']
bands = []
for i in range(img.n_frames):
    img.seek(i)
    bands.append(np.array(img, dtype=np.int16))
x_orig = torch.from_numpy(np.stack(bands)).to(device).to(torch.bfloat16).unsqueeze(0).unsqueeze(0) / 32768.0
x = pad3d(x_orig, config.F, extra=(0,0,0), small_dim_mode=False)

t0 = time.time()
with torch.no_grad():
    z = model.quantize.compand(model.encode(x)).round()
z = einops.rearrange(z, 'b c f h w -> (c f) b h w')
img_list = latent_to_pil(z.cpu(), n_bits=8, C=1)

buff = []
for img in img_list:
    buff.append(io.BytesIO())
    img.save(buff[-1], format= "TIFF", compression='tiff_adobe_deflate')
encode_time = time.time() - t0

size_bytes = sum(len(b.getbuffer()) for b in buff)
cr = 2*x.numel() / size_bytes

t0 = time.time()
z = pil_to_latent([PIL.Image.open(b) for b in buff], N=1, n_bits=8, C=1)
z = einops.rearrange(z, '(c f) b h w -> b c f h w', c = config.latent_dim).to(device).to(torch.bfloat16)
with torch.no_grad():
    xhat = model.decode(z).clamp(-1,1)

decode_time = time.time() - t0

xhat = center_crop_3d(x=xhat, f=x_orig.shape[2], h=x_orig.shape[3], w=x_orig.shape[4])
mse = torch.nn.functional.mse_loss(x_orig.to(torch.float),xhat.to(torch.float))
psnr = -10*mse.log10().item() + 6.02

In [5]:
def evaluate_quality(sample):   
    img = sample['image']
    bands = []
    for i in range(img.n_frames):
        img.seek(i)
        bands.append(np.array(img, dtype=np.int16))
    x_orig = torch.from_numpy(np.stack(bands)).to(device).to(torch.bfloat16).unsqueeze(0).unsqueeze(0) / 32768.0
    x = pad3d(x_orig, config.F, extra=(0,0,0), small_dim_mode=False)
    
    t0 = time.time()
    with torch.no_grad():
        z = model.quantize.compand(model.encode(x)).round()
    z = einops.rearrange(z, 'b c f h w -> (c f) b h w')
    img_list = latent_to_pil(z.cpu(), n_bits=8, C=1)
    
    buff = []
    for img in img_list:
        buff.append(io.BytesIO())
        img.save(buff[-1], format= "TIFF", compression='tiff_adobe_deflate')
    encode_time = time.time() - t0
    
    size_bytes = sum(len(b.getbuffer()) for b in buff)
    cr = 2*x.numel() / size_bytes
    
    t0 = time.time()
    z = pil_to_latent([PIL.Image.open(b) for b in buff], N=1, n_bits=8, C=1)
    z = einops.rearrange(z, '(c f) b h w -> b c f h w', c = config.latent_dim).to(device).to(torch.bfloat16)
    with torch.no_grad():
        xhat = model.decode(z).clamp(-1,1)
    
    decode_time = time.time() - t0
    
    xhat = center_crop_3d(x=xhat, f=x_orig.shape[2], h=x_orig.shape[3], w=x_orig.shape[4])
    mse = torch.nn.functional.mse_loss(x_orig.to(torch.float),xhat.to(torch.float))
    psnr = -10*mse.log10().item() + 6.02

    return {
        "voxels": x_orig.numel(),
        "encode_time": encode_time,
        "decode_time": decode_time,
        "cr": cr,
        "psnr": psnr,
    }

In [6]:
metrics = [
    'voxels',
    'encode_time',
    'decode_time',
    'cr',
    'psnr',
]

In [7]:
results_dataset = aviris.map(evaluate_quality)

In [8]:
print("mean\n---")
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['voxels'])/1e6/np.array(results_dataset['encode_time']))} MVox/sec")
print(f"{np.mean(np.array(results_dataset['voxels'])/1e6/np.array(results_dataset['decode_time']))} MVox/sec")

mean
---
voxels: 87734420.76595744
encode_time: 0.14616816094581117
decode_time: 0.0333839883195593
cr: 574.5738471332235
psnr: 18.52037821049386
600.1471162177203 MVox/sec
2680.8554334029 MVox/sec


In [37]:
print("mean\n---")
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['voxels'])/1e6/np.array(results_dataset['encode_time']))} MVox/sec")
print(f"{np.mean(np.array(results_dataset['voxels'])/1e6/np.array(results_dataset['decode_time']))} MVox/sec")

mean
---
voxels: 87734420.76595744
encode_time: 0.14616816094581117
decode_time: 0.0333839883195593
cr: 574.5738471332235
psnr: 18.52037821049386
600.1471162177203 MVox/sec
2680.8554334029 MVox/sec
