In [None]:
from collections import defaultdict
from itertools import islice
from pathlib import Path

import piq
import torch
import torch_tensorrt  # kept for its import side effects — presumably needed to load the .ts modules; confirm before removing
from tqdm import tqdm

from binarization.config import get_default_config
from binarization.dataset import get_test_batches
from binarization.datatools import (
    min_max_scaler,
    inv_make_4times_downscalable,
)
from binarization.traintools import CustomLPIPS
from binarization.traintools import prepare_generator, prepare_cuda_device

- [✅] load a model
- [✅] fetch a pair compressed/original image
- [✅] generated = model(compressed)
- [✅] metric(original, generated)
- [✅] scale up

In [None]:
# Shared state for every evaluation cell below.

# Cap on the number of test batches scored per model. The evaluation loops
# only enforce the cap when this value is truthy, so 0/None means "no cap".
n_evaluations = 50
default_cfg = get_default_config()
device = prepare_cuda_device()
# LPIPS metric reused by all models (net="alex" presumably selects the AlexNet
# backbone — confirm against CustomLPIPS).
lpips_alex = CustomLPIPS(net="alex")
# Per-batch scores, keyed as "<model>_<metric>" (e.g. "unet_lpips"); each
# evaluation cell appends one value per batch.
metrics = defaultdict(list)

In [None]:
# Score the UNet checkpoint on the test batches and append the per-batch
# results to `metrics` under the "unet_*" keys.
unet_cfg = default_cfg.copy()
unet_cfg.model.name = "unet"
unet_cfg.paths.ckpt_path_to_resume = Path(
    default_cfg.paths.artifacts_dir,
    "best_checkpoints",
    "2022_12_19_unet_4_318780.pth",
)
# Inference mode is set once here; it does not have to be re-applied per batch.
unet = prepare_generator(unet_cfg, device).eval()

test_batches = get_test_batches()
# A falsy n_evaluations (0 or None) means "score every batch". islice ends the
# iteration right after the last requested batch, so no surplus batch is pulled
# from the loader just to be thrown away by a `break`.
limited_batches = islice(test_batches, n_evaluations) if n_evaluations else test_batches
progress_bar = tqdm(limited_batches, total=n_evaluations or None)

for original, compressed in progress_bar:
    original = min_max_scaler(original)
    compressed = compressed.to(device)

    # Only the forward pass runs on `device`; the output is moved back to the
    # CPU, where `original` stays and the metrics are computed.
    with torch.no_grad():
        generated = unet(compressed).cpu()
    generated = inv_make_4times_downscalable(original, generated)

    # Full-reference metrics compare against `original`; BRISQUE is
    # no-reference and only looks at the generated image.
    metrics["unet_lpips"].append(lpips_alex(generated, original).item())
    metrics["unet_ssim"].append(piq.ssim(generated, original).item())
    metrics["unet_psnr"].append(piq.psnr(generated, original).item())
    metrics["unet_ms_ssim"].append(piq.multi_scale_ssim(generated, original).item())
    metrics["unet_brisque"].append(piq.brisque(generated).item())

In [None]:
# Score the SRUNet checkpoint on the test batches and append the per-batch
# results to `metrics` under the "srunet_*" keys.
srunet_cfg = default_cfg.copy()
srunet_cfg.model.name = "srunet"
srunet_cfg.paths.ckpt_path_to_resume = Path(
    default_cfg.paths.artifacts_dir,
    "best_checkpoints",
    "2022_12_19_srunet_4_318780.pth",
)
# Inference mode is set once here; it does not have to be re-applied per batch.
srunet = prepare_generator(srunet_cfg, device).eval()

test_batches = get_test_batches()
# A falsy n_evaluations (0 or None) means "score every batch". islice ends the
# iteration right after the last requested batch, so no surplus batch is pulled
# from the loader just to be thrown away by a `break`.
limited_batches = islice(test_batches, n_evaluations) if n_evaluations else test_batches
progress_bar = tqdm(limited_batches, total=n_evaluations or None)

for original, compressed in progress_bar:
    original = min_max_scaler(original)
    compressed = compressed.to(device)

    # Only the forward pass runs on `device`; the output is moved back to the
    # CPU, where `original` stays and the metrics are computed.
    with torch.no_grad():
        generated = srunet(compressed).cpu()
    generated = inv_make_4times_downscalable(original, generated)

    # Full-reference metrics compare against `original`; BRISQUE is
    # no-reference and only looks at the generated image.
    metrics["srunet_lpips"].append(lpips_alex(generated, original).item())
    metrics["srunet_ssim"].append(piq.ssim(generated, original).item())
    metrics["srunet_psnr"].append(piq.psnr(generated, original).item())
    metrics["srunet_ms_ssim"].append(piq.multi_scale_ssim(generated, original).item())
    metrics["srunet_brisque"].append(piq.brisque(generated).item())

In [None]:
# Score the TorchScript UNet stored in `trt_dir` (presumably the quantized /
# TensorRT-compiled export — confirm) and append the per-batch results to
# `metrics` under the "quant_unet_*" keys.
quant_unet = torch.jit.load(
    default_cfg.paths.trt_dir / "unet.ts"
).to(device).eval()  # inference mode is set once; no need to repeat it per batch

test_batches = get_test_batches()
# A falsy n_evaluations (0 or None) means "score every batch". islice ends the
# iteration right after the last requested batch, so no surplus batch is pulled
# from the loader just to be thrown away by a `break`.
limited_batches = islice(test_batches, n_evaluations) if n_evaluations else test_batches
progress_bar = tqdm(limited_batches, total=n_evaluations or None)

for original, compressed in progress_bar:
    original = min_max_scaler(original)
    compressed = compressed.to(device)

    # Only the forward pass runs on `device`; the output is moved back to the
    # CPU, where `original` stays and the metrics are computed.
    with torch.no_grad():
        generated = quant_unet(compressed).cpu()
    generated = inv_make_4times_downscalable(original, generated)

    # Full-reference metrics compare against `original`; BRISQUE is
    # no-reference and only looks at the generated image.
    metrics["quant_unet_lpips"].append(lpips_alex(generated, original).item())
    metrics["quant_unet_ssim"].append(piq.ssim(generated, original).item())
    metrics["quant_unet_psnr"].append(piq.psnr(generated, original).item())
    metrics["quant_unet_ms_ssim"].append(piq.multi_scale_ssim(generated, original).item())
    metrics["quant_unet_brisque"].append(piq.brisque(generated).item())

In [None]:
# Score the TorchScript SRUNet stored in `trt_dir` (presumably the quantized /
# TensorRT-compiled export — confirm) and append the per-batch results to
# `metrics` under the "quant_srunet_*" keys.
quant_srunet = torch.jit.load(
    default_cfg.paths.trt_dir / "srunet.ts"
).to(device).eval()  # inference mode is set once; no need to repeat it per batch

test_batches = get_test_batches()
# A falsy n_evaluations (0 or None) means "score every batch". islice ends the
# iteration right after the last requested batch, so no surplus batch is pulled
# from the loader just to be thrown away by a `break`.
limited_batches = islice(test_batches, n_evaluations) if n_evaluations else test_batches
progress_bar = tqdm(limited_batches, total=n_evaluations or None)

for original, compressed in progress_bar:
    original = min_max_scaler(original)
    compressed = compressed.to(device)

    # Only the forward pass runs on `device`; the output is moved back to the
    # CPU, where `original` stays and the metrics are computed.
    with torch.no_grad():
        generated = quant_srunet(compressed).cpu()
    generated = inv_make_4times_downscalable(original, generated)

    # Full-reference metrics compare against `original`; BRISQUE is
    # no-reference and only looks at the generated image.
    metrics["quant_srunet_lpips"].append(lpips_alex(generated, original).item())
    metrics["quant_srunet_ssim"].append(piq.ssim(generated, original).item())
    metrics["quant_srunet_psnr"].append(piq.psnr(generated, original).item())
    metrics["quant_srunet_ms_ssim"].append(piq.multi_scale_ssim(generated, original).item())
    metrics["quant_srunet_brisque"].append(piq.brisque(generated).item())

In [None]:
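# Disabled export step: uncomment to write the collected `metrics` to
# "<YYYY_MM_DD>_results.csv" inside default_cfg.paths.outputs_dir.
# import pandas as pd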
# from datetime import datetime

# df = pd.DataFrame(metrics)
# today_str = datetime.now().strftime(r"%Y_%m_%d")
# save_path = default_cfg.paths.outputs_dir / f"{today_str}_results.csv"
# df.to_csv(save_path, index=False)

# Visualize results

In [None]:
# The config is imported and rebuilt again here, presumably so that this
# section can be run without executing the evaluation cells — NOTE(review):
# confirm that this is the intent.
from binarization.config import get_default_config

default_cfg = get_default_config()
# CSV of previously collected metrics. The date prefix is hard-coded, so this
# always points at one specific saved run rather than the latest one.
results_path = default_cfg.paths.outputs_dir / "2023_02_28_results.csv"

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# One row per evaluated batch, one column per "<model>_<metric>" pair.
df = pd.read_csv(results_path)

models = ["unet", "srunet", "quant_unet", "quant_srunet"]
# Deliberately not called `metrics`: that name is the defaultdict the
# evaluation cells append to, and rebinding it to a list would break any
# evaluation cell that is re-run after this one.
metric_names = ["lpips", "ssim", "psnr", "ms_ssim", "brisque"]

# For each metric, a frame whose columns are the models, so every box plot
# compares the four models on that one metric.
dfs = {
    metric: df[[f"{model}_{metric}" for model in models]].set_axis(models, axis=1)
    for metric in metric_names
}

n_metrics = len(metric_names)
# squeeze=False always yields a 2-D array of axes, so flattening it gives an
# indexable sequence even if only a single metric is plotted.
fig, ax = plt.subplots(
    1,
    n_metrics,
    figsize=(n_metrics * n_metrics, n_metrics),
    sharey=True,
    squeeze=False,
)
ax = ax.ravel()
for panel, name in zip(ax, metric_names):
    # `data=` is passed by keyword because the first positional parameter of
    # sns.boxplot is not `data` in every seaborn release.
    sns.boxplot(data=dfs[name], ax=panel, orient="h")
    panel.set_title(name, fontsize=16)
    panel.grid()