In [1]:
import io
import time
import torch
import datasets
import PIL.Image
import numpy as np
import torch.nn as nn
from types import SimpleNamespace
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from torchvision.transforms.v2 import PILToTensor

In [None]:
# Target accelerator for the model and for every tensor fed to it.
device = "cuda"

# Validation split of the AVIRIS dataset used as the evaluation set.
aviris = datasets.load_dataset(
    "danjacobellis/aviris_1k_val",
    split="validation",
)

# weights_only=False runs the full pickle loader, so only point this at a trusted file.
checkpoint = torch.load(
    '../../hf/autocodec/hyper_f8c8.pth',
    map_location="cpu",
    weights_only=False,
)
config, state_dict = checkpoint['config'], checkpoint['state_dict']

# Rebuild the codec from the hyper-parameters stored alongside the weights.
# J is derived from config.F; the remaining fields are forwarded unchanged.
model = AutoCodecND(
    dim=3,
    J=int(np.log2(config.F)),
    **{
        field: getattr(config, field)
        for field in (
            "input_channels",
            "latent_dim",
            "encoder_depth",
            "encoder_kernel_size",
            "decoder_depth",
            "lightweight_encode",
            "lightweight_decode",
        )
    },
)
model = model.to(device).to(torch.bfloat16)
model.load_state_dict(state_dict)
model.eval()

In [None]:
def tiff_to_tensor(img: PIL.Image.Image) -> torch.Tensor:
    """Stack every frame of a multi-frame image into one float32 tensor.

    Each frame is read as int16 and the stacked result is divided by 32768,
    so the output has one leading entry per frame and values in [-1, 1).
    The image is left positioned on its last frame.
    """
    def _read_frame(index):
        # seek() selects the frame that the following np.array() call reads.
        img.seek(index)
        return np.array(img, dtype='int16')

    stacked = np.stack([_read_frame(k) for k in range(img.n_frames)])
    as_float = torch.tensor(stacked, dtype=torch.float32)
    return as_float / 32768.0


def _sync_device():
    """Block until queued CUDA work has finished so wall-clock timings cover it."""
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()


def evaluate_quality(sample):
    """Round-trip one dataset sample through the codec and measure rate, quality and speed.

    The image is encoded with the global ``model``, the rounded latent is
    written to in-memory PNG files to measure the compressed size, and the
    latent images are decoded again for comparison with the input.

    Args:
        sample: Dataset row; ``sample['image']`` is the multi-frame image
            handed to ``tiff_to_tensor``.

    Returns:
        dict with the keys
        ``voxels``      - number of elements in the input tensor,
        ``encode_time`` - seconds for encode + quantize + PNG serialization,
        ``decode_time`` - seconds for latent reconstruction + model decode,
        ``bpv``         - PNG bits per input voxel,
        ``PSNR``        - dB, computed in float32 on data rescaled to [0, 1].
    """
    img = sample['image']
    # Keep a float32 copy as the metric reference; the model itself runs in bfloat16.
    x_ref = tiff_to_tensor(img).unsqueeze(0).to(device)
    x_orig = x_ref.to(torch.bfloat16).clamp(-1, 1)
    voxels = x_orig.numel()

    # --- Encode ---
    _sync_device()  # do not bill the pending host-to-device copy to the encoder
    t0 = time.perf_counter()
    with torch.no_grad():
        z = model.encode(x_orig)
        latent = model.quantize.compand(z).round()
    # latent.cpu() waits for the GPU, so the encode timer needs no extra sync.
    latent_imgs = latent_to_pil(latent.cpu(), n_bits=8, C=config.latent_dim)
    size_bytes = 0
    for im in latent_imgs:
        buff = io.BytesIO()
        im.save(buff, format='PNG')  # lossless
        size_bytes += buff.getbuffer().nbytes
    encode_time = time.perf_counter() - t0

    # --- Decode ---
    t0 = time.perf_counter()
    latent_decoded = pil_to_latent(latent_imgs, N=config.latent_dim, n_bits=8, C=config.latent_dim).to(device).to(torch.bfloat16)
    with torch.no_grad():
        x_hat = model.decode(latent_decoded).clamp(-1, 1)
    # CUDA kernels launch asynchronously: wait for them before stopping the clock,
    # otherwise only the launch overhead is measured.
    _sync_device()
    decode_time = time.perf_counter() - t0

    # --- Metrics ---
    # Compute in float32: in bfloat16 the values in [0.5, 1) are spaced about 2**-8
    # apart, which adds rounding noise of its own to the error being measured.
    x_ref_01 = x_ref / 2 + 0.5
    x_hat_01 = x_hat.float() / 2 + 0.5

    mse = torch.nn.functional.mse_loss(x_ref_01[0], x_hat_01[0])
    PSNR = (-10 * mse.log10()).item()

    bpv = 8 * size_bytes / voxels  # bits per voxel

    return {
        'voxels': voxels,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bpv': bpv,
        'PSNR': PSNR,
    }

In [None]:
# Column names reported in the summary, in print order.
metrics = ['voxels', 'encode_time', 'decode_time', 'bpv', 'PSNR']

In [None]:
# Run the evaluation on every sample; the returned metric dict becomes dataset columns.
results_dataset = aviris.map(function=evaluate_quality)

In [None]:
# Mean of every per-sample metric.
print("mean\n---")
for metric in metrics:
    print(f"{metric}: {np.mean(results_dataset[metric])}")

# Throughput is the mean of the per-sample rates, in millions of voxels per second.
# Each line is labeled so the encode and decode figures can be told apart.
megavoxels = np.array(results_dataset['voxels']) / 1e6
for stage in ('encode', 'decode'):
    seconds = np.array(results_dataset[f'{stage}_time'])
    print(f"{stage}: {np.mean(megavoxels / seconds)} MVox/sec")