In [1]:
import sys
import pickle
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict
from PIL import Image
import torch
import torchvision.transforms.functional as tvf
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1fd0873b130>

In [2]:
def get_object_bits(obj):
    return sys.getsizeof(pickle.dumps(obj)) * 8

def evaluate_model(model, dataset_root):
    device = next(model.parameters()).device
    img_paths = Path(dataset_root).rglob('*.*')

    all_image_stats = defaultdict(float)
    pbar = tqdm(img_paths)
    for impath in pbar:
        # read image
        im = tvf.to_tensor(Image.open(impath)).unsqueeze_(0).to(device=device)
        # compression
        compressed_obj = model.compress(im)
        num_bits = get_object_bits(compressed_obj)
        # decompression
        im_hat = model.decompress(compressed_obj)
        mse = torch.nn.functional.mse_loss(im, im_hat, reduction='mean')

        # metrics
        bpp  = float(num_bits / (im.shape[2] * im.shape[3]))
        psnr = float(-10 * torch.log10(mse))
        # logging
        pbar.set_description(f'image {impath.stem}: bpp={bpp:.5f}, psnr={psnr:.3f}')
        all_image_stats['bpp'] += bpp
        all_image_stats['psnr']  += psnr
        all_image_stats['count'] += 1

    # average over all images
    count = all_image_stats.pop('count')
    results = {k: v/count for k,v in all_image_stats.items()}
    return results

In [3]:
from mycv.models.vae.qres import qres34m

dataset_root = 'd:/datasets/improcessing/kodak'
from mycv.paths import MYCV_DIR
weights_root = MYCV_DIR / 'weights/my-vaes/dh-64s4x'

all_lmb_stats = defaultdict(list)
for lmb in [16, 32, 64, 128, 256, 512, 1024, 2048]:
    # initialize model
    model = qres34m(lmb=lmb)

    wpath = weights_root / f'dh_64s4x-lmb{lmb}/last_ema.pt'
    msd = torch.load(wpath)['model']
    model.load_state_dict(msd)

    print(f'Evaluating lmb={lmb}. Model weights={wpath}')
    model.compress_mode()
    model = model.cuda()
    model.eval()

    results = evaluate_model(model, dataset_root)
    # print(results)
    for k,v in results.items():
        all_lmb_stats[k].append(v)

for k, vlist in all_lmb_stats.items():
    vlist_str = ', '.join([f'{v:.12f}'[:8] for v in vlist])
    print(f'{k:<6s} = [{vlist_str}]')

Evaluating lmb=16. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb16\last_ema.pt


image kodim24: bpp=0.28709, psnr=28.849: : 24it [00:07,  3.39it/s]


Evaluating lmb=32. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb32\last_ema.pt


image kodim24: bpp=0.44863, psnr=30.903: : 24it [00:04,  5.09it/s]


Evaluating lmb=64. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb64\last_ema.pt


image kodim24: bpp=0.64368, psnr=33.194: : 24it [00:04,  5.11it/s]


Evaluating lmb=128. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb128\last_ema.pt


image kodim24: bpp=0.91211, psnr=35.330: : 24it [00:04,  5.05it/s]


Evaluating lmb=256. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb256\last_ema.pt


image kodim24: bpp=1.25291, psnr=37.459: : 24it [00:04,  5.01it/s]


Evaluating lmb=512. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb512\last_ema.pt


image kodim24: bpp=1.62880, psnr=39.515: : 24it [00:04,  4.94it/s]


Evaluating lmb=1024. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb1024\last_ema.pt


image kodim24: bpp=2.14693, psnr=41.367: : 24it [00:04,  4.95it/s]


Evaluating lmb=2048. Model weights=d:\projects\mycv-pytorch\mycv\weights\my-vaes\dh-64s4x\dh_64s4x-lmb2048\last_ema.pt


image kodim24: bpp=2.78479, psnr=43.288: : 24it [00:05,  4.41it/s]

bpp    = [0.183965, 0.301801, 0.452579, 0.675982, 0.954627, 1.287195, 1.748591, 2.357312]
psnr   = [30.01786, 31.97380, 33.88750, 35.84958, 38.14332, 40.22931, 42.21028, 44.31357]





In [4]:
# save to json
import json

save_json_path = Path().cwd() / '../results/kodak-qres34m.json'
with open(save_json_path, 'w') as f:
    json.dump(all_lmb_stats, fp=f, indent=4)
