In [None]:
!wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_Li_27c_J3_nf4_v1.0.2.pth

In [1]:
import io
import time
import torch
import numpy as np
import PIL
from torchvision.transforms import ToPILImage, PILToTensor
from datasets import load_dataset, Image
from walloc import walloc
from piq import LPIPS, DISTS, psnr, multi_scale_ssim

class Config:
    """Empty attribute container.

    NOTE(review): not referenced directly in this notebook; presumably it must
    exist so ``torch.load(..., weights_only=False)`` can unpickle the
    ``checkpoint['config']`` object — confirm before removing.
    """

In [2]:
# NOTE(review): weights_only=False unpickles arbitrary objects from the
# downloaded file (needed here for checkpoint['config']); only load
# checkpoints from a trusted source.
checkpoint = torch.load(
    "RGB_Li_27c_J3_nf4_v1.0.2.pth",
    map_location="cpu",
    weights_only=False,
)
codec_config = checkpoint['config']

# Rebuild the codec from the hyper-parameters stored alongside the weights.
codec = walloc.Codec2D(**{
    arg_name: getattr(codec_config, arg_name)
    for arg_name in (
        "channels", "J", "Ne", "Nd",
        "latent_dim", "latent_bits", "lightweight_encode",
    )
})
codec.load_state_dict(checkpoint['model_state_dict'])
codec.eval();

In [3]:
# piq perceptual-distance models, placed on the GPU once; the compression
# function below feeds them CUDA tensors.
lpips_loss = LPIPS().to(device="cuda")
dists_loss = DISTS().to(device="cuda")



In [4]:
# LSDIR validation split from the Hugging Face Hub; rows are selected from it below.
dataset = load_dataset(path="danjacobellis/LSDIR", split="validation")

Resolving data files:   0%|          | 0/195 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/195 [00:00<?, ?it/s]

In [5]:
def walloc_compress(sample):
    """Round-trip one dataset row through the WaLLoC codec and score the result.

    The image is encoded with the global ``codec`` on the global ``device``.
    The latent is serialised as a lossless WEBP, decoded back, and compared
    with the input.

    Args:
        sample: dataset row; only ``sample['image']`` (a PIL image) is read.

    Returns:
        dict with
          'recovered'   - lossless WEBP bytes of the reconstructed image,
          'compressed'  - lossless WEBP bytes of the latent,
          'encode_time' / 'decode_time' - wall-clock seconds,
          'bpp'         - bits per input pixel of 'compressed',
          'PSNR', 'MSSIM' - piq scores on the [0, 1]-shifted images,
          'LPIPS_dB', 'DISTS_dB' - -10*log10 of the piq LPIPS / DISTS loss.
    """
    def _sync():
        # CUDA kernels run asynchronously: wait for them before reading the
        # clock, otherwise GPU timings only measure kernel-launch overhead.
        if torch.device(device).type == "cuda":
            torch.cuda.synchronize()

    with torch.no_grad():
        img = sample['image']
        x = PILToTensor()(img).to(torch.float)
        # 8-bit pixels -> [-0.5, 0.5], plus a leading batch dimension.
        x = (x/255 - 0.5).unsqueeze(0).to(device)
        H, W = x.size(2), x.size(3)
        x_padded = walloc.pad(x, p=8)
        _sync()  # keep preprocessing kernels out of the encode timing

        t0 = time.perf_counter()
        X = codec.wavelet_analysis(x_padded, codec.J)
        Y = codec.encoder(X)
        webp = walloc.latent_to_pil(Y.to("cpu"), codec.latent_bits, 3)[0]
        buff = io.BytesIO()
        webp.save(buff, format='WEBP', lossless=True)
        # getvalue() gives an independent bytes copy rather than a memoryview
        # that keeps the BytesIO's internal buffer exported.
        webp_bytes = buff.getvalue()
        encode_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        Y = walloc.pil_to_latent([PIL.Image.open(buff)], codec.latent_dim, codec.latent_bits, 3).to(device)
        X_hat = codec.decoder(Y)
        x_hat = codec.wavelet_synthesis(X_hat, codec.J)
        x_hat = codec.clamp(x_hat)
        _sync()  # make sure the decoder has actually finished
        decode_time = time.perf_counter() - t0

        # Crop back to the original H x W (the input was padded with walloc.pad).
        x_hat = walloc.crop(x_hat, (H, W))
        rec = ToPILImage()(x_hat[0] + 0.5)
        buff2 = io.BytesIO()
        rec.save(buff2, format='WEBP', lossless=True)
        rec_webp_bytes = buff2.getvalue()

        bpp = 8*len(webp_bytes)/(H*W)
        # Shift both images from [-0.5, 0.5] back to [0, 1] once and reuse them.
        ref = x + 0.5
        est = x_hat + 0.5
        PSNR = psnr(ref, est)
        MSSIM = multi_scale_ssim(ref, est)
        # lpips_loss / dists_loss were moved to "cuda" when they were created.
        ref_gpu, est_gpu = ref.to("cuda"), est.to("cuda")
        LPIPS_dB = -10*np.log10(lpips_loss(ref_gpu, est_gpu).item())
        DISTS_dB = -10*np.log10(dists_loss(ref_gpu, est_gpu).item())

    return {
        'recovered': rec_webp_bytes,
        'compressed': webp_bytes,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'bpp': bpp,
        'PSNR': PSNR,
        'MSSIM': MSSIM,
        'LPIPS_dB': LPIPS_dB,
        'DISTS_dB': DISTS_dB,
    }

In [11]:
# GPU pass: round-trip the first two validation images and compute metrics;
# the two WEBP byte columns are then decoded as images.
device = "cuda"
codec = codec.to(device)
walloc_ds = (
    dataset
    .select(range(2))
    .map(walloc_compress)
    .cast_column('recovered', Image())
    .cast_column('compressed', Image())
)

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

In [12]:
def walloc_compress_cpu(sample):
    """Time a WaLLoC encode/decode round trip of one dataset row.

    Runs on the global ``device`` with the global ``codec``. The notebook
    sets ``device = "cpu"`` before mapping this over the dataset to collect
    CPU latencies. No quality metrics are computed.

    Args:
        sample: dataset row; only ``sample['image']`` (a PIL image) is read.

    Returns:
        {'cpu_encode_time': float, 'cpu_decode_time': float}, wall-clock seconds.
    """
    def _sync():
        # Only matters if device is CUDA (asynchronous kernels); a no-op on CPU.
        if torch.device(device).type == "cuda":
            torch.cuda.synchronize()

    with torch.no_grad():
        x = PILToTensor()(sample['image']).to(torch.float)
        # 8-bit pixels -> [-0.5, 0.5], plus a leading batch dimension.
        x = (x/255 - 0.5).unsqueeze(0).to(device)
        x_padded = walloc.pad(x, p=8)
        _sync()

        t0 = time.perf_counter()
        X = codec.wavelet_analysis(x_padded, codec.J)
        Y = codec.encoder(X)
        webp = walloc.latent_to_pil(Y.to("cpu"), codec.latent_bits, 3)[0]
        buff = io.BytesIO()
        webp.save(buff, format='WEBP', lossless=True)
        encode_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        Y = walloc.pil_to_latent([PIL.Image.open(buff)], codec.latent_dim, codec.latent_bits, 3).to(device)
        X_hat = codec.decoder(Y)
        x_hat = codec.wavelet_synthesis(X_hat, codec.J)
        # The result is unused, but the clamp is part of the decode path being timed.
        x_hat = codec.clamp(x_hat)
        _sync()
        decode_time = time.perf_counter() - t0

    return {
        'cpu_encode_time': encode_time,
        'cpu_decode_time': decode_time,
    }

In [13]:
# CPU pass: same two validation images, timing only.
device = "cpu"
codec = codec.to(device=device)
walloc_cpu = dataset.select(indices=range(2)).map(function=walloc_compress_cpu)

In [14]:
# Attach the CPU latencies to the GPU results. Both datasets were built from
# dataset.select(range(2)), so their rows are assumed to line up one-to-one.
for cpu_col in ('cpu_encode_time', 'cpu_decode_time'):
    walloc_ds = walloc_ds.add_column(cpu_col, walloc_cpu[cpu_col])

In [15]:
# Display one fully annotated row (GPU metrics plus CPU timings).
row_index = 1
walloc_ds[row_index]

{'path': 'LSDIR_val/val1/HR/val/0000130.png',
 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1152x768>,
 'w': 1152,
 'h': 768,
 'mode': 'RGB',
 'aspect': 1.5,
 'n_pixels': 884736,
 'recovered': <PIL.WebPImagePlugin.WebPImageFile image mode=RGB size=1152x768>,
 'compressed': <PIL.WebPImagePlugin.WebPImageFile image mode=RGB size=432x288>,
 'encode_time': 0.08014345169067383,
 'decode_time': 0.006960868835449219,
 'bpp': 1.1057038483796295,
 'PSNR': 31.71041488647461,
 'MSSIM': 0.9848196506500244,
 'LPIPS_dB': 8.433589351193366,
 'DISTS_dB': 15.761822997039431,
 'cpu_encode_time': 0.1143491268157959,
 'cpu_decode_time': 0.3262336254119873}

In [20]:
# NOTE(review): unexplained magic numbers. This is presumably a nominal
# compression ratio (bits of an 8x8 RGB 8-bit block over latent bits per
# block) — TODO confirm and name these constants.
numerator = 8 * 8 * 3 * 8
denominator = 4 * 16
numerator / denominator

24.0