In [1]:
import pickle
from pathlib import Path
from PIL import Image
import torch
import torchvision as tv
import torchvision.transforms.functional as tvf

# Inference only: disable autograd globally so no computation graphs are built.
torch.set_grad_enabled(False)

import sys
# Make the repository root (parent of the current working directory) importable so
# `models.library` resolves. NOTE(review): assumes the kernel's CWD is the notebook's
# own directory — confirm when launching from elsewhere.
repo_root = str(Path.cwd().parent)
if repo_root not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(repo_root)
from models.library import qres34m_lossless

Initialize model and load pre-trained weights

In [2]:
# Build the lossless model and restore its pre-trained weights.
model = qres34m_lossless()

# map_location='cpu' lets the checkpoint load even on machines without the device it
# was saved from; load_state_dict then copies the tensors into the model's parameters.
msd = torch.load('../checkpoints/qres34m-lossless/last_ema.pt', map_location='cpu')['model']
model.load_state_dict(msd)

model.eval()
# compress_mode() is a project-specific method; presumably it prepares the model for
# the compress()/decompress() calls below — TODO confirm against models.library.
model.compress_mode()

Compress an RGB image

In [3]:
img_path = '../images/collie128.png'

# Open the file in a context manager so its handle is released, and force 3-channel
# RGB so palette, grayscale or RGBA PNGs become the RGB image this section compresses.
# convert() loads the pixel data, so the result stays valid after the file is closed.
with Image.open(img_path) as pil_img:
    rgb_img = pil_img.convert('RGB')

# to_tensor yields a float (C, H, W) tensor scaled to [0, 1]; add a batch dim -> (1, C, H, W).
im = tvf.to_tensor(rgb_img).unsqueeze_(0)
compressed_obj = model.compress(im)

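Save the compressed object to a file and compute the bit rate in bits per pixel (bpp). The file size includes pickle serialization overhead.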

In [4]:
save_path = '../results/image.bits'
# open(..., 'wb') fails if the output directory does not exist yet.
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
with open(save_path, 'wb') as f:
    # The whole compressed object is pickled, so the file size below includes
    # pickle's framing overhead on top of the coded bits.
    pickle.dump(compressed_obj, file=f)

total_bits = Path(save_path).stat().st_size * 8
# Normalize by the spatial size H * W (im is (1, C, H, W)), i.e. bits per pixel.
bpp = total_bits / (im.shape[2] * im.shape[3])
print(f'Compressed file size: {total_bits} bits = {bpp:.6f} bpp')

Compressed file size: 286112 bits = 17.462891 bpp


Decompress and reconstruct the image

In [5]:
# Read the bitstream file back in one go and unpickle the compressed object.
# NOTE: unpickling can execute arbitrary code — only load .bits files you wrote yourself.
compressed_obj = pickle.loads(Path(save_path).read_bytes())

# Decode, drop the leading batch dimension, and move the reconstruction to the CPU.
im_hat = (
    model.decompress(compressed_obj)
    .squeeze(0)
    .cpu()
)

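Check whether the compression is lossless: the reconstruction, converted back to 8-bit integers, must be bit-identical to the original image.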

In [6]:
# Reference: the original file decoded as uint8 (C, H, W), forced to RGB to match the
# RGB input that was compressed.
real = tv.io.read_image(str(img_path), mode=tv.io.ImageReadMode.RGB)
# Map the reconstruction back to 8-bit integers. Assumes im_hat is scaled to [0, 1],
# mirroring to_tensor. Clamp before the cast: converting out-of-range floats to
# uint8 is not well defined.
fake = (im_hat * 255.0).round().clamp(0, 255).to(dtype=torch.uint8)

torch.equal(real, fake)

True