In [1]:
import io
import time
import torch
import datasets
import PIL.Image
import numpy as np
import torch.nn as nn
from types import SimpleNamespace
from piq import LPIPS, DISTS, SSIMLoss
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from torchvision.transforms.v2 import Pad, CenterCrop, PILToTensor
from torchvision.transforms.v2 import CenterCrop, PILToTensor, ToPILImage, Pad, CenterCrop

In [2]:
device = "cuda"
lpips_loss = LPIPS().to(device)
dists_loss = DISTS().to(device)
ssim_loss = SSIMLoss().to(device)
# kodak = datasets.load_dataset("danjacobellis/kodak", split='validation')
# lsdir = datasets.load_dataset("danjacobellis/LSDIR_val", split='validation')
inet = datasets.load_dataset("timm/imagenet-1k-wds",split='validation')
checkpoint = torch.load('../../hf/autocodec/rgb_f16c12_ft.pth', map_location="cpu",weights_only=False)
config = checkpoint['config']
state_dict = checkpoint['state_dict']
model = AutoCodecND(
    dim=2,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device).to(torch.bfloat16)
model.load_state_dict(state_dict)
model.eval();




Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

In [3]:
def evaluate_quality_h1024(sample):
    img = sample['jpg'].convert("RGB")
    aspect = img.width/img.height
    img = img.resize((int(16*(1024*aspect//16)),1024),resample=PIL.Image.Resampling.LANCZOS)
    if img.width>2048:
        img = CenterCrop((1024,2048))(img)
    x_orig = PILToTensor()(img).to(device).unsqueeze(0).to(torch.bfloat16) / 127.5 - 1.0
    orig_dim = x_orig.numel() 

    t0 = time.time()
    with torch.no_grad():
        z = model.encode(x_orig)
        latent = model.quantize.compand(z).round()
    webp = latent_to_pil(latent.cpu(), n_bits=8, C=3)
    buff = io.BytesIO()
    webp[0].save(buff, format='WEBP', lossless=True)
    encode_time = time.time() - t0
    size_bytes = len(buff.getbuffer())
    t0 = time.time()
    latent_decoded = pil_to_latent(webp, N=config.latent_dim, n_bits=8, C=3).to(device).to(torch.bfloat16)
    with torch.no_grad():
        x_hat = model.decode(latent_decoded).clamp(-1,1)
    decode_time = time.time() - t0

    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5

    pixels = img.width * img.height
    bpp = 8 * size_bytes / pixels
    mse = torch.nn.functional.mse_loss(x_orig_01[0], x_hat_01[0])
    PSNR = -10 * mse.log10().item()
    LPIPS_dB = -10 * np.log10(lpips_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item())
    DISTS_dB = -10 * np.log10(dists_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item())
    SSIM = 1 - ssim_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item()

    return {
        'pixels': pixels,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bpp': bpp,
        'PSNR': PSNR,
        'LPIPS_dB': LPIPS_dB,
        'DISTS_dB': DISTS_dB,
        'SSIM': SSIM,
    }

In [4]:
metrics = [
    'pixels',
    'encode_time',
    'decode_time',
    'bpp',
    'PSNR',
    'LPIPS_dB',
    'DISTS_dB',
    'SSIM',
]

In [5]:
results_dataset = inet.map(lambda s: evaluate_quality_h1024(s))

print("mean\n---")
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['pixels'])/1e6/np.array(results_dataset['encode_time']))} MP/sec")

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [7]:
for metric in metrics:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
print(f"{np.mean(np.array(results_dataset['pixels'])/1e6/np.array(results_dataset['encode_time']))} MP/sec")
print(f"{np.mean(np.array(results_dataset['pixels'])/1e6/np.array(results_dataset['decode_time']))} MP/sec")

pixels: 1231579.70944
encode_time: 0.009090588297843934
decode_time: 0.003910760917663575
bpp: 0.17156131121335294
PSNR: 31.0866
LPIPS_dB: 5.000525321339839
DISTS_dB: nan
SSIM: 0.871455390625
143.27379297185936 MP/sec
316.62148344109784 MP/sec


In [10]:
config

namespace(F=16,
          latent_dim=12,
          input_channels=3,
          encoder_depth=4,
          encoder_kernel_size=1,
          decoder_depth=8,
          lightweight_encode=True,
          lightweight_decode=False,
          freeze_encoder_after=0.7,
          λ=0.03,
          lr_pow=6,
          epochs=6,
          progressive_sizes=[480, 496, 512, 528, 544, 560],
          batch_size=12,
          max_lr=0.005333333333333333,
          min_lr=5.333333333333333e-06,
          num_workers=32,
          total_steps=560916,
          checkpoint='../../hf/autocodec/rgb_f16c12.pth')