In [None]:
!wget https://huggingface.co/danjacobellis/dance/resolve/main/LF_rgb_f16c12_v1.6.pth

In [1]:
import torch
import torch.nn as nn
import PIL.Image
import io
import numpy as np
import datasets
import time
from types import SimpleNamespace
from datasets import Dataset
from torchvision.transforms.v2 import CenterCrop
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor
from autocodec.codec import AutoCodecND, latent_to_pil, pil_to_latent

In [39]:
def compress_and_evaluate(sample, quality=0.1):
    """Time the neural encode + lossless WebP packing of a single sample.

    Uses the notebook-level globals ``model`` and ``device``.

    Args:
        sample: Mapping whose ``'image'`` entry is passed to ``pil_to_tensor``.
        quality: Not used by this function; kept so callers that pass it
            (positionally or by keyword) continue to work.

    Returns:
        dict with two entries:
            ``'neural_encode_time'``: seconds elapsed from just before
                ``model.encode`` until the WebP bytes have been written to an
                in-memory buffer (tensor preparation is not included).
            ``'mp'``: image size in megapixels (height * width / 1e6).
    """
    img = sample['image']
    # Scale by 1/127.5 and shift by -1 (8-bit values 0..255 land in [-1, 1]),
    # after adding a leading batch axis.
    x_orig = pil_to_tensor(img).to(device).unsqueeze(0).to(torch.float) / 127.5 - 1.0
    # Pixel count comes from the two trailing (spatial) dims so the megapixel
    # figure does not depend on the channel count; for 3-channel input this is
    # the same value as numel / 3e6.
    height, width = x_orig.shape[-2:]

    # perf_counter is monotonic and high-resolution; time.time() can jump when
    # the system clock is adjusted, which would corrupt a benchmark.
    t0 = time.perf_counter()
    with torch.no_grad():
        z = model.encode(x_orig)
        latent = model.quantize.compand(z).round()
    webp = latent_to_pil(latent.cpu(), n_bits=8, C=3)
    buff = io.BytesIO()
    webp[0].save(buff, format='WEBP', lossless=True)
    neural_encode_time = time.perf_counter() - t0

    return {
        'neural_encode_time': neural_encode_time,
        'mp': height * width / 1e6,
    }

In [32]:
# Device used for the input tensors, the checkpoint's map_location and the model.
device = "cpu"

# Benchmark subset: the first 10 examples of the validation split.
valid_dataset = (
    datasets
    .load_dataset("danjacobellis/LSDIR_val", split='validation')
    .select(range(10))
)

In [40]:
# NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
# objects from the file — only use it with a checkpoint you trust.
checkpoint = torch.load(
    'LF_rgb_f16c12_v1.6.pth',
    map_location=device,
    weights_only=False,
)
config = checkpoint['config']
# Only the config is read from the checkpoint; the weight loading below is
# left disabled, so the model keeps its freshly initialised parameters.
# state_dict = checkpoint['state_dict']

# Channel count and the lightweight flags come from the checkpoint config;
# the remaining architecture settings are fixed here.
model = AutoCodecND(
    dim=2,
    input_channels=config.input_channels,
    J=4,
    latent_dim=12,
    encoder_depth=6,
    encoder_kernel_size=3,
    decoder_depth=1,
    lightweight_encode=config.lightweight_encode,
    lightweight_decode=config.lightweight_decode,
)
model = model.to(device)
# model.load_state_dict(state_dict)
model.eval()

# Encoder size in millions of parameters (counted before compiling).
print(
    sum(p.numel() for p in model.encoder_blocks.parameters()) / 1e6
)
model = torch.compile(model)

4.868352


In [38]:
# Label for this run. NOTE(review): it is a hand-written string, not derived
# from the model — presumably encoder kernel size / depth; confirm it matches
# the model that was built before running this cell.
print('ks=1, d=1')
# load_from_cache_file=False forces every sample to be re-encoded; otherwise
# Dataset.map may return previously cached results and the timings would not
# reflect the current model. map is not batched here, so the callback
# receives one sample at a time.
r = valid_dataset.map(
    lambda sample: compress_and_evaluate(sample, 0.45),
    load_from_cache_file=False,
).with_format('torch')
# Per-image encode throughput in megapixels per second.
r['mp'] / r['neural_encode_time']

ks=1, d=1


tensor([0.4837, 0.6055, 0.5899, 0.5406, 0.5649, 0.5846, 0.5747, 0.6035, 0.6152,
        0.6324])

In [41]:
# Label for this run. NOTE(review): it is a hand-written string, not derived
# from the model — presumably encoder kernel size / depth; confirm it matches
# the model that was built before running this cell.
print('ks=3, d=6')
# load_from_cache_file=False forces every sample to be re-encoded; otherwise
# Dataset.map may return previously cached results and the timings would not
# reflect the current model. map is not batched here, so the callback
# receives one sample at a time.
r = valid_dataset.map(
    lambda sample: compress_and_evaluate(sample, 0.45),
    load_from_cache_file=False,
).with_format('torch')
# Per-image encode throughput in megapixels per second.
r['mp'] / r['neural_encode_time']

ks=3, d=6


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

tensor([0.1688, 0.1711, 0.1713, 0.1725, 0.1811, 0.1735, 0.1797, 0.1820, 0.1745,
        0.1773])