In [1]:
%env CUDA_VISIBLE_DEVICES=MIG-cbafb023-40ef-594e-9092-fb0e3c44baa2

env: CUDA_VISIBLE_DEVICES=MIG-cbafb023-40ef-594e-9092-fb0e3c44baa2


In [2]:
import io
import time
import torch
import torch.nn as nn
import PIL.Image
import io
import numpy as np
import datasets
from torchvision.transforms.v2 import Pad, CenterCrop
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from types import SimpleNamespace
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from piq import LPIPS, DISTS, SSIMLoss

In [3]:
device = "cuda"
lpips_loss = LPIPS().to(device)
dists_loss = DISTS().to(device)
ssim_loss = SSIMLoss().to(device)



In [4]:
inet = datasets.load_dataset("timm/imagenet-1k-wds",split='validation')

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Using the latest cached version of the dataset since timm/imagenet-1k-wds couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/dgj335/.cache/huggingface/datasets/timm___imagenet-1k-wds/default/0.0.0/cdf403ce12f01022a0c36e584e588c0b9cebc4af (last modified on Fri Apr  4 13:57:51 2025).


In [5]:
checkpoint = torch.load('../../hf/dance/inet_ft_rgb_f8c12.pth', map_location=device,weights_only=False)
config = checkpoint['config']
state_dict = checkpoint['state_dict']

model = AutoCodecND(
    dim=2,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device)
model.load_state_dict(state_dict)
model.eval();

In [6]:
def compress_and_evaluate(sample):
    img = sample['jpg'].convert("RGB")
    aspect = img.width/img.height
    img = img.resize((int(16*(1024*aspect//16)),1024),resample=PIL.Image.Resampling.LANCZOS)
    x_orig = pil_to_tensor(img).to(device).unsqueeze(0).to(torch.float) / 127.5 - 1.0
    orig_dim = x_orig.numel() 

    t0 = time.time()
    with torch.no_grad():
        z = model.encode(x_orig)
        latent = model.quantize.compand(z).round()
    webp = latent_to_pil(latent.cpu(), n_bits=8, C=3)
    buff = io.BytesIO()
    webp[0].save(buff, format='WEBP', lossless=True)
    encode_time = time.time() - t0
    size_bytes = len(buff.getbuffer())
    t0 = time.time()
    latent_decoded = pil_to_latent(webp, N=config.latent_dim, n_bits=8, C=3).to(device)
    with torch.no_grad():
        x_hat = model.decode(latent_decoded).clamp(-1,1)
    decode_time = time.time() - t0

    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5

    pixels = img.width * img.height
    bpp = 8 * size_bytes / pixels
    mse = torch.nn.functional.mse_loss(x_orig_01[0], x_hat_01[0])
    PSNR = -10 * mse.log10().item()
    LPIPS_dB = -10 * np.log10(lpips_loss(x_orig_01.to(device), x_hat_01.to(device)).item())
    DISTS_dB = -10 * np.log10(dists_loss(x_orig_01.to(device), x_hat_01.to(device)).item())
    SSIM = 1 - ssim_loss(x_orig_01.to(device), x_hat_01.to(device)).item()

    return {
        'pixels': pixels,
        'bpp': bpp,
        'PSNR': PSNR,
        'LPIPS_dB': LPIPS_dB,
        'DISTS_dB': DISTS_dB,
        'SSIM': SSIM,
        'encode_time': encode_time,
        'decode_time': decode_time,
    }

metrics = [
    'pixels',
    'bpp',
    'PSNR',
    'LPIPS_dB',
    'DISTS_dB',
    'SSIM',
    'encode_time',
    'decode_time',
]

f16c12

In [None]:
results_dataset = inet.select(range(100)).map(compress_and_evaluate)

print("mean\n---")
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['pixels'])/1e6/np.array(results_dataset['encode_time']))} MP/sec")

f8c12 lsdir

In [None]:
results_dataset = inet.select(range(100)).map(compress_and_evaluate)

print("mean\n---")
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['pixels'])/1e6/np.array(results_dataset['encode_time']))} MP/sec")

f8c12 lsdir+inet

In [7]:
results_dataset = inet.map(compress_and_evaluate)

print("mean\n---")
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['pixels'])/1e6/np.array(results_dataset['encode_time']))} MP/sec")

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

mean
---
pixels: 1233140.44928
bpp: 0.6455606527768093
PSNR: 38.895700521564486
LPIPS_dB: 10.482115628613725
DISTS_dB: 19.412052084736796
SSIM: 0.9871415659821033
encode_time: 0.058272379055023193
decode_time: 0.007950794057846069
23.599175936410298 MP/sec


f8c12 lsdir+inet+lsdir

In [None]:
results_dataset = inet.select(range(100)).map(compress_and_evaluate)

print("mean\n---")
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['pixels'])/1e6/np.array(results_dataset['encode_time']))} MP/sec")