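# Visualization: CRC

Loads the output of a calibration experiment and visualizes, for a chosen image, the per-pixel number of classes kept by the thresholded prediction sets (uncertainty heatmap).

Losses:
- binary (with threshold)
- miscoverage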
 

In [None]:
import matplotlib.pyplot as plt
import torch
import json

from cose.utils import load_project_paths
from app.tools import setup_mmseg_inference
from cose.conformal import split_dataset_idxs
from mmengine.registry import init_default_scope  # type: ignore

init_default_scope("mmseg")


# device_str = "cuda:0"
device_str = "cuda:1"
# device_str = "cpu"
device = torch.device(device_str)

prj_path = load_project_paths().COSE_PATH

# dataset: Literal["Cityscapes", "ADE20K", "LoveDA"]
dataset_name = "ADE20K"

# loss: Literal["binary", "miscoverage"]
loss = "miscoverage"
# loss = "binary"

# my_alpha = 0.1
my_alpha = 0.01

# Experiment output JSON files, keyed by (dataset, loss, alpha).
# Paths are relative to f"{prj_path}/experiments/outputs".
_experiment_outputs = {
    ("ADE20K", "binary", 0.1): (
        "ADE20K/binary_loss/"
        "20240315_23h17m12s_ADE20K__id_102__alpha_0.1__binary_loss.json"
    ),
    # mincov: 0.75, alpha=0.01
    ("ADE20K", "binary", 0.01): (
        "ADE20K/binary_loss/"
        "20240316_02h39m23s_ADE20K__id_102__alpha_0.01__binary_loss.json"
    ),
    ("ADE20K", "miscoverage", 0.01): (
        "ADE20K/miscoverage_loss/"
        "20240303_02h19m59s_ADE20K__id_102__alpha_0.01__miscoverage_loss.json"
    ),
}

# Fail loudly for an unregistered combination instead of leaving
# `config_json` undefined (which would surface later as a NameError).
_rel_config_path = _experiment_outputs.get((dataset_name, loss, my_alpha))
if _rel_config_path is None:
    raise ValueError(
        f"No experiment output registered for {dataset_name = }, {loss = }, {my_alpha = }"
    )
config_json = f"{prj_path}/experiments/outputs/{_rel_config_path}"

with open(config_json) as _config_file:
    my_config = json.load(_config_file)
n_calib = my_config["n_calib"]
print(f"{my_config['mincov'] = }")
print(my_config.keys())
print(f"{my_config['experiment_id'] = }")

dataset, model, input_paths = setup_mmseg_inference(
    device, dataset_name, my_config["experiment_id"], n_calib=n_calib
)

# Recreate the calibration/test split used by the experiment (same seed).
_cal_ids, test_ids = split_dataset_idxs(
    len_dataset=len(dataset),
    n_calib=n_calib,
    random_seed=my_config["experiment_id"],
)

In [None]:
import mmcv

image_id_ = None

show_some_images = False
if show_some_images:
    # Browse test images 100..199 four at a time to pick an interesting one.
    # The upper bound is clamped so a short test split does not raise IndexError.
    n_per_row = 4
    stop = min(200, len(test_ids) - n_per_row + 1)

    for start in range(100, stop, n_per_row):
        row_ids = test_ids[start : start + n_per_row]

        fig, axs = plt.subplots(1, n_per_row, figsize=(10, 4))
        for ax, img_id in zip(axs, row_ids):
            im_path = dataset[img_id]["data_samples"].img_path
            im = mmcv.imread(im_path)
            # mmcv reads images as BGR; matplotlib expects RGB.
            ax.imshow(mmcv.bgr2rgb(im))
            ax.set_title(f"id: {img_id}")

        plt.show()

In [None]:
if dataset_name == "ADE20K":
    # Hand-picked ADE20K images; put another id here to visualize it.
    # Other candidates tried: 3, 37, 44, 184, 203, 308, 314.
    image_id_ = 312

    im_path = dataset[image_id_]["data_samples"].img_path
    print(f"[LUCA dbg] {im_path = }")

    # Show the selected image (mmcv loads BGR, so convert for matplotlib).
    im = mmcv.imread(im_path)
    plt.imshow(mmcv.bgr2rgb(im))
    plt.show()

In [None]:
from app.tools import parse_arguments, setup_gpu, setup_mmseg_inference
from app.tools import heatmap_from_multimaks, segmask_from_softmax
from cose.conformal import lac_multimask, PredictionHandler
from PIL import Image
import matplotlib as mpl


def plot_heatmap_from_input_img_path(
    input_img_path: str,
    expe_config: dict,
    normalize_by_total_number_of_classes: bool,
) -> None:
    """Plot input image, predicted segmentation mask and uncertainty heatmap.

    The model is set up from ``expe_config`` via ``setup_mmseg_inference``.
    The heatmap is built with ``lac_multimask``, using
    ``expe_config["optimal_lambda"]`` as the threshold, and is summed over the
    first axis. It is labelled as the number of classes per pixel.

    Args:
        input_img_path: path of the image to run inference on.
        expe_config: experiment output dict; reads the keys "optimal_lambda",
            "dataset", "experiment_id" and "n_calib".
        normalize_by_total_number_of_classes: if True, the heatmap colour
            scale tops out at ``dataset.n_classes``; otherwise it tops out at
            the largest value present in this image's heatmap.
    """
    torch.cuda.empty_cache()

    threshold = expe_config["optimal_lambda"]

    dataset, model, _ = setup_mmseg_inference(
        torch_device=device,
        use_case=expe_config["dataset"],
        random_seed=expe_config["experiment_id"],
        n_calib=expe_config["n_calib"],
    )

    with torch.no_grad():
        predictor = PredictionHandler(model["model"].mmseg_model)

        softmax_prediction = predictor.predict(input_img_path)
        segmask = segmask_from_softmax(softmax_prediction).cpu().numpy()

        # Colour the label map with tab20. Guard the normalisation so an
        # all-zero mask does not divide by zero (which would give NaN colours).
        seg_cmap = mpl.colormaps["tab20"]
        seg_max = segmask.max()
        segmask_rgba = Image.fromarray(
            seg_cmap(segmask / max(seg_max, 1), bytes=True)
        )

        figure_size = (21, 5)
        # Use the axes returned by subplots directly instead of re-fetching
        # them through plt.subplot.
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figure_size, dpi=200)

        _input = mmcv.imread(input_img_path)
        ax1.set_title("Input image")
        ax1.imshow(mmcv.bgr2rgb(_input))

        ax2.set_title("Segmentation mask")
        # segmask_rgba is already RGBA, so a cmap would be ignored here.
        ax2.imshow(segmask_rgba, alpha=0.5)

        multimask = lac_multimask(
            threshold=threshold,
            predicted_softmax=softmax_prediction,
            n_labels=dataset.n_classes,
        )

        # Per-pixel count shown in the heatmap (computed once, reused for vmax).
        class_counts = multimask.sum(dim=0).cpu().numpy()

        if normalize_by_total_number_of_classes:
            vmax = dataset.n_classes
        else:
            vmax = class_counts.max()

        ax3.set_title("Uncertainty heatmap")
        heat = ax3.imshow(
            class_counts, cmap=mpl.colormaps["turbo"], aspect="equal", vmax=vmax
        )

        fig.colorbar(heat, ax=ax3, label="Number of classes per pixel", pad=0.05)
        plt.rcParams.update({"font.size": 10})
        plt.show()


plot_heatmap_from_input_img_path(
    im_path,
    my_config,
    # Set to False to scale the colour bar to this image's own maximum.
    normalize_by_total_number_of_classes=True,
)

## Thresholding: visualize how $\lambda$ determines the heatmap

In [None]:
from typing import Sequence


def plot_threshold_heatmap_from_input_img_path(
    input_img_path: str,
    expe_config: dict,
    normalize_by_total_number_of_classes: bool,
    n_classes: int,
    lbd: Sequence[float],
) -> None:
    """Show how the threshold lambda changes the uncertainty heatmap.

    First displays the predicted segmentation overlaid on the input image.
    Then it draws one heatmap panel per threshold in ``lbd``. Each panel is the
    ``lac_multimask`` output for that threshold, summed over the first axis.

    Args:
        input_img_path: path of the image to run inference on.
        expe_config: experiment output dict; reads the keys "dataset",
            "experiment_id" and "n_calib".
        normalize_by_total_number_of_classes: if True, every panel's colour
            scale tops out at ``n_classes``; otherwise each panel uses its
            own maximum.
        n_classes: upper bound of the colour scale when normalizing.
        lbd: thresholds to compare, one panel each.

    Raises:
        ValueError: if ``lbd`` is empty.
    """
    # Validate before the (expensive) model setup rather than failing later
    # inside plt.subplots with zero columns.
    if len(lbd) == 0:
        raise ValueError("lbd must contain at least one threshold")

    dataset, model, _ = setup_mmseg_inference(
        torch_device=device,
        use_case=expe_config["dataset"],
        random_seed=expe_config["experiment_id"],
        n_calib=expe_config["n_calib"],
    )

    with torch.no_grad():
        predictor = PredictionHandler(model["model"].mmseg_model)
        softmax_prediction = predictor.predict(input_img_path)

        # Preview: predicted label map overlaid on the input image.
        segmask = segmask_from_softmax(softmax_prediction).cpu().numpy()
        _input = mmcv.imread(input_img_path)
        plt.imshow(mmcv.bgr2rgb(_input))
        plt.imshow(segmask, cmap="tab20", alpha=0.5)
        plt.show()

        cmap = mpl.colormaps["turbo"]
        figure_size = (22, 10)
        # squeeze=False keeps `axs` 2-D even for a single threshold, so the
        # indexing below also works when len(lbd) == 1.
        fig, axs = plt.subplots(
            1, len(lbd), figsize=figure_size, dpi=300, squeeze=False
        )
        for i, (ax, lb) in enumerate(zip(axs[0], lbd)):
            multimask = lac_multimask(
                threshold=lb,
                predicted_softmax=softmax_prediction,
                n_labels=dataset.n_classes,
            )
            heatmap = multimask.sum(dim=0).cpu().numpy()
            if normalize_by_total_number_of_classes:
                vmax = n_classes
            else:
                vmax = heatmap.max().item()

            hm = ax.imshow(heatmap, cmap=cmap, vmax=vmax)

            if i > 0:
                # Only the left-most panel keeps its y tick labels.
                ax.set_yticks([])

            # Raw string: "\l" is an invalid escape in a normal string literal.
            ax.set_title(rf"$\lambda = $ {lb}")

            fig.colorbar(
                hm, ax=ax, label="Number of classes per pixel", pad=0.05, shrink=0.25
            )

        plt.rcParams.update({"font.size": 10})

        plt.show()

    torch.cuda.empty_cache()


plot_threshold_heatmap_from_input_img_path(
    input_img_path=im_path,
    expe_config=my_config,
    # One panel per threshold, increasingly close to 1.
    lbd=[0.99, 0.999, 0.999999, 0.99999999, 0.9999999999],
    n_classes=dataset.n_classes,
    # True: colour scale up to n_classes; False: up to each panel's own maximum.
    normalize_by_total_number_of_classes=False,
)