In [None]:
from __future__ import annotations

from collections import defaultdict
from itertools import islice
from pathlib import Path

import torch
import torch_tensorrt
import piq
from tqdm import tqdm

from binarization.config import get_default_config, Gifnoc
from binarization.dataset import get_test_batches
from binarization.traintools import (
    prepare_generator, prepare_cuda_device, CustomLPIPS
)
from binarization.datatools import (
    min_max_scaler,
    inv_make_4times_downscalable,
)

- [✅] load a model
- [✅] fetch a pair compressed/original image
- [✅] generated = model(compressed)
- [✅] metric(original, generated)
- [✅] scale up

In [None]:
def eval_models(
    n_evaluations: int = 50,
    model_names: tuple[str, ...] = ("unet", "srunet"),
) -> dict[str, list[float]]:
    """Score the PyTorch checkpoints of each generator on the test set.

    For every architecture in ``model_names`` the checkpoint
    ``<artifacts_dir>/best_checkpoints/2022_12_19_<name>_4_318780.pth`` is loaded,
    applied to the compressed test inputs, and compared with the originals.

    Args:
        n_evaluations: number of test batches scored per model; ``0``/``None``
            scores every batch yielded by ``get_test_batches()``.
        model_names: generator architectures to evaluate.

    Returns:
        Mapping ``"<model>_<metric>"`` -> list with one score per evaluated test
        batch, for metrics lpips, ssim, psnr, ms_ssim (full-reference) and
        brisque (no-reference).
    """
    default_cfg = get_default_config()
    device = prepare_cuda_device()

    lpips_alex = CustomLPIPS(net="alex")
    metrics = defaultdict(list)

    # metric name -> fn(generated, original); BRISQUE is no-reference and ignores `original`.
    metric_fns = {
        "lpips": lpips_alex,
        "ssim": piq.ssim,
        "psnr": piq.psnr,
        "ms_ssim": piq.multi_scale_ssim,
        "brisque": lambda generated, original: piq.brisque(generated),
    }
    # islice(..., None) walks the whole test set when no limit is requested.
    limit = n_evaluations or None

    for model_name in model_names:
        cfg = default_cfg.copy()
        cfg.model.name = model_name
        cfg.paths.ckpt_path_to_resume = Path(
            cfg.paths.artifacts_dir,
            "best_checkpoints",
            f"2022_12_19_{model_name}_4_318780.pth",
        )
        gen = prepare_generator(cfg, device).eval()

        batches = islice(get_test_batches(), limit)
        for original, compressed in tqdm(batches, total=limit, desc=model_name):
            original = min_max_scaler(original)
            # Neither inference nor the metrics need gradients (LPIPS is itself a network).
            with torch.no_grad():
                generated = gen(compressed.to(device)).cpu()
                generated = inv_make_4times_downscalable(original, generated)
                for metric_name, metric_fn in metric_fns.items():
                    metrics[f"{model_name}_{metric_name}"].append(
                        metric_fn(generated, original).item()
                    )
    return metrics

In [None]:
def eval_trt_models(
    n_evaluations: int = 50,
    model_names: tuple[str, ...] = ("unet", "srunet"),
    dtypes: tuple[str, ...] = ("fp32", "fp16", "int8"),
) -> dict[str, list[float]]:
    """Score the TorchScript/TensorRT exports of each generator on the test set.

    Loads ``<trt_dir>/<model>_<dtype>.ts`` for every (model, dtype) pair and
    compares its output on the compressed test inputs with the originals.

    Args:
        n_evaluations: number of test batches scored per model; ``0``/``None``
            scores every batch yielded by ``get_test_batches()``.
        model_names: generator architectures to evaluate.
        dtypes: precisions to evaluate, a subset of ``("fp32", "fp16", "int8")``.

    Returns:
        Mapping ``"<model>_<dtype>_<metric>"`` -> list with one score per evaluated
        test batch, for metrics lpips, ssim, psnr, ms_ssim and brisque.

    Raises:
        ValueError: if ``dtypes`` contains an unsupported precision (checked
            before any model is loaded).
    """
    available_dtypes = ("fp32", "fp16", "int8")
    for dtype in dtypes:
        if dtype not in available_dtypes:
            raise ValueError(
                f"Unknown dtype: {dtype}. Choose in {', '.join(available_dtypes)}."
            )

    cfg = get_default_config()
    device = prepare_cuda_device()
    lpips_alex = CustomLPIPS(net="alex")
    metrics = defaultdict(list)

    # metric name -> fn(generated, original); BRISQUE is no-reference and ignores `original`.
    metric_fns = {
        "lpips": lpips_alex,
        "ssim": piq.ssim,
        "psnr": piq.psnr,
        "ms_ssim": piq.multi_scale_ssim,
        "brisque": lambda generated, original: piq.brisque(generated),
    }
    # islice(..., None) walks the whole test set when no limit is requested.
    limit = n_evaluations or None

    for model_name in model_names:
        for dtype in dtypes:
            prefix = f"{model_name}_{dtype}"
            quant_gen = torch.jit.load(
                cfg.paths.trt_dir / f"{prefix}.ts"
            ).to(device).eval()

            batches = islice(get_test_batches(), limit)
            for original, compressed in tqdm(batches, total=limit, desc=prefix):
                original = min_max_scaler(original)
                compressed = compressed.to(device)
                # fp16 engines are fed half-precision input; fp32/int8 receive it unchanged.
                if dtype == "fp16":
                    compressed = compressed.half()

                with torch.no_grad():
                    # Cast back to fp32 so every precision is scored by identical metric
                    # code (the fp16 engine presumably returns half tensors).
                    generated = quant_gen(compressed).cpu().float()
                    generated = inv_make_4times_downscalable(original, generated)
                    for metric_name, metric_fn in metric_fns.items():
                        metrics[f"{prefix}_{metric_name}"].append(
                            metric_fn(generated, original).item()
                        )
    return metrics

In [None]:
import json
from datetime import datetime
def save_json(json_obj: dict, save_path: Path) -> None:
    """Serialize ``json_obj`` as UTF-8 JSON to ``save_path``, overwriting it.

    Missing parent directories are created first, so the export also works
    on a fresh checkout where the outputs directory does not exist yet.
    ``save_path`` may be a ``Path`` or a plain string.
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as out_file:
        json.dump(json_obj, out_file)

# Test batches scored per model; 0/None scores the whole test set.
N_EVALUATIONS = 30

# Both evaluators return dicts of per-batch score lists with disjoint keys;
# merge them into a new dict instead of mutating either result.
metrics_pytorch = eval_models(n_evaluations=N_EVALUATIONS)
metrics_trt = eval_trt_models(n_evaluations=N_EVALUATIONS)
all_metrics = {**metrics_pytorch, **metrics_trt}

# Date-stamped so the visualization cells can find today's results.
today_str = datetime.now().strftime(r"%Y_%m_%d")
cfg = get_default_config()
save_path = cfg.paths.outputs_dir / f"{today_str}_metrics.json"
save_json(all_metrics, save_path)

In [None]:
# NOTE(review): disabled CSV variant of the JSON export in the cell above.
# If re-enabled, `default_cfg` is only a local variable inside `eval_models`,
# so it is undefined at notebook scope — use `get_default_config()` instead.
# import pandas as pd
# from datetime import datetime

# df = pd.DataFrame(metrics)
# today_str = datetime.now().strftime(r"%Y_%m_%d")
# save_path = default_cfg.paths.outputs_dir / f"{today_str}_results.csv"
# df.to_csv(save_path, index=False)

# Visualize results

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import json
from datetime import datetime

today_str = datetime.now().strftime(r"%Y_%m_%d")
cfg = get_default_config()
metrics_json_path = cfg.paths.outputs_dir / f"{today_str}_metrics.json"
if not metrics_json_path.exists():
    # The evaluation cell date-stamps its output, so results from an earlier day are
    # not found under today's name. Fall back to the most recent YYYY_MM_DD_metrics.json
    # (the zero-padded date prefix sorts chronologically).
    candidates = sorted(
        cfg.paths.outputs_dir.glob("[0-9][0-9][0-9][0-9]_[0-9][0-9]_[0-9][0-9]_metrics.json")
    )
    if not candidates:
        raise FileNotFoundError(
            f"No date-stamped *_metrics.json found in {cfg.paths.outputs_dir}; "
            "run the evaluation cell first."
        )
    metrics_json_path = candidates[-1]
    print(f"No results for {today_str}; loading {metrics_json_path.name}")

with open(metrics_json_path, "r") as in_file:
    metrics_json = json.load(in_file)

# One column per "<model>_<metric>" key, one row per evaluated test batch.
df = pd.DataFrame(metrics_json)

In [None]:
# Model groups to compare: every variant, then the TensorRT builds of each architecture.
models_list = [
    [
        "unet",
        "unet_fp32",
        "unet_fp16",
        "unet_int8",
        "srunet",
        "srunet_fp32",
        "srunet_fp16",
        "srunet_int8"
    ],
    [
        "unet_fp32",
        "unet_fp16",
        "unet_int8",
    ],
    [
        "srunet_fp32",
        "srunet_fp16",
        "srunet_int8"
    ],
]
# Named `metric_names` (not `metrics`) so it does not shadow the evaluation results.
metric_names = ["lpips", "ssim", "psnr", "ms_ssim", "brisque"]
n_metrics = len(metric_names)

for models in models_list:
    # Per metric: a (batches x models) frame whose columns are renamed
    # from "<model>_<metric>" to "<model>" so the boxplot labels are short.
    dfs = {
        metric: df[[f"{model}_{metric}" for model in models]].set_axis(models, axis=1)
        for metric in metric_names
    }
    fig, ax = plt.subplots(1, n_metrics, figsize=(20, 3), sharey=True)
    for i, metric in enumerate(metric_names):
        sns.boxplot(data=dfs[metric], ax=ax[i], orient="h")
        ax[i].set_title(metric, fontsize=16)
        ax[i].grid()
    # Identify which model group this figure shows; render it before the next one.
    fig.suptitle(", ".join(models))
    fig.tight_layout()
    plt.show()

In [None]:
from pathlib import Path
import json
from datetime import datetime
import pandas as pd

cfg = get_default_config()
today_str = datetime.now().strftime(r"%Y_%m_%d")

# Merge today's per-architecture timing files into a single table, unet first.
# Later files win on key collisions, exactly like the previous dict.update() chain.
timings_dict = {}
for arch in ("unet", "srunet"):
    timings_json_path = cfg.paths.outputs_dir / f"{today_str}_timings_{arch}.json"
    with open(timings_json_path, "r") as in_file:
        timings_dict.update(json.load(in_file))

# Raw timings (one column per model variant); converted to seconds in the next cell.
timings = pd.DataFrame(timings_dict)
save_path = cfg.paths.outputs_dir / f"{today_str}_timings.csv"
timings.to_csv(save_path, index=False)

In [None]:
# Convert the raw timing table to seconds (the 1e9 factor implies nanosecond input).
# The unit is recorded in `DataFrame.attrs` so re-running this cell does not divide
# twice; reloading the timings produces a fresh frame with empty attrs, which is
# converted again.
if timings.attrs.get("unit") != "s":
    timings = timings / 1e+9
    timings.attrs["unit"] = "s"

In [None]:
# Model groups (one subplot row each): every variant, then the TensorRT
# builds of each architecture.
trt_precisions = ("fp32", "fp16", "int8")
unet_trt = [f"unet_{precision}" for precision in trt_precisions]
srunet_trt = [f"srunet_{precision}" for precision in trt_precisions]
models_list = [
    ["unet", *unet_trt, "srunet", *srunet_trt],
    unet_trt,
    srunet_trt,
]

n_groups = len(models_list)
fig, ax = plt.subplots(n_groups, 1, figsize=(12, 8), sharex=True)
for row, group in enumerate(models_list):
    sns.boxplot(timings[group], ax=ax[row], orient="h")
    ax[row].grid()
# Shared x axis: only the bottom row needs the unit label.
ax[n_groups - 1].set_xlabel("seconds")
fig.suptitle("Time elapsed evaluating one image in seconds")
fig.tight_layout()