In [None]:
"""Script to evaluate an image with a super-resolution model"""

from __future__ import annotations

from itertools import islice
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_tensorrt  # mandatory for inference even without calling it
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm

from binarization.config import Gifnoc, get_default_config
from binarization.dataset import get_test_batches
from binarization.traintools import prepare_cuda_device, prepare_generator

from binarization.datatools import (
    compose,
    list_directories,
    list_files,
    make_4times_downscalable,
    min_max_scaler,
    random_crop_images,
    draw_validation_fig,
    postprocess,
)

In [None]:
# Shared notebook state: default project configuration and the compute device
# returned by `prepare_cuda_device` for index 0.
cfg = get_default_config()
cuda_or_cpu = prepare_cuda_device(0)

# Maps a display label (e.g. "UNet-FP16") to the image tensor produced by the
# corresponding model; filled in by the inference cells below.
generated_dict = dict()

In [None]:
# Run every exported TorchScript model (UNet / SRUNet in FP32, FP16 and INT8)
# on one full frame and keep the same crop of each output in `generated_dict`.
original_path = cfg.paths.original_frames_dir / "old_town_cross_1080p50/old_town_cross_1080p50_0001.png"
compressed_path = cfg.paths.compressed_frames_dir / "old_town_cross_1080p50/old_town_cross_1080p50_0001.jpg"
original, compressed = Image.open(original_path), Image.open(compressed_path)
# NOTE(review): `pipe` is built but never applied anywhere in this notebook;
# the same three steps are executed one by one below.
pipe = compose(TF.pil_to_tensor, min_max_scaler, make_4times_downscalable)
original, compressed = TF.pil_to_tensor(original), TF.pil_to_tensor(compressed)
compressed = min_max_scaler(compressed)
compressed = make_4times_downscalable(compressed)
compressed = compressed.to(cuda_or_cpu)
# The FP16 models are fed a half-precision copy of the same input.
half_compressed = compressed.half()

# (top, left, height, width) passed to `TF.crop` for every model output.
# These are the same numbers the checkpoint cell uses to crop the original
# frame (150, 250, 96, 96 scaled by 4).
TRT_CROP_BOX = (150 * 4, 250 * 4, 96 * 4, 96 * 4)


def run_trt_model(
    ts_path: Path,
    model_input: torch.Tensor,
    device: torch.device | str,
    width_original: int,
    height_original: int,
    crop_box: tuple[int, int, int, int],
) -> torch.Tensor:
    """Runs one TorchScript model on a single image and crops the result.

    Args:
        ts_path (Path): file loaded with `torch.jit.load`.
        model_input (torch.Tensor): unbatched input image, already on `device`
            and already in the dtype the model expects; a batch dimension is
            added for the forward pass and removed afterwards.
        device (torch.device | str): device the loaded model is moved to.
        width_original (int): forwarded to `postprocess`.
        height_original (int): forwarded to `postprocess`.
        crop_box (tuple[int, int, int, int]): (top, left, height, width)
            forwarded to `TF.crop`.

    Returns:
        torch.Tensor: the postprocessed and cropped model output, on CPU.
    """
    # The model is local to this call, so it can be released once the
    # function returns instead of staying alive as a notebook global.
    model = torch.jit.load(ts_path).to(device).eval()
    with torch.no_grad():
        restored = postprocess(
            model(model_input.unsqueeze(0)).squeeze(0).cpu(),
            width_original=width_original,
            height_original=height_original,
        )
        return TF.crop(restored, *crop_box)


# Same insertion order as before: UNet-FP32, UNet-FP16, UNet-INT8, SRUNet-...
for arch_label in ("UNet", "SRUNet"):
    for precision in ("FP32", "FP16", "INT8"):
        # Only the FP16 models receive the half-precision input; the FP32 and
        # INT8 models are fed the float input.
        trt_input = half_compressed if precision == "FP16" else compressed
        generated_dict[f"{arch_label}-{precision}"] = run_trt_model(
            ts_path=cfg.paths.trt_dir / f"{arch_label.lower()}_{precision.lower()}.ts",
            model_input=trt_input,
            device=cuda_or_cpu,
            width_original=original.shape[-1],
            height_original=original.shape[-2],
            crop_box=TRT_CROP_BOX,
        )

In [None]:
# Run the PyTorch checkpoints (UNet and SRUNet) on a small crop of one frame
# and store the outputs in `generated_dict` under "UNet" / "SRUNet".
original_path = cfg.paths.original_frames_dir / "old_town_cross_1080p50/old_town_cross_1080p50_0001.png"
compressed_path = cfg.paths.compressed_frames_dir / "old_town_cross_1080p50/old_town_cross_1080p50_0001.jpg"
original, compressed = Image.open(original_path), Image.open(compressed_path)
# NOTE(review): `pipe` is built but never applied anywhere in this notebook.
pipe = compose(TF.pil_to_tensor, min_max_scaler, make_4times_downscalable)
original, compressed = TF.pil_to_tensor(original), TF.pil_to_tensor(compressed)

# Crop window in compressed-frame pixels; the original frame is cropped with
# the same numbers multiplied by 4 (assumes the original is 4x the compressed
# resolution -- TODO confirm).
crop_top, crop_left, crop_size, crop_scale = 150, 250, 96, 4
compressed = TF.crop(compressed, crop_top, crop_left, crop_size, crop_size)
original = TF.crop(
    original,
    crop_top * crop_scale,
    crop_left * crop_scale,
    crop_size * crop_scale,
    crop_size * crop_scale,
)
compressed = min_max_scaler(compressed)

compressed = compressed.to(cuda_or_cpu)


def run_checkpoint_model(
    base_cfg: Gifnoc,
    arch_name: str,
    model_input: torch.Tensor,
    device: torch.device | str,
    width_original: int,
    height_original: int,
) -> torch.Tensor:
    """Builds a generator from its best checkpoint and runs it on one image.

    Args:
        base_cfg (Gifnoc): configuration that is copied and then pointed at
            the checkpoint of `arch_name`.
        arch_name (str): model name written to `cfg.model.name` and used in
            the checkpoint file name (e.g. "unet", "srunet").
        model_input (torch.Tensor): unbatched input image, already on `device`.
        device (torch.device | str): device passed to `prepare_generator`.
        width_original (int): forwarded to `postprocess`.
        height_original (int): forwarded to `postprocess`.

    Returns:
        torch.Tensor: the postprocessed model output, on CPU.
    """
    ckpt_path = Path(
        base_cfg.paths.artifacts_dir,
        "best_checkpoints",
        f"2023_03_24_{arch_name}_2_191268.pth",
    )
    # NOTE(review): if `copy()` is shallow, the two assignments below also
    # modify `base_cfg.model` -- confirm against the Gifnoc implementation.
    model_cfg = base_cfg.copy()
    model_cfg.model.ckpt_path_to_resume = ckpt_path
    model_cfg.model.name = arch_name

    model = prepare_generator(model_cfg, device=device).eval()
    with torch.no_grad():
        return postprocess(
            model(model_input.unsqueeze(0)).squeeze(0).cpu(),
            width_original=width_original,
            height_original=height_original,
        )


for display_label, arch_name in (("UNet", "unet"), ("SRUNet", "srunet")):
    generated_dict[display_label] = run_checkpoint_model(
        base_cfg=cfg,
        arch_name=arch_name,
        model_input=compressed,
        device=cuda_or_cpu,
        width_original=original.shape[-1],
        height_original=original.shape[-2],
    )

In [None]:
# Qualitative comparison for the UNet family: one row per model variant,
# columns are (high quality, generated, low quality).
compressed = compressed.cpu()  # the earlier cell moved it to the device
figsize = (10, 3 * 4)
original_image_pil = TF.to_pil_image(original)
compressed_image_pil = TF.to_pil_image(compressed)
fig, ax = plt.subplots(4, 3, figsize=figsize)

for i, model_name in enumerate(["UNet", "UNet-FP32", "UNet-FP16", "UNet-INT8"]):
    # The reference and the low-quality crops are the same in every row;
    # only the middle panel changes with the model.
    row_panels = (
        (original_image_pil, 'high quality'),
        (TF.to_pil_image(generated_dict[model_name]), f'generated by {model_name}'),
        (compressed_image_pil, 'low quality'),
    )
    for panel_ax, (panel_image, panel_title) in zip(ax[i], row_panels):
        panel_ax.imshow(panel_image)
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

fig.subplots_adjust(
    left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.15
)
fig.tight_layout()
fig.savefig(cfg.paths.outputs_dir / "unet_qualitative_results.png")

In [None]:
# Qualitative comparison for the SRUNet family: one row per model variant,
# columns are (high quality, generated, low quality).
compressed = compressed.cpu()  # the earlier cell moved it to the device
figsize = (10, 3 * 4)
original_image_pil = TF.to_pil_image(original)
compressed_image_pil = TF.to_pil_image(compressed)
fig, ax = plt.subplots(4, 3, figsize=figsize)

for i, model_name in enumerate(["SRUNet", "SRUNet-FP32", "SRUNet-FP16", "SRUNet-INT8"]):
    ax[i][0].imshow(original_image_pil)
    ax[i][0].set_title('high quality')
    ax[i][0].axis('off')
    ax[i][1].imshow(TF.to_pil_image(generated_dict[model_name]))
    ax[i][1].set_title(f'generated by {model_name}')
    ax[i][1].axis('off')
    ax[i][2].imshow(compressed_image_pil)
    ax[i][2].set_title('low quality')
    ax[i][2].axis('off')

fig.subplots_adjust(
    top=1.0, bottom=0.0, right=1.0, left=0.0, hspace=0.15, wspace=0.0
)
# Same layout pass as the UNet figure cell. Without it the axes reach the very
# top of the canvas (top=1.0), so the first-row titles fall outside the figure
# and are clipped in the saved PNG.
fig.tight_layout()
fig.savefig(cfg.paths.outputs_dir / "srunet_qualitative_results.png")

In [None]:
# Ad-hoc evaluation of one PyTorch checkpoint on the first test batches, with
# the figures drawn inline instead of being written to disk.
n_evaluations = 1
model_name = "unet"
cuda_or_cpu = "cuda"
dtype = "fp32"

# Validate once, before any model or data is loaded (this check used to run on
# every loop iteration). The message text is unchanged.
valid_dtypes = ('fp32', 'fp16', 'int8')
if dtype not in valid_dtypes:
    raise ValueError(
        f"Unknown dtype: {dtype}. Choose in {valid_dtypes}."
    )

if cuda_or_cpu.startswith("cuda"):
    cuda_or_cpu = prepare_cuda_device(0)
cfg = get_default_config()

ckpt_path = Path(
    cfg.paths.artifacts_dir,
    "best_checkpoints",
    f"2023_03_24_{model_name}_2_191268.pth",
)
cfg.model.ckpt_path_to_resume = ckpt_path
cfg.model.name = model_name

# Only needed by the commented-out `savefig` call at the end of the cell.
save_dir = cfg.paths.outputs_dir / ckpt_path.stem
save_dir.mkdir(exist_ok=True, parents=True)

# Switched to eval mode once here; nothing in the loop changes the mode.
gen = prepare_generator(cfg, device=cuda_or_cpu).eval()

test_batches = get_test_batches(cfg)
# `islice` stops after exactly `n_evaluations` batches. The previous
# enumerate/break pattern pulled one extra batch before it could stop.
# A falsy `n_evaluations` (None or 0) still means "all batches".
limited_batches = islice(test_batches, n_evaluations) if n_evaluations else test_batches
progress_bar = tqdm(limited_batches, total=n_evaluations)
for step_id, (original, compressed) in enumerate(progress_bar):
    compressed = compressed.to(cuda_or_cpu)
    if dtype == 'fp16':
        compressed = compressed.half()

    with torch.no_grad():
        generated = gen(compressed)

    compressed = compressed.cpu()
    generated = generated.cpu()
    generated = postprocess(
        generated=generated,
        width_original=original.shape[-1],
        height_original=original.shape[-2],
    )

    for i in range(original.shape[0]):
        original_image = original[i]
        compressed_image = compressed[i]
        generated_image = generated[i]

        # Trim the compressed frame vertically (centred) to a quarter of the
        # original height; presumably this removes the rows added upstream to
        # make the frame 4-times downscalable -- TODO confirm.
        compressed_h, compressed_w = compressed_image.shape[-2:]
        offset = compressed_h - original_image.shape[-2] // 4
        compressed_image = TF.crop(
            compressed_image, offset // 2, 0, compressed_h - offset, compressed_w
        )

        figsize = (15, 5 * 6)
        original_image_pil = TF.to_pil_image(original_image)
        compressed_image_pil = TF.to_pil_image(compressed_image)
        generated_image_pil = TF.to_pil_image(generated_image)
        # NOTE(review): the grid has 6 rows but only the first one is filled.
        fig, ax = plt.subplots(6, 3, figsize=figsize)
        ax[0][0].imshow(original_image_pil)
        ax[0][0].set_title('high quality')
        ax[0][0].axis('off')
        ax[0][1].imshow(generated_image_pil)
        ax[0][1].set_title(f'generated by {model_name}-{dtype}')
        ax[0][1].axis('off')
        ax[0][2].imshow(compressed_image_pil)
        ax[0][2].set_title('low quality')
        ax[0][2].axis('off')
        fig.subplots_adjust(
            top=1.0, bottom=0.0, right=1.0, left=0.0, hspace=0.0, wspace=0.0
        )
        # fig.savefig(save_dir / f'{step_id:05d}_validation_fig.jpg')
        # plt.close(fig)  # close the current fig to prevent OOM issues

In [None]:
def eval_images(
    gen: torch.nn.Module,
    save_dir: Path,
    cfg: Gifnoc | None = None,
    n_evaluations: int | None = None,
    dtype: str = "fp32",
    cuda_or_cpu: str = "cuda",
):
    """Upscales a bunch of images given a super-resolution model.

    For every test batch the generator output is postprocessed and, for each
    image of the batch, a comparison figure is written to `save_dir`.

    Args:
        gen (torch.nn.Module): a PyTorch generator model.
        save_dir (Path): path to directory where to save evaluation figures.
            It is created if it does not exist.
        cfg (Gifnoc, optional): configuration settings. The only useful
            option to be modified here is `cfg.params.buffer_size`.
            Defaults to None (the default configuration is used).
        n_evaluations (Union[int, None], optional): num of batches to
            evaluate. Defaults to None (that means all the available
            frames); 0 is treated the same way.
        dtype (str): Choose in {"fp32", "fp16", "int8"}. Only "fp16" changes
            the input, which is cast to half precision. Defaults to "fp32".
        cuda_or_cpu (str, optional): {"cuda", "cpu"}. Defaults to "cuda".
            Any value starting with "cuda" is replaced by the result of
            `prepare_cuda_device(0)`.

    Raises:
        ValueError: if `dtype` is not one of the supported values (checked
            before any data is loaded), or if `n_evaluations` is negative.
    """
    # Fail fast: this check used to run inside the loop, after a batch had
    # already been loaded and moved to the device. Message text is unchanged.
    valid_dtypes = ('fp32', 'fp16', 'int8')
    if dtype not in valid_dtypes:
        raise ValueError(
            f"Unknown dtype: {dtype}. Choose in {valid_dtypes}."
        )

    if cfg is None:
        cfg = get_default_config()
    if cuda_or_cpu.startswith("cuda"):
        cuda_or_cpu = prepare_cuda_device(0)

    save_dir.mkdir(parents=True, exist_ok=True)

    test_batches = get_test_batches(cfg)
    # `islice` stops after exactly `n_evaluations` batches; the previous
    # enumerate/break pattern pulled one extra batch before it could stop.
    limited_batches = (
        islice(test_batches, n_evaluations) if n_evaluations else test_batches
    )
    progress_bar = tqdm(limited_batches, total=n_evaluations)

    gen.eval()  # loop-invariant, so done once
    for step_id, (original, compressed) in enumerate(progress_bar):
        compressed = compressed.to(cuda_or_cpu)
        if dtype == 'fp16':
            compressed = compressed.half()

        with torch.no_grad():
            generated = gen(compressed)

        compressed = compressed.cpu()
        generated = generated.cpu()
        generated = postprocess(
            generated=generated,
            width_original=original.shape[-1],
            height_original=original.shape[-2],
        )

        batch_size = original.shape[0]
        for i in range(batch_size):
            fig = draw_validation_fig(
                original_image=original[i],
                compressed_image=compressed[i],
                generated_image=generated[i],
                figsize=(36, 15),
            )
            # With more than one image per batch, every image used to be
            # written to the same path and only the last one survived.
            # Single-image batches keep the historical file name.
            if batch_size == 1:
                fig_stem = f'{step_id:05d}'
            else:
                fig_stem = f'{step_id:05d}_{i:02d}'
            fig.savefig(save_dir / f'{fig_stem}_validation_fig.jpg')
            plt.close(fig)  # close the current fig to prevent OOM issues


In [None]:
# Evaluate every exported precision of one architecture and write the
# comparison figures to one output directory per precision.
n_evaluations = 1
model_name = "unet"
cuda_or_cpu = "cuda:0"
cfg = get_default_config()

available_dtypes = ("fp32", "fp16", "int8")
for dtype in available_dtypes:
    quant_save_dir = cfg.paths.outputs_dir / f"{model_name}_{dtype}"
    quant_save_dir.mkdir(exist_ok=True, parents=True)

    quant_path = cfg.paths.trt_dir / f"{model_name}_{dtype}.ts"
    quant_gen = torch.jit.load(quant_path).to(cuda_or_cpu).eval()

    # Forward this cell's `cfg` and device instead of letting `eval_images`
    # fall back to its own defaults, so editing the settings above actually
    # takes effect. With the current values the behaviour is the same.
    eval_images(
        gen=quant_gen,
        save_dir=quant_save_dir,
        cfg=cfg,
        n_evaluations=n_evaluations,
        dtype=dtype,
        cuda_or_cpu=cuda_or_cpu,
    )

In [None]:
def draw_validation_fig(
    original_image: torch.Tensor,
    compressed_image: torch.Tensor,
    generated_image: torch.Tensor,
    figsize: tuple[int, int] = (36, 15),
    save_path: Path | None = None,
) -> plt.Figure:
    """Draws the original, generated and compressed images side by side.

    Note: this definition shadows the `draw_validation_fig` imported from
    `binarization.datatools` at the top of the notebook.

    Args:
        original_image (torch.Tensor): shown on the left as 'high quality'.
        compressed_image (torch.Tensor): shown on the right as 'low quality'.
        generated_image (torch.Tensor): shown in the middle as
            'reconstructed'.
        figsize (tuple[int, int], optional): matplotlib figure size.
            Defaults to (36, 15).
        save_path (Path | None, optional): if given, the figure is also
            saved there. Defaults to None.

    Returns:
        plt.Figure: the figure, left open for the caller to close.
    """
    # Convert every tensor before the figure is created, so a failed
    # conversion does not leave an empty figure behind.
    high_quality = TF.to_pil_image(original_image)
    low_quality = TF.to_pil_image(compressed_image)
    reconstructed = TF.to_pil_image(generated_image)

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    panels = (
        (high_quality, 'high quality'),
        (reconstructed, 'reconstructed'),
        (low_quality, 'low quality'),
    )
    for panel_ax, (panel_image, panel_title) in zip(axes, panels):
        panel_ax.imshow(panel_image)
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

    # Let the three panels fill the whole canvas with no gaps between them.
    fig.subplots_adjust(
        left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0
    )
    if save_path is None:
        return fig
    fig.savefig(save_path)
    return fig