In [1]:
import io
import time
import torch
import datasets
import PIL.Image
import numpy as np
import torch.nn as nn
from types import SimpleNamespace
from piq import LPIPS, DISTS, SSIMLoss
from huggingface_hub import snapshot_download
from cosmos_tokenizer.image_lib import ImageTokenizer
from torchvision.transforms.v2 import Pad, CenterCrop
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor

In [2]:
device = "cuda"
lpips_loss = LPIPS().to(device)
dists_loss = DISTS().to(device)
ssim_loss = SSIMLoss().to(device)
kodak = datasets.load_dataset("danjacobellis/kodak", split='validation')
lsdir = datasets.load_dataset("danjacobellis/LSDIR_val", split='validation')
inet = datasets.load_dataset("timm/imagenet-1k-wds",split='validation')
model_path = snapshot_download(repo_id='nvidia/Cosmos-Tokenizer-DI8x8')
encoder = ImageTokenizer(checkpoint_enc=f'{model_path}/encoder.jit').to(device)
decoder = ImageTokenizer(checkpoint_dec=f'{model_path}/decoder.jit').to(device)



Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

In [3]:
def evaluate_quality_h1024(sample):
    img = sample['jpg'].convert("RGB")
    aspect = img.width/img.height
    img = img.resize((int(16*(1024*aspect//16)),1024),resample=PIL.Image.Resampling.LANCZOS)
    x_orig = pil_to_tensor(img).to(device).unsqueeze(0).to(torch.float) / 127.5 - 1.0
    orig_dim = x_orig.numel() 

    t0 = time.time()
    with torch.no_grad():
        z, _ = encoder.encode(x_orig)
    encode_time = time.time() - t0
    size_bytes = 2*z.numel()
    t0 = time.time()
    with torch.no_grad():
        z, _ = encoder.encode(x_orig)
        x_hat = decoder.decode(z).to(torch.float).clamp(-1,1)
    decode_time = time.time() - t0

    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5

    pixels = img.width * img.height
    bpp = 8 * size_bytes / pixels
    mse = torch.nn.functional.mse_loss(x_orig_01[0], x_hat_01[0])
    PSNR = -10 * mse.log10().item()
    LPIPS_dB = -10 * np.log10(lpips_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item())
    DISTS_dB = -10 * np.log10(dists_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item())
    SSIM = 1 - ssim_loss(x_orig_01.to("cuda"), x_hat_01.to("cuda")).item()

    return {
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bpp': bpp,
        'PSNR': PSNR,
        'LPIPS_dB': LPIPS_dB,
        'DISTS_dB': DISTS_dB,
        'SSIM': SSIM,
    }

In [None]:
results_dataset = inet.map(evaluate_quality_h1024)



Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [20]:
print("mean\n---")
for metric in [
    'bpp',
    'PSNR',
    'LPIPS_dB',
    'DISTS_dB',
    'SSIM',
]:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")

mean
---
encode_time: 0.05055922746658325
decode_time: 0.10153760433197022
bpp: 0.0625
PSNR: 24.753208494186403
LPIPS_dB: 5.816998623644801
DISTS_dB: 11.682074767345494
SSIM: 0.8084122937917709
