# Inference script

Da usare per:
- creare gli output video
- testare nuovi metodi di inferenza
- calcolare metrics di un modello salvato
- testare nuovi metodi per valutare le preds

Nel config file, modificare i parametri nella "testing" section

In [1]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2

In [2]:
# To import modules from parent directory in Jupyter Notebook
import sys

sys.path.append("..")

In [3]:
import logging
import os

import torch
from torch import nn
from torch.utils.data import DataLoader

from config import TrainingConfig, config
from data.data_processing_tools import masks_to_instances_dict, process_raw_predictions
from utils.in_out_tools import write_videos_on_disk

# from torch.cuda.amp import GradScaler
from utils.training_inference_tools import do_inference
from utils.training_script_utils import init_dataset, init_model, get_sample_ids

logger = logging.getLogger(__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
config.verbosity = 3  # To get debug messages

##################### Get training-specific parameters #####################

# run_name = "final_model"
run_name = "TEMP_new_annotated_peaks_physio"  # TEMP local (run on laptop)
config_filename = "config_final_model.ini"
load_epoch = 100000

use_train_data = False
get_final_pred = True  # set to False to only compute raw predictions
testing = False  # set to False to only generate unet predictions
# set to True to also compute processed outputs and metrics
# inference_types = ['overlap', 'average', 'gaussian', 'max']
inference_types = None  # set to None to use the default inference type from
# the config file

# Initialize general parameters
params = TrainingConfig(
    training_config_file=os.path.join("..", "config_files", config_filename)  # notebook
    # training_config_file=os.path.join("config_files", config_filename)
)
if run_name:
    params.run_name = run_name
model_filename = f"network_{load_epoch:06d}.pth"

# Print parameters to console if needed
# params.print_params()

if testing:
    get_final_pred = True

debug = True if config.verbosity == 3 else False

[15:39:44] [  INFO  ] [   config   ] <288 > -- Loading ..\config_files\config_final_model.ini


In [5]:
######################### Configure output folder ##########################

output_folder = os.path.join(
    "evaluation", "inference_script"
)  # Same folder for train and test preds
os.makedirs(output_folder, exist_ok=True)

# Subdirectory of output_folder where predictions are saved.
# Change this to save results for same model with different inference
# approaches.
# output_name = training_name + "_step=2"
output_name = params.run_name

save_folder = os.path.join(output_folder, output_name)
os.makedirs(save_folder, exist_ok=True)
logger.info(f"Annotations and predictions will be saved on '{save_folder}'")

[15:39:46] [  INFO  ] [  __main__  ] < 16 > -- Annotations and predictions will be saved on 'evaluation\inference_script\final_model'


In [6]:
########################### Detect GPU, if available ###########################

params.set_device(device="auto")
params.display_device_info()

[15:39:47] [  INFO  ] [   config   ] <520 > -- Using cuda


In [8]:
###################### Config dataset and UNet model #######################

logger.info(f"Processing training '{params.run_name}'...")

# Define the sample IDs based on dataset size and usage
sample_ids = get_sample_ids(
    train_data=use_train_data, dataset_size=params.dataset_size, custom_ids=[]
)
logger.info(f"Predicting outputs for samples {sample_ids}")

logger.info(f"Using {params.dataset_dir} as dataset root path")

# Create dataset
dataset = init_dataset(
    params=params,
    sample_ids=sample_ids,
    inference_dataset=True,
    print_dataset_info=True,
)

# Create a dataloader
dataset_loader = DataLoader(
    dataset,
    batch_size=params.inference_batch_size,
    shuffle=False,
    num_workers=params.num_workers,
    pin_memory=params.pin_memory,
)

### Configure UNet ###

network = init_model(params=params)
network = nn.DataParallel(network).to(params.device)

### Load UNet model ###

# Path to the saved model checkpoint
models_relative_path = os.path.join(
    "models", "saved_models", params.run_name, model_filename
)
model_dir = os.path.realpath(os.path.join(config.basedir, models_relative_path))

# Load the model state dictionary
logger.info(f"Loading trained model '{run_name}' at epoch {load_epoch}...")
network.load_state_dict(torch.load(model_dir, map_location=params.device))
network.eval()

[15:40:14] [  INFO  ] [  __main__  ] < 3  > -- Processing training 'final_model'...
[15:40:14] [  INFO  ] [  __main__  ] < 9  > -- Predicting outputs for samples ['34']
[15:40:14] [  INFO  ] [  __main__  ] < 11 > -- Using C:\Users\prisc\Code\sparks_project\data\sparks_dataset as dataset root path
[15:40:14] [  INFO  ] [utils.training_script_utils] <136 > -- Samples in training dataset: 22
[15:40:15] [  INFO  ] [  __main__  ] < 44 > -- Loading trained model 'final_model' at epoch 100000...


FileNotFoundError: [Errno 2] No such file or directory: 'models\\saved_models\\final_model\\network_100000.pth'

In [None]:
########################### Run samples in UNet ############################

if inference_types is None:
    inference_types = [params.inference]

# get U-Net's raw predictions
raw_preds = do_inference(
    network=network,
    params=params,
    dataloader=dataset_loader,
    device=params.device,
    compute_loss=False,
    inference_types=inference_types,
)

In [None]:
############# Get movies and labels (and instances if testing) #############

xs = dataset.get_movies()
ys = dataset.get_labels()

if testing:
    ys_instances = dataset.get_instances()

    # convert instance masks to dictionaries
    ys_instances = {
        i: masks_to_instances_dict(
            instances_mask=instances_mask,
            labels_mask=ys[i],
            shift_ids=True,
        )
        for i, instances_mask in ys_instances.items()
    }

    # remove ignored events entry from ys_instances
    for inference in ys_instances:
    ys_instances[inference].pop("ignore", None)

In [None]:
#################### Get processed output (if required) ####################

if get_final_pred:
    logger.debug("Getting processed output (segmentation and instances)")

    final_segmentation_dict = {}
    final_instances_dict = {}
    for i in range(len(sample_ids)):
        movie_segmentation = {}
        movie_instances = {}

        for inference in inference_types:
            # transform raw predictions into a dictionary
            raw_preds_dict = {
                event_type: raw_preds[i][inference][event_label]
                for event_type, event_label in config.classes_dict
            }

            preds_instances, preds_segmentation, _ = process_raw_predictions(
                raw_preds_dict=raw_preds_dict,
                input_movie=xs[i],
                training_mode=False,
                debug=debug,
            )

            movie_segmentation[inference] = preds_segmentation
            movie_instances[inference] = preds_instances

        final_segmentation_dict[sample_ids[i]] = movie_segmentation
        final_instances_dict[sample_ids[i]] = movie_instances

else:
    final_segmentation_dict = {}
    final_instances_dict = {}

In [None]:
############################ Save preds on disk ############################

logger.info(f"\tSaving annotations and predictions...")

for i, sample_id in enumerate(sample_ids):
    for inference in inference_types:
        video_name = f"{str(params.load_epoch)}_{sample_id}_{inference}"

        raw_preds_movie = raw_preds[i][inference]
        if get_final_pred:
            segmented_preds_movie = final_segmentation_dict[sample_id][inference]
            instances_preds_movie = final_instances_dict[sample_id][inference]
        else:
            segmented_preds_movie = None
            instances_preds_movie = None

        write_videos_on_disk(
            training_name=output_name,
            video_name=video_name,
            path=os.path.join(save_folder, "inference_" + inference),
            xs=xs[i],
            ys=ys[i],
            raw_preds=raw_preds_movie,
            segmented_preds=segmented_preds_movie,
            instances_preds=instances_preds_movie,
        )

logger.info(f"DONE")

## Visualize preds with Napari

- load given training at given iteration (testing section of config file) and visualize predictions using Napari
- (idea is to check if it makes sense to stop the training earlier)
- adapt dataset size (& other params) in config file prior to execute code above if necessary

Trainings already checked with this method:
- final_model (minimal dataset, only movie 34)
    - 10K; 20K; 30K; 40K: too many sparks, does not make sense (>2K after correction)
    - 50K: too many sparks, does not make sense (>1K after correction)
    - 60K: 72 sparks before correction/82 sparks after correction
    - 70K: 71/78 -> does not change much from 60K aesthetically
    - 80K: 53/55 -> same; in the end what matters is to chose the best metrics (I believe)
    - 90K: 51/55 -> same
    - 100K: 61/63 -> same

In [None]:
cmap = get_discrete_cmap(name="gray", lut=16)

In [None]:
preds_classes = {}

In [None]:
sample_id = sample_ids[0]

# get contours of annotations, for visualization
ys_contours = get_annotations_contour(annotations=ys[sample_id], contour_val=2)

# get predicted segmentation and event instances
# _, raw_sparks, raw_waves, raw_puffs = preds[sample_id]
# preds_events = preds_instances[sample_id]

preds_classes[params["load_epoch"]] = (
    preds_segmentation[sample_id]["sparks"]
    + 3 * preds_segmentation[sample_id]["puffs"]
    + 2 * preds_segmentation[sample_id]["waves"]
)

In [None]:
viewer = napari.Viewer()
viewer.add_image(input_movies[sample_id], name="raw", colormap=("colors", cmap))
viewer.add_labels(ys_contours, name="gt", opacity=0.5, color=get_labels_cmap())

# viewer.add_image(raw_sparks, name='raw_sparks')
# viewer.add_image(raw_waves, name='raw_waves')
# viewer.add_image(raw_puffs, name='raw_puffs')
for epoch in preds_classes.keys():
    viewer.add_labels(
        preds_classes[epoch],
        name=f"preds_{epoch}",
        opacity=0.3,
        color=get_labels_cmap(),
    )

## Compute metrics

- if considering more than one inference type:
    - `preds_instances` is a nested dict indexed first by movie id, then by inference type, and finally by class. E.g., `preds_instances['05']['overlap']['sparks']` is a numpy array of shape (500, 64, 512) with integer values denoting the events' IDs.
    - `preds_segmentation` is a nested dict indexed first by movie id, then by inference type, and finally by class e.g., `preds_segmentation['05']['overlap']['sparks']` is a numpy array of shape (500, 64, 512) with boolean values denoting the events' presence.
- if considering only one inference type: `preds_instances` and `preds_segmentation`
are the same as above, buth without the inference type index 
- `ys_instances` is a nested dict indexed by movie id and class of events containing arrays with integer values denoting the events' IDs.
- `ys` is a dict indexed by movie id with integers between 0 and 4 to denote the class of the annotated events.

In [None]:
if not testing:
    print("!!!!!!!!! THE FOLLOWING CODE WON'T WORK !!!!!!!!!")

In [None]:
# compute loss on all samples
sum_loss = 0.0  # SERVE???

In [None]:
# to unify the results for any number of inferences, transform 'preds_instances'
# and 'preds_segmentation' adding a new "nested" key 'inference_type'

if len(inference_types) == 1:
    temp_instances = {
        movie_id: {params.inference: preds_instances[movie_id]}
        for movie_id in sample_ids
    }
    temp_segmentation = {
        movie_id: {params.inference: preds_segmentation[movie_id]}
        for movie_id in sample_ids
    }

    preds_instances = temp_instances
    preds_segmentation = temp_segmentation

### Segmentation-based metrics (i.e., IoU)

In [None]:
# define the columns of the dataframe
segmentation_cols = ["inference_type", "event_type", "iou"]

In [None]:
# concatenate annotations and preds to compute segmentation-based metrics
ys_concat = []
preds_concat = {i: [] for i in inference_types}

for sample_id in sample_ids:
    ys_concat.append(ys[sample_id])

    for i in inference_types:
        # get preds segmentation as integer array with values in [0, 1, 2, 3]
        temp_preds = dict_to_int_mask(preds_segmentation[sample_id][i])
        preds_concat[i].append(temp_preds)

ys_concat = np.concatenate(ys_concat)
preds_concat = {i: np.concatenate(preds_concat[i]) for i in inference_types}

# get masks for pixels labelled with 4
ignore_concat = ys_concat == 4

In [None]:
# compute intersection over union for each inference type
# (per class, average, and as binary classification)

iou_dict = {i: {} for i in inference_types}

for i in inference_types:
    for event_type, event_label in config.classes_dict.items():
        if event_type in ["ignore", "background"]:
            continue
        class_ys = ys_concat == event_label
        class_preds = preds_concat[i] == event_label

        iou_dict[i][event_type] = compute_iou(
            ys_roi=class_ys, preds_roi=class_preds, ignore_mask=ignore_concat
        )

    # compute average iou
    iou_dict[i]["average"] = np.mean(list(iou_dict[i].values()))

    # compute binary classification iou
    iou_dict[i]["binary"] = compute_iou(
        ys_roi=ys_concat != 0, preds_roi=preds_concat[i] != 0, ignore_mask=ignore_concat
    )

In [None]:
pd.set_option("display.precision", 3)

# create dataframe where index is event type and columns are inference types
df_barplot = pd.DataFrame(iou_dict).T.T
df_barplot

In [None]:
df_barplot.plot.bar(rot=0, figsize=(10, 5))

### Instance-based metrics

In [None]:
ca_release_events = [
    event_type
    for event_type in config.classes_dict.keys()
    if event_type not in ["ignore", "background"]
]

In [None]:
# sum events over all samples
preds_cat = ["tot", "tp", "ignored", "unlabeled"] + ca_release_events
ys_cat = ["tot", "tp", "undetected"] + ca_release_events

In [None]:
# initialize dicts that will contain the results
matched_preds_ids = {i: {} for i in inference_types}
matched_ys_ids = {i: {} for i in inference_types}

for i in inference_types:
    matched_preds_ids[i]["sum"] = {}
    matched_ys_ids[i]["sum"] = {}
    for ca_event in ca_release_events:
        matched_preds_ids[i]["sum"][ca_event] = {
            cat: 0 for cat in preds_cat if cat != ca_event
        }
        matched_ys_ids[i]["sum"][ca_event] = {
            cat: 0 for cat in ys_cat if cat != ca_event
        }

for sample_id in sample_ids:
    logger.info(f"Processing sample {sample_id}...")

    # get ignore mask for this sample
    ignore_mask = ys[sample_id] == 4

    # compute pairwise scores between annotated and predicted ROIs
    # (for each inference type)
    for i in inference_types:
        logger.info(f"\tInference type {i}...")
        iomin_scores = get_score_matrix(
            ys_instances=ys_instances[sample_id],
            preds_instances=preds_instances[sample_id][i],
            ignore_mask=None,
            score="iomin",
        )

        # get ids of matched ROIs
        (
            matched_ys_ids[i][sample_id],
            matched_preds_ids[i][sample_id],
        ) = get_matches_summary(
            ys_instances=ys_instances[sample_id],
            preds_instances=preds_instances[sample_id][i],
            scores=iomin_scores,
            ignore_mask=ignore_mask,
        )

        # count number of categorized events that are necessary for the metrics
        for ca_event in ca_release_events:
            for cat in matched_ys_ids[i][sample_id][ca_event].keys():
                matched_ys_ids[i]["sum"][ca_event][cat] += len(
                    matched_ys_ids[i][sample_id][ca_event][cat]
                )

            for cat in matched_preds_ids[i][sample_id][ca_event].keys():
                matched_preds_ids[i]["sum"][ca_event][cat] += len(
                    matched_preds_ids[i][sample_id][ca_event][cat]
                )

In [None]:
# get dict that only contains sum over all samples
tot_preds = {
    i: {
        ca_class: matched_preds_ids[i]["sum"][ca_class]["tot"]
        for ca_class in ca_release_events
    }
    for i in inference_types
}
tp_preds = {
    i: {
        ca_class: matched_preds_ids[i]["sum"][ca_class]["tp"]
        for ca_class in ca_release_events
    }
    for i in inference_types
}
ignored_preds = {
    i: {
        ca_class: matched_preds_ids[i]["sum"][ca_class]["ignored"]
        for ca_class in ca_release_events
    }
    for i in inference_types
}
unlabeled_preds = {
    i: {
        ca_class: matched_preds_ids[i]["sum"][ca_class]["unlabeled"]
        for ca_class in ca_release_events
    }
    for i in inference_types
}
tot_ys = {
    i: {
        ca_class: matched_ys_ids[i]["sum"][ca_class]["tot"]
        for ca_class in ca_release_events
    }
    for i in inference_types
}
tp_ys = {
    i: {
        ca_class: matched_ys_ids[i]["sum"][ca_class]["tp"]
        for ca_class in ca_release_events
    }
    for i in inference_types
}
undetected_ys = {
    i: {
        ca_class: matched_ys_ids[i]["sum"][ca_class]["undetected"]
        for ca_class in ca_release_events
    }
    for i in inference_types
}


metrics_all = {i: {} for i in inference_types}

# get other metrics (precision, recall, % correctly classified, % detected)
for i in inference_types:
    metrics_all[i] = get_metrics_from_summary(
        tot_preds=tot_preds[i],
        tp_preds=tp_preds[i],
        ignored_preds=ignored_preds[i],
        unlabeled_preds=unlabeled_preds[i],
        tot_ys=tot_ys[i],
        tp_ys=tp_ys[i],
        undetected_ys=undetected_ys[i],
    )

In [None]:
# Compute same metrics but as % instad of absolute numbers
matched_preds_percent = {
    i: {ca_event: {} for ca_event in ca_release_events} for i in inference_types
}
matched_ys_percent = {
    i: {ca_event: {} for ca_event in ca_release_events} for i in inference_types
}

for i in inference_types:
    for ca_event in ca_release_events:
        for cat in matched_preds_ids[i]["sum"][ca_event].keys():
            if cat != "tot":
                matched_preds_percent[i][ca_event][cat] = (
                    matched_preds_ids[i]["sum"][ca_event][cat]
                    / matched_preds_ids[i]["sum"][ca_event]["tot"]
                    * 100
                )

        for cat in matched_ys_ids[i]["sum"][ca_event].keys():
            if cat != "tot":
                matched_ys_percent[i][ca_event][cat] = (
                    matched_ys_ids[i]["sum"][ca_event][cat]
                    / matched_ys_ids[i]["sum"][ca_event]["tot"]
                    * 100
                )

In [None]:
print("Summary of detected events")
for i in inference_types:
    df = get_df_summary_events(
        inference_type=i,
        matched_ids=matched_preds_ids,
        matched_percent=matched_preds_percent,
        is_detected=True,
    )
    # Format the DataFrame
    styled_df = df.style.format(precision=2, na_rep="N/A")
    print(f"{i} inference")
    display(styled_df)

print("Summary of labeled events")
for i in inference_types:
    df = get_df_summary_events(
        inference_type=i,
        matched_ids=matched_ys_ids,
        matched_percent=matched_ys_percent,
        is_detected=False,
    )
    # Format the DataFrame
    styled_df = df.style.format(precision=2, na_rep="N/A")
    print(f"{i} inference")
    display(styled_df)

In [None]:
for i in inference_types:
    df = get_df_metrics(inference_type=i, metrics_all=metrics_all)
    # Set the display precision to 2 decimal places
    pd.set_option("display.precision", 2)
    # Format the DataFrame
    styled_df = df.style.format(precision=2)
    print(f"Metrics using {i} inference")
    display(styled_df)

## Visualize other properties of data...

### Plot histograms of raw predictions

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(5, 5))

for i, ax in zip(inference_types, axs.flatten()):
    # concatenate raw predictions
    raw_preds_concat = [preds_dict[sample_id][i] for sample_id in sample_ids]
    raw_preds_concat = np.concatenate(raw_preds_concat, axis=1)

    ax.hist(raw_preds_concat.flatten(), bins=10)
    ax.set_title(
        f"Histogram of raw predictions \n(inference type '{i}')", fontsize=10.0
    )

fig.tight_layout()

### code to try and visualize different types of inference

In [None]:
# ## DEBUG ###

# test_dataset = SparkDataset(
#         base_path=dataset_path,
#         sample_ids=['34'],
#         testing=testing,
#         smoothing=c.get("dataset", "data_smoothing"),
#         step=c.getint("testing", "data_stride"),
#         #step=2,
#         duration=c.getint("testing", "data_duration"),
#         remove_background=c.get("dataset", "remove_background"),
#         temporal_reduction=c.getboolean(
#             "network", "temporal_reduction", fallback=False
#         ),
#         num_channels=c.getint("network", "num_channels", fallback=1),,
#         normalize_video=c.get("dataset", "norm_video"),
#         only_sparks=c.getboolean("dataset", "only_sparks", fallback=False),
#         sparks_type=c.get("dataset", "sparks_type"),
#         ignore_frames=c.getint("training", "ignore_frames_loss"),
#         ignore_index=4,
#         gt_available=True,
#         inference=inference,
#     )

# testing_dataloader = torch.utils.data.DataLoader(
#     test_dataset,
#     batch_size=batch_size,
#     shuffle=False,
#     num_workers=0,
#     pin_memory=True,
# )

# from training_inference_tools import do_inference

# test_dataset.inference = 'overlap'
# pred_overlap = do_inference(network=network,
#                             test_dataset=test_dataset,
#                             test_dataloader=testing_dataloader,
#                             device=device,
#                             )


# test_dataset.inference = 'average'
# pred_average = do_inference(network=network,
#                             test_dataset=test_dataset,
#                             test_dataloader=testing_dataloader,
#                             device=device,
#                             )

# test_dataset.inference = 'max'
# pred_max = do_inference(network=network,
#                             test_dataset=test_dataset,
#                             test_dataloader=testing_dataloader,
#                             device=device,
#                             )

# test_dataset.inference = 'gaussian'
# pred_gaussian = do_inference(network=network,
#                             test_dataset=test_dataset,
#                             test_dataloader=testing_dataloader,
#                             device=device,
#                             )

# pred_overlap.shape, pred_average.shape, pred_max.shape, pred_gaussian.shape
# empty_vertical = np.ones((pred_overlap.shape[0],
#                           pred_overlap.shape[1],
#                           pred_overlap.shape[2],
#                           10))*0.5

# empty_horizontal = np.ones((pred_overlap.shape[0],
#                             pred_overlap.shape[1],
#                             10,
#                             2*pred_overlap.shape[3]+10))*0.5
# # stack the four predictions togheter in a squared grid

# stack1 = np.concatenate((pred_overlap, empty_vertical, pred_average), axis=3)
# stack2 = np.concatenate((pred_max, empty_vertical, pred_gaussian), axis=3)

# stack_all = np.concatenate((stack1, empty_horizontal, stack2), axis=2)
# import napari
# viewer = napari.Viewer()
# viewer.theme = 'dark'

# viewer.add_image(stack_all[0],
#                  name='background',
#                     #colormap='white',
#                     blending='additive',
#                     opacity=0.5,
#                     #visible=False,
#                 )

# viewer.add_image(stack_all[1],
#                     name='sparks',
#                     colormap='green',
#                     blending='additive',
#                     opacity=0.5,
#                     #visible=False,
#                 )

# viewer.add_image(stack_all[2],
#                     name='waves',
#                     colormap='yellow',
#                     blending='additive',
#                     opacity=0.5,
#                     #visible=False,
#                 )

# viewer.add_image(stack_all[3],
#                     name='puffs',
#                     colormap='red',
#                     blending='additive',
#                     opacity=0.5,
#                     #visible=False,
#                 )

### code to visualize a confusion matrix (OLD)

In [None]:
# plt.rcParams.update({"font.size": 24})

# n_rows = 1
# n_cols = 1
# num_plots = n_rows * n_cols

# pad = 5  # in points

# fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(20, 100))

# cols = ["IoU"]
# tick_labels = ["Background", "Sparks", "Waves", "Puffs"]

# for ax, col in zip(axs[0], cols):
#     ax.annotate(
#         col,
#         xy=(0.5, 1.2),
#         xytext=(0, pad),
#         xycoords="axes fraction",
#         textcoords="offset points",
#         size="large",
#         ha="center",
#         va="baseline",
#     )


# for ax, row in zip(axs[:, 0], ["TOT"] + movie_ids):
#     ax.annotate(
#         row,
#         xy=(0, 0.5),
#         xytext=(-ax.yaxis.labelpad - pad, 0),
#         xycoords=ax.yaxis.label,
#         textcoords="offset points",
#         size="large",
#         ha="right",
#         va="center",
#     )

# fig.suptitle("Confusion matrices", fontsize=36, y=1)


# # configure heatmap background
# colors = sns.color_palette(
#     ["white", "lightcoral", "paleturquoise", "lemonchiffon", "lightgreen"], as_cmap=True
# )
# colored_bg = [[0, 2, 2, 2], [1, 4, 3, 3], [1, 3, 4, 3], [1, 3, 3, 4]]

# # Get array with confusion matrices to be plotted
# cm_array = np.concatenate(
#     (
#         [[iou_confusion_matrix_tot, iomin_confusion_matrix_tot]],
#         [
#             [iou_confusion_matrix[sample_id], iomin_confusion_matrix[sample_id]]
#             for sample_id in movie_ids
#         ],
#     ),
#     axis=0,
# )

# for row_id in range(n_rows):
#     for col_id in range(n_cols):
#         cm = cm_array[row_id, col_id].astype(int).astype(str)

#         ax = axs[row_id, col_id]
#         sns.heatmap(
#             data=colored_bg,
#             cmap=colors,
#             annot=cm,
#             fmt="",
#             annot_kws={"fontsize": 36},
#             cbar=False,
#             square=True,
#             ax=ax,
#         )

#         ax.tick_params(length=0, labeltop=True, labelbottom=False)
#         ax.tick_params(axis="both", which="major", pad=16)

#         ax.set_xlabel("Predicted", labelpad=32)
#         ax.xaxis.set_label_position("top")
#         ax.set_xticklabels(tick_labels)
#         ax.add_patch(
#             plt.Rectangle(
#                 (-0.01, 1),
#                 1.01,
#                 0.1,
#                 color="yellow",
#                 clip_on=False,
#                 zorder=0,
#                 transform=ax.transAxes,
#             )
#         )

#         ax.set_ylabel("Actual Values", labelpad=32)
#         ax.set_yticklabels(tick_labels, rotation=90, va="center")
#         ax.add_patch(
#             plt.Rectangle(
#                 (0, 0),
#                 -0.1,
#                 1,
#                 color="yellow",
#                 clip_on=False,
#                 zorder=0,
#                 transform=ax.transAxes,
#             )
#         )

# # plt.subplots_adjust(hspace=0.005, wspace=0.)
# # plt.subplots_adjust(hspace=0.1, wspace=0.1, top=0.9, left=0.05, right=0.95)
# # fig.subplots_adjust(left=0.15, top=0.95)
# fig.subplots_adjust(wspace=1.5)
# plt.tight_layout()
# plt.savefig(os.path.join(out_dir, "all_confusion_matrices.png"))
# plt.show()