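# Visualization: CRC

Losses:
- Binary loss (with threshold)
- Miscoverage loss

Set `loss` in the first code cell to choose which experiment's results are visualized.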
 

In [None]:
import matplotlib.pyplot as plt
import torch
import json

from cose.utils import load_project_paths
from app.tools import setup_mmseg_inference
from cose.conformal import split_dataset_idxs
from mmengine.registry import init_default_scope  # type: ignore

init_default_scope("mmseg")


device_str = "cuda:0"
device = torch.device(device_str)

prj_path = load_project_paths().COSE_PATH

dataset_name = "LoveDA"

# loss: Literal["binary", "miscoverage"]
# loss = "miscoverage"
loss = "binary"

# Only the LoveDA / binary-loss experiment has a config path wired up here.
# Fail loudly for any other combination instead of hitting a NameError on
# `my_config` further down.
if dataset_name == "LoveDA" and loss == "binary":
    config_json = f"{prj_path}/experiments/outputs/LoveDA/binary_loss/20240320_20h27m17s_LoveDA__id_101__alpha_0.01__binary_loss.json"
else:
    raise ValueError(
        f"No experiment config registered for {dataset_name=} and {loss=}"
    )

# Load the saved experiment config; `with` guarantees the file is closed.
with open(config_json, encoding="utf-8") as config_file:
    my_config = json.load(config_file)

n_calib = my_config["n_calib"]
print(f"{my_config['mincov'] = }")
print(my_config.keys())
print(f"{my_config['experiment_id'] = }")

dataset, model, input_paths = setup_mmseg_inference(
    device, dataset_name, my_config["experiment_id"], n_calib=n_calib
)

# Split dataset indices into calibration / test ids, seeded by the experiment id.
_cal_ids, test_ids = split_dataset_idxs(
    len_dataset=len(dataset),
    n_calib=n_calib,
    random_seed=my_config["experiment_id"],
)

In [None]:
from PIL import Image

if dataset_name == "LoveDA":
    image_id_ = 132

    im_path = dataset[image_id_]["data_samples"].img_path
    # `mmcv` is never imported in this notebook; load the image directly as RGB
    # with PIL instead (no BGR->RGB swap needed). `with` closes the file handle.
    with Image.open(im_path) as _img:
        im = _img.convert("RGB")
    plt.imshow(im)
    plt.show()

In [None]:
from app.tools import setup_mmseg_inference, segmask_from_softmax
from cose.conformal import lac_multimask, PredictionHandler
from PIL import Image
import matplotlib as mpl


def plot_heatmap_from_input_img_path(
    input_img_path: str,
    expe_config: dict,
    normalize_by_total_number_of_classes: bool,
    *,
    dataset=None,
    model=None,
) -> None:
    """Plot an image, its segmentation mask and a per-pixel set-size heatmap.

    Three panels are drawn side by side: the input image, the mask returned by
    ``segmask_from_softmax`` on the model's softmax output, and a heatmap of how
    many classes the LAC multimask keeps for each pixel at the experiment's
    calibrated threshold.

    Args:
        input_img_path: Path of the image to run inference on and display.
        expe_config: Experiment config. Reads ``"optimal_lambda"`` (used as the
            LAC threshold), plus ``"dataset"``, ``"experiment_id"`` and
            ``"n_calib"`` when the model/dataset must be loaded.
        normalize_by_total_number_of_classes: If True, the heatmap colour scale
            tops out at ``dataset.n_classes``; otherwise at the largest
            per-pixel count found in this image.
        dataset: Optional pre-loaded dataset as returned by
            ``setup_mmseg_inference``. Loaded from ``expe_config`` if omitted.
        model: Optional pre-loaded model as returned by
            ``setup_mmseg_inference``. Loaded from ``expe_config`` if omitted.

    Uses the module-level ``device`` when loading.
    """
    torch.cuda.empty_cache()

    threshold = expe_config["optimal_lambda"]

    # Reloading the model on every call is expensive; only do it when the
    # caller did not hand us already-loaded objects.
    if dataset is None or model is None:
        loaded_dataset, loaded_model, _ = setup_mmseg_inference(
            torch_device=device,
            use_case=expe_config["dataset"],
            random_seed=expe_config["experiment_id"],
            n_calib=expe_config["n_calib"],
        )
        dataset = loaded_dataset if dataset is None else dataset
        model = loaded_model if model is None else model

    with torch.no_grad():
        predictor = PredictionHandler(model["model"].mmseg_model)
        softmax_prediction = predictor.predict(input_img_path)
        segmask = segmask_from_softmax(softmax_prediction).cpu().numpy()
        multimask = lac_multimask(
            threshold=threshold,
            predicted_softmax=softmax_prediction,
            n_labels=dataset.n_classes,
        )

    # Colour the class mask with a qualitative colormap. `or 1` guards the
    # all-zero mask case, which would otherwise divide by zero and yield NaNs.
    class_cmap = mpl.colormaps["tab20"]
    segmask_rgba = Image.fromarray(
        class_cmap(segmask / (segmask.max() or 1), bytes=True)
    )

    # Per-pixel number of classes kept by the multimask (summed over dim 0).
    set_sizes = multimask.sum(dim=0).cpu().numpy()
    if normalize_by_total_number_of_classes:
        vmax = dataset.n_classes
    else:
        vmax = set_sizes.max()

    # rcParams are read when artists are created, so the font size must be set
    # before building the figure for it to apply to this plot.
    plt.rcParams.update({"font.size": 10})
    fig, (ax_input, ax_mask, ax_heat) = plt.subplots(
        1, 3, figsize=(21, 5), dpi=200
    )

    ax_input.set_title("Input image")
    with Image.open(input_img_path) as img:
        ax_input.imshow(img.convert("RGB"))

    ax_mask.set_title("Segmentation mask")
    # Data is already RGBA, so no colormap argument is needed here.
    ax_mask.imshow(segmask_rgba, alpha=0.5)

    ax_heat.set_title("Uncertainty heatmap")
    heat = ax_heat.imshow(
        set_sizes, cmap=mpl.colormaps["turbo"], aspect="equal", vmax=vmax
    )
    fig.colorbar(heat, ax=ax_heat, label="Number of classes per pixel", pad=0.05)
    plt.show()

# Render the three-panel figure for the image selected in the cell above,
# scaling the heatmap against the dataset's total class count.
plot_heatmap_from_input_img_path(
    im_path,
    my_config,
    normalize_by_total_number_of_classes=True,
)