In [None]:
import pandas as pd
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Plot defaults: render figures at 120 DPI on seaborn's white-grid theme.
plt.rcParams.update({"figure.dpi": 120})
sns.set_theme(style="whitegrid")

In [None]:
# Base directories.
# The default points at the original author's checkout; set the
# FEATURE_MATCHING_OUTPUT_DIR environment variable to run this notebook
# against another copy of the evaluation outputs.
import os

_default_results_root = Path(
    "/home/guillemc/dev/LuPNT-private/output/2025_FeatureMatching"
)
results_root = Path(
    os.environ.get("FEATURE_MATCHING_OUTPUT_DIR") or _default_results_root
)
models_dir = results_root / "eval_results_models"
legacy_dir = results_root / "eval_results"

# Accumulator: one summary row per (model/method, dataset, step).
all_data = []
# Reserved for per-pair records; this cell does not fill it.
all_per_pair_list = []

# Dataset name mapping (raw directory name -> display name).
dataset_map = {
    "short_base": "Baseline",
    "short_camera_effects": "Camera Effects",
    "short_higher_elevation": "Higher Sun Elevation",
    "short_no_lights": "No Lights",
    "long_base": "Long Base",
    "long_camera_effects": "Long Camera Effects",
    "long_higher_elevation": "Long Higher Elevation",
    "long_no_lights": "Long No Lights",
    "rover_0": "Spirals",
}

# Model name mapping: the first key that occurs anywhere in a model directory
# name (substring match) selects the display name. The last three entries are
# also in skip_patterns, so they only apply if un-skipped.
model_name_map = {
    "SuperPoint+LightGlue_Spirals": "Finetuned",
    "lightglue_spirals_v1": "SuperPoint+LightGlue (Spirals)",
    "lightglue_unreal_base1": "SuperPoint+LightGlue (Traverse)",
    "spirals_seg_rover0": "SuperPoint+LightGlue (Spirals+Segmentation)",
}

# Model directories whose name contains any of these substrings are not loaded.
skip_patterns = [
    # Test / experimental runs
    "semantic_test",
    "local_traverse_fov90",
    "spirals_20251211",
    # Other trained variants, hidden so only "Finetuned" is compared to baselines
    "lightglue_spirals_v1",
    "lightglue_unreal_base1",
    "spirals_seg_rover0",
]

# Extra display columns copied from summary keys when those keys are present.
TRAINED_METRIC_ALIASES = {
    "abs_loc_t_error": "Abs Trans Error (m)",
    "rel_pose_r_error": "Rel Rot Error (deg)",
}
LEGACY_METRIC_ALIASES = {
    "rel_pose_r_error": "Rel Rot Error (deg)",
    "rel_pose_t_error": "Trans Error (m)",
}


def prettify_dataset_name(raw_name):
    """Return the display name for a raw dataset directory name."""
    return dataset_map.get(raw_name, raw_name.replace("_", " ").title())


def display_name_for_model(raw_model_name):
    """Return the display name for a trained-model directory name.

    The first ``model_name_map`` key found as a substring wins. Otherwise,
    names mentioning "spirals" (case-insensitive) become
    "Finetuned (<name>)", and anything else keeps its raw name.
    """
    for key, display_name in model_name_map.items():
        if key in raw_model_name:
            return display_name
    if "spirals" in raw_model_name.lower():
        return f"Finetuned ({raw_model_name})"
    return raw_model_name


def parse_step_from_filename(path):
    """Return N for a file named ``step_N[_...].pkl``, or None otherwise."""
    parts = path.stem.split("_")
    if len(parts) < 2 or parts[0] != "step":
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


def make_summary_row(summary, *, model_type, model, method, dataset, step, aliases):
    """Return a copy of ``summary`` tagged with model/dataset/step metadata.

    For each ``source -> alias`` pair in ``aliases`` whose source key is
    present, the value is duplicated under the alias column name.
    """
    row = summary.copy()  # never mutate the loaded summary in place
    row["Model Type"] = model_type
    row["Model"] = model
    row["Method"] = method
    row["Dataset"] = dataset
    row["Step"] = step
    for source_key, alias in aliases.items():
        if source_key in row:
            row[alias] = row[source_key]
    return row


def collect_summary_rows(step_file, *, step, dataset, model_type, aliases, model=None):
    """Load one evaluation pickle and return one summary row per method.

    The pickle's ``"results"`` entry maps method name -> result dict. Only
    methods with a ``"summary"`` entry produce a row; per-pair results are
    deliberately not loaded, to save memory. ``model=None`` labels each row
    with its method name (used for the baselines).
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally produced evaluation outputs.
    with open(step_file, "rb") as f:
        data = pickle.load(f)
    rows = []
    for method, res in data.get("results", {}).items():
        if "summary" not in res:
            continue
        rows.append(
            make_summary_row(
                res["summary"],
                model_type=model_type,
                model=method if model is None else model,
                method=method,
                dataset=dataset,
                step=step,
                aliases=aliases,
            )
        )
    return rows


# --- 1. Load trained models ---
print(f"Scanning models in {models_dir}...")
if models_dir.exists():
    for model_path in sorted(models_dir.iterdir()):
        raw_model_name = model_path.name
        if not model_path.is_dir() or any(
            skip in raw_model_name for skip in skip_patterns
        ):
            continue
        model_display_name = display_name_for_model(raw_model_name)

        for dataset_path in sorted(model_path.iterdir()):
            raw_dataset_name = dataset_path.name
            # Only short_* datasets are compared (e.g. rover_0/Spirals is excluded).
            if not dataset_path.is_dir() or "short" not in raw_dataset_name:
                continue
            dataset_name = prettify_dataset_name(raw_dataset_name)

            for step_file in sorted(dataset_path.glob("step_*.pkl")):
                step = parse_step_from_filename(step_file)
                if step is None:
                    continue
                try:
                    rows = collect_summary_rows(
                        step_file,
                        step=step,
                        dataset=dataset_name,
                        model_type="Finetuned",
                        aliases=TRAINED_METRIC_ALIASES,
                        model=model_display_name,
                    )
                except Exception as err:  # best-effort: report and keep loading
                    print(f"Warning: skipping {step_file}: {err!r}")
                    continue
                all_data.extend(rows)

# --- 2. Load legacy results (baselines) ---
print(f"Scanning legacy results in {legacy_dir}...")
if legacy_dir.exists():
    for dataset_path in sorted(legacy_dir.iterdir()):
        raw_dataset_name = dataset_path.name
        # Same short_* filter as above, checked once per dataset.
        if not dataset_path.is_dir() or "short" not in raw_dataset_name:
            continue
        dataset_name = prettify_dataset_name(raw_dataset_name)

        for agent_path in sorted(dataset_path.iterdir()):  # e.g. rover_0
            if not agent_path.is_dir():
                continue
            for step_file in sorted(agent_path.glob("step_*.pkl")):
                step = parse_step_from_filename(step_file)
                if step is None:
                    continue
                try:
                    # Baselines are labelled by their method name
                    # (e.g. SuperPoint+LightGlue), without a "Baseline" prefix.
                    rows = collect_summary_rows(
                        step_file,
                        step=step,
                        dataset=dataset_name,
                        model_type="Baseline",
                        aliases=LEGACY_METRIC_ALIASES,
                    )
                except Exception as err:  # best-effort: report and keep loading
                    print(f"Warning: skipping {step_file}: {err!r}")
                    continue
                all_data.extend(rows)

df_summary = pd.DataFrame(all_data)

print(f"Loaded {len(df_summary)} total summary records.")
if not df_summary.empty:
    # "Labels" mirrors "Model" as a separate column for label-based plotting.
    df_summary["Labels"] = df_summary["Model"]

    # 2b. Completeness table: which Steps exist for each Dataset x Label.
    print("Generating Completeness Table...")
    completeness = (
        df_summary.groupby(["Dataset", "Labels"])["Step"]
        .apply(lambda steps: sorted(set(steps)))
        .unstack(fill_value="-")
    )

    # display() is the IPython/Jupyter rich-output builtin.
    display(completeness)
    display(df_summary.head())

## 8. Summary Heatmaps (Detailed Absolute Metrics)

In [None]:
# Metric specs used by the heatmaps: (column name, plot title, value type).
# Most statistics are stored twice: as a mean ("m" prefix) and as a median
# ("median_" prefix). Those entries are therefore generated in pairs from a
# single (base key, label) table.
def _mean_and_median(base, label):
    """Return the mean and median metric specs for one summary statistic."""
    return [
        ("m" + base, "Mean " + label, "float"),
        ("median_" + base, "Median " + label, "float"),
    ]


_paired_statistics = [
    # Matching statistics
    ("num_matches", "Num Matches"),
    ("num_keypoints0", "Num Keypoints 0"),
    ("num_keypoints1", "Num Keypoints 1"),
    ("num_keypoints", "Num Keypoints"),
    # Epipolar precision
    ("epi_prec@1e-4", "Epi Prec @ 1e-4"),
    ("epi_prec@5e-4", "Epi Prec @ 5e-4"),
    ("epi_prec@1e-3", "Epi Prec @ 1e-3"),
    # Reprojection precision
    ("reproj_prec@1px", "Reproj Prec @ 1px"),
    ("reproj_prec@3px", "Reproj Prec @ 3px"),
    ("reproj_prec@5px", "Reproj Prec @ 5px"),
    # Covisibility
    ("covisible", "Covisible"),
    ("covisible_percent", "Covisible %"),
    # Ground-truth recall / precision
    ("gt_match_recall@3px", "GT Match Recall @ 3px"),
    ("gt_match_precision@3px", "GT Match Precision @ 3px"),
    # Relative pose error (general)
    ("rel_pose_error", "Rel Pose Error"),
    ("ransac_inl", "RANSAC Inliers"),
    ("ransac_inl%", "RANSAC Inliers %"),
    # Relative pose error (translation / rotation)
    ("rel_pose_t_error", "Rel Pose Trans Error"),
    ("rel_pose_r_error", "Rel Pose Rot Error"),
    ("rel_pose_t_error_rel", "Rel Pose Trans Error (Rel)"),
    ("rel_pose_r_error_rel", "Rel Pose Rot Error (Rel)"),
    # Relative pose error (components)
    ("rel_pose_t_error_x", "Rel Pose Trans Error X"),
    ("rel_pose_t_error_y", "Rel Pose Trans Error Y"),
    ("rel_pose_t_error_z", "Rel Pose Trans Error Z"),
    ("rel_pose_r_error_roll", "Rel Pose Rot Error Roll"),
    ("rel_pose_r_error_pitch", "Rel Pose Rot Error Pitch"),
    ("rel_pose_r_error_yaw", "Rel Pose Rot Error Yaw"),
    # Absolute localization error
    ("abs_loc_t_error", "Abs Loc Trans Error"),
    ("abs_loc_r_error", "Abs Loc Rot Error"),
    ("abs_loc_t_error_rel", "Abs Loc Trans Error (Rel)"),
    ("abs_loc_r_error_rel", "Abs Loc Rot Error (Rel)"),
    # Absolute localization error (components)
    ("abs_loc_t_error_x", "Abs Loc Trans Error X"),
    ("abs_loc_t_error_y", "Abs Loc Trans Error Y"),
    ("abs_loc_t_error_z", "Abs Loc Trans Error Z"),
    ("abs_loc_r_error_roll", "Abs Loc Rot Error Roll"),
    ("abs_loc_r_error_pitch", "Abs Loc Rot Error Pitch"),
    ("abs_loc_r_error_yaw", "Abs Loc Rot Error Yaw"),
    # Absolute localization accuracy thresholds
    ("abs_loc_acc@0.25m_2deg", "Abs Loc Acc @ 0.25m 2deg"),
    ("abs_loc_acc@0.5m_5deg", "Abs Loc Acc @ 0.5m 5deg"),
    ("abs_loc_acc@1.0m_10deg", "Abs Loc Acc @ 1.0m 10deg"),
    # Timing
    ("extraction_time", "Extraction Time"),
    ("matching_time", "Matching Time"),
    ("total_time", "Total Time"),
]

summary_metrics = [
    spec
    for base, label in _paired_statistics
    for spec in _mean_and_median(base, label)
]

# Throughput: a single FPS value per pipeline stage (no mean/median split).
summary_metrics.extend(
    (f"{stage}_fps", f"{stage.capitalize()} FPS", "float")
    for stage in ("extraction", "matching", "total")
)

# Pose AUC: mean only, at three angular thresholds.
summary_metrics.extend(
    (f"mpose_auc@{deg}", f"Mean Pose AUC @ {deg} deg", "float")
    for deg in (5, 10, 20)
)

# Metadata columns (identifiers rather than measured metrics).
summary_metrics.extend(
    [
        ("Model Type", "Model Type", "str"),
        ("Model", "Model", "str"),
        ("Method", "Method", "str"),
        ("Dataset", "Dataset", "str"),
        ("Step", "Step", "int"),
        ("Labels", "Labels", "str"),
    ]
)

import os

# 1. Create the output directory (one PNG per metric and step is written here).
output_dir = "./heatmaps"
os.makedirs(output_dir, exist_ok=True)

# Metrics whose name contains one of these substrings are "lower is better"
# (errors, timings); names containing "fps" are always "higher is better".
lower_is_better_keywords = ["error", "time"]


def split_metrics_by_direction(metrics, available_columns, lower_keywords):
    """Split metric specs into ``(higher_is_better, lower_is_better)`` lists.

    Only ``"float"`` specs whose column exists in ``available_columns`` are
    kept, so metadata such as the integer ``Step`` column is never plotted
    as if it were a metric.
    """
    higher_better, lower_better = [], []
    for spec in metrics:
        col_name, _, kind = spec
        if kind != "float" or col_name not in available_columns:
            continue
        name_lower = col_name.lower()
        if any(k in name_lower for k in lower_keywords) and "fps" not in name_lower:
            lower_better.append(spec)
        else:
            higher_better.append(spec)
    return higher_better, lower_better


def save_metric_heatmaps(subset, step, metrics_list, cmap_name, out_dir):
    """Save one Dataset x Model heatmap per metric for a single step.

    Several rows for the same (Dataset, Model) pair are averaged, where
    ``DataFrame.pivot`` would raise. Metrics that have only missing values at
    this step are skipped. Returns the list of written file paths.
    """
    saved_paths = []
    for metric, title, _ in metrics_list:
        frame = subset[["Dataset", "Model"]].assign(
            value=pd.to_numeric(subset[metric], errors="coerce")
        )
        # dropna=False keeps Datasets/Models lacking this metric as blank cells.
        pivot_df = frame.pivot_table(
            index="Dataset", columns="Model", values="value",
            aggfunc="mean", dropna=False,
        )
        if pivot_df.empty or pivot_df.isna().all().all():
            continue

        # Height grows with the number of dataset rows.
        fig, ax = plt.subplots(figsize=(10, len(pivot_df) * 0.8 + 2))
        try:
            sns.heatmap(
                pivot_df, annot=True, fmt=".3f", cmap=cmap_name,
                linewidths=0.5, ax=ax,
            )
            # Raw string + balanced "$...$" so mathtext renders the Delta symbol.
            ax.set_title(rf"{title} ($\Delta t={step}$)")
            fig.tight_layout()

            # Only "/" is replaced; other characters (e.g. "@", "%") are kept.
            safe_name = metric.replace("/", "_")
            save_path = os.path.join(out_dir, f"{safe_name}_step_{step}.png")
            fig.savefig(save_path, dpi=300)  # high-resolution export
        finally:
            plt.close(fig)  # free figure memory even if plotting fails
        print(f"Saved: {save_path}")
        saved_paths.append(save_path)
    return saved_paths


if not df_summary.empty:
    unique_steps = sorted(df_summary["Step"].unique())

    # The column set is identical for every step, so categorize metrics once.
    numeric_metrics = [m for m in summary_metrics if m[2] == "float"]
    higher_better_metrics, lower_better_metrics = split_metrics_by_direction(
        numeric_metrics, df_summary.columns, lower_is_better_keywords
    )

    for step in unique_steps:
        subset_step = df_summary[df_summary["Step"] == step]
        if subset_step.empty:
            continue
        if subset_step.duplicated(["Dataset", "Model"]).any():
            print(
                f"Note: step {step} has several rows per (Dataset, Model); "
                "heatmap cells show their mean."
            )

        # 1. "Higher is better" metrics (viridis: yellow = high value).
        save_metric_heatmaps(
            subset_step, step, higher_better_metrics, "viridis", output_dir
        )
        # 2. "Lower is better" metrics (magma: yellow = high error/time).
        save_metric_heatmaps(
            subset_step, step, lower_better_metrics, "magma", output_dir
        )