In [1]:
import io
import time
import math
import torch
import datasets
import einops
import PIL.Image
import numpy as np
import torch.nn as nn
from types import SimpleNamespace
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor

In [2]:
# Device for the tensors built in evaluate_quality.
device = "cpu"

# Validation split of the AVIRIS cubes; evaluate_quality reads each sample's
# multi-frame 'image' one frame (band) at a time.
aviris = datasets.load_dataset(
    "danjacobellis/aviris_1k_val",
    split="validation",
)

Resolving data files:   0%|          | 0/66 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/66 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/48 [00:00<?, ?it/s]

In [3]:
def pad3d(x, p, extra, small_dim_mode):
    """Reflect-pad the last three axes of a 5-D tensor.

    Each of the (f, h, w) axes of ``x`` (shape ``(b, c, f, h, w)``) is grown to
    the next multiple of ``p``. The padding is split as evenly as possible
    between the two sides, and any odd leftover element goes on the trailing
    side. ``extra`` is a 3-tuple of additional padding added to *both* sides
    of the f, h and w axes respectively.

    When ``small_dim_mode`` is true, an axis shorter than ``p`` is not rounded
    up to a multiple of ``p``. It only receives its ``extra`` padding on each
    side.

    Returns the padded tensor (padding mode ``"reflect"``).
    """
    def _side_pads(size, extra_pad):
        # (leading, trailing) padding for one axis.
        if small_dim_mode and size < p:
            return extra_pad, extra_pad
        shortfall = math.ceil(size / p) * p - size
        lead = shortfall // 2
        return lead + extra_pad, (shortfall - lead) + extra_pad

    _, _, frames, height, width = x.shape
    extra_f, extra_h, extra_w = extra
    front, back = _side_pads(frames, extra_f)
    top, bottom = _side_pads(height, extra_h)
    left, right = _side_pads(width, extra_w)
    # F.pad lists padding starting from the LAST axis: (w, h, f).
    return torch.nn.functional.pad(
        x,
        pad=(left, right, top, bottom, front, back),
        mode="reflect",
    )

def center_crop_3d(x, f, h, w):
    """Center-crop the last three axes of a 5-D tensor to ``(f, h, w)``.

    Args:
        x: tensor of shape ``(b, c, F, H, W)``.
        f, h, w: target sizes. Each must lie in ``[0, input size]`` for its axis.

    Returns:
        The slice ``x[:, :, front:front+f, top:top+h, left:left+w]``, where
        each offset is ``(input - target) // 2``. When the surplus on an axis
        is odd, the extra element is dropped from the trailing side.

    Raises:
        AssertionError: if ``x`` is not 5-D.
        ValueError: if a target size is negative or larger than the input.
            Without this check the offsets go negative and Python slicing
            silently returns a wrong, smaller crop.
    """
    assert x.ndim == 5
    _, _, F, H, W = x.shape
    for axis, target, size in (("f", f, F), ("h", h, H), ("w", w, W)):
        if not 0 <= target <= size:
            raise ValueError(
                f"center_crop_3d: {axis}={target} is outside [0, {size}]"
            )
    front = (F - f) // 2
    top = (H - h) // 2
    left = (W - w) // 2
    return x[:, :, front:front + f, top:top + h, left:left + w]

In [4]:
def evaluate_quality(sample, rate=284):
    """Round-trip one hyperspectral cube through per-band JPEG2000 and score it.

    Steps:
    1. Read each frame of ``sample['image']`` as one int16 band and map the
       stack to [0, 1) via ``v / 2**16 + 0.5``.
    2. Reflect-pad the cube to a multiple of 8 on every axis.
    3. Round it to 8-bit levels and encode every band as its own JPEG2000
       image with ``quality_layers=[rate]``.
    4. Decode the bands, crop the cube back to its original extent, and
       compare it against the unquantized reference.

    Args:
        sample: dataset row whose ``'image'`` is a multi-frame PIL image.
        rate: the single ``quality_layers`` entry passed to Pillow's JPEG2000
            encoder (default 284, the previously hard-coded value).

    Returns:
        dict with
        - ``voxels``: voxel count of the unpadded cube,
        - ``encode_time`` / ``decode_time``: elapsed seconds,
        - ``cr``: original int16 size (2 bytes per unpadded voxel) divided by
          the total encoded bytes,
        - ``psnr``: PSNR in dB with a peak value of 1.
    """
    pil_cube = sample['image']
    bands = []
    for i in range(pil_cube.n_frames):
        pil_cube.seek(i)
        bands.append(np.array(pil_cube, dtype=np.int16))
    # float32 holds every int16 value exactly. float16 would round them and
    # so pre-quantize the reference the PSNR is measured against.
    x_orig = (
        torch.from_numpy(np.stack(bands))
        .to(device=device, dtype=torch.float32)[None, None]
        / 2**16
        + 0.5
    )
    x = pad3d(x_orig, 8, extra=(0, 0, 0), small_dim_mode=False)
    num_frames = x.shape[2]  # padded band count (not a fixed 224)

    t0 = time.perf_counter()
    # Round to the nearest 8-bit level. Passing a float tensor to to_pil_image
    # would truncate via .mul(255).byte(), biasing every sample downward.
    x_u8 = (x * 255).round().clamp(0, 255).to(torch.uint8)
    buffers = []
    for i in range(num_frames):
        buf = io.BytesIO()
        to_pil_image(x_u8[0, :, i]).save(
            buf, format="JPEG2000", quality_layers=[rate]
        )
        buffers.append(buf)
    encode_time = time.perf_counter() - t0

    size_bytes = sum(len(buf.getbuffer()) for buf in buffers)
    # Measured against the original (unpadded) int16 data, consistent with
    # the reported voxel count.
    cr = 2 * x_orig.numel() / size_bytes

    t0 = time.perf_counter()
    xhat = torch.cat([pil_to_tensor(PIL.Image.open(buf)) for buf in buffers])
    decode_time = time.perf_counter() - t0

    xhat = center_crop_3d(
        x=xhat[None, None],
        f=x_orig.shape[2],
        h=x_orig.shape[3],
        w=x_orig.shape[4],
    )
    xhat = xhat.to(device=x_orig.device, dtype=torch.float32) / 255
    mse = torch.nn.functional.mse_loss(x_orig, xhat)
    psnr = -10 * mse.log10().item()

    return {
        "voxels": x_orig.numel(),
        "encode_time": encode_time,
        "decode_time": decode_time,
        "cr": cr,
        "psnr": psnr,
    }

In [5]:
# Per-sample fields (keys of evaluate_quality's result dict) averaged in the summary.
metrics = ['voxels', 'encode_time', 'decode_time', 'cr', 'psnr']

In [None]:
# Run the JPEG2000 round-trip on every cube. The returned metric dict is
# merged into each row of the resulting dataset.
results_dataset = aviris.map(function=evaluate_quality)

Map:   0%|          | 0/94 [00:00<?, ? examples/s]

In [None]:
# Mean of every per-sample metric.
print("mean\n---")
for metric in metrics:
    mean_value = np.mean(results_dataset[metric])
    print(f"{metric}: {mean_value}")

# Throughput averaged per sample (that cube's voxels / that cube's seconds),
# labelled so the encode and decode figures can be told apart.
voxel_counts = np.array(results_dataset['voxels'])
for phase in ("encode", "decode"):
    mvox_per_sec = voxel_counts / 1e6 / np.array(results_dataset[f"{phase}_time"])
    print(f"{phase}: {np.mean(mvox_per_sec)} MVox/sec")