In [1]:
import sys
import json
from tqdm import tqdm  # NOTE(review): `tqdm.auto` avoids the autonotebook warning captured in this cell's output
from pathlib import Path
from collections import defaultdict, OrderedDict  # NOTE(review): OrderedDict appears unused in the visible cells
from PIL import Image
import numpy as np
import torch
import torchvision.transforms.functional as tvf
torch.set_grad_enabled(False)  # evaluation only: disable autograd globally for the whole kernel

# Make the parent directory importable so `models.*` resolves;
# presumably the notebook lives one level below the repo root — confirm.
sys.path.append('../')
from models.library import qres34m
from models.qresvae import pad_divisible_by
# NOTE(review): not referenced by evaluate_model/evaluate, which take the device
# from the model's parameters / call `.cuda()` directly.
device = torch.device('cuda:0')


  from .autonotebook import tqdm as notebook_tqdm


In [12]:

@torch.no_grad()
def evaluate_model(model, dataset_root,
                   suffixes=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp', '.ppm')):
    """Evaluate a compression model on every image found under ``dataset_root``.

    For each image this:
      1. runs ``model.forward_eval`` on the image padded to a multiple of 64 to get
         the reconstruction (``'im_hat'``) and the model's rate estimate (``'bppix'``);
      2. calls ``model.compress_file`` to write an actual bitstream and measures its size;
      3. computes PSNR between the original image and the reconstruction cropped
         back to the original size.

    Args:
        model: compression model exposing ``parameters()``, ``forward_eval`` and
            ``compress_file`` (as used below).
        dataset_root: directory searched recursively for image files.
        suffixes: file extensions (case-insensitive) treated as images. Any other
            file (e.g. ``.DS_Store``, ``.txt``) is skipped.

    Returns:
        dict mapping ``'psnr'``, ``'bpp'`` and ``'bpp-estimated'`` to their mean
        over all images.

    Raises:
        ValueError: if no image files are found under ``dataset_root``.
    """
    import tempfile  # local import keeps this cell self-contained

    device = next(model.parameters()).device
    allowed = {s.lower() for s in suffixes}
    # Only regular files with an image extension. The old '*.*' pattern also matched
    # dotfiles and dotted directories, which then crashed Image.open.
    img_paths = sorted(
        p for p in Path(dataset_root).rglob('*')
        if p.is_file() and p.suffix.lower() in allowed
    )
    if not img_paths:
        raise ValueError(f'No image files found under {str(dataset_root)!r}')

    totals = defaultdict(float)
    pbar = tqdm(img_paths)
    # Write bitstreams into a private temp dir (removed on exit) instead of leaving
    # a stray 'tmp.bits' in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        bits_path = Path(tmp_dir) / 'tmp.bits'
        for impath in pbar:
            # Read the image and close the file handle right away. convert('RGB')
            # yields 3 channels, so `real` below is (H, W, 3) even for grayscale,
            # palette or RGBA inputs.
            # NOTE(review): assumes the codec reconstructs 3-channel images.
            with Image.open(impath) as img_file:
                img = img_file.convert('RGB')
            imgh, imgw = img.height, img.width
            img_padded = pad_divisible_by(img, div=64)
            im = tvf.to_tensor(img_padded).unsqueeze_(0).to(device=device)

            out = model.forward_eval(im, return_rec=True)
            im_hat = out['im_hat']
            # Rescale the estimate from the padded area (H*W of `im`) to the
            # original image area.
            bpp_estimated = float(out['bppix']) * (im.shape[2] * im.shape[3]) / (imgh * imgw)

            # Actual rate: size in bits of the file written by compress_file.
            model.compress_file(impath, str(bits_path))
            num_bits = bits_path.stat().st_size * 8

            # PSNR: `real` is scaled to [0, 1]; the reconstruction is cropped back
            # to the original (unpadded) size.
            real = np.array(img).astype(np.float32) / 255.0
            fake = im_hat.cpu().squeeze(0).permute(1, 2, 0)[:imgh, :imgw, :].numpy()
            mse = np.square(real - fake).mean()
            per_image = {
                'psnr': float(-10 * np.log10(mse)),
                'bpp': num_bits / (imgh * imgw),
                'bpp-estimated': bpp_estimated,
            }

            for k, v in per_image.items():
                totals[k] += float(v)
            msg = ', '.join(f'{k}={float(v):.3f}' for k, v in per_image.items())
            pbar.set_description(f'image {impath.stem}: {msg}')

    # Average over all images (guaranteed non-empty by the check above).
    count = len(img_paths)
    return {k: v / count for k, v in totals.items()}


def evaluate(root, weights_root='../checkpoints/qres34m',
             lmbs=(16, 32, 64, 128, 256, 512, 1024, 2048), save_json_path=None):
    """Evaluate every qres34m rate point (one checkpoint per lambda) on one dataset.

    For each ``lmb`` this:
      - builds ``qres34m(lmb=lmb)`` and loads ``<weights_root>/lmb<lmb>/last_ema.pt``;
      - calls ``compress_mode()``, moves the model to the GPU and runs ``evaluate_model``.

    At the end it prints one line per metric, with the per-lambda values joined
    by `` & `` for pasting into a LaTeX table.

    Args:
        root: dataset directory passed to ``evaluate_model``.
        weights_root: directory containing the ``lmb<lmb>/last_ema.pt`` checkpoints.
        lmbs: lambda values to evaluate, in order.
        save_json_path: if given, the accumulated per-metric lists are written
            there as JSON after every lambda, so partial results survive an
            interrupted run.
    """
    weights_root = Path(weights_root)

    all_lmb_stats = defaultdict(list)
    for lmb in lmbs:
        # initialize model
        model = qres34m(lmb=lmb)

        wpath = weights_root / f'lmb{lmb}/last_ema.pt'
        # map_location='cpu' lets the checkpoint load no matter which device it
        # was saved from; the model itself is moved to the GPU below.
        msd = torch.load(wpath, map_location='cpu')['model']
        model.load_state_dict(msd)

        print(f'Evaluating lmb={lmb}. Model weights={wpath}')
        model.compress_mode()
        model = model.cuda()
        model.eval()

        results = evaluate_model(model, root)
        for k, v in results.items():
            all_lmb_stats[k].append(v)

        if save_json_path is not None:
            save_path = Path(save_json_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            with open(save_path, 'w') as f:
                json.dump(all_lmb_stats, fp=f, indent=4)

    for k, vlist in all_lmb_stats.items():
        # Each value is cut to 7 characters (truncated, not rounded) to form
        # fixed-width LaTeX table cells.
        vlist_str = ' & '.join([f'{v:.12f}'[:7] for v in vlist])
        print(f'{k:<6s} = [{vlist_str}]')


In [10]:
# Kodak benchmark across all lambdas. NOTE: absolute local path — adjust for your machine.
evaluate(root='d:/datasets/improcessing/kodak')

Evaluating lmb=16. Model weights=..\checkpoints\qres34m\lmb16\last_ema.pt


image kodim24: psnr=28.858, bpp=0.287, bpp-estimated=0.282: 100%|██████████| 24/24 [00:04<00:00,  5.01it/s]


Evaluating lmb=32. Model weights=..\checkpoints\qres34m\lmb32\last_ema.pt


image kodim24: psnr=30.925, bpp=0.449, bpp-estimated=0.444: 100%|██████████| 24/24 [00:04<00:00,  5.41it/s]


Evaluating lmb=64. Model weights=..\checkpoints\qres34m\lmb64\last_ema.pt


image kodim24: psnr=33.226, bpp=0.643, bpp-estimated=0.639: 100%|██████████| 24/24 [00:04<00:00,  5.41it/s]


Evaluating lmb=128. Model weights=..\checkpoints\qres34m\lmb128\last_ema.pt


image kodim24: psnr=35.362, bpp=0.913, bpp-estimated=0.909: 100%|██████████| 24/24 [00:04<00:00,  5.37it/s]


Evaluating lmb=256. Model weights=..\checkpoints\qres34m\lmb256\last_ema.pt


image kodim24: psnr=37.556, bpp=1.253, bpp-estimated=1.248: 100%|██████████| 24/24 [00:04<00:00,  5.25it/s]


Evaluating lmb=512. Model weights=..\checkpoints\qres34m\lmb512\last_ema.pt


image kodim24: psnr=39.634, bpp=1.629, bpp-estimated=1.624: 100%|██████████| 24/24 [00:04<00:00,  5.33it/s]


Evaluating lmb=1024. Model weights=..\checkpoints\qres34m\lmb1024\last_ema.pt


image kodim24: psnr=41.516, bpp=2.147, bpp-estimated=2.143: 100%|██████████| 24/24 [00:04<00:00,  5.21it/s]


Evaluating lmb=2048. Model weights=..\checkpoints\qres34m\lmb2048\last_ema.pt


image kodim24: psnr=43.426, bpp=2.785, bpp-estimated=2.780: 100%|██████████| 24/24 [00:04<00:00,  5.31it/s]

psnr   = [30.0210 & 31.9801 & 33.8986 & 36.1126 & 38.1649 & 40.2613 & 42.2478 & 44.3549]
bpp    = [0.18352 & 0.30125 & 0.45200 & 0.67388 & 0.95406 & 1.28697 & 1.74814 & 2.35659]
bpp-estimated = [0.17960 & 0.29680 & 0.44780 & 0.66960 & 0.94993 & 1.28291 & 1.74430 & 2.35219]





In [13]:
# CLIC test-2022 benchmark across all lambdas. NOTE: absolute local path — adjust for your machine.
evaluate(root='d:/datasets/improcessing/clic/test-2022')

Evaluating lmb=16. Model weights=..\checkpoints\qres34m\lmb16\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=32.175, bpp=0.067, bpp-estimated=0.067: 100%|██████████| 30/30 [00:42<00:00,  1.42s/it]


Evaluating lmb=32. Model weights=..\checkpoints\qres34m\lmb32\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=34.253, bpp=0.109, bpp-estimated=0.109: 100%|██████████| 30/30 [00:43<00:00,  1.46s/it]


Evaluating lmb=64. Model weights=..\checkpoints\qres34m\lmb64\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=35.739, bpp=0.161, bpp-estimated=0.160: 100%|██████████| 30/30 [00:44<00:00,  1.47s/it]


Evaluating lmb=128. Model weights=..\checkpoints\qres34m\lmb128\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=37.658, bpp=0.259, bpp-estimated=0.259: 100%|██████████| 30/30 [00:44<00:00,  1.47s/it]


Evaluating lmb=256. Model weights=..\checkpoints\qres34m\lmb256\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=39.489, bpp=0.409, bpp-estimated=0.410: 100%|██████████| 30/30 [00:44<00:00,  1.47s/it]


Evaluating lmb=512. Model weights=..\checkpoints\qres34m\lmb512\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=41.394, bpp=0.611, bpp-estimated=0.612: 100%|██████████| 30/30 [00:43<00:00,  1.47s/it]


Evaluating lmb=1024. Model weights=..\checkpoints\qres34m\lmb1024\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=43.096, bpp=0.941, bpp-estimated=0.942: 100%|██████████| 30/30 [00:46<00:00,  1.57s/it]


Evaluating lmb=2048. Model weights=..\checkpoints\qres34m\lmb2048\last_ema.pt


image 8019e85654193a14b938689e3cf06b790d39197eb5a57f9de83f7b58d2e3302c: psnr=45.010, bpp=1.430, bpp-estimated=1.434: 100%|██████████| 30/30 [00:43<00:00,  1.46s/it]

psnr   = [30.6719 & 32.7126 & 34.1318 & 36.2879 & 38.2443 & 40.2436 & 42.1072 & 44.0814]
bpp    = [0.15405 & 0.24315 & 0.35457 & 0.54065 & 0.79773 & 1.10183 & 1.55027 & 2.14379]
bpp-estimated = [0.15370 & 0.24236 & 0.35435 & 0.54013 & 0.79785 & 1.10251 & 1.55179 & 2.14681]



