In [None]:
import sys
sys.path.append("./licos/")

import torch

import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from dotmap import DotMap
import toml
from PIL import Image

import math
from pytorch_msssim import ms_ssim

from licos.model_utils import get_model
from licos.l0_image_folder import L0ImageFolder
from licos.l0_utils import DN_MAX
from licos.utils import get_savepath_str
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Folder that the test images are read from.
dataset = "data/test/"

# Experiment configuration parsed from TOML and wrapped in a DotMap for
# attribute-style access (`_dynamic=False` is passed through to DotMap).
cfg = DotMap(
    toml.load("cfg/l0.toml"),
    _dynamic=False,
)

# Trained weights evaluated by this notebook.
checkpoint_path = "results/bmshj2018-factorizedqual=1_l0=raw_seed=42.pth.tar"

In [None]:
import io

def pillow_encode(img, fmt='jpeg', quality=10):
    """Round-trip ``img`` through an in-memory Pillow encode/decode.

    Args:
        img: Image object exposing ``save`` and ``size`` (a PIL image).
        fmt: Format name forwarded to ``img.save``.
        quality: Encoder quality setting forwarded to ``img.save``.

    Returns:
        Tuple ``(rec, bpp)``: the image re-opened from the encoded bytes and
        the size of the encoded stream in bits per pixel.
    """
    buffer = io.BytesIO()
    img.save(buffer, format=fmt, quality=quality)
    # Rewind so the decoder below reads the stream from its first byte.
    buffer.seek(0)
    encoded_bits = buffer.getbuffer().nbytes * float(8)
    pixel_count = img.size[0] * img.size[1]
    bits_per_pixel = encoded_bits / pixel_count
    return Image.open(buffer), bits_per_pixel

def find_closest_bpp(target, img, fmt='jpeg'):
    """Bisect the encoder quality so the bit-rate approaches ``target``.

    The search runs over quality values in [0, 100] for at most 10 probes and
    stops early once the integer quality no longer changes between probes.

    Args:
        target: Desired bit-rate in bits per pixel.
        img: PIL image to encode.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        Tuple ``(rec, bpp)`` from the last quality setting that was probed.
    """
    lower = 0
    upper = 100
    prev_mid = upper
    for _ in range(10):
        mid = (upper - lower) / 2 + lower
        # Same integer quality as the previous probe: encoding again would
        # reproduce the result we already hold.
        if int(mid) == int(prev_mid):
            break
        # Track the probe so the convergence check above can fire; without
        # this update the check always compared against the initial value.
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        if bpp > target:
            # Too many bits: only lower qualities remain candidates.
            upper = mid - 1
        else:
            lower = mid
    return rec, bpp

def find_closest_psnr(target, img, fmt='jpeg'):
    """Bisect the encoder quality so the reconstruction PSNR approaches ``target``.

    The search runs over quality values in [0, 100] for at most 10 probes and
    stops early once the integer quality no longer changes between probes.

    Args:
        target: Desired PSNR in dB (8-bit peak of 255).
        img: PIL image to encode.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        Tuple ``(rec, bpp, psnr_val)`` from the last quality setting probed.
    """
    lower = 0
    upper = 100
    prev_mid = upper

    def _psnr(a, b):
        """PSNR in dB between decoded image ``a`` and reference image ``b``."""
        # A codec may decode into a different mode than the input (e.g. a
        # grayscale input coming back as RGB); align modes so the pixel
        # arrays have matching shapes before subtracting.
        if a.mode != b.mode:
            a = a.convert(b.mode)
        a = np.asarray(a).astype(np.float32)
        b = np.asarray(b).astype(np.float32)
        mse = np.mean(np.square(a - b))
        if mse == 0:
            # Lossless reconstruction: log10(0) is undefined, PSNR is infinite.
            return math.inf
        return 20*math.log10(255.) - 10. * math.log10(mse)

    for _ in range(10):
        mid = (upper - lower) / 2 + lower
        # Same integer quality as the previous probe: nothing new to learn.
        if int(mid) == int(prev_mid):
            break
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        psnr_val = _psnr(rec, img)
        if psnr_val > target:
            # Better than required: only lower qualities remain candidates.
            upper = mid - 1
        else:
            lower = mid
    return rec, bpp, psnr_val

def find_closest_msssim(target, img, fmt='jpeg'):
    """Bisect the encoder quality so the reconstruction MS-SSIM approaches ``target``.

    The search runs over quality values in [0, 100] for at most 10 probes and
    stops early once the integer quality no longer changes between probes.

    Args:
        target: Desired MS-SSIM value.
        img: PIL image to encode (multi-channel or single-channel).
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        Tuple ``(rec, bpp, msssim_val)`` from the last quality setting probed.
    """
    lower = 0
    upper = 100
    prev_mid = upper

    def _to_batch(image):
        """Convert a PIL image into a float32 tensor shaped (1, C, H, W)."""
        pixels = np.asarray(image).astype(np.float32)
        if pixels.ndim == 2:
            # Single-channel images come out as (H, W); add the channel axis
            # that the permute below expects.
            pixels = pixels[:, :, None]
        return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)

    def _mssim(a, b):
        """MS-SSIM between decoded image ``a`` and reference image ``b``."""
        # A codec may decode into a different mode than the input (e.g. a
        # grayscale input coming back as RGB); align modes so both tensors
        # have the same channel count.
        if a.mode != b.mode:
            a = a.convert(b.mode)
        return ms_ssim(_to_batch(a), _to_batch(b), data_range=255.).item()

    for _ in range(10):
        mid = (upper - lower) / 2 + lower
        # Same integer quality as the previous probe: nothing new to learn.
        if int(mid) == int(prev_mid):
            break
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        msssim_val = _mssim(rec, img)
        if msssim_val > target:
            # Better than required: only lower qualities remain candidates.
            upper = mid - 1
        else:
            lower = mid
    return rec, bpp, msssim_val

In [None]:
def compute_psnr(a, b):
    """Return the PSNR in dB between tensors ``a`` and ``b``.

    The formula ``-10 * log10(mse)`` corresponds to a signal peak of 1.0.

    Args:
        a: First tensor.
        b: Second tensor, broadcast-compatible with ``a``.

    Returns:
        PSNR as a Python float; ``math.inf`` when the tensors are identical.
    """
    mse = torch.mean((a - b)**2).item()
    if mse == 0:
        # Identical inputs: log10(0) would raise, and PSNR is infinite.
        return math.inf
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    """Return MS-SSIM between ``a`` and ``b`` as a Python float.

    Both tensors are handed to ``ms_ssim`` with ``data_range=1.``.
    """
    score = ms_ssim(a, b, data_range=1.)
    return score.item()

def compute_bpp(out_net):
    """Estimate the bit-rate in bits per pixel from a model output dict.

    Args:
        out_net: Mapping with an ``'x_hat'`` tensor and a ``'likelihoods'``
            dict of tensors.

    Returns:
        Sum over all likelihood tensors of ``-log2(likelihood)``, divided by
        the pixel count taken from dimensions 0, 2 and 3 of ``'x_hat'``.
    """
    dims = out_net['x_hat'].size()
    num_pixels = dims[0] * dims[2] * dims[3]
    total = 0
    for likelihoods in out_net['likelihoods'].values():
        # Natural log converted to bits and normalised per pixel.
        total = total + torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
    return total.item()

def process_img(img):
    """Run the notebook-level ``net`` on a single image.

    Args:
        img: Image tensor; a batch dimension is added before the forward pass
            and dimensions 1 and 2 are used as the crop size.

    Returns:
        Tuple ``(out_net, reconstructed, diff)``: the model output dict with
        ``'x_hat'`` clamped to [0, 1] and cropped to the input size, the
        squeezed reconstruction on the CPU, and the absolute error averaged
        over dimension 1 (squeezed, on the CPU).
    """
    with torch.no_grad():
        out_net = net.forward(img.unsqueeze(0))

    x_hat = out_net['x_hat']
    x_hat.clamp_(0, 1)

    # Crop the trailing two dimensions back to the input's size.
    rows, cols = img.shape[1], img.shape[2]
    x_hat = x_hat[..., :rows, :cols]
    out_net['x_hat'] = x_hat

    reconstructed = x_hat.squeeze().cpu()
    abs_error = (x_hat - img).abs()
    diff = torch.mean(abs_error, axis=1).squeeze().cpu()
    return out_net, reconstructed, diff

In [None]:
# Build the architecture named in the config and restore the trained weights.
net = get_model(cfg.model,False,1,cfg.model_quality)
print(f'Parameters: {sum(p.numel() for p in net.parameters())}')
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(checkpoint["state_dict"])
# PyTorch modules default to training mode and loading a state dict does not
# change that; switch to inference mode so the metrics below are computed with
# evaluation behaviour (harmless if get_model already did this).
net.eval()
# NOTE(review): `device` is only used as map_location; the model itself is not
# moved, so inference runs wherever get_model placed it.
net.update()

In [None]:
# Override the configured train/test split value for this evaluation run.
cfg.l0_train_test_split = 0.5 # for testing

# Test split of the image folder, built from the config values.
test_data = L0ImageFolder(
    dataset,
    cfg.seed,
    cfg.l0_train_test_split,
    cfg.l0_format,
    cfg.l0_target_resolution_merged_m,
    split="test",
)

In [None]:
# Show one test image next to its reconstruction and the per-pixel error.
idx = 1
fix, axes = plt.subplots(1, 3, dpi = 400)

out,rec,diff = process_img(test_data[idx])

# One (title, image, imshow keyword arguments) entry per subplot.
panels = (
    ('Original', test_data[idx].squeeze(), {}),
    ('Reconstructed', rec, {}),
    ('Difference', diff, {'cmap': 'viridis'}),
)
for ax, (title, image, imshow_kwargs) in zip(axes, panels):
    ax.axis('off')
    ax.imshow(image, **imshow_kwargs)
    ax.title.set_text(title)

plt.show()

In [None]:
# Metrics for the image visualised with `idx`. `out` is the model output
# returned by process_img for that image; the previously referenced names
# `x` and `out_net` are not defined at notebook scope.
reference = test_data[idx].unsqueeze(0)
print(f'PSNR: {compute_psnr(reference, out["x_hat"]):.2f}dB')
print(f'MS-SSIM: {compute_msssim(reference, out["x_hat"]):.4f}')
print(f'Bit-rate: {compute_bpp(out):.3f} bpp')

In [None]:
# Per-image metrics of the learned codec over the whole test split.
psnr, ssim, bpp = [],[],[]

# Accumulators for the classical-codec baselines, grouped by which quantity
# was matched to the learned codec (bit-rate, PSNR or MS-SSIM).
psnr_bbp_matched, ssim_bbp_matched, bpp_bbp_matched = [],[],[]
psnr_psnr_matched, ssim_psnr_matched, bpp_psnr_matched = [],[],[]
psnr_ssim_matched, ssim_ssim_matched, bpp_ssim_matched = [],[],[]
# TODO(review): nothing in the loop below appends to the *_matched lists, and
# the JPEG/WebP results are overwritten by each successive search. Decide how
# the baselines should be recorded (per codec?) and populate them.

for img in tqdm(test_data):
    out, reconstructed, diff = process_img(img)
    out_bpp = compute_bpp(out)
    out_psnr = compute_psnr(img.unsqueeze(0), out["x_hat"])
    out_ssim = compute_msssim(img.unsqueeze(0), out["x_hat"])
    psnr.append(out_psnr)
    ssim.append(out_ssim)
    bpp.append(out_bpp)

    # 8-bit grayscale copy of the image for the Pillow-based codecs.
    PIL_img = Image.fromarray(np.uint8(img.squeeze() * 255) , 'L')

    rec_jpeg, bpp_jpeg = find_closest_bpp(out_bpp, PIL_img)
    rec_webp, bpp_webp = find_closest_bpp(out_bpp, PIL_img, fmt='webp')

    # The search helpers encode through Pillow, so they need the PIL image;
    # passing the tensor `img` fails as soon as they call `img.save`.
    rec_jpeg, bpp_jpeg, psnr_jpeg = find_closest_psnr(out_psnr, PIL_img)
    rec_webp, bpp_webp, psnr_webp = find_closest_psnr(out_psnr, PIL_img, fmt='webp')

    rec_jpeg, bpp_jpeg, msssim_jpeg = find_closest_msssim(out_ssim, PIL_img)
    rec_webp, bpp_webp, msssim_webp = find_closest_msssim(out_ssim, PIL_img, fmt='webp')

In [None]:
# Averages of the learned codec's per-image metrics over the test split.
mean_psnr = np.mean(psnr)
mean_ssim = np.mean(ssim)
mean_bpp = np.mean(bpp)
print("PSNR=", mean_psnr, "\t SSIM=", mean_ssim, "\t BPP=", mean_bpp)