In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
import time
from timm.optim import Mars
from types import SimpleNamespace
from IPython.display import HTML
from types import SimpleNamespace
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import CenterCrop, RandomCrop
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from decord import VideoReader
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent

In [2]:
device = "cuda"
medmnist_types = ['organ', 'adrenal', 'fracture', 'nodule', 'synapse', 'vessel']
dataset_train = datasets.concatenate_datasets([datasets.load_dataset(f"danjacobellis/{type}mnist3d_64", split='train') for type in medmnist_types])
dataset_valid = datasets.concatenate_datasets([datasets.load_dataset(f"danjacobellis/{type}mnist3d_64", split='validation') for type in medmnist_types])
dataset = datasets.DatasetDict({
    'train': dataset_train,
    'validation': dataset_valid
})

In [3]:
checkpoint = torch.load('../../hf/autocodec/med3d_f8c8.pth', map_location="cpu",weights_only=False)

config = checkpoint['config']
state_dict = checkpoint['state_dict']
model = AutoCodecND(
    dim=3,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device)
model.load_state_dict(state_dict)
model.to(torch.bfloat16)
model.eval();

In [4]:
def pil_to_grid3d(img):
    x = torch.tensor(np.array(img))
    x = einops.rearrange(x, '(a y) (b z) c -> (a b c) y z', a=4, b=4, c=4)
    return x

In [5]:
def evaluate_quality(sample):
    
    x = pil_to_grid3d(sample['image']).unsqueeze(0).unsqueeze(0).to(device).to(torch.bfloat16) / 127.5 - 1.0
    orig_dim = x.numel()
    
    t0 = time.time()
    with torch.no_grad():
        z = model.encode(x)
        latent = model.quantize.compand(z).round()
    latent = einops.rearrange(latent, 'b c d h w -> b (c d) h w').cpu()
    img = latent_to_pil(latent, n_bits=8, C=4)
    buff = io.BytesIO()
    img[0].save(buff, format='TIFF', compression = 'tiff_adobe_deflate')
    encode_time = time.time() - t0
    
    size_bytes = len(buff.getbuffer())
    cr = x.numel()/size_bytes
    
    t0 = time.time()
    latent_decoded = pil_to_latent([PIL.Image.open(buff)], N=64, n_bits=8, C=4)
    latent_decoded = einops.rearrange(latent_decoded, 'b (c d) h w -> b c d h w', d=8).to(device).to(torch.bfloat16)
    with torch.no_grad():
        x_hat = model.decode(latent_decoded).clamp(-1, 1)
    decode_time = time.time() - t0
    
    
    mse = torch.nn.functional.mse_loss(x, x_hat)
    psnr = -10 * mse.log10().item() + 6.02

    return {
        'voxels': x.numel(),
        'cr': cr,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'psnr': psnr
    }

In [6]:
gpu_results = dataset['validation'].map(evaluate_quality)

Map:   0%|          | 0/895 [00:00<?, ? examples/s]

In [7]:
metrics = [
    'voxels',
    'encode_time',
    'decode_time',
    'cr',
    'psnr',
]

In [8]:
print("mean\n---")
for metric in metrics:
    μ = np.mean(gpu_results[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(gpu_results['voxels'])/1e6/np.array(gpu_results['encode_time']))} MVox/sec")
print(f"{np.mean(np.array(gpu_results['voxels'])/1e6/np.array(gpu_results['decode_time']))} MVox/sec")

mean
---
voxels: 262144.0
encode_time: 0.005050100560960823
decode_time: 0.004511122197412246
cr: 209.2325504413078
psnr: 24.740452164804466
54.07623692142692 MVox/sec
59.76341054905106 MVox/sec
