In [1]:
import io
import time
import torch
import datasets
import PIL.Image
import numpy as np
import torch.nn as nn
from types import SimpleNamespace
from piq import LPIPS, DISTS, SSIMLoss
from huggingface_hub import snapshot_download
from cosmos_tokenizer.image_lib import ImageTokenizer
from torchvision.transforms.v2 import Pad, CenterCrop
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor

In [2]:
# All tensors and both TorchScript modules are placed on this device.
device = "cpu"

# Validation split of the "timm/imagenet-1k-wds" dataset.
inet = datasets.load_dataset("timm/imagenet-1k-wds", split="validation")

# Download the tokenizer snapshot, then load its two TorchScript halves.
model_path = snapshot_download(repo_id="nvidia/Cosmos-Tokenizer-DI8x8")
encoder = torch.jit.load(f"{model_path}/encoder.jit", map_location=device)
decoder = torch.jit.load(f"{model_path}/decoder.jit", map_location=device)

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

In [3]:
def evaluate_perf(sample, resolution_height=128):
    """Time one encode/decode round trip on a single image and score it.

    Uses the notebook-level globals ``encoder``, ``decoder`` and ``device``.

    Args:
        sample: Mapping holding a PIL image under the ``'jpg'`` key.
        resolution_height: Height in pixels the image is resized to. The width
            follows the aspect ratio, rounded down to a multiple of 16 (and
            never below 16).

    Returns:
        dict with the keys
        ``'pixels'`` (width * height of the resized image),
        ``'encode_time'`` (seconds; encoder call plus serialising the latent),
        ``'decode_time'`` (seconds; deserialising the latent plus decoder call),
        ``'bpp'`` (8 * size_bytes / pixels) and
        ``'PSNR'`` (dB, computed on images rescaled to [0, 1]).
    """
    img = sample['jpg'].convert("RGB")
    aspect = img.width / img.height
    # Round the width down to a multiple of 16, but keep at least 16 pixels:
    # a very narrow image would otherwise produce a zero width and make
    # resize() fail.
    width = max(16, int(16 * (resolution_height * aspect // 16)))
    img = img.resize((width, resolution_height), resample=PIL.Image.Resampling.LANCZOS)
    # uint8 [0, 255] -> float [-1, 1], with a leading batch dimension.
    x_orig = pil_to_tensor(img).to(device).unsqueeze(0).to(torch.float) / 127.5 - 1.0

    # The latent is round-tripped through an in-memory buffer instead of a
    # fixed 'temp.pth' on disk: nothing is left behind in the working
    # directory and concurrent calls cannot overwrite each other's file.
    buffer = io.BytesIO()

    # NOTE(review): `encoder.encoder` / `decoder.decoder` address submodules of
    # the loaded TorchScript modules rather than calling the modules
    # themselves. If those modules contain further stages, they are skipped
    # here -- confirm this is intended before trusting 'bpp' and 'PSNR'.
    t0 = time.perf_counter()  # monotonic clock; time.time() can jump
    with torch.no_grad():
        z = encoder.encoder(x_orig)
    torch.save(z, buffer)
    encode_time = time.perf_counter() - t0

    # Counts 2 bytes per latent element (assumes 16-bit storage -- TODO
    # confirm against the dtype that would actually be stored).
    size_bytes = 2 * z.numel()

    t0 = time.perf_counter()
    buffer.seek(0)
    z = torch.load(buffer)
    with torch.no_grad():
        x_hat = decoder.decoder(z).to(torch.float).clamp(-1, 1)
    decode_time = time.perf_counter() - t0

    # Back to [0, 1] so that the PSNR peak value is 1.
    x_orig_01 = x_orig / 2 + 0.5
    x_hat_01 = x_hat / 2 + 0.5

    pixels = img.width * img.height
    bpp = 8 * size_bytes / pixels
    mse = torch.nn.functional.mse_loss(x_orig_01[0], x_hat_01[0])
    PSNR = -10 * mse.log10().item()

    return {
        'pixels': pixels,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bpp': bpp,
        'PSNR': PSNR,
    }

In [6]:
# Result columns to summarise, in reporting order; these are the keys of the
# dict returned by evaluate_perf().
metrics = ['pixels', 'encode_time', 'decode_time', 'bpp', 'PSNR']

In [None]:
# Evaluate every 53rd image among the first 500 validation samples (10 images).
# evaluate_perf is passed directly; wrapping it in a lambda added nothing.
# The summary printout is in the following cell, so it is not repeated here.
results_dataset = inet.select(range(0,500,53)).map(evaluate_perf)

In [7]:
# Mean of every recorded metric over the evaluated samples.
print("mean\n---")
for metric in metrics:
    mean_value = np.mean(results_dataset[metric])
    print(f"{metric}: {mean_value}")

# Encoder throughput: per-image megapixels per second, averaged over images.
megapixels = np.array(results_dataset['pixels']) / 1e6
encode_seconds = np.array(results_dataset['encode_time'])
print(f"{np.mean(megapixels / encode_seconds)} MP/sec")

mean
---
pixels: 18432.0
encode_time: 4.611590790748596
decode_time: 29.12102770805359
bpp: 64.0
PSNR: 9.125008463859558
0.004116413068231155 MP/sec
