In [1]:
!wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_Li_12c_J3_nf8_v1.0.2.pth

--2024-09-23 21:03:41--  https://hf.co/danjacobellis/walloc/resolve/main/RGB_Li_12c_J3_nf8_v1.0.2.pth
Resolving hf.co (hf.co)... 44.212.132.255, 34.204.155.59, 34.198.14.237, ...
Connecting to hf.co (hf.co)|44.212.132.255|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://huggingface.co/danjacobellis/walloc/resolve/main/RGB_Li_12c_J3_nf8_v1.0.2.pth [following]
--2024-09-23 21:03:41--  https://huggingface.co/danjacobellis/walloc/resolve/main/RGB_Li_12c_J3_nf8_v1.0.2.pth
Resolving huggingface.co (huggingface.co)... 108.156.211.95, 108.156.211.90, 108.156.211.51, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.95|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/f3/ba/f3ba06623731c38247420d7019770a013d176a409abc994df4fee215d214a026/6460e20646745084ddf80e8913df1d9f4c733f5827e65d07889b377e71adc9b5?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27R

In [2]:
!nvidia-smi --query-gpu=name --format=csv,noheader | head -n 1

NVIDIA GeForce RTX 4090


In [2]:
import io
import time
import torch
import numpy as np
import PIL
from torchvision.transforms import ToPILImage, PILToTensor
from datasets import load_dataset, Image
from walloc import walloc
from piq import LPIPS, DISTS, psnr, multi_scale_ssim

class Config: pass

In [3]:
checkpoint = torch.load("RGB_Li_12c_J3_nf8_v1.0.2.pth",map_location="cpu",weights_only=False)
codec_config = checkpoint['config']
codec = walloc.Codec2D(
    channels = codec_config.channels,
    J = codec_config.J,
    Ne = codec_config.Ne,
    Nd = codec_config.Nd,
    latent_dim = codec_config.latent_dim,
    latent_bits = codec_config.latent_bits,
    lightweight_encode = codec_config.lightweight_encode
)
codec.load_state_dict(checkpoint['model_state_dict'])
codec.eval();

In [4]:
lpips_loss = LPIPS().to("cuda")
dists_loss = DISTS().to("cuda")



In [5]:
LSDIR = load_dataset("danjacobellis/LSDIR_val", split='validation')

In [6]:
def walloc_compress(sample):
    with torch.no_grad():
        img = sample['image']
        x = PILToTensor()(img).to(torch.float)
        x = (x/255 - 0.5).unsqueeze(0).to(device)
        H, W = x.size(2), x.size(3)
        x_padded = walloc.pad(x,p=16)

        t0 = time.time()
        X = codec.wavelet_analysis(x_padded,codec.J)
        Y = codec.encoder(X)
        webp = walloc.latent_to_pil(Y.to("cpu"),codec.latent_bits, 3)[0]
        buff = io.BytesIO()
        webp.save(buff, format='WEBP', lossless=True)
        webp_bytes = buff.getbuffer()
        encode_time = time.time() - t0

        t0 = time.time()
        Y = walloc.pil_to_latent([PIL.Image.open(buff)], codec.latent_dim, codec.latent_bits, 3).to(device)
        X_hat = codec.decoder(Y)
        x_hat = codec.wavelet_synthesis(X_hat,codec.J)
        x_hat = codec.clamp(x_hat)
        decode_time = time.time() - t0
        
        x_hat = walloc.crop(x_hat, (H,W))
        rec = ToPILImage()(x_hat[0] + 0.5)
        buff2 = io.BytesIO()
        rec.save(buff2, format='WEBP', lossless=True)
        rec_webp_bytes = buff2.getbuffer()

        bpp = 8*len(webp_bytes)/(H*W)
        PSNR = psnr(x+0.5,x_hat+0.5)
        MSSIM = multi_scale_ssim(x+0.5,x_hat+0.5)
        LPIPS_dB = -10*np.log10(lpips_loss(x.to("cuda")+0.5, x_hat.to("cuda")+0.5).item())
        DISTS_dB = -10*np.log10(dists_loss(x.to("cuda")+0.5, x_hat.to("cuda")+0.5).item())        
        
    return {
        'recovered': rec_webp_bytes,
        'compressed': webp_bytes,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bpp': bpp,
        'PSNR': PSNR,
        'MSSIM': MSSIM,
        'LPIPS_dB': LPIPS_dB,
        'DISTS_dB': DISTS_dB,
    }

In [7]:
device = "cuda"
codec = codec.to(device)
gpu = LSDIR.map(walloc_compress)
gpu = gpu.cast_column('recovered',Image())
gpu = gpu.cast_column('compressed',Image())

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [8]:
def walloc_compress_cpu(sample):
    with torch.no_grad():
        img = sample['image']
        x = PILToTensor()(img).to(torch.float)
        x = (x/255 - 0.5).unsqueeze(0).to(device)
        H, W = x.size(2), x.size(3)
        x_padded = walloc.pad(x,p=8)

        t0 = time.time()
        X = codec.wavelet_analysis(x_padded,codec.J)
        Y = codec.encoder(X)
        webp = walloc.latent_to_pil(Y.to("cpu"),codec.latent_bits, 3)[0]
        buff = io.BytesIO()
        webp.save(buff, format='WEBP', lossless=True)
        webp_bytes = buff.getbuffer()
        encode_time = time.time() - t0

        t0 = time.time()
        Y = walloc.pil_to_latent([PIL.Image.open(buff)], codec.latent_dim, codec.latent_bits, 3).to(device)
        X_hat = codec.decoder(Y)
        x_hat = codec.wavelet_synthesis(X_hat,codec.J)
        x_hat = codec.clamp(x_hat)
        decode_time = time.time() - t0
                
    return {
        'cpu_encode_time': encode_time,
        'cpu_decode_time': decode_time,
    }

In [9]:
device = "cpu"
codec = codec.to(device)
cpu = LSDIR.map(walloc_compress_cpu)
combined = gpu.add_column('cpu_encode_time',cpu['cpu_encode_time'])
combined = combined.add_column('cpu_decode_time',cpu['cpu_decode_time'])

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [10]:
metrics = [
     'encode_time',
     'decode_time',
     'bpp',
     'PSNR',
     'MSSIM',
     'LPIPS_dB',
     'DISTS_dB', 
     'cpu_encode_time',
     'cpu_decode_time',
]

In [11]:
for metric in metrics:
    μ = np.mean(combined[metric])
    print(f"{metric}: {μ}")

encode_time: 0.03576316738128662
decode_time: 0.008054869651794434
bpp: 0.6819637504217378
PSNR: 27.452449684143065
MSSIM: 0.9699939274787903
LPIPS_dB: 6.513075175518463
DISTS_dB: 13.851437656052818
cpu_encode_time: 0.05944069290161133
cpu_decode_time: 2.847661691665649


In [12]:
combined.push_to_hub("danjacobellis/RGB_Li_12c_J3_nf8",split='validation')

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/RGB_Li_12c_J3_nf8/commit/b40347b64bb170059fc6fd144be922257a1284e3', commit_message='Upload dataset', commit_description='', oid='b40347b64bb170059fc6fd144be922257a1284e3', pr_url=None, pr_revision=None, pr_num=None)