In [1]:
from pathlib import Path
from common import get_data_module
import torchmetrics
import numpy as np
import torch
import tqdm
import pandas as pd
from typing import Dict
import pandas as pd



## Utility Functions

In [2]:
def compute_metrics_per_sample(
    pred,
    label,
    num_classes=6,
    class_names=("0", "1", "2", "3", "4", "5", "7"),
    exclude_from_means=("6",),
):
    """Compute segmentation metrics for one predicted label map.

    Every torchmetrics metric is instantiated fresh on each call, so the
    returned values describe this single sample only.

    Args:
        pred (numpy.ndarray): Predicted class indices for one sample.
        label (numpy.ndarray): Ground-truth class indices, same shape as
            ``pred``. Assumed to contain non-negative integers (it is passed
            to ``torch.bincount``).
        num_classes (int): Number of classes given to every metric.
        class_names (Sequence[str]): Names for class indices
            ``0..num_classes-1``, used as dictionary keys / column suffixes.
            Indices beyond ``len(class_names)`` fall back to ``str(index)``.
        exclude_from_means (Iterable[str]): Class names removed from the
            per-class IoU/precision/recall/F1/accuracy/specificity dicts
            before the "Mean ..." values are computed; their per-class
            entries are then reported as 0. Class percentages are unaffected.

    NOTE(review): with the default ``class_names`` the class at index 6 is
    named "7", so the default exclusion of "6" never applies when
    ``num_classes=7``. Confirm whether "7" is intended.

    Returns:
        dict: "Mean <metric>" (NaN-aware mean over the kept classes),
        weighted precision/recall/F1/accuracy, Cohen's kappa, MCC, Hamming
        loss, and "<metric> <class name>" entries for every class name
        (0 for classes that were excluded or not computed).
    """
    # Display name per class index. Fall back to the index itself so that
    # num_classes > len(class_names) no longer raises IndexError.
    names = [
        class_names[i] if i < len(class_names) else str(i)
        for i in range(num_classes)
    ]

    # Add a leading batch dimension for torchmetrics.
    pred_tensor = torch.from_numpy(pred).unsqueeze(0)
    label_tensor = torch.from_numpy(label).unsqueeze(0)

    def _score(metric_cls, **kwargs):
        """Instantiate ``metric_cls`` for this sample and evaluate it."""
        metric = metric_cls(task="multiclass", num_classes=num_classes, **kwargs)
        return metric(pred_tensor, label_tensor)

    def _by_name(per_class_values):
        """Map a length-``num_classes`` tensor to {class name: float}."""
        return {names[i]: per_class_values[i].item() for i in range(num_classes)}

    # Per-class metrics (average=None -> one value per class).
    accuracy_per_class = _score(torchmetrics.Accuracy, average=None)
    per_class = {
        "IoU": _by_name(_score(torchmetrics.JaccardIndex, average=None)),
        "Precision": _by_name(_score(torchmetrics.Precision, average=None)),
        "Recall": _by_name(_score(torchmetrics.Recall, average=None)),
        "F1 Score": _by_name(_score(torchmetrics.F1Score, average=None)),
        "Accuracy": _by_name(accuracy_per_class),
        "Specificity": _by_name(_score(torchmetrics.Specificity, average=None)),
    }

    # Support-weighted aggregates.
    weighted_precision_value = _score(torchmetrics.Precision, average="weighted").item()
    weighted_recall_value = _score(torchmetrics.Recall, average="weighted").item()
    weighted_f1_value = _score(torchmetrics.F1Score, average="weighted").item()

    # Whole-sample agreement scores.
    cohen_kappa_value = _score(torchmetrics.CohenKappa).item()
    mcc_value = _score(torchmetrics.MatthewsCorrCoef).item()
    hamming_loss_value = _score(torchmetrics.HammingDistance).item()

    # Share (in %) of each class in the ground truth.
    class_counts = torch.bincount(label_tensor.flatten(), minlength=num_classes).float()
    total_samples = class_counts.sum().item()
    class_percentages = {
        names[i]: (class_counts[i].item() / total_samples * 100) if total_samples > 0 else 0.0
        for i in range(num_classes)
    }

    # Per-class accuracy weighted by each class's pixel count.
    weighted_accuracy_value = sum(
        (accuracy_per_class[i] * class_counts[i]).item() for i in range(num_classes)
    ) / total_samples if total_samples > 0 else 0.0

    # Drop excluded classes before averaging; NaNs are ignored by nanmean.
    for scores in per_class.values():
        for excluded in exclude_from_means:
            scores.pop(excluded, None)
    means = {metric: np.nanmean(list(scores.values())) for metric, scores in per_class.items()}

    summary = {
        "Mean IoU": means["IoU"],
        "Mean F1 Score": means["F1 Score"],
        "Mean Recall": means["Recall"],
        "Mean Precision": means["Precision"],
        "Mean Accuracy": means["Accuracy"],
        "Mean Specificity": means["Specificity"],
        "Weighted Precision": weighted_precision_value,
        "Weighted Recall": weighted_recall_value,
        "Weighted F1 Score": weighted_f1_value,
        "Weighted Accuracy": weighted_accuracy_value,
        "Cohen's Kappa": cohen_kappa_value,
        "Matthews Correlation Coefficient": mcc_value,
        "Hamming Loss": hamming_loss_value,
    }

    # Per-class columns; classes missing from a dict (e.g. excluded) get 0.
    all_class_names = sorted(set(class_names).union(names))
    for metric, scores in list(per_class.items()) + [("Class percentages", class_percentages)]:
        for class_name in all_class_names:
            summary[f"{metric} {class_name}"] = scores.get(class_name, 0)

    return summary

## Data loading

In [3]:
# All inputs and outputs live under the shared workspace root.
WORKSPACE_ROOT = Path("/workspaces/Minerva-Discovery")
SEAM_AI_ROOT = WORKSPACE_ROOT / "shared_data/seam_ai_datasets/seam_ai"

# Folder with the prediction .npy files to evaluate (CSV results go here too).
root_predictions_path = (
    WORKSPACE_ROOT
    / "my_experiments/sam_original/evaluate_experiments/parihaka/tmp/predictions/sam_vit_b_experiment_3"
)
root_data_dir = SEAM_AI_ROOT / "images"
root_annotation_dir = SEAM_AI_ROOT / "annotations"

data_module = get_data_module(
    root_data_dir=root_data_dir,
    root_annotation_dir=root_annotation_dir,
    img_size=None,
    batch_size=1,
    single_channel=True,
)

## Looping over predictions

The cell below loops over the predictions (`.npy` files), calculates several metrics for each prediction (using Parihaka's test set), and stores the results in a CSV file with the same name as the prediction file but with the `.csv` extension. The CSV is stored in the same directory as the predictions.

**NOTE**: No metrics are calculated if the csv file already exists.

In [4]:
def summary_dataset(dataset, predictions, dname, n_classes=7):
    """Score every prediction of every sample and collect the metrics.

    Args:
        dataset: Indexable dataset; ``dataset[i]`` returns ``(image, label)``
            and only the label is used.
        predictions: Indexable predictions aligned with ``dataset`` by
            position. ``predictions[i]`` is either a single label map with the
            same number of dimensions as the label, or a stack of label maps
            along its first axis (e.g. ``(10, H, W)`` in the saved ``.npy``
            files); each map in the stack is scored separately.
        dname (str): Name shown in the progress bar.
        n_classes (int): Number of classes passed to
            ``compute_metrics_per_sample``.

    Returns:
        pandas.DataFrame: One row per scored label map, in sample-major order.
    """
    results = []

    for i in tqdm.tqdm(range(len(dataset)), total=len(dataset), desc=f"Processing dataset {dname}"):
        _, label = dataset[i]
        pred = predictions[i]
        # A prediction with the label's rank is one map; otherwise iterate
        # over its leading axis (one map per candidate).
        candidates = [pred] if np.ndim(pred) == np.ndim(label) else pred
        for single_pred in candidates:
            results.append(compute_metrics_per_sample(single_pred, label, num_classes=n_classes))

    return pd.DataFrame(results)

# Score every prediction file that has no metrics CSV yet.
predict_dataset = None  # set up lazily: only needed if at least one file is scored

for pred_path in sorted(root_predictions_path.rglob("*.npy")):
    csv_path = pred_path.with_suffix(".csv")
    if csv_path.exists():
        print(f"Skipping {pred_path} as {csv_path.name} already exists")
        continue

    # The data module setup is loop-invariant: do it once, not once per file.
    if predict_dataset is None:
        data_module.setup("predict")
        predict_dataset = data_module.datasets["predict"]

    print(f"Loading predictions from {pred_path}")
    predictions = np.load(pred_path)  # e.g. (200, 10, 1006, 590) for the SAM runs
    # Predictions are matched to samples by position, so the counts must agree.
    if len(predictions) != len(predict_dataset):
        raise ValueError(
            f"{pred_path.name} has {len(predictions)} predictions but the "
            f"predict dataset has {len(predict_dataset)} samples"
        )

    df = summary_dataset(predict_dataset, predictions, pred_path.stem)
    df.to_csv(csv_path, index=False)

    print(f"Saved results to {csv_path}")
    print()

Loading predictions from /workspaces/Minerva-Discovery/my_experiments/sam_original/evaluate_experiments/parihaka/tmp/predictions/sam_vit_b_experiment_3/sam_vit_b_experiment_3.npy


Processing dataset sam_vit_b_experiment_3: 100%|██████████| 200/200 [01:07<00:00,  2.95it/s]

Saved results to /workspaces/Minerva-Discovery/my_experiments/sam_original/evaluate_experiments/parihaka/tmp/predictions/sam_vit_b_experiment_3/sam_vit_b_experiment_3.csv






## Make per-dimension mean IoU views

In [5]:
# Ground-truth volume for the per-dimension IoU maps: one label map per sample.
data_module.setup("predict")
predict_dataset = data_module.datasets["predict"]
labels = np.array(
    [predict_dataset[sample_idx][1] for sample_idx in range(len(predict_dataset))]
)
labels.shape

(200, 1006, 590)

In [10]:
# Predictions of experiment 1, used for the per-dimension IoU maps below.
pred_path = Path(
    "/workspaces/Minerva-Discovery/my_experiments/sam_original/evaluate_experiments/parihaka/tmp/predictions/sam_vit_b_experiment_1/sam_vit_b_experiment_1.npy"
)
predictions = np.load(pred_path)
# Sample count and spatial size must line up with the ground-truth `labels`.
assert predictions.shape[0] == labels.shape[0], (predictions.shape, labels.shape)
assert predictions.shape[-2:] == labels.shape[-2:], (predictions.shape, labels.shape)

pred_path, predictions.shape

(200, 10, 1006, 590)


(PosixPath('/workspaces/Minerva-Discovery/my_experiments/sam_original/evaluate_experiments/parihaka/tmp/predictions/sam_vit_b_experiment_1/sam_vit_b_experiment_1.npy'),
 (200, 10, 1006, 590))

In [9]:
import torch
import numpy as np
import torchmetrics

def compute_mean_iou_plane(predictions, labels, axis, num_classes=6):
    """
    Compute a 2-D map of mean IoU by scoring every 1-D trace along ``axis``.

    For each position in the two remaining axes, the trace of ``predictions``
    along ``axis`` is compared with the matching trace of ``labels`` with a
    per-class multiclass Jaccard index. The per-class values are then
    averaged with ``torch.mean``; classes absent from a trace contribute
    whatever torchmetrics returns for them.

    Args:
        predictions (numpy.ndarray): Predicted class indices, shape (I, J, K).
        labels (numpy.ndarray): Ground-truth class indices, shape (I, J, K).
        axis (int): Axis the traces run along: 0 -> map of shape (J, K),
            1 -> (I, K), 2 -> (I, J).
        num_classes (int): Number of classes for the Jaccard index (default 6).
            NOTE(review): the per-sample metrics in this notebook use 7
            classes; confirm that predictions/labels never contain class 6.

    Returns:
        numpy.ndarray: A 2D plane with mean IoU values.

    Raises:
        ValueError: If ``predictions`` is not 3-D or ``axis`` is not 0, 1 or 2.
    """
    # Define IoU metric
    iou = torchmetrics.JaccardIndex(task="multiclass", num_classes=num_classes, average=None)

    if predictions.ndim != 3:
        raise ValueError(f"predictions must be 3-D (I, J, K), got shape {predictions.shape}")
    if axis not in (0, 1, 2):
        raise ValueError("Invalid axis. Choose 0 (JxK), 1 (IxK), or 2 (IxJ).")

    # Move the trace axis last so all three cases become the same double loop:
    # pred_traces[r, c] is the 1-D trace at that position (a strided view).
    pred_traces = np.moveaxis(predictions, axis, -1)
    label_traces = np.moveaxis(labels, axis, -1)
    rows, cols = pred_traces.shape[:2]
    iou_map = np.zeros((rows, cols))

    for r in range(rows):
        for c in range(cols):
            pred_tensor = torch.from_numpy(pred_traces[r, c]).unsqueeze(0)
            label_tensor = torch.from_numpy(label_traces[r, c]).unsqueeze(0)

            iou_per_class = iou(pred_tensor, label_tensor)
            iou_map[r, c] = torch.mean(iou_per_class).item()

    return iou_map


In [11]:
import matplotlib.pyplot as plt


def plot_iou_map(
    iou_map,
    title="IoU Map",
    cmap="viridis",
    xlabel="Dimension 2",
    ylabel="Dimension 1",
    figsize=(8, 6),
    save_path=None,
    show=True,
    colorbar_label="Mean IoU",
    vmin=None,
    vmax=None,
):
    """
    Plot a 2-D map of metric values with Matplotlib.

    Args:
        iou_map (numpy.ndarray): 2D array of values (e.g. mean IoU).
        title (str): Title of the plot.
        cmap (str): Colormap to use (e.g., 'viridis', 'plasma', 'jet', etc.).
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        figsize (tuple): Size of the figure (width, height).
        save_path (str or None): If provided, saves the plot to this path (300 dpi).
        show (bool): Whether to display the figure; if False it is closed.
        colorbar_label (str): Label of the colour bar.
        vmin, vmax (float or None): Colour range. Pass e.g. 0 and 1 to make
            maps comparable across plots; None lets Matplotlib autoscale.

    Returns:
        matplotlib.figure.Figure: The generated figure.
    """
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(iou_map, cmap=cmap, interpolation="nearest", vmin=vmin, vmax=vmax)
    fig.colorbar(image, ax=ax, label=colorbar_label)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)  # Prevents displaying if show=False

    return fig

In [None]:
# compute_mean_iou_plane needs a 3-D (I, J, K) volume. A 4-D (N, P, H, W)
# prediction stack (presumably P candidate label maps per sample, scored
# separately in summary_dataset) is reduced to its first candidate map.
prediction_volume = predictions[:, 0] if predictions.ndim == 4 else predictions
plane_jk = compute_mean_iou_plane(prediction_volume, labels, axis=0)
print(plane_jk.shape)

In [None]:
fig = plot_iou_map(
    plane_jk,
    title="Mean IoU (JxK)",
    ylabel="J",
    xlabel="K",
    cmap="viridis",
)

In [None]:
# compute_mean_iou_plane needs a 3-D (I, J, K) volume; a 4-D (N, P, H, W)
# prediction stack is reduced to its first candidate map per sample.
prediction_volume = predictions[:, 0] if predictions.ndim == 4 else predictions
plane_ik = compute_mean_iou_plane(prediction_volume, labels, axis=1)
print(plane_ik.shape)

In [None]:
fig = plot_iou_map(
    plane_ik,
    title="Mean IoU (IxK)",
    ylabel="I",
    xlabel="K",
    cmap="viridis",
)

In [None]:
# compute_mean_iou_plane needs a 3-D (I, J, K) volume; a 4-D (N, P, H, W)
# prediction stack is reduced to its first candidate map per sample.
prediction_volume = predictions[:, 0] if predictions.ndim == 4 else predictions
plane_ij = compute_mean_iou_plane(prediction_volume, labels, axis=2)
print(plane_ij.shape)

In [None]:
fig = plot_iou_map(
    plane_ij,
    title="Mean IoU (IxJ)",
    ylabel="I",
    xlabel="J",
    cmap="viridis",
)

In [None]:
import torch
import numpy as np
import torchmetrics

# Mean IoU for every (i, j) trace along K. This is exactly
# compute_mean_iou_plane(..., axis=2), reused instead of duplicating its loop.
# A 4-D (N, P, H, W) prediction stack is reduced to its first candidate map.
prediction_volume = predictions[:, 0] if predictions.ndim == 4 else predictions
iou_map = compute_mean_iou_plane(prediction_volume, labels, axis=2)

In [None]:
from matplotlib import pyplot as plt
# Same rendering as the other maps: titled axes and a labelled colour bar.
fig = plot_iou_map(iou_map, title="Mean IoU (IxJ)", xlabel="J", ylabel="I", cmap="viridis")

## Loading CSVs

The cell below loads all the CSV files and creates a dictionary where the key is the model name (the CSV file name without extension) and the value is another dictionary mapping the run id (the name of the CSV's parent directory) to the dataframe with the metrics.

The metrics of multiple runs are merged into a single dataframe per model by taking the per-sample mean and standard deviation across runs. Thus, the final result is a dictionary where the key is the model name and the value is a dataframe with the mean and standard deviation of the metrics across the runs.

In [None]:
def merge_dfs(
    dfs_dict: Dict[str, Dict[str, pd.DataFrame]]
) -> Dict[str, pd.DataFrame]:
    """Collapse several runs of each model into one mean/std DataFrame.

    Rows are matched across runs by their row index, so every run of a model
    is expected to list the same samples in the same order.

    Args:
        dfs_dict: Model name -> {run id -> per-sample metrics DataFrame}.

    Returns:
        Model name -> DataFrame with a ``sample`` column (the original row
        index), the across-run mean of every metric column, and a
        ``"<col> (std)"`` column with the across-run standard deviation.
        NaNs in both are replaced by 0 (the std is NaN whenever a model has
        a single run).

    Raises:
        ValueError: If a model has no runs (nothing to concatenate).
    """
    merged_results = {}

    for model_name, runs in dfs_dict.items():
        # Tag each row with its index so matching rows of different runs group together.
        combined_df = pd.concat([df.assign(sample=df.index) for df in runs.values()])
        grouped = combined_df.groupby("sample")

        mean_df = grouped.mean().fillna(0)
        std_df = grouped.std().fillna(0).add_suffix(" (std)")

        merged_results[model_name] = pd.concat([mean_df, std_df], axis=1).reset_index()

    return merged_results

In [None]:
# model name (CSV stem) -> run id (parent directory name) -> raw metrics frame
runs_by_model = {}
for csv_file in sorted(root_predictions_path.rglob("*.csv")):
    runs_by_model.setdefault(csv_file.stem, {})[csv_file.parent.stem] = pd.read_csv(csv_file)

# model name -> per-sample mean/std across runs
dfs = merge_dfs(runs_by_model)
dfs.keys()

In [None]:
# Class distribution of the ground truth. It is computed from the labels only,
# so, assuming every model was evaluated on the same test set, any model gives
# the same numbers. Prefer "simclr" when present; otherwise use the first model
# (avoids a KeyError for prediction folders without a "simclr" run).
reference_model = "simclr" if "simclr" in dfs else next(iter(dfs))
percentages = {
    f"Class {i}": dfs[reference_model][f"Class percentages {i}"].mean()
    for i in range(6)
}
percentages

## Make Heatmaps

In [None]:
def get_metrics_df(dfs: Dict[str, pd.DataFrame], metric_name, class_ids=range(6)) -> pd.DataFrame:
    """Summarize one metric per model as a table of percentages.

    For every model, averages over all rows of its merged DataFrame (output
    of ``merge_dfs``) the following columns:

    - ``Mean <metric>``;
    - the per-class ``<metric> <class>`` columns;
    - the ``(std)`` counterparts of both.

    Args:
        dfs: Model name -> merged DataFrame.
        metric_name: Metric prefix used in the column names, e.g. "IoU" or
            "F1 Score".
        class_ids: Class identifiers whose per-class columns are included
            (default: classes 0-5).

    Returns:
        DataFrame indexed by model name with the columns: mean, per-class,
        mean (std), per-class (std). Rows are sorted by ``Mean <metric>``,
        descending, and all values are multiplied by 100.

    Raises:
        KeyError: If a required column is missing from any DataFrame.
    """
    value_columns = [f"Mean {metric_name}"] + [f"{metric_name} {c}" for c in class_ids]
    columns = value_columns + [f"{col} (std)" for col in value_columns]

    metrics = {
        model_name: {col: df[col].mean() for col in columns}
        for model_name, df in dfs.items()
    }

    metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=columns).sort_index()
    metrics_df = metrics_df.sort_values(f"Mean {metric_name}", ascending=False) * 100
    return metrics_df

In [None]:
def get_text_df(metrics_df):
    """Build heat-map cell labels of the form ``"<value><br>(±<std>)"``.

    Args:
        metrics_df: DataFrame in which every column whose name does not
            contain "std" has a matching ``"<col> (std)"`` column (as
            produced by ``get_metrics_df``).

    Returns:
        DataFrame with the non-std columns of ``metrics_df`` (same index),
        each cell a string with both numbers formatted to two decimals.

    Raises:
        KeyError: If a value column has no ``"<col> (std)"`` counterpart.
    """
    non_std_cols = [col for col in metrics_df.columns if "std" not in col]
    new_metrics_df = metrics_df[non_std_cols].copy()
    for col in non_std_cols:
        # One pass per column over (value, std) pairs instead of a row-wise
        # apply; assigning a list is positional, so it also works on empty frames.
        new_metrics_df[col] = [
            f"{value:.2f}<br>(±{std:.2f})"
            for value, std in zip(metrics_df[col], metrics_df[f"{col} (std)"])
        ]

    return new_metrics_df


def get_df_without_std(metrics_df):
    """Return a copy of ``metrics_df`` without any column whose name contains "std"."""
    std_columns = [name for name in metrics_df.columns if "std" in name]
    return metrics_df.drop(columns=std_columns)
    

In [None]:
# Model key (CSV stem) -> display name used in heat-map labels and titles.
index_map = {
    "byol": "BYOL<br>(ResNet50)",
    "deeplabv3": "DeepLabV3<br>(ResNet50)",
    "tribyol": "TriBYOL<br>(ResNet50)",
    "sam": "SAM<br>(SAM-ViT-B)",
    "kenshodense": "KenShoDense<br>(ResNet50)",
    "fastsiam": "FastSiam<br>(ResNet50)",
    "simclr": "SimCLR<br>(ResNet50)",
    "lfr": "LFR<br>(ResNet50)",
    "dinov2_mla": "DINOv2-MLA<br>(DinoViT)",
    "dinov2_mla_reflect_6_classes": "DINOv2-MLA - reflect<br>(DinoViT)",
    "dinov2_mla_interpolate_6_classes": "DINOv2-MLA - interpolate<br>(DinoViT)",
    "dinov2_mla_constant_7_classes": "DINOv2-MLA - constant<br>(DinoViT)",
    "dinov2_dpt": "DINOv2-DPT<br>(DinoViT)",
    "dinov2_pup": "DINOv2-PUP<br>(DinoViT)",
    "sfm_base_patch16": "SFM<br>(SFM-ViT-B)",
    "setr_pup": "SETR-PUP<br>(ViT-L)"
}

In [None]:
import plotly.graph_objects as go

def plot_heat_map(metrics_df, text_df, colorbar_title: str = None, height: int = 1200, width: int = 800, index_map: dict = None):
    """Render a model-by-metric table as an annotated Plotly heat map.

    Args:
        metrics_df: Numeric DataFrame (rows = models, columns = metrics) that
            sets the cell colours.
        text_df: DataFrame of cell annotations. If its rows/columns are not
            already in ``metrics_df``'s order, it is reindexed to match, so
            each label lands on the cell it describes.
        colorbar_title: Title shown on the colour bar (None for no title).
        height: Figure height in pixels.
        width: Figure width in pixels.
        index_map: Optional mapping from raw row labels (model keys) to
            display names; unmapped labels are kept as-is.

    Returns:
        plotly.graph_objects.Figure: The heat map (not shown).
    """
    if not (text_df.index.equals(metrics_df.index) and text_df.columns.equals(metrics_df.columns)):
        text_df = text_df.reindex(index=metrics_df.index, columns=metrics_df.columns)

    df = metrics_df.rename(index=index_map) if index_map is not None else metrics_df

    fig = go.Figure(
        go.Heatmap(
            x=df.columns,
            # Rows are reversed so the first row of metrics_df ends up at the top.
            y=df.index[::-1],
            z=df.values[::-1],
            text=text_df.values[::-1],
            texttemplate="%{text}",
            textfont={"family": "Times New Roman", "size": 15},
            colorscale="Blues",
            colorbar_title=colorbar_title,
        )
    )

    fig.update_layout(
        height=height,
        width=width,
        xaxis=dict(
            showticklabels=True,
            title=None,  # column names already label the x axis
            tickfont=dict(family="Times New Roman", size=16),
        ),
        margin=dict(l=0, r=0, t=10, b=10),
    )

    return fig

### IoU

In [None]:
# Models left out of the published heat map; names not present are ignored,
# so the cell also works for prediction folders that don't contain them.
excluded_models = ["setr_pup", "dinov2_mla"]

metrics_df_iou = get_metrics_df(dfs, "IoU")
metrics_df = get_df_without_std(metrics_df_iou).drop(index=excluded_models, errors="ignore")
text_df = get_text_df(metrics_df_iou).drop(index=excluded_models, errors="ignore")

fig = plot_heat_map(metrics_df, text_df, colorbar_title="IoU (%)", height=1000, width=700, index_map=index_map)
fig.show()

iou_heatmap_path = root_predictions_path / "iou_heatmap.png"
fig.write_image(iou_heatmap_path)
print(f"Saved heatmap to {iou_heatmap_path}")

### F1-Score

In [None]:
# Models left out of the published heat map; names not present are ignored,
# so the cell also works for prediction folders that don't contain them.
excluded_models = ["setr_pup", "dinov2_mla"]

# Own variable name, so the IoU table from the previous section is not overwritten.
metrics_df_f1 = get_metrics_df(dfs, "F1 Score")
metrics_df = get_df_without_std(metrics_df_f1).drop(index=excluded_models, errors="ignore")
text_df = get_text_df(metrics_df_f1).drop(index=excluded_models, errors="ignore")

fig = plot_heat_map(metrics_df, text_df, colorbar_title="F1-Score (%)", height=1000, width=700, index_map=index_map)
fig.show()

f1_heatmap_path = root_predictions_path / "f1_heatmap.png"
fig.write_image(f1_heatmap_path)
print(f"Saved heatmap to {f1_heatmap_path}")

## Make Prediction Maps

In [None]:
import plotly.graph_objects as go
import numpy as np


def plot_value_heatmap(
    values,
    colorscale="Blues",
    title="Heatmap of Values",
    show_colorbar=True,
    width=1000,
    height=200,
    metric_name="IoU",
    filename=None,
    xaxis_title="Crossline index",
):
    """
    Plot a single-row Plotly heat map of metric values in [0, 1].

    The colour range is fixed to [0, 1], so figures for different models are
    directly comparable.

    Parameters:
    - values (list or np.array): 1-D sequence of values; one cell per value.
    - colorscale (str): Colour scale for the heatmap (e.g., 'Viridis', 'Plasma', 'Jet').
    - title (str): Title of the heatmap.
    - show_colorbar (bool): Whether to display the colour bar.
    - width (int): Width of the plot in pixels.
    - height (int): Height of the plot in pixels.
    - metric_name (str): Title of the colour bar.
    - filename (str, Path or None): If given, the figure is also written to
      this path with ``fig.write_image`` after being shown.
    - xaxis_title (str): Title of the x axis. NOTE(review): the default
      "Crossline index" assumes one value per crossline; with several
      predictions per sample each cell is one (sample, prediction) row.

    Returns:
    - plotly.graph_objects.Figure: The figure (already shown via ``fig.show()``).
    """
    # Wrap in a list to get a (1, len(values)) array: a single-row heatmap.
    heatmap_values = np.array([values])
    # 0.0, 0.2, ..., 1.0. linspace avoids the float-step fragility of arange.
    tick_values = np.linspace(0.0, 1.0, 6)

    fig = go.Figure(
        data=go.Heatmap(
            z=heatmap_values,
            colorscale=colorscale,
            showscale=show_colorbar,
            zmin=0,
            zmax=1,
            colorbar=dict(
                title=metric_name,
                tickvals=tick_values,
                ticktext=[f"{tick:.1f}" for tick in tick_values],
            ) if show_colorbar else None,
        )
    )

    fig.update_layout(
        xaxis=dict(
            tickvals=list(range(len(values))),
            showticklabels=False,  # one tick per cell would be unreadable
            title=xaxis_title,
        ),
        yaxis=dict(
            showticklabels=False,  # single row: nothing to label
            title="",
        ),
        margin=dict(l=10, r=10, t=40, b=10),
        title=title,
        height=height,
        width=width,
    )

    fig.show()

    if filename:
        fig.write_image(filename)
        print(f"Saved heatmap to {filename}")

    return fig

In [None]:
for key, df in sorted(dfs.items(), key=lambda item: item[0]):
    values = df["Mean IoU"].values
    # Models without an entry in index_map are shown under their raw key.
    model_label = index_map.get(key, key).replace("<br>", " ")
    plot_value_heatmap(
        values,
        colorscale="Plasma",
        title=f"IoU for model: {model_label} (min: {values.min() * 100:.2f}%, average: {values.mean() * 100:.2f}%, max: {values.max() * 100:.2f}%)",
        width=1000,
        height=250,
        metric_name="IoU",
        filename=root_predictions_path / f"{key}_iou_heatmap.png",
    )

In [None]:
for key, df in sorted(dfs.items(), key=lambda item: item[0]):
    values = df["Mean F1 Score"].values
    # Models without an entry in index_map are shown under their raw key.
    model_label = index_map.get(key, key).replace("<br>", " ")
    plot_value_heatmap(
        values,
        colorscale="Plasma",
        title=f"F1-Score for model: {model_label} (min: {values.min() * 100:.2f}%, average: {values.mean() * 100:.2f}%, max: {values.max() * 100:.2f}%)",
        width=1000,
        height=250,
        metric_name="F1-Score",
        filename=root_predictions_path / f"{key}_f1_heatmap.png",
    )
