In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
import random
import time
from timm.optim import Mars
from types import SimpleNamespace
from IPython.display import HTML
from types import SimpleNamespace
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import CenterCrop, RandomCrop
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image

In [2]:
device = "cuda"
medmnist_types = ['organ', 'adrenal', 'fracture', 'nodule', 'synapse', 'vessel']


def _load_medmnist_split(split):
    """Load `split` of every `danjacobellis/<type>mnist3d_64` dataset and concatenate them.

    Subsets are loaded and concatenated in the order given by `medmnist_types`.
    """
    per_type = [
        datasets.load_dataset(f"danjacobellis/{name}mnist3d_64", split=split)
        for name in medmnist_types
    ]
    return datasets.concatenate_datasets(per_type)


dataset_train = _load_medmnist_split('train')
dataset_valid = _load_medmnist_split('validation')
dataset = datasets.DatasetDict({'train': dataset_train, 'validation': dataset_valid})

In [3]:
def pil_to_grid3d(img):
    """Unpack a tiled 2-D image into a stack of 2-D slices.

    The image array is read as (H, W, C) with a 4x4 grid of tiles and exactly
    C == 4 channels (enforced by the einops pattern). Each (tile row, tile column,
    channel) triple becomes one slice, so the result has shape (64, H/4, W/4),
    ordered tile-row-major, then tile column, then channel.
    """
    pixels = np.array(img)
    grid = torch.tensor(pixels)
    slices = einops.rearrange(grid, '(a y) (b z) c -> (a b c) y z', a=4, b=4, c=4)
    return slices

In [78]:
def _jpeg2000_header_size(slice_shape):
    """Estimate the fixed per-file JPEG 2000 overhead, in bytes, for one slice.

    An all-zero single-channel image of ``slice_shape`` is encoded at an extreme
    target ratio (``quality_layers=[1000]``) so that little besides container and
    codestream headers remains. This is an approximation: any residual payload
    bytes of that probe are also counted as "header".
    """
    buff = io.BytesIO()
    to_pil_image(torch.zeros(slice_shape)).save(buff, format='JPEG2000', quality_layers=[1000])
    return len(buff.getbuffer())


def evaluate_quality(sample, rate=210):
    """Compress a volume slice-by-slice with JPEG 2000 and report rate, distortion and speed.

    The volume is unpacked from ``sample['image']`` with ``pil_to_grid3d``. Every
    slice along axis 0 is then encoded as an independent JPEG 2000 file and decoded back.

    Args:
        sample: a dataset row; only ``sample['image']`` is read.
        rate: value passed to Pillow as ``quality_layers=[rate]``. With Pillow's
            default ``quality_mode`` ("rates") this is an approximate compression
            ratio. Defaults to 210, the previously hard-coded value.

    Returns:
        dict with
          - 'voxels': number of elements in the volume,
          - 'cr': compression ratio, with a single shared header counted once,
          - 'encode_time' / 'decode_time': seconds, from ``time.perf_counter``,
          - 'psnr': dB for intensities scaled to [0, 1] (``inf`` if lossless).
    """
    x = pil_to_grid3d(sample['image'])

    # Encode each slice along axis 0 into its own in-memory JPEG 2000 file.
    t0 = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    buff_list = []
    for s in x:
        buff = io.BytesIO()
        to_pil_image(s).save(buff, format='JPEG2000', quality_layers=[rate])
        buff_list.append(buff)
    encode_time = time.perf_counter() - t0

    # Size accounting: count the header once, as if all slices shared one
    # container, plus each slice's remaining (payload) bytes.
    # NOTE(review): assumes x is uint8 (see the /255 scaling below), so
    # x.numel() equals the uncompressed size in bytes.
    header_size = _jpeg2000_header_size(x.shape[1:])
    compressed_size = header_size + sum(len(b.getbuffer()) - header_size for b in buff_list)
    cr = x.numel() / compressed_size

    # Decode. PIL.Image.open only parses the header; pixel data is decoded when
    # pil_to_tensor reads it, so both calls stay inside the timed region.
    t0 = time.perf_counter()
    xhat = torch.cat([pil_to_tensor(PIL.Image.open(b)) for b in buff_list])
    decode_time = time.perf_counter() - t0

    # PSNR with a peak signal of 1.0 after scaling both volumes to [0, 1].
    x = x.to(torch.float) / 255
    xhat = xhat.to(torch.float) / 255
    mse = torch.nn.functional.mse_loss(x, xhat)
    psnr = -10 * mse.log10().item()

    return {
        'voxels': x.numel(),
        'cr': cr,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'psnr': psnr
    }

In [79]:
# Benchmark the codec on the first N_EVAL_SAMPLES validation volumes.
# NOTE: despite the variable name, evaluate_quality never moves data to `device`;
# the JPEG 2000 encode/decode (Pillow) runs on the CPU.
# load_from_cache_file=False forces a fresh run. Otherwise datasets may reload a
# previous run's cached result, and the timing columns would be stale.
N_EVAL_SAMPLES = 100
gpu_results = dataset['validation'].select(range(N_EVAL_SAMPLES)).map(
    evaluate_quality, load_from_cache_file=False
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [80]:
# Per-sample columns produced by evaluate_quality, in the order they are summarized.
metrics = 'voxels encode_time decode_time cr psnr'.split()

In [81]:
# Mean of each per-sample metric.
print("mean\n---")
for metric in metrics:
    μ = np.mean(gpu_results[metric])
    print(f"{metric}: {μ}")

# Throughput in mega-voxels per second, labeled by direction. This is the mean of
# per-sample rates (voxels / time), not total voxels / total time.
mvox = np.array(gpu_results['voxels']) / 1e6
print(f"encode: {np.mean(mvox / np.array(gpu_results['encode_time']))} MVox/sec")
print(f"decode: {np.mean(mvox / np.array(gpu_results['decode_time']))} MVox/sec")

mean
---
voxels: 262144.0
encode_time: 0.038227436542510984
decode_time: 0.0052431726455688474
cr: 188.2600935526794
psnr: 16.01819063425064
6.971167760978013 MVox/sec
50.0155813867643 MVox/sec
