In [None]:
!wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_Li_48c_J4_nf8_v1.0.2.pth

In [1]:
# Record the host CPU so the encode/decode timings below can be interpreted
# against the hardware they were measured on.
!lscpu

Architecture:                    aarch64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
CPU(s):                          4
On-line CPU(s) list:             0-3
Thread(s) per core:              1
Core(s) per socket:              4
Socket(s):                       1
Vendor ID:                       ARM
Model:                           3
Model name:                      Cortex-A72
Stepping:                        r0p3
CPU max MHz:                     1500.0000
CPU min MHz:                     600.0000
BogoMIPS:                        108.00
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; __user pointer sanitization
Vulnerability Spectre v2:        Vulnerable
Vulnerability Srbds:             Not affected
Vulnerability Tsx

In [2]:
import io
import time
import torch
import numpy as np
import PIL
from torchvision.transforms import ToPILImage, PILToTensor
from datasets import load_dataset, Image
from walloc import walloc

class Config:
    """Empty attribute namespace.

    No visible cell uses it directly. It is presumably defined so that
    ``torch.load(..., weights_only=False)`` can unpickle ``checkpoint['config']``,
    if that object was saved as a ``__main__.Config`` instance — TODO confirm.
    """

In [3]:
# Load the pretrained WaLLoC checkpoint and rebuild the codec from its stored config.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects.
# It is presumably required here because checkpoint['config'] is an attribute-style
# config object rather than a tensor. Only load checkpoints from a trusted source.
checkpoint = torch.load("RGB_Li_48c_J4_nf8_v1.0.2.pth", map_location="cpu", weights_only=False)
codec_config = checkpoint['config']
codec = walloc.Codec2D(
    channels=codec_config.channels,
    J=codec_config.J,
    Ne=codec_config.Ne,
    Nd=codec_config.Nd,
    latent_dim=codec_config.latent_dim,
    latent_bits=codec_config.latent_bits,
    lightweight_encode=codec_config.lightweight_encode,
)
codec.load_state_dict(checkpoint['model_state_dict'])
# The decoder is intentionally kept: walloc_compress() runs codec.decoder to time decoding.
codec.eval();  # eval mode (assumes Codec2D is a torch nn.Module); `;` hides the repr

In [4]:
# LSDIR validation images from the Hugging Face Hub, used as the benchmark set.
dataset = load_dataset(path="danjacobellis/LSDIR_val", split="validation")

In [5]:
def walloc_compress(sample, device=None):
    """Compress one dataset row with WaLLoC and time encoding and decoding.

    Written as a ``datasets.Dataset.map`` function: it receives one row and
    returns the new columns to add to it.

    Args:
        sample: dataset row; ``sample['image']`` must be a PIL image.
        device: device on which to run the codec. Defaults to the device the
            codec's parameters live on, so this function no longer depends on
            a ``device`` variable defined in a later cell.

    Returns:
        dict with
          - ``compressed``: lossless-WebP bytes of the quantized latent,
          - ``encode_time``: seconds for wavelet analysis + encoder + WebP packing,
          - ``decode_time``: seconds for WebP unpacking + decoder + synthesis + clamp,
          - ``n_pixels``: H*W of the input image, read by the throughput cells.
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    # Alias it so it cannot collide with `datasets.Image`.
    from PIL import Image as PILImage

    if device is None:
        device = next(codec.parameters()).device

    with torch.no_grad():
        # Force 3-channel input. Grayscale/RGBA/palette images would otherwise
        # yield a tensor with the wrong channel count for the RGB codec.
        img = sample['image'].convert("RGB")
        x = PILToTensor()(img).to(torch.float)
        x = (x / 255 - 0.5).unsqueeze(0).to(device)  # [0, 255] -> [-0.5, 0.5], add batch dim
        H, W = x.size(2), x.size(3)
        x_padded = walloc.pad(x, p=8)

        # Encode: wavelet analysis -> encoder -> latent packed into a lossless WebP.
        # NOTE(review): the literal 3 given to latent_to_pil / pil_to_latent is
        # presumably the channels per packed image — TODO confirm.
        t0 = time.perf_counter()
        X = codec.wavelet_analysis(x_padded, codec.J)
        Y = codec.encoder(X)
        webp = walloc.latent_to_pil(Y.to("cpu"), codec.latent_bits, 3)[0]
        buff = io.BytesIO()
        webp.save(buff, format='WEBP', lossless=True)
        # Immutable bytes copy rather than getbuffer(): a memoryview aliases and pins the buffer.
        webp_bytes = buff.getvalue()
        encode_time = time.perf_counter() - t0

        # Decode from exactly the bytes that are stored in the dataset.
        t0 = time.perf_counter()
        Y = walloc.pil_to_latent(
            [PILImage.open(io.BytesIO(webp_bytes))], codec.latent_dim, codec.latent_bits, 3
        ).to(device)
        X_hat = codec.decoder(Y)
        x_hat = codec.wavelet_synthesis(X_hat, codec.J)
        x_hat = codec.clamp(x_hat)
        if x_hat.is_cuda:
            # CUDA kernels run asynchronously; wait so decode_time is not understated.
            torch.cuda.synchronize()
        decode_time = time.perf_counter() - t0
        # x_hat is not returned: the decode is run only to measure its cost.

    return {
        'compressed': webp_bytes,
        'encode_time': encode_time,
        'decode_time': decode_time,
        'n_pixels': H * W,
    }

In [6]:
# Run the codec on CPU and attach the compressed bytes plus timings to every row.
device = "cpu"
codec = codec.to(device)
walloc_ds = (
    dataset
    # writer_batch_size: rows buffered per cache-file write; presumably kept
    # small because every row carries a full image.
    .map(walloc_compress, writer_batch_size=10)
    # `Image` here is datasets.Image (imported at the top), not PIL.Image.
    .cast_column('compressed', Image())
)

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [9]:
# Median per-image encode wall time, in seconds.
np.median(np.asarray(walloc_ds['encode_time']))

1.3567014932632446

In [22]:
# Mean of per-image encode throughput (pixels / encode seconds), in megapixels/s.
# Requires an `n_pixels` column (pixel count per image) in walloc_ds.
# NOTE(review): this averages rates, while the decode cell averages the inverse
# (seconds per pixel), so the two summaries are not reciprocals of each other.
enc_px_per_sec = np.divide(walloc_ds['n_pixels'], walloc_ds['encode_time'])
print(f"Encode: {enc_px_per_sec.mean()/1e6 : .5g} megapixels per second")

Encode:  1.0085 megapixels per second


In [23]:
# Mean of per-image decode cost (seconds / pixel), reported per megapixel.
# Requires an `n_pixels` column (pixel count per image) in walloc_ds.
dec_sec_per_px = np.divide(walloc_ds['decode_time'], walloc_ds['n_pixels'])
print(f"Decode: {1e6*dec_sec_per_px.mean() : .5g} seconds per megapixel")

Decode:  31.951 seconds per megapixel


In [25]:
# Drop the image payload columns so only the remaining per-image columns are uploaded.
upload_ds = walloc_ds.remove_columns(column_names=['image', 'compressed'])

In [28]:
# Publish the benchmark table to the Hub (needs a logged-in Hugging Face token).
upload_ds.push_to_hub(repo_id="danjacobellis/LSDIR_walloc_4x_raspi", split='validation')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/LSDIR_walloc_4x_raspi/commit/fcaaa886d63680d0e1a63a1a82d8f7c5538dfb0f', commit_message='Upload dataset', commit_description='', oid='fcaaa886d63680d0e1a63a1a82d8f7c5538dfb0f', pr_url=None, pr_revision=None, pr_num=None)