In [None]:
from __future__ import annotations

from collections import defaultdict
from pathlib import Path

import torch
import torch_tensorrt  # keep this
import piq
from tqdm import tqdm
from DISTS_pytorch import DISTS

from binarization.config import get_default_config
from binarization.dataset import get_test_batches
from binarization.traintools import (
    prepare_generator, prepare_cuda_device, CustomLPIPS
)
from binarization.datatools import (
    min_max_scaler,
    inv_make_4times_downscalable,
)

- [✅] load a model
- [✅] fetch a pair compressed/original image
- [✅] generated = model(compressed)
- [✅] metric(original, generated)
- [✅] scale up
- [ ] adjust names (e.g., ms_ssim -> MS-SSIM)

In [None]:
def eval_models(
    n_evaluations: int = 50,
    model_names: tuple[str, ...] = ("unet", "srunet"),
    ckpt_template: str = "2023_03_24_{model_name}_2_191268.pth",
) -> dict[str, list[float]]:
    """Evaluate the PyTorch generator checkpoints on the test set.

    For every architecture in ``model_names`` the checkpoint
    ``<artifacts_dir>/best_checkpoints/<ckpt_template>`` is loaded. Each
    (original, compressed) test pair is pushed through the generator, the
    output is cropped back to the original size, and LPIPS, DISTS, BRISQUE,
    SSIM, MS-SSIM and PSNR are computed against the (min-max scaled) original.

    Args:
        n_evaluations: Maximum number of test batches per model. A falsy value
            (``0`` / ``None``) evaluates the whole test set.
        model_names: Generator architectures to evaluate.
        ckpt_template: Checkpoint file name; ``{model_name}`` is substituted.

    Returns:
        Mapping ``"<model>_<metric>"`` -> list with one score per evaluated batch.
    """
    default_cfg = get_default_config()
    device = prepare_cuda_device()

    lpips_alex = CustomLPIPS(net="alex")
    dists = DISTS()
    # Metric name -> callable(generated, original); BRISQUE is no-reference.
    metric_fns = {
        "lpips": lambda gen_img, ref_img: lpips_alex(gen_img, ref_img),
        "dists": lambda gen_img, ref_img: dists(gen_img, ref_img),
        "brisque": lambda gen_img, ref_img: piq.brisque(gen_img),
        "ssim": lambda gen_img, ref_img: piq.ssim(gen_img, ref_img),
        "ms_ssim": lambda gen_img, ref_img: piq.multi_scale_ssim(gen_img, ref_img),
        "psnr": lambda gen_img, ref_img: piq.psnr(gen_img, ref_img),
    }
    metrics = defaultdict(list)

    for model_name in model_names:
        # Both overridden fields are reassigned on every iteration, so this is
        # correct whether `copy()` is shallow or deep.
        cfg = default_cfg.copy()
        cfg.model.name = model_name
        cfg.model.ckpt_path_to_resume = Path(
            cfg.paths.artifacts_dir,
            "best_checkpoints",
            ckpt_template.format(model_name=model_name),
        )
        gen = prepare_generator(cfg, device).eval()

        # Context manager closes the bar even when we stop early via `break`.
        with tqdm(get_test_batches(), total=n_evaluations or None, desc=model_name) as progress_bar:
            for step_id, (original, compressed) in enumerate(progress_bar):
                if n_evaluations and step_id >= n_evaluations:
                    break
                original = min_max_scaler(original)

                # No gradients are needed for inference nor for scoring.
                with torch.no_grad():
                    generated = gen(compressed.to(device)).cpu()
                    generated = inv_make_4times_downscalable(
                        generated,
                        width_original=original.shape[-1],
                        height_original=original.shape[-2],
                    )
                    for metric_name, metric_fn in metric_fns.items():
                        metrics[f"{model_name}_{metric_name}"].append(
                            metric_fn(generated, original).item()
                        )
    return metrics

In [None]:
def eval_trt_models(
    n_evaluations: int = 50,
    model_names: tuple[str, ...] = ("unet", "srunet"),
    available_dtypes: tuple[str, ...] = ("fp32", "fp16", "int8"),
) -> dict[str, list[float]]:
    """Evaluate the TorchScript (Torch-TensorRT) generator exports on the test set.

    For every ``(model_name, dtype)`` pair, ``<trt_dir>/<model_name>_<dtype>.ts``
    is loaded with ``torch.jit.load``. Each test pair is pushed through it, the
    output is cropped back to the original size, and LPIPS, DISTS, BRISQUE,
    SSIM, MS-SSIM and PSNR are computed against the (min-max scaled) original.

    Args:
        n_evaluations: Maximum number of test batches per export. A falsy value
            (``0`` / ``None``) evaluates the whole test set.
        model_names: Generator architectures to evaluate.
        available_dtypes: Precisions to evaluate; each must be one of
            ``"fp32"``, ``"fp16"``, ``"int8"``.

    Returns:
        Mapping ``"<model>_<dtype>_<metric>"`` -> list with one score per batch.

    Raises:
        ValueError: if ``available_dtypes`` contains an unsupported precision.
            Checked up front, before any export is loaded.
    """
    supported_dtypes = ("fp32", "fp16", "int8")
    for dtype in available_dtypes:
        if dtype not in supported_dtypes:
            # Interpolate a tuple: a `{...}` set literal inside the f-string would
            # print in arbitrary order.
            raise ValueError(f"Unknown dtype: {dtype}. Choose in {supported_dtypes}.")

    cfg = get_default_config()
    device = prepare_cuda_device()
    lpips_alex = CustomLPIPS(net="alex")
    dists = DISTS()
    # Metric name -> callable(generated, original); BRISQUE is no-reference.
    metric_fns = {
        "lpips": lambda gen_img, ref_img: lpips_alex(gen_img, ref_img),
        "dists": lambda gen_img, ref_img: dists(gen_img, ref_img),
        "brisque": lambda gen_img, ref_img: piq.brisque(gen_img),
        "ssim": lambda gen_img, ref_img: piq.ssim(gen_img, ref_img),
        "ms_ssim": lambda gen_img, ref_img: piq.multi_scale_ssim(gen_img, ref_img),
        "psnr": lambda gen_img, ref_img: piq.psnr(gen_img, ref_img),
    }
    metrics = defaultdict(list)

    for model_name in model_names:
        for dtype in available_dtypes:
            quant_gen = torch.jit.load(
                cfg.paths.trt_dir / f"{model_name}_{dtype}.ts"
            ).to(device).eval()
            key_prefix = f"{model_name}_{dtype}"

            # Context manager closes the bar even when we stop early via `break`.
            with tqdm(get_test_batches(), total=n_evaluations or None, desc=key_prefix) as progress_bar:
                for step_id, (original, compressed) in enumerate(progress_bar):
                    if n_evaluations and step_id >= n_evaluations:
                        break
                    original = min_max_scaler(original)
                    compressed = compressed.to(device)
                    if dtype == "fp16":
                        # fp16 exports are fed half-precision input; fp32 and int8
                        # exports receive the tensor unchanged.
                        compressed = compressed.half()

                    with torch.no_grad():
                        # Score every precision in fp32 so results are comparable
                        # (no-op if the export already returns fp32).
                        generated = quant_gen(compressed).cpu().float()
                        generated = inv_make_4times_downscalable(
                            generated,
                            width_original=original.shape[-1],
                            height_original=original.shape[-2],
                        )
                        for metric_name, metric_fn in metric_fns.items():
                            metrics[f"{key_prefix}_{metric_name}"].append(
                                metric_fn(generated, original).item()
                            )
    return metrics

In [None]:
import json
from datetime import datetime
def save_json(json_obj: dict, save_path: Path, indent: "int | None" = None) -> None:
    """Serialize ``json_obj`` to ``save_path`` as UTF-8 JSON.

    Missing parent directories are created, so a fresh checkout without an
    outputs folder does not fail at the very end of a long evaluation.

    Args:
        json_obj: JSON-serializable mapping (a ``defaultdict`` works too).
        save_path: Destination file; ``str`` paths are accepted as well.
        indent: Forwarded to ``json.dump``; ``None`` keeps the compact layout.
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as out_file:
        json.dump(json_obj, out_file, indent=indent)

# Score the PyTorch checkpoints and their TensorRT exports on the same number
# of test batches, then persist every per-batch score list into one JSON file.
n_evaluations = 60
metrics = eval_models(n_evaluations=n_evaluations)
trt_metrics = eval_trt_models(n_evaluations=n_evaluations)
metrics.update(trt_metrics)

cfg = get_default_config()
today_str = f"{datetime.now():%Y_%m_%d}"
save_path = cfg.paths.outputs_dir / f"{today_str}_metrics.json"
save_json(metrics, save_path)

In [None]:
# NOTE(review): dead code, superseded by the JSON export in the previous cell.
# If revived, `default_cfg` below is undefined at notebook scope (it is a local
# of `eval_models`); use `cfg = get_default_config()` instead.
# import pandas as pd
# from datetime import datetime

# df = pd.DataFrame(metrics)
# today_str = datetime.now().strftime(r"%Y_%m_%d")
# save_path = default_cfg.paths.outputs_dir / f"{today_str}_results.csv"
# df.to_csv(save_path, index=False)

# Visualize results

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import json
from datetime import datetime
today_str = datetime.now().strftime(r"%Y_%m_%d")
cfg = get_default_config()
# Date stamp of the evaluation run to visualize. The frozen thesis run is the
# default; set this to `today_str` to plot the file written by the evaluation cell.
metrics_run_date = "2023_03_27"
metrics_json_path = cfg.paths.outputs_dir / f"{metrics_run_date}_metrics.json"
with open(metrics_json_path, "r", encoding="utf-8") as in_file:
    metrics_json = json.load(in_file)

# One column per "<model>_<metric>" key, one row per evaluated test batch.
df = pd.DataFrame(metrics_json)

In [None]:
# NOTE(review): dead code, superseded by the perceptual/traditional LaTeX table
# cells further down. As written it is also wrong: `metrics` is the dict of
# "<model>_<metric>" score lists returned by the evaluation, so iterating it
# yields full keys such as "unet_lpips" and `model + '_' + metric` would build
# column names that do not exist in `df`.
# model_names = [
#     "unet",
#     "unet_fp32",
#     "unet_fp16",
#     "unet_int8",
#     "srunet",
#     "srunet_fp32",
#     "srunet_fp16",
#     "srunet_int8"
# ]

# info_metrics = {}
# metrics_table = {}
# for metric in metrics:
#     info_metrics[metric] = df[[model + '_' + metric for model in model_names]]
#     info_metrics[metric].columns = model_names
#     metrics_table[metric] = info_metrics[metric].apply(lambda x: f"{x.mean().round(3)} ± {x.std().round(3)}", axis=0)
# print(pd.DataFrame(metrics_table).to_latex())

In [None]:
# Single source of truth: result-key prefix -> display label for figures/tables.
model_names_map = {
    "unet": "UNet",
    "unet_fp32": "UNet-FP32",
    "unet_fp16": "UNet-FP16",
    "unet_int8": "UNet-INT8",
    "srunet": "SRUNet",
    "srunet_fp32": "SRUNet-FP32",
    "srunet_fp16": "SRUNet-FP16",
    "srunet_int8": "SRUNet-INT8",
}
# Derived views (dicts preserve insertion order) so keys and labels cannot drift apart.
models = list(model_names_map)
corrected_model_names = list(model_names_map.values())

perceptual_metrics = ["lpips", "dists", "brisque"]
traditional_metrics = ["ssim", "ms_ssim", "psnr"]

In [None]:
static_dir = Path.home() / "Projects/master-thesis/thesis/static"
perceptual_metrics_savepath = static_dir / "boxplots_perceptual_metrics.jpg"
traditional_metrics_savepath = static_dir / "boxplots_traditional_metrics.jpg"

dfs = {}
for metric in perceptual_metrics:
    dfs[metric] = df[[model + "_" + metric for model in models]]
    dfs[metric].columns = corrected_model_names
fig, ax = plt.subplots(1, len(perceptual_metrics), figsize=(16, 4), sharey=True)
for i, k in enumerate(perceptual_metrics):
    sns.boxplot(dfs[k], ax=ax[i], orient="h")
    ax[i].set_title(k.upper(), fontsize=16)
    ax[i].grid()
fig.savefig(perceptual_metrics_savepath)

In [None]:
dfs = {}
for metric in traditional_metrics:
    dfs[metric] = df[[model + "_" + metric for model in models]]
    dfs[metric].columns = corrected_model_names
fig, ax = plt.subplots(1, len(traditional_metrics), figsize=(16, 4), sharey=True)
for i, k in enumerate(traditional_metrics):
    sns.boxplot(dfs[k], ax=ax[i], orient="h")
    ax[i].set_title(k.replace("_", "-").upper(), fontsize=16)
    ax[i].grid()
fig.savefig(traditional_metrics_savepath)


In [None]:
from pathlib import Path
import json
from datetime import datetime
import pandas as pd

cfg = get_default_config()
today_str = datetime.now().strftime(r"%Y_%m_%d")
# Date stamp of the timing run to load; set to `today_str` for a fresh run.
timings_run_date = "2023_03_30"
unet_timings_json_path = cfg.paths.outputs_dir / f"{timings_run_date}_timings_unet.json"
srunet_timings_json_path = cfg.paths.outputs_dir / f"{timings_run_date}_timings_srunet.json"

# Merge the per-architecture files into a fresh dict (later files win on key
# clashes) instead of mutating the first loaded dict in place.
raw_timings: dict = {}
for timings_json_path in (unet_timings_json_path, srunet_timings_json_path):
    with open(timings_json_path, "r", encoding="utf-8") as in_file:
        raw_timings.update(json.load(in_file))

# Downstream tables/plots label these values as seconds, so raw values are
# divided by 1e9 (presumably nanoseconds from the timing script — TODO confirm).
NS_PER_SECOND = 1e9
timings = (pd.DataFrame(raw_timings) / NS_PER_SECOND).rename(columns=model_names_map)
# save_path = cfg.paths.outputs_dir / f"{today_str}_timings.csv"
# timings.to_csv(save_path, index=False)

In [None]:
timings_table = timings.apply(lambda x: f"{x.mean().round(5):.5f} ± {x.std().round(5):.5f}", axis=0).to_frame(name="times [s]")
print(timings_table.to_latex())

In [None]:
perceptual_metrics_info = {}
perceptual_metrics_table = {}
for metric in perceptual_metrics:
    perceptual_metrics_info[metric] = df[[model + '_' + metric for model in models]]
    perceptual_metrics_info[metric].columns = corrected_model_names
    perceptual_metrics_table[metric.replace("_", "-").upper()] = perceptual_metrics_info[metric].apply(lambda x: f"{x.mean().round(5):.5f} ± {x.std().round(5):.5f}", axis=0)
print(pd.DataFrame(perceptual_metrics_table).to_latex())

In [None]:
traditional_metrics_info = {}
traditional_metrics_table = {}
for metric in traditional_metrics:
    traditional_metrics_info[metric] = df[[model + '_' + metric for model in models]]
    traditional_metrics_info[metric].columns = corrected_model_names
    traditional_metrics_table[metric.replace("_", "-").upper()] = traditional_metrics_info[metric].apply(lambda x: f"{x.mean().round(5):.5f} ± {x.std().round(5):.5f}", axis=0)
print(pd.DataFrame(traditional_metrics_table).to_latex())

In [None]:
# Distribution of per-image generation latency for every model variant.
# (A single Axes: the former `sharex=True` had nothing to share with.)
fig, ax = plt.subplots(figsize=(16, 4))
sns.boxplot(timings, ax=ax, orient="h")
ax.grid()
ax.set_xlabel("seconds")
ax.set_ylabel("model")
fig.suptitle("Time elapsed generating one image")
fig.tight_layout()
fig.savefig(static_dir / "boxplots_timings.jpg")

In [None]:
# Build the VMAF frame here: the plotting cell below also defines `vmaf_df`, but
# it runs after this one, so relying on it fails on Restart & Run All.
vmaf_df = pd.DataFrame(
    json.loads(
        Path(cfg.paths.artifacts_dir, "vmaf", "vmaf_res.json").read_text(encoding="utf-8")
    )
)
# NOTE(review): assumes rows come out in (mean, harmonic mean) order — confirm
# against the script that writes vmaf_res.json.
vmaf_df.index = ("mean", "harmonic_mean")
print(vmaf_df.T.to_latex())

In [None]:
vmaf_json_path = Path(cfg.paths.artifacts_dir, "vmaf", "vmaf_res.json")
with open(vmaf_json_path, "r", encoding="utf-8") as in_file:
    vmaf_json = json.load(in_file)

vmaf_df = pd.DataFrame(vmaf_json)
# vmaf_df.columns = [x + "_vmaf" for x in vmaf_df.columns]
# NOTE(review): assumes rows come out in (mean, harmonic mean) order — confirm
# against the script that writes vmaf_res.json.
vmaf_df.index = ("mean", "harmonic_mean")

# One two-panel figure per aggregate: UNet variants on top, SRUNet below.
# ("srunet..." does not start with "unet", so the prefixes split cleanly.)
vmaf_stat_xlabels = {
    "mean": "VMAF (mean)",
    "harmonic_mean": "VMAF (harmonic mean)",
}
for vmaf_stat, stat_xlabel in vmaf_stat_xlabels.items():
    stat_row = vmaf_df.loc[vmaf_stat]
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    stat_row[[c for c in vmaf_df.columns if c.startswith("unet")]].plot(kind="barh", ax=ax1)
    stat_row[[c for c in vmaf_df.columns if c.startswith("srunet")]].plot(kind="barh", ax=ax2)
    ax2.set_xlabel(stat_xlabel)
    for panel in (ax1, ax2):
        panel.grid(True)