In [1]:
import torch, io, datasets, PIL.Image,  numpy as np
from huggingface_hub import hf_hub_download
from types import SimpleNamespace
from piq import LPIPS, DISTS, SSIMLoss
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor

In [2]:
device = "cuda"
dataset = datasets.load_dataset("danjacobellis/kodak")
checkpoint_file = hf_hub_download(
    repo_id="danjacobellis/autocodec",
    filename="rgb_f16c48_ft.pth"
)
checkpoint = torch.load(checkpoint_file, map_location="cpu",weights_only=False)
config = checkpoint['config']
codec = AutoCodecND(
    dim=2,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device).to(torch.bfloat16)
codec.load_state_dict(checkpoint['state_dict'])
codec.eval();

lpips_loss = LPIPS().to(device)
dists_loss = DISTS().to(device)
ssim_loss = SSIMLoss().to(device)



In [3]:
def evaluate_quality(sample):
    img = sample['image'].convert("RGB")
    x_orig = pil_to_tensor(img).to(device).unsqueeze(0).to(torch.bfloat16) / 127.5 - 1.0
    orig_dim = x_orig.numel() 
    with torch.no_grad():
        z = codec.encode(x_orig)
        latent = codec.quantize.compand(z).round()
    webp = latent_to_pil(latent.cpu(), n_bits=8, C=3)
    buff = io.BytesIO()
    webp[0].save(buff, format='WEBP', lossless=True)
    size_bytes = len(buff.getbuffer())
    latent_decoded = pil_to_latent(webp, N=config.latent_dim, n_bits=8, C=3).to(device).to(torch.bfloat16)
    with torch.no_grad():
        x_hat = codec.decode(latent_decoded).clamp(-1,1)
    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5
    pixels = img.width * img.height
    bpp = 8 * size_bytes / pixels
    mse = torch.nn.functional.mse_loss(x_orig_01[0], x_hat_01[0])
    PSNR = -10 * mse.log10().item()
    LPIPS_dB = -10 * np.log10(lpips_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item())
    DISTS_dB = -10 * np.log10(dists_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item())
    SSIM = 1 - ssim_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item()

    return {
        'pixels': pixels,
        'bpp': bpp,
        'PSNR': PSNR,
        'LPIPS_dB': LPIPS_dB,
        'DISTS_dB': DISTS_dB,
        'SSIM': SSIM,
    }

In [4]:
results_dataset = dataset['validation'].map(evaluate_quality)

In [5]:
print("mean\n---")
for metric in [
    'pixels',
    'bpp',
    'PSNR',
    'LPIPS_dB',
    'DISTS_dB',
    'SSIM',
]:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")

mean
---
pixels: 393216.0
bpp: 0.7802615695529515
PSNR: 31.03515625
LPIPS_dB: 7.04490827424515
DISTS_dB: 13.877594294056875
SSIM: 0.8351236979166666
