In [None]:
import torch
from torchvision import transforms
import numpy as np
import rasterio

import matplotlib.pyplot as plt
from ipywidgets import interact, widgets

from dotmap import DotMap
import toml,os 

import math
from pytorch_msssim import ms_ssim

from licos.model_utils import get_model
from licos.l0_utils import DN_MAX
# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
# NOTE(review): in this notebook `device` is only passed as torch.load's
# map_location; the model and input tensors are never moved onto it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Trained checkpoint to evaluate.
# NOTE(review): the filename hard-codes model name / quality / seed; confirm it
# matches cfg.model and cfg.model_quality from the config loaded below.
checkpoint_path = "results/bmshj2018-factorizedqual=1_l0=raw_seed=42.pth.tar"

# Experiment configuration. _dynamic=False stops DotMap from auto-creating
# empty child maps when a missing key is accessed.
cfg = DotMap(toml.load("cfg/l0.toml"), _dynamic=False)

In [None]:
# Build the architecture named in cfg.model at quality cfg.model_quality.
# NOTE(review): the positional False and 1 go straight to the project's
# get_model; their meaning is not visible from this notebook.
net = get_model(cfg.model, False, 1, cfg.model_quality)
print(f'Parameters: {sum(map(torch.numel, net.parameters()))}')

In [None]:
# Load the trained weights; stored tensors are mapped onto `device` while reading.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint["state_dict"]
net.load_state_dict(state_dict)

In [None]:
# Test raster to compress (relative to the notebook's working directory).
test_image_path = "test.tif"

# Read band 1 as float32 and normalise by DN_MAX; x has shape (1, H, W).
with rasterio.open(test_image_path) as src:
    x = torch.from_numpy(np.asarray(src.read(1), dtype=np.float32)).unsqueeze(0) / DN_MAX

In [None]:
%matplotlib inline
plt.figure(figsize=(12, 9))
plt.axis('off')
plt.imshow(x[0])
plt.show()

In [None]:
# Run the codec on the single-band image (input shaped (1, 1, H, W)).
# eval() is required for inference: in training mode, modules such as dropout/
# batch-norm and (presumably, for this compressai-style model) the entropy
# bottleneck's noise-based quantisation behave differently, so x_hat and the
# likelihoods would not reflect the deployed codec.
net.eval()
with torch.no_grad():
    # Call the module rather than .forward() so registered hooks also run.
    out_net = net(x.unsqueeze(1))
# Keep the reconstruction in the same [0, 1] range as the normalised input.
out_net['x_hat'].clamp_(0, 1)
print(out_net.keys())

In [None]:
# Crop the reconstruction back to the input's spatial size; x is (1, H, W).
height, width = x.shape[-2], x.shape[-1]
out_net['x_hat'] = out_net['x_hat'][..., :height, :width]
rec_net = out_net['x_hat'].squeeze().cpu()
# Absolute error (x broadcasts against x_hat), averaged over dim 1, squeezed to 2-D.
diff = (out_net['x_hat'] - x).abs().mean(dim=1).squeeze().cpu()

In [None]:
%matplotlib inline
# Side-by-side: input, reconstruction and absolute error.
fig, axes = plt.subplots(1, 3, dpi=300)
for ax in axes:
    ax.axis('off')

# Share one display range between the original and the reconstruction;
# independent autoscaling would hide (or invent) brightness differences.
vmin, vmax = float(x[0].min()), float(x[0].max())

axes[0].imshow(x[0], vmin=vmin, vmax=vmax)
axes[0].set_title('Original')

axes[1].imshow(rec_net, vmin=vmin, vmax=vmax)
axes[1].set_title('Reconstructed')

# The colorbar makes the magnitude of the error readable.
diff_image = axes[2].imshow(diff, cmap='viridis')
axes[2].set_title('Difference')
fig.colorbar(diff_image, ax=axes[2], fraction=0.046, pad=0.04)

plt.show()

In [None]:
def compute_psnr(a, b, data_range=1.0):
    """Peak signal-to-noise ratio, in dB, between tensors ``a`` and ``b``.

    Args:
        a, b: tensors of broadcast-compatible shape.
        data_range: peak signal value (1.0 for data normalised to [0, 1]).

    Returns:
        PSNR as a Python float; ``math.inf`` when the inputs are identical.
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        # log10(0) would raise ValueError; identical signals have infinite PSNR.
        return math.inf
    # 20*log10(1) == 0.0, so the default reproduces -10*log10(mse) exactly.
    return 20 * math.log10(data_range) - 10 * math.log10(mse)

def compute_msssim(a, b, data_range=1.0):
    """Multi-scale SSIM between image batches ``a`` and ``b`` via pytorch_msssim.

    Args:
        a, b: image tensors passed unchanged to ``ms_ssim`` (presumably
            (N, C, H, W); confirm against pytorch_msssim's requirements).
        data_range: peak signal value (1.0 for data normalised to [0, 1]).

    Returns:
        The MS-SSIM score as a Python float.
    """
    return ms_ssim(a, b, data_range=data_range).item()

def compute_bpp(out_net):
    """Estimated rate in bits per pixel from a model's likelihood tensors.

    Args:
        out_net: mapping with ``'x_hat'`` (indexed as N, C, H, W, so the pixel
            count is N*H*W) and ``'likelihoods'`` (a dict of likelihood tensors).

    Returns:
        Sum over all likelihood tensors of -log2(p) divided by the pixel count,
        as a Python float (0.0 when there are no likelihood tensors).
    """
    size = out_net['x_hat'].size()
    num_pixels = size[0] * size[2] * size[3]
    total = sum(
        torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
        for likelihoods in out_net['likelihoods'].values()
    )
    # float() handles both a 0-d tensor and the plain int 0 that sum() returns
    # for an empty dict (where .item() would raise AttributeError).
    return float(total)

In [None]:
out_net["x_hat"].size()  # shape of the (cropped) reconstruction

In [None]:
x.size()  # shape of the normalised input, (1, H, W)

In [None]:
# Add a batch axis to x so it lines up with x_hat (presumably (1, 1, H, W)).
x_ref = x.unsqueeze(0)
print(f'PSNR: {compute_psnr(x_ref, out_net["x_hat"]):.2f}dB')
print(f'MS-SSIM: {compute_msssim(x_ref, out_net["x_hat"]):.4f}')
print(f'Bit-rate: {compute_bpp(out_net):.3f} bpp')