# Model Comparison for Robot Policy Performance

This notebook compares the performance of different models tested using the `eval.py` script.

We load all model results from the `results/` dir and concatenate them into a dataframe. Repeats of the same title (such as `f"{model_name}-histlen-{history_length}-histchoice-{history_choice}"`) are averaged into one result for the downstream metrics and renamed to `"...-avg-N"`.

In [267]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import pandas as pd
import seaborn as sns
from typing import Dict, List, Any

# Set plot style
# Both calls change global plotting defaults, so they affect every figure
# created by the cells below.
plt.style.use('ggplot')
sns.set_context("talk")

In [None]:
def calc_expected_frames(history_length: int, history_choice: str) -> int:
    """Return how many frames a history-sampling strategy is expected to keep.

    Args:
        history_length: Number of frames available in the history window.
        history_choice: Sampling strategy. "all", "last", "first",
            "alternate" and "third" must match exactly; "none" and "one"
            are matched case-insensitively.

    Returns:
        The expected frame count for the given strategy.

    Raises:
        ValueError: If ``history_choice`` is not one of the known strategies.
    """
    # Strategies that are compared verbatim (case-sensitive).
    if history_choice == "all":
        return history_length
    if history_choice in ("last", "first"):
        return 1
    if history_choice == "alternate":
        return history_length // 2
    if history_choice == "third":
        return history_length // 3

    # Only these two strategies tolerate a different capitalisation.
    lowered = history_choice.lower()
    if lowered == "none":
        return 0
    if lowered == "one":
        return 1

    raise ValueError(f"Invalid history_choice: {history_choice}")

# Load every "*metrics*.pkl" file (and its matching "*actions*.pkl" file) found
# in the sub-directories of RESULTS_DIR. RESULTS maps a model name (directory
# name without the "history_sweep_" prefix) to a list of
# {"info", "actions", "metrics"} records, one per metrics file.
RESULTS_DIR = "results"
RESULTS = dict()
model_dirs = [mdir for mdir in os.listdir(RESULTS_DIR) if os.path.isdir(os.path.join(RESULTS_DIR, mdir))]
for model_name in model_dirs:
    print(f"loading results for {model_name}")
    # model_dirs was already filtered to directories, so no second isdir() check is needed.
    model_dir = os.path.join(RESULTS_DIR, model_name)
    # Sweep directories are merged under the same key as the plain model name.
    result_key = model_name.replace("history_sweep_", "")
    RESULTS.setdefault(result_key, list())
    for filename in os.listdir(model_dir):
        if 'metrics' not in filename or not filename.endswith('.pkl'):
            continue
        metrics_path = os.path.join(model_dir, filename)
        print(f"found metrics file {metrics_path}")
        # NOTE(security): unpickling can execute arbitrary code; only load
        # result files that were produced by a trusted evaluation run.
        # Context managers close the files even when unpickling fails.
        with open(metrics_path, "rb") as metrics_fh:
            metrics_dict = pickle.load(metrics_fh)
        actions_file = filename.replace('metrics', 'actions')
        with open(os.path.join(model_dir, actions_file), "rb") as actions_fh:
            actions_dict = pickle.load(actions_fh)
        info_dict = {k: metrics_dict[k] for k in ['timestamp', 'model_name', 'history_length', 'history_choice']}
        # The directory-derived name overrides whatever name was stored in the pickle.
        info_dict['model_name'] = result_key
        RESULTS[result_key].append({"info": info_dict, "actions": actions_dict, "metrics": metrics_dict})


In [None]:
def info_dict_to_title(info_dict):
    """Build the display title for one result from its info dict.

    The title is the model name, followed by
    "-histlen-<history_length>-histchoice-<history_choice>" unless the
    history length is None.
    """
    title = f"{info_dict['model_name']}"
    if info_dict['history_length'] is None:
        return title
    return f"{title}-histlen-{info_dict['history_length']}-histchoice-{info_dict['history_choice']}"

# Check which models were loaded successfully
n_loaded_results = sum(len(runs) for runs in RESULTS.values())
print(f"Successfully loaded data for {len(RESULTS)} models and {n_loaded_results} results")
for model_name, results_list in RESULTS.items():
    print(f"model name: {model_name}")
    # One bullet per loaded run, labelled with its configuration title.
    for title in (info_dict_to_title(result['info']) for result in results_list):
        print(f"- {title}")

In [270]:
# Function to compute metrics for models without pre-computed metrics
def compute_metrics(pred_actions, gt_actions):
    """Compare predicted actions with ground-truth actions.

    Both arguments must support element-wise arithmetic and a row-wise norm
    (``np.linalg.norm(..., axis=1)``), i.e. they need at least two dimensions.

    Returns:
        A dict with, in this order: "mse", "mae", "nmse" (mse divided by the
        ground-truth variance, or inf when that variance is not positive),
        "pred_mag" and "gt_mag" (mean row norm of each input) and
        "mag_ratio" (pred_mag / gt_mag, or inf when gt_mag is not positive).
    """
    error = pred_actions - gt_actions
    mse = np.mean(np.square(error))
    mae = np.mean(np.abs(error))

    # Normalise by the spread of the ground truth; undefined for constant targets.
    gt_var = np.var(gt_actions)
    nmse = mse / gt_var if gt_var > 0 else float('inf')

    # Average per-row magnitude of each action set.
    pred_mag = np.mean(np.linalg.norm(pred_actions, axis=1))
    gt_mag = np.mean(np.linalg.norm(gt_actions, axis=1))
    mag_ratio = pred_mag / gt_mag if gt_mag > 0 else float('inf')

    return {
        "mse": mse,
        "mae": mae,
        "nmse": nmse,
        "pred_mag": pred_mag,
        "gt_mag": gt_mag,
        "mag_ratio": mag_ratio,
    }

In [271]:
# Extract and organize metrics across models for comparison
def extract_comparison_metrics(model_data: dict[str, list[dict]], force_compute: bool = True, min_results: int = 4):
    """Collect one metrics dict per evaluated run, keyed by a unique title.

    Args:
        model_data: Maps a model name to its list of result records. Each
            record needs an "info" dict (with "model_name", "history_length"
            and "history_choice") and, unless pre-computed metrics are used,
            an "actions" iterable of trajectories holding "pred_actions",
            "gt_actions" and optionally "inference_times".
        force_compute: When False, a record's ``metrics["avg_metrics"]`` is
            used as-is if present; otherwise metrics are recomputed from the
            stored actions.
        min_results: Runs with fewer trajectories than this are skipped
            (only applies when metrics are computed from actions).

    Returns:
        Dict mapping a run title to its metrics. The first run of a
        configuration keeps the plain title; later runs of the same
        configuration get "-1", "-2", ... appended.

    Note:
        The "info" dicts are normalised in place: a missing history length
        becomes 0, a zero-length or missing history choice becomes "None",
        and a history length of 1 always uses the choice "all".
    """
    comparison = {}

    for model_name, model_results in model_data.items():
        for data in model_results:
            info = data['info']
            if info['history_length'] is None:
                info['history_length'] = 0
            hist_len = info['history_length']

            if hist_len == 0 or info['history_choice'] is None:
                info['history_choice'] = 'None'
            elif hist_len == 1:
                info['history_choice'] = 'all'
            hist_choice = info['history_choice']

            model_title = info_dict_to_title(info)
            if model_title in comparison:
                # Take the first free rerun index. Probing the exact key keeps
                # this correct past nine reruns and avoids false matches on
                # titles that merely contain this one as a substring.
                rerun_idx = 1
                while f"{model_title}-{rerun_idx}" in comparison:
                    rerun_idx += 1
                model_title = f"{model_title}-{rerun_idx}"

            # If we have pre-computed metrics, use those
            if 'metrics' in data and 'avg_metrics' in data['metrics'] and not force_compute:
                comparison[model_title] = data['metrics']['avg_metrics']
                continue

            # Otherwise, compute metrics from actions, one result per trajectory
            all_results = []
            for traj in data['actions']:
                results = compute_metrics(traj['pred_actions'], traj['gt_actions'])
                # Add timing information if available
                if 'inference_times' in traj:
                    results['mean_inference_time'] = np.mean(traj['inference_times'])
                all_results.append(results)

            # Skip under-sampled runs; the emptiness check also protects the
            # all_results[0] lookup below when min_results is 0.
            if len(all_results) < min_results or not all_results:
                print(f"Not enough results for {model_title}, skipping {len(all_results)} results")
                continue

            # Average every metric across all trajectories of this run
            metrics = {k: np.mean([r[k] for r in all_results]) for k in all_results[0]}
            # add info to the df
            metrics['model_name'] = model_name
            metrics['model_title'] = model_title
            metrics['history_length'] = hist_len
            metrics['history_choice'] = hist_choice
            metrics['expected_frames'] = calc_expected_frames(hist_len, hist_choice)
            comparison[model_title] = metrics

    return comparison

def consolidate_reruns(comparison: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Average repeated runs of the same configuration into a single entry.

    A configuration counts as rerun when both ``title`` and ``title + "-1"``
    are present; ``title + "-2"``, ``"-3"``, ... are included for as long as
    the numbering is contiguous. The group is replaced by one entry named
    ``title + "-avg-<count>"``. String and None values are copied from the
    first run; every other value is averaged with ``np.mean``.

    Args:
        comparison: Maps run title to its metrics dict. It is NOT modified,
            so callers can keep using the per-run data afterwards.

    Returns:
        A new dict: the untouched entries in their original order, followed
        by the averaged entries.
    """
    consolidated = dict(comparison)  # shallow copy; metric dicts are only read
    averaged_comparison = dict()

    for title in comparison:
        if not title.endswith('-1'):
            continue
        base_title = title[:-2]
        if base_title not in comparison:
            # A title that merely ends in "-1" without a base run is not a rerun.
            continue

        # Follow the exact "-<n>" numbering instead of prefix/length matching,
        # so groups with ten or more runs stay complete and unrelated titles
        # sharing a prefix are never pulled in.
        matched_titles = [base_title]
        rerun_idx = 1
        while f"{base_title}-{rerun_idx}" in comparison:
            matched_titles.append(f"{base_title}-{rerun_idx}")
            rerun_idx += 1

        avg_metrics = dict()
        for key, first_value in comparison[base_title].items():
            if first_value is None or isinstance(first_value, str):
                # string metrics like choice, name, title etc are the same
                avg_metrics[key] = first_value
            else:
                avg_metrics[key] = np.mean([comparison[m][key] for m in matched_titles])

        averaged_comparison[base_title + f"-avg-{len(matched_titles)}"] = avg_metrics
        print(f"averaged {len(matched_titles)} models for {base_title}")
        for member in matched_titles:
            consolidated.pop(member, None)

    consolidated.update(averaged_comparison)
    return consolidated

In [None]:
# Extract performance metrics for comparison
all_comparison_metrics = extract_comparison_metrics(RESULTS)
# Consolidate a copy so the per-run dict above stays intact for the raw table,
# regardless of whether consolidate_reruns() modifies the dict it is given.
comparison_metrics = consolidate_reruns(dict(all_comparison_metrics))

# The first consolidated entry defines which metrics become table rows.
metrics_of_interest = list(comparison_metrics.values())[0].keys()
print(metrics_of_interest)

# Raw table: one column per individual run (reruns not averaged).
raw_comparison_df = pd.DataFrame(index=metrics_of_interest)
for model_name, metrics in all_comparison_metrics.items():
    model_values = [metrics.get(metric, np.nan) for metric in metrics_of_interest]
    raw_comparison_df[model_name] = model_values


# Consolidated table: one column per configuration, reruns averaged.
comparison_df = pd.DataFrame(index=metrics_of_interest)

for model_name, metrics in comparison_metrics.items():
    model_values = [metrics.get(metric, np.nan) for metric in metrics_of_interest]
    comparison_df[model_name] = model_values

# Results that should be called "o4-mini" were apparently stored as "4-mini";
# rewrite the model_name row accordingly.
comparison_df.loc['model_name'] = comparison_df.loc['model_name'].str.replace('4-mini', 'o4-mini')

# Display the comparison table
comparison_df

In [None]:
# Number of columns (configurations) in comparison_df per model name.
comparison_df.T.model_name.value_counts()

In [None]:
VLA_MODEL_NAMES = ["openvla", "ecot"]
runs_by_title = comparison_df.T  # one row per configuration
# A model counts as a VLA when its name contains any of the VLA_MODEL_NAMES.
existing_vlas = [
    name
    for name in runs_by_title.model_name.unique()
    if any(vla in name for vla in VLA_MODEL_NAMES)
]
print(f"found {len(existing_vlas)} VLA models: {existing_vlas}")
vla_df = runs_by_title[runs_by_title.model_name.isin(existing_vlas)].T

# find the VLA metrics to use as baseline comparisons
desired_metrics = ["mse"]
BASELINE_METRICS = dict()
for _, row in vla_df.T.iterrows():
    # A later row with the same model name replaces the earlier entry.
    BASELINE_METRICS[row['model_name']] = {metric: row[metric] for metric in desired_metrics}

print(f"Collected baseline VLA metrics: {BASELINE_METRICS}")

In [None]:
# Summarise the raw table: how many entries exist per model, per
# configuration title, per history length and per history choice.
raw_runs = raw_comparison_df.T

model_name_counts = raw_runs["model_name"].value_counts()
print(model_name_counts)
print(f"Total unique models: {len(model_name_counts)}\n")

model_title_counts = raw_runs["model_title"].value_counts()
print(f"Total unique configurations: {len(model_title_counts)}\n")

histlen_counts = raw_runs["history_length"].value_counts()
print(histlen_counts)

histchoice_counts = raw_runs["history_choice"].value_counts()
print(histchoice_counts)


In [None]:
# Print how many comparison_df columns belong to each model name.
print(comparison_df.T.model_name.value_counts())


## Performance Metric Comparisons

In [277]:
from typing import Callable, Optional

from typing import Callable, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import Callable, Optional



def plot_metric_comparison(comparison_df: pd.DataFrame, metric_name: str, title: Optional[str]=None, lower_is_better=True, groupby: Optional[str] = None, filterby: Optional[Callable] = None, other_vals: Optional[dict[str, float]] = None):
    """Draw a bar chart of one metric across the columns of ``comparison_df``.

    Args:
        comparison_df: Table with metrics as rows and one column per entry.
        metric_name: Row of ``comparison_df`` to plot.
        title: Figure title; defaults to "<METRIC> Comparison Across Models".
        lower_is_better: Sort bars ascending when truthy, descending otherwise.
        groupby: Optional row name (e.g. "model_name"); entries sharing the
            same value are averaged into a single bar.
        filterby: Optional predicate applied to each entry (a row of the
            transposed table); only entries for which it is truthy are kept.
        other_vals: Optional ``{label: value}`` reference levels drawn as
            dashed horizontal lines.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current figure.
    """
    plt.figure(figsize=(10, 6))

    # Get the data for this metric
    if filterby is not None:
        by_entry = comparison_df.T
        comparison_df = by_entry[by_entry.apply(filterby, axis=1)].T

    if groupby is not None:
        comparison_df = comparison_df.T.groupby([groupby])[[metric_name]].agg('mean').T

    metric_values = comparison_df.loc[metric_name]

    # Sort models by metric value so the best entry comes first
    sorted_models = metric_values.sort_values(ascending=bool(lower_is_better)).index

    # Create bar plot
    bars = plt.bar(sorted_models, metric_values[sorted_models])

    # Add value labels on top of bars with smaller font size
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,
                f'{height:.4f}', ha='center', va='bottom', rotation=0, fontsize=8)

    # Add title and labels
    plt.title(title or f'{metric_name.upper()} Comparison Across Models')
    plt.ylabel(metric_name)
    if groupby is not None:
        plt.xlabel(groupby.replace('_', ' ').title())

    # Set rotated tick labels with proper alignment
    plt.xticks(rotation=45, ha='right')

    if other_vals:
        for k, v in other_vals.items():
            plt.axhline(y=v, color='gray', linestyle='--', alpha=0.7, label=k)
        # The reference lines are the only labelled artists; calling legend()
        # without them just produces an empty legend and a warning.
        plt.legend()

    plt.tight_layout()

    return plt

In [278]:
# Row names of comparison_df that the MSE comparison is grouped by,
# one figure per entry.
GROUPBY_OPTIONS = ['model_name', 'history_length', 'history_choice', 'expected_frames']

In [None]:
# Plot MSE comparison (lower is better)
# filterby = lambda x: "claude" in x['model_name'] and 'first' in x['history_choice']
filterby = None
groupby = "model_name"

# The OpenVLA MSE is drawn as a dashed reference line in every figure.
openvla_reference = {"openvla": BASELINE_METRICS['openvla']['mse']}
mse_plot_title = 'Averaged MSE (lower is better) for each VLM + OpenVLA baseline'

for g in GROUPBY_OPTIONS:
    plot_metric_comparison(
        comparison_df,
        'mse',
        title=mse_plot_title,
        groupby=g,
        filterby=filterby,
        other_vals=openvla_reference,
    )

In [None]:
# Best (minimum) value of every row per model, ordered by the best MSE.
best_any_way_df = comparison_df.T.groupby('model_name').agg('min').sort_values('mse')

plt.figure(figsize=(10, 6))
xvals = best_any_way_df.index.tolist()
yvals = best_any_way_df['mse'].tolist()
bars = plt.bar(xvals, yvals)

# Dashed reference line at the OpenVLA baseline
openvla_mse = BASELINE_METRICS['openvla']['mse']
plt.axhline(y=openvla_mse, color='gray', linestyle='--', alpha=0.7, label='openvla')

# Add value labels on top of bars
for bar in bars:
    bar_top = bar.get_height()
    bar_center = bar.get_x() + bar.get_width()/2.
    plt.text(bar_center, bar_top + 0.001, f'{bar_top:.4f}',
             ha='center', va='bottom', rotation=0, fontsize=8)

plt.title('Best MSE (lower is better) for each VLM + OpenVLA Baseline')
plt.xlabel('Model')
plt.ylabel('MSE')
plt.xticks(rotation=45, ha='right')
plt.legend()
plt.tight_layout()

## Action Magnitude Analysis

In [None]:
# Action magnitude comparison: one bar per consolidated configuration with the
# ground-truth magnitude as a dashed reference line.
# Extract data
models = list(comparison_metrics.keys())
pred_mags = [comparison_metrics[model]['pred_mag'] for model in models]
# NOTE(review): takes the ground-truth magnitude of the first entry only and
# assumes it is the same for all models - confirm all runs share trajectories.
gt_mag = comparison_metrics[models[0]]['gt_mag']

# tab10 only has plt.cm.tab10.N distinct colours; wrap the index so entries
# past the palette size cycle through it instead of all getting the last colour.
colors = plt.cm.tab10(np.arange(len(models)) % plt.cm.tab10.N)

# Create plot
fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(np.arange(len(models)), pred_mags, color=colors, alpha=0.8)

# Add horizontal line for ground truth
ax.axhline(y=gt_mag, color='red', linestyle='--', alpha=0.7, label=f'Ground Truth: {gt_mag:.4f}')

# Add value labels on top of bars
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2.,
            height + 0.01,
            f'{height:.4f}',
            ha='center',
            va='bottom',
            fontsize=10)

# Set labels and title
ax.set_xlabel('Model')
ax.set_ylabel('Action Magnitude')
ax.set_title('Action Magnitude Comparison')
ax.set_xticks(np.arange(len(models)))
ax.set_xticklabels(models, rotation=45, ha='right')
ax.legend()

# Add grid for better readability
ax.grid(axis='y', linestyle='--', alpha=0.3)

plt.tight_layout()

In [None]:

from __future__ import annotations

"""plot_mse_by_model
====================
flexible visualisation of mse vs history_length with robust sizing logic.

changes vs previous rev
-----------------------
* `expected_frames` may contain **zero** → size fallback of *1 frame* so the dot is visible.
* size legend automatically deduplicates values and spans the whole numeric range, incl. `0` when present.
* public API unchanged: `plot_mse_by_model(df, grid=False, size_scale=20.0)`.

layout recap
------------
* one‑axes (default) or N×N grid (`grid=True`).
* marker → model_name, color → history_choice, size → expected_frames.
* tri‑column legend outside right for color, size, marker.
"""

from typing import Dict, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ----------------------------- aesthetics helpers -----------------------------
MARKERS: list[str] = [
    "o", "s", "^", "D", "v",
    "P", "X", "*", "<", ">",
    "h", "H", "p", "8",
]


def _build_marker_map(models: list[str]) -> dict[str, str]:
    if len(models) > len(MARKERS):
        raise ValueError("too many models for built‑in marker set; extend MARKERS list")
    return {m: MARKERS[i] for i, m in enumerate(models)}


# ----------------------------- size helpers -----------------------------

def _effective_size(frames: float | pd.Series, scale: float) -> float | pd.Series:
    """convert expected_frames → marker size in points²; zero gets minimal dot."""
    if isinstance(frames, pd.Series):
        return frames.clip(lower=1.0) * scale  # 0 → 1
    return max(frames, 1.0) * scale  # 0 → 1


def _legend_size_values(frames: np.ndarray) -> list[int]:
    """choose up‑to‑4 representative frame counts for legend.

    * always include 0 when present.
    * include min, median, max of non‑zero values (deduped).
    """
    frames = frames[~np.isnan(frames)]
    zeros_present = np.any(frames == 0)
    nonzero = frames[frames > 0]
    legend_vals: list[int] = []

    if zeros_present:
        legend_vals.append(0)

    if nonzero.size:
        # pick min, median, max unique ints
        q_vals = [np.min(nonzero), np.percentile(nonzero, 25), np.median(nonzero), np.percentile(nonzero, 75), np.percentile(nonzero, 90)]
        for v in q_vals:
            iv = int(round(v))
            if iv not in legend_vals:
                legend_vals.append(iv)

    return sorted(legend_vals, key=lambda x: int(x))


# ----------------------------- single‑axes plot -----------------------------

def _plot_single(df: pd.DataFrame, marker_map: dict[str, str], cdict: dict[str, str], *, size_scale: float, other_vals: dict[str, float]) -> None:
    """draw every row of df as one dot on a single axes and show the figure.

    x is history_length, y is mse; the marker comes from marker_map (by
    model_name), the color from cdict (by history_choice) and the dot size
    from expected_frames. each other_vals item becomes a dashed horizontal
    reference line.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    for _, run in df.iterrows():
        point_style = dict(
            s=_effective_size(run["expected_frames"], size_scale),
            marker=marker_map[run["model_name"]],
            color=cdict[run["history_choice"]],
            alpha=0.8,
        )
        ax.scatter(run["history_length"], run["mse"], **point_style)

    for line_label, level in other_vals.items():
        ax.axhline(y=level, color='gray', linestyle='--', alpha=0.7, label=line_label)

    ax.set_title("mse vs history_len across models")
    ax.set_xlabel("history_len (# frames)")
    ax.set_ylabel("mse")

    _add_legends(ax, df, marker_map, cdict, size_scale)
    plt.tight_layout()
    plt.show()


# ----------------------------- grid plot -----------------------------
from typing import Dict
import math
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict
import math
import pandas as pd
import matplotlib.pyplot as plt

def _plot_grid(  # noqa: N802
    df: pd.DataFrame,
    marker_map: dict[str, str],
    cdict: dict[str, str],
    *,
    size_scale: float,
    sync_axes: bool = True,
    other_vals: dict[str, float] | None = None,
) -> None:
    """draw one mse-vs-history_length subplot per model in a square grid.

    parameters
    ----------
    df : pd.DataFrame
        needs the columns model_name, history_length, mse, history_choice
        and expected_frames.
    marker_map, cdict : dict
        marker per model_name and color per history_choice.
    size_scale : float
        multiplier passed to `_effective_size` for the dot sizes.
    sync_axes : bool
        True → all subplots share the same x/y limits, padded by 10% of the
        data range (reference values from `other_vals` included on y).
    other_vals : dict or None
        `{label: value}` reference levels drawn as dashed horizontal lines
        in every subplot; None means no reference lines.
    """
    # None instead of a shared mutable default
    other_vals = other_vals or {}

    extend_pct = 0.1
    if sync_axes:
        x_min, x_max = df["history_length"].dropna().agg(["min", "max"])
        y_min, y_max = df["mse"].dropna().agg(["min", "max"])
        if other_vals:
            y_min = min(y_min, min(other_vals.values()))
            y_max = max(y_max, max(other_vals.values()))
        x_range = x_max - x_min
        x_min, x_max = x_min - (x_range * extend_pct), x_max + (x_range * extend_pct)
        y_range = y_max - y_min
        y_min, y_max = y_min - (y_range * extend_pct), y_max + (y_range * extend_pct)
        sharex = sharey = True
    else:
        sharex = sharey = False

    models = sorted(df["model_name"].unique())
    n_mod = len(models)
    side = math.ceil(math.sqrt(n_mod))

    fig, axes = plt.subplots(
        side,
        side,
        figsize=(6 * side, 6 * side),
        sharex=sharex,
        sharey=sharey,
    )

    # plt.subplots returns a bare Axes (not an array) for a 1×1 grid
    axes_flat = axes.ravel() if n_mod > 1 else [axes]
    for ax in axes_flat[n_mod:]:
        ax.set_visible(False)

    for idx, (ax, model) in enumerate(zip(axes_flat, models)):
        sdf = df[df["model_name"] == model]
        for choice, sub in sdf.groupby("history_choice"):
            ax.scatter(
                sub["history_length"],
                sub["mse"],
                s=_effective_size(sub["expected_frames"], size_scale).astype(float),
                marker=marker_map[model],
                color=cdict[choice],
                alpha=0.8,
                label=str(choice),  # labels grabbed later
            )

        for k, v in other_vals.items():
            ax.axhline(y=v, color='gray', linestyle='--', alpha=0.7, label=k)

        ax.set_title(model)
        if sync_axes:
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)

        # only the first column carries the y label
        if idx % side == 0:
            ax.set_ylabel("mse")

        ax.set_xlabel("history_len (# frames)")
        ax.tick_params(axis="x", labelbottom=True)   # show xtick labels here

    # build one combined legend from ALL subplots (first handle per label wins)
    # so a history_choice missing from the first model still gets an entry
    legend_entries = {}
    for ax in axes_flat[:n_mod]:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    fig.legend(
        list(legend_entries.values()),
        list(legend_entries.keys()),
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
    )

    plt.tight_layout()
    plt.subplots_adjust(right=0.82)  # leave room for legend
    plt.show()




# ----------------------------- legend builder -----------------------------

def _add_legends(ax, df: pd.DataFrame, marker_map: dict[str, str], cdict: dict[str, str], size_scale: float) -> None:
    """attach three stacked legends to the right of `ax`.

    from top to bottom: color → history_choice, dot size → expected_frames,
    marker → model_name. the handles are proxy artists, so nothing extra is
    drawn inside the axes.
    """
    choices = sorted(df["history_choice"].unique())

    # 1) color legend (history_choice)
    color_handles = [
        Line2D([], [], marker="o", linestyle="", markerfacecolor=cdict[c], markeredgecolor="none", markersize=10, label=c)
        for c in choices
    ]

    # 2) size legend (expected_frames)
    size_vals = _legend_size_values(df["expected_frames"].astype(float).values)
    # scatter sizes are areas (points²) while Line2D markersize is a length in
    # points, so take the square root to make legend dots match the plotted dots
    size_handles = [
        Line2D(
            [], [],
            marker="o",
            linestyle="",
            color="gray",
            markersize=_effective_size(v, size_scale) ** 0.5,
            label=f"{v} frames",
        )
        for v in size_vals
    ]

    # 3) marker legend (model_name)
    marker_handles = [
        Line2D([], [], marker=marker_map[m], linestyle="", color="gray", markersize=10, label=m)
        for m in sorted(marker_map)
    ]

    legend_kw = dict(loc="upper left", bbox_to_anchor=(1.02, 1.1), frameon=False, alignment="left")
    # ax.legend() replaces the axes' current legend, so the earlier ones are
    # re-added as plain artists to keep all three visible
    first = ax.legend(handles=color_handles, title="History Choice", **legend_kw)
    ax.add_artist(first)

    legend_kw["bbox_to_anchor"] = (1.02, 0.45)
    second = ax.legend(handles=size_handles, title="Expected # Frames", **legend_kw)
    ax.add_artist(second)

    legend_kw["bbox_to_anchor"] = (1.02, 0.0)
    ax.legend(handles=marker_handles, title="Model Name", **legend_kw)
    
# ----------------------------- public API -----------------------------

def plot_mse_by_model(
    df: pd.DataFrame,
    *,
    grid: bool = False,
    sync_axes: bool = True,
    size_scale: float = 20.0,
    other_vals: dict[str, float] | None = None,
) -> dict[str, str]:
    """plot mse vs history_length.

    parameters
    ----------
    df : pd.DataFrame
        requires columns `{"model_name", "history_length", "mse", "history_choice", "expected_frames"}`.
    grid : bool, default `False`
        True → one subplot per model in a square grid.
    sync_axes : bool, default `True`
        grid mode only: use the same axis limits for every subplot.
    size_scale : float, default `20.0`
        multiplier converting frames → marker points²; zero frames mapped to size of *1 frame*.
    other_vals : dict or None, default `None`
        `{label: value}` reference levels drawn as dashed horizontal lines.

    returns
    -------
    dict[str, str]
        the history_choice → color mapping used in the plot.

    raises
    ------
    ValueError
        if `df` lacks a required column, or there are more models than markers.
    """

    required = {"model_name", "history_length", "mse", "history_choice", "expected_frames"}
    if not required.issubset(df.columns):
        raise ValueError(f"df missing columns: {required.difference(df.columns)}")

    # None instead of a shared mutable default; the helpers iterate over a dict
    other_vals = {} if other_vals is None else other_vals

    models = sorted(df["model_name"].unique())
    marker_map = _build_marker_map(models)

    # colors come from the active style's property cycle, wrapping around when
    # there are more history choices than cycle colors
    choices = sorted(df["history_choice"].unique())
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    cdict: dict[str, str] = {c: colors[i % len(colors)] for i, c in enumerate(choices)}

    if grid:
        _plot_grid(df, marker_map, cdict, sync_axes=sync_axes, size_scale=size_scale, other_vals=other_vals)
    else:
        _plot_single(df, marker_map, cdict, size_scale=size_scale, other_vals=other_vals)
    return cdict


# Single-axes overview with the OpenVLA MSE as reference line; the returned
# history_choice -> color mapping is kept in color_dict for later cells.
color_dict = plot_mse_by_model(comparison_df.T, grid=False, other_vals={"openvla": BASELINE_METRICS['openvla']['mse']})

In [None]:
# History lengths of the configurations whose expected frame count is zero.
comparison_df.T[comparison_df.T['expected_frames'] == 0].history_length.unique()


In [None]:
# Same data as a grid with one subplot per model, again with the OpenVLA reference line.
plot_mse_by_model(comparison_df.T, grid=True, other_vals={"openvla": BASELINE_METRICS['openvla']['mse']})

In [None]:
# Two extra heatmap categories on top of the history choices:
# cells without any run (black) and impossible combinations (white).
NO_DATA_KEY = "no data"
IMPOSSIBLE_KEY = "impossible"
color_dict.update({NO_DATA_KEY: '#000000', IMPOSSIBLE_KEY: '#FFFFFF'})
print(color_dict)

# Integer code per category, following the insertion order of color_dict.
color_mapping = {key: code for code, key in enumerate(color_dict)}


In [286]:
def get_history_choice_heatmap(df, global_expected_frames=None, global_hist_lengths=None):
    """Find the best history choice for every (history length, expected frames) cell.

    Args:
        df: Table with metrics as rows and one column per configuration; needs
            the rows "history_length", "expected_frames", "mse" and
            "history_choice".
        global_expected_frames: Optional expected-frame values to use as
            columns instead of the ones found in ``df``.
        global_hist_lengths: Optional history lengths to use as rows instead
            of the ones found in ``df``.

    Returns:
        ``(arr, unique_expected_frames, unique_history_lengths)`` where
        ``arr[i, j]`` is the ``color_mapping`` code of the history choice with
        the lowest mse for history length ``i`` and expected frames ``j``.
        Cells with more expected frames than history length get the
        IMPOSSIBLE_KEY code, cells without any run the NO_DATA_KEY code.
        Zero and odd values other than 1 are dropped from both axes.

    Relies on the notebook-level ``color_mapping``, ``IMPOSSIBLE_KEY`` and
    ``NO_DATA_KEY``.
    """
    # Transpose once (one row per configuration) instead of in every cell.
    runs = df.T

    if global_hist_lengths:
        unique_history_lengths = global_hist_lengths
    else:
        unique_history_lengths = sorted(runs.history_length.unique())

    if global_expected_frames:
        unique_expected_frames = global_expected_frames
    else:
        unique_expected_frames = sorted(runs.expected_frames.unique())

    def _keep_axis_value(value):
        # drop odds (except 1) for a cleaner diagram and 0 for a history-focused one
        return (value % 2 == 0 or value == 1) and value != 0

    unique_history_lengths = [x for x in unique_history_lengths if _keep_axis_value(x)]
    unique_expected_frames = [x for x in unique_expected_frames if _keep_axis_value(x)]

    arr = np.zeros((len(unique_history_lengths), len(unique_expected_frames)), dtype=int)

    # Loop-invariant integer views of the two lookup columns.
    run_lengths = runs.history_length.astype(int)
    run_frames = runs.expected_frames.astype(int)

    for i, history_length in enumerate(unique_history_lengths):
        for j, expected_frames in enumerate(unique_expected_frames):
            if expected_frames > history_length:
                arr[i, j] = color_mapping[IMPOSSIBLE_KEY]
                continue
            cell_runs = runs[(run_lengths == history_length) & (run_frames == expected_frames)]
            if cell_runs.empty:
                arr[i, j] = color_mapping[NO_DATA_KEY]
                continue

            # Sort into a new frame: an in-place sort on a filtered slice
            # triggers pandas' SettingWithCopy warning.
            best_run = cell_runs.sort_values(by='mse', ascending=True).iloc[0]
            arr[i, j] = color_mapping[best_run['history_choice']]
    return arr, unique_expected_frames, unique_history_lengths

In [287]:
# Heatmap over all runs first; its axis values are reused for every
# per-model heatmap so that all grids have the same shape.
overall_history_choice_details = get_history_choice_heatmap(comparison_df)
history_choice_heatmaps = {'overall': overall_history_choice_details[0]}

configs_by_title = comparison_df.T
for model_name in configs_by_title.model_name.unique():
    # NOTE(review): this is a substring test, so a model whose name is
    # contained in another model's name also picks up that model's runs.
    name_mask = configs_by_title.model_name.apply(lambda x: model_name in x)
    this_df = configs_by_title[name_mask].T
    model_heatmap = get_history_choice_heatmap(
        this_df,
        overall_history_choice_details[1],
        overall_history_choice_details[2],
    )
    history_choice_heatmaps[model_name] = model_heatmap[0]

In [288]:
from matplotlib.colors import ListedColormap, BoundaryNorm

def plot_history_choice_heatmaps(arr_dict, unique_expected_frames, unique_history_lengths, color_dict, color_mapping, overall_title):
    """Draw one categorical heatmap per entry of ``arr_dict`` with a shared colorbar.

    Args:
        arr_dict: Maps a subplot title to a 2-D integer array of category codes.
        unique_expected_frames: Tick labels for the x axis (array columns).
        unique_history_lengths: Tick labels for the y axis (array rows).
        color_dict: Maps category label to color; its order defines both the
            colormap and the colorbar tick labels.
        color_mapping: Maps category label to integer code; only its size is
            used here, to place the colorbar ticks.
        overall_title: Figure title; skipped when falsy.

    Raises:
        ValueError: If ``arr_dict`` is empty.
    """
    n_plots = len(arr_dict)
    if n_plots == 0:
        # Fail with a clear message instead of a division by zero below.
        raise ValueError("arr_dict is empty; nothing to plot")
    n_cols = int(np.ceil(np.sqrt(n_plots)))
    n_rows = int(np.ceil(n_plots / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15,15))
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array
    axes = np.atleast_1d(axes).ravel()

    # One discrete color per category code
    colors = list(color_dict.values())
    cmap = ListedColormap(colors)
    bounds = np.arange(len(color_dict)+1)
    norm = BoundaryNorm(bounds, cmap.N)

    print(f"bounds: {bounds}")

    print(f"colors: {colors}")
    print(f"cmap: {cmap}")

    for ax, (title, arr) in zip(axes, arr_dict.items()):
        im = ax.imshow(arr, cmap=cmap, norm=norm)
        ax.grid(False)
        ax.set_xticks(np.arange(len(unique_expected_frames)), unique_expected_frames)
        ax.set_yticks(np.arange(len(unique_history_lengths)), unique_history_lengths)
        ax.set_title(f"{title}", fontsize=14)
        ax.set_xlabel("Expected Frames", fontsize=12)
        ax.set_ylabel("History Length", fontsize=12)

        # Add grid lines to better separate cells
        ax.set_xticks(np.arange(len(unique_expected_frames) + 1) - 0.5, minor=True)
        ax.set_yticks(np.arange(len(unique_history_lengths) + 1) - 0.5, minor=True)
        ax.tick_params(which='minor', length=0)  # Remove the tick marks on grid lines

    # Shared colorbar; all images use the same cmap/norm, so the last one serves
    cbar_ax = fig.add_axes([1.05, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax, boundaries=bounds)

    print(f"color mapping ({len(color_mapping)}): {color_mapping}")
    print(f"color dict ({len(color_dict)}): {color_dict}")

    # Create custom colorbar ticks at the center of each color segment
    cbar_ticks = np.arange(len(color_mapping)) + 0.5
    cbar.set_ticks(cbar_ticks)
    cbar.set_ticklabels(list(color_dict.keys()))
    cbar.set_label('Best History Choice', fontsize=12)
    cbar_ax.tick_params(which='minor', length=0)  # Remove the tick marks on grid lines

    if overall_title:
        fig.suptitle(overall_title)

    plt.tight_layout()  # Adjust layout for better spacing
    # Hide any unused subplots
    for unused_ax in axes[n_plots:]:
        unused_ax.axis('off')

In [None]:
# Single heatmap over all runs, using the axis values returned with the overall heatmap.
overall_title = "Best History Choice for each combination of History Length and Expected Frames"
plot_history_choice_heatmaps({"All Runs": overall_history_choice_details[0]}, overall_history_choice_details[1], overall_history_choice_details[2], color_dict, color_mapping, overall_title=overall_title)

In [None]:
# One heatmap per entry of history_choice_heatmaps (overall plus each model), all on the same axes values.
plot_history_choice_heatmaps(history_choice_heatmaps, overall_history_choice_details[1], overall_history_choice_details[2], color_dict, color_mapping, overall_title="Best History Choice per History Length-Expected Frame Combination")

In [None]:
# figure out the missing cells: which (model, history length, history choice)
# combinations seen anywhere in the data have no run for a given model
unique_expected_frames = overall_history_choice_details[1]
configs = comparison_df.T
unique_history_lengths = configs.history_length.unique()
unique_model_names = configs.model_name.unique()
unique_history_choices = configs.history_choice.unique()

found_rows = configs[['history_length', 'history_choice', 'model_name']].drop_duplicates()
# Set of existing combinations for constant-time membership tests.
found_keys = set(zip(found_rows['model_name'], found_rows['history_length'], found_rows['history_choice']))

missing_rows = []
for model_name in unique_model_names:
    for history_length in unique_history_lengths:
        for history_choice in unique_history_choices:
            # History lengths 0 and 1 have exactly one valid choice each,
            # matching the normalisation in extract_comparison_metrics():
            # length 0 -> 'None', length 1 -> 'all'.
            if history_length == 0:
                if history_choice != 'None':
                    continue
            elif history_length == 1:
                if history_choice != 'all':
                    continue
            elif history_choice in ('None', 'one'):
                # for other history lengths, none and one aren't valid options
                continue

            if (model_name, history_length, history_choice) not in found_keys:
                missing_rows.append((model_name, history_length, history_choice))

missing_rows = pd.DataFrame(missing_rows, columns=['model_name', 'history_length', 'history_choice'])
missing_rows.sort_values(by=['model_name', 'history_length', 'history_choice'], inplace=True)

print(f"found {len(found_rows)} rows")
print(f"missing {len(missing_rows)} rows")

# Only the first five missing combinations per model are printed.
for k, v in missing_rows.groupby("model_name"):
    print(f"missing for {k}")
    print(v.head())