In [8]:
import json
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict, OrderedDict
from PIL import Image
import torch
import torchvision.transforms.functional as tvf
# inference only: disable autograd globally for the whole notebook
torch.set_grad_enabled(False)


# test set root directory
images_root = Path('d:/datasets/improcessing/clic/test-2022')
# target lengths (pixels) of the longer image side at which every model is evaluated
resolutions = [192, 256, 384, 512, 768, 1024, 1536, 2048]

# fall back to CPU so the notebook still runs (slowly) on machines without CUDA
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [9]:
# image paths: only real image files, sorted so the evaluation order is
# reproducible across runs and operating systems
IMAGE_SUFFIXES = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp'}
img_paths = sorted(
    p for p in images_root.rglob('*')
    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)
assert len(img_paths) > 0, f'no images found under {images_root}'


def _load_rgb(path):
    """Read the image at ``path`` fully into memory as RGB and close the file handle."""
    with Image.open(path) as im:
        # convert() returns a new, fully loaded image independent of the file
        return im.convert('RGB')


# test set statistics
images = [_load_rgb(impath) for impath in img_paths]
longsides = [max(img.height, img.width) for img in images]
print(f'image longest side: min={min(longsides)}, max={max(longsides)}')

image longest side: min=2048, max=2048


In [10]:
def scale_image(img: Image.Image, longside: int):
    """Resize ``img`` with bicubic resampling so its longer side equals ``longside``.

    The aspect ratio is preserved; the new height and width are each rounded
    to the nearest integer. If the longer side already equals ``longside``,
    the very same ``img`` object is returned without resampling.
    """
    height, width = img.height, img.width
    current_long = max(height, width)
    if current_long == longside:
        return img
    scale = longside / current_long
    # PIL's resize expects (width, height)
    target_size = (round(width * scale), round(height * scale))
    return img.resize(size=target_size, resample=Image.BICUBIC)

def crop_divisible(img: 'Image.Image', div=64):
    """Center-crop ``img`` so that its height and width are both multiples of ``div``.

    Each side is reduced to the largest multiple of ``div`` that does not exceed
    it. Equal margins are removed from both ends, and an odd leftover pixel is
    taken from the bottom/right. The image is returned unchanged when both sides
    are already divisible by ``div``.

    Args:
        img: PIL image (anything exposing ``height``, ``width`` and ``crop``).
        div: required divisor of the output height and width.

    Returns:
        The (possibly cropped) image.

    Raises:
        ValueError: if a side of ``img`` is shorter than ``div``; the crop would
            otherwise be empty and fail much later inside the model.
    """
    height, width = img.height, img.width
    if height % div == 0 and width % div == 0:
        return img
    if height < div or width < div:
        raise ValueError(
            f'image of size {width}x{height} is smaller than the crop divisor {div}'
        )
    h_new = div * (height // div)
    w_new = div * (width // div)
    top = (height - h_new) // 2
    left = (width - w_new) // 2
    return img.crop(box=(left, top, left + w_new, top + h_new))

@torch.no_grad()
def evaluate_model(model, im_longside):
    """Evaluate ``model`` on every test image rescaled to a given long side.

    Each image in the global ``images`` list goes through four steps:
      1. Resize so its longer side is ``im_longside`` (``scale_image``).
      2. Center-crop so both sides are multiples of 64 (``crop_divisible``).
      3. Convert to a batch-of-one float tensor on ``device``.
      4. Pass to ``model.forward_eval``.

    Args:
        model: object with a ``forward_eval(im)`` method returning a dict of
            scalar statistics (any values ``float()`` accepts).
        im_longside: target length, in pixels, of each image's longer side.

    Returns:
        dict mapping every statistic name to its mean over all test images.

    Raises:
        ValueError: if the global ``images`` list is empty.
    """
    num_images = len(images)
    if num_images == 0:
        raise ValueError('no test images loaded; check `images_root`')
    stat_sums = defaultdict(float)
    for img in images:
        img = scale_image(img, im_longside)
        img = crop_divisible(img)
        if img.mode != 'RGB':
            # the models take 3-channel input; guard against grayscale / RGBA files
            img = img.convert('RGB')
        im = tvf.to_tensor(img).unsqueeze_(0).to(device=device)
        stats = model.forward_eval(im)
        for k, v in stats.items():
            stat_sums[k] += float(v)
    # Average over all images. The image count is kept outside the stats dict,
    # so a statistic that happens to be named 'count' cannot clobber it.
    return {k: v / num_images for k, v in stat_sums.items()}

In [4]:
from mycv.models.vae.qres import qres34m
# model checkpoints root directory. This is a drive-qualified absolute path:
# joining it onto Path.cwd() gives a bogus `cwd/d:/...` path on non-Windows systems.
weights_root = Path('d:/projects/mycv-pytorch/mycv/weights/my-vaes/dh-64s4x')
# output JSON holding the qres34m results for every resolution
save_json_path = Path.cwd() / '../results/clic-test-qres34m.json'

def evaluate_all_models_at_res(im_longside, lmbs=(16, 32, 64, 128, 256, 512, 1024, 2048)):
    """Evaluate each trained qres34m checkpoint on the test set at one resolution.

    For each lambda in ``lmbs``:
      - build a ``qres34m(lmb=...)`` model;
      - load the ``'model'`` entry of
        ``weights_root / f'dh_64s4x-lmb{lmb}/last_ema.pt'`` into it;
      - evaluate it with ``evaluate_model`` at ``im_longside``.

    Args:
        im_longside: target long side (pixels) of the rescaled test images.
        lmbs: lambda values whose checkpoints are evaluated, in order.

    Returns:
        defaultdict(list) with a ``'lambda'`` list plus one list per statistic
        returned by ``evaluate_model``. Lists line up index-by-index with
        ``lmbs``, assuming every model reports the same set of statistics.
    """
    all_model_stats = defaultdict(list)
    for lmb in lmbs:
        # initialize model
        model = qres34m(lmb=lmb)

        wpath = weights_root / f'dh_64s4x-lmb{lmb}/last_ema.pt'
        # Load onto the CPU first: the checkpoint may have been saved from a GPU
        # that is not present here, and the model is moved to `device` below.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(wpath, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        # uncomment to get real entropy coding bpp (rather than theoretical bpp)
        # model.compress_mode()

        model = model.to(device=device).eval()
        results = evaluate_model(model, im_longside=im_longside)

        all_model_stats['lambda'].append(lmb)
        for k, v in results.items():
            all_model_stats[k].append(v)
    return all_model_stats


all_resoutions_stats = OrderedDict()
# create the output directory up front, instead of failing after the first
# (slow) resolution has already been evaluated
save_json_path.parent.mkdir(parents=True, exist_ok=True)
for res in tqdm(resolutions):
    stats = evaluate_all_models_at_res(im_longside=res)

    all_resoutions_stats[res] = stats
    # rewrite the whole JSON after every resolution so that partial results
    # survive an interrupted run
    with open(save_json_path, 'w') as f:
        json.dump(all_resoutions_stats, fp=f, indent=4)

100%|██████████| 8/8 [04:36<00:00, 34.60s/it]


In [11]:
import compressai.zoo
import torch.nn.functional as tnf

class CompressAIWrapper():
    """Adapt a pretrained CompressAI zoo model to the ``forward_eval`` interface
    expected by ``evaluate_model``.
    """

    def __init__(self, name, q):
        """Load pretrained zoo model ``name`` at quality ``q`` onto ``device`` in eval mode."""
        self.nic_model = compressai.zoo.models[name](quality=q, pretrained=True)
        self.nic_model = self.nic_model.to(device=device)
        self.nic_model.eval()

    def forward_eval(self, im):
        """Run the model on ``im`` and report estimated bits per pixel and PSNR.

        Args:
            im: image batch of shape (B, C, H, W). Values are assumed to lie in
                [0, 1], as produced by ``to_tensor`` in ``evaluate_model``.

        Returns:
            dict with:
              - ``'bppix'``: the entropy-model estimate of bits per pixel, summed
                over every likelihood tensor the model returns (not the size of
                an actual bitstream).
              - ``'psnr'``: PSNR in dB with peak value 1.0.
        """
        output = self.nic_model(im)
        # A decoded image can only hold valid pixel values. Without clamping,
        # out-of-range reconstructions would be scored with errors that no real
        # decoded image could have.
        x_hat = output['x_hat'].clamp(0.0, 1.0)
        mse = tnf.mse_loss(x_hat, im, reduction='mean')

        imH, imW = im.shape[2:4]
        num_pixels = float(imH * imW)
        # Sum the cost of every entropy-coded latent ('y' and 'z' for hyperprior
        # models). Models returning a different set of latents therefore work too.
        bpp = sum(
            -torch.log2(lk).sum(dim=(1, 2, 3)).mean(0) / num_pixels
            for lk in output['likelihoods'].values()
        )

        # logging
        stats = {
            'bppix': float(bpp),
            'psnr': float(-10.0 * torch.log10(mse)),
        }
        return stats

def evaluate_compressai_at_res(name, im_longside):
    """Evaluate every pretrained quality level of CompressAI zoo model ``name`` at one resolution.

    Args:
        name: key into ``compressai.zoo.models`` (e.g. ``'mbt2018'``).
        im_longside: target long side (pixels) of the rescaled test images.

    Returns:
        defaultdict(list) with a ``'quality'`` list plus one list per statistic
        reported by ``evaluate_model``. Lists are aligned by index.
    """
    all_model_stats = defaultdict(list)
    # The cheng2020 family (anchor and attn) is only pretrained for qualities
    # 1-6; the other models used here (e.g. mbt2018) go up to 8. Requesting an
    # unavailable quality presumably makes the zoo constructor raise.
    max_q = 6 if name.startswith('cheng2020') else 8
    for q in range(1, max_q + 1):
        # initialize model
        model = CompressAIWrapper(name, q)
        results = evaluate_model(model, im_longside=im_longside)

        all_model_stats['quality'].append(q)
        for k, v in results.items():
            all_model_stats[k].append(v)
    return all_model_stats

def evaluate_compressai_at_all_res(name):
    """Evaluate all quality levels of CompressAI model ``name`` at every resolution in ``resolutions``.

    Results go to ``../results/clic-test-{name}.json`` (relative to the current
    working directory) as ``{resolution: evaluate_compressai_at_res(...)}``.
    The file is rewritten after each resolution, so partial results survive an
    interrupted run. Returns None.
    """
    json_path = Path.cwd() / f'../results/clic-test-{name}.json'
    # create the output directory up front, instead of failing after the first
    # (slow) resolution has already been evaluated
    json_path.parent.mkdir(parents=True, exist_ok=True)
    all_resolutions_stats = OrderedDict()
    pbar = tqdm(resolutions)
    for res in pbar:
        pbar.set_description(f'model: {name}, resolution: {res}')
        stats = evaluate_compressai_at_res(name, im_longside=res)

        all_resolutions_stats[res] = stats
        with open(json_path, 'w') as f:
            json.dump(all_resolutions_stats, fp=f, indent=4)

In [12]:
# evaluate every pretrained cheng2020-anchor quality level at all test resolutions
evaluate_compressai_at_all_res('cheng2020-anchor')

model: cheng2020-anchor, resolution: 2048: 100%|██████████| 8/8 [02:06<00:00, 15.79s/it]


In [14]:
# evaluate every pretrained mbt2018 quality level at all test resolutions
evaluate_compressai_at_all_res('mbt2018')

model: mbt2018, resolution: 2048: 100%|██████████| 8/8 [01:41<00:00, 12.73s/it]
