In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

# Output directory for figures, relative to the kernel's current working directory.
PLOTS_DIR = Path("plots")
# exist_ok=True makes re-running this cell a no-op when the directory is already there.
PLOTS_DIR.mkdir(exist_ok=True)
print(f"Plots will be saved to: {PLOTS_DIR.absolute()}")

In [None]:
# Column names, presumably those of the "Session" sheet ('Action' and 'Round' are
# read from that sheet further down). NOTE(review): this list is not referenced by
# any code visible in this part of the notebook - confirm it is still needed.
columns = ["Who", "Action", "Round", "AgentAUtility", "AgentBUtility","NashDistance"]

In [None]:
# Metric names. The extract_* helpers look them up as column names of each model
# sheet, optionally with a caller-supplied prefix and an "_A"/"_B" suffix.
metrics = ["RMSE", "Spearman", "KendallTau", "Pearson"]

In [None]:
# Opponent model names. Each name is used as the Excel sheet name when a session
# file is loaded, and as the key of the per-model entries in every results dict.
models =['Classic Frequency Opponent Model',"CUHK Frequency Opponent Model","Bayesian Opponent Model","Stepwise COMB Opponent Model", 
         "Expectation COMB Opponent Model", "Conflict-Based Opponent Model", "Frequency Window Opponent Model"]

# =============================================================================
# COLOR SCHEME CONFIGURATION - Wong's Colorblind-Safe Palette
# =============================================================================
# Reference: Wong, B. (2011). Points of view: Color blindness. Nature Methods 8, 441.
# This palette is widely used in Nature, Science, and other scientific journals.

# Named hex codes of the palette; 'black' doubles as the fallback color returned
# by get_model_color() for a model that has no entry in MODEL_COLORS.
WONG_COLORS = {
    'orange': '#E69F00',
    'sky_blue': '#56B4E9', 
    'bluish_green': '#009E73',
    'yellow': '#F0E442',
    'blue': '#0072B2',
    'vermillion': '#D55E00',
    'reddish_purple': '#CC79A7',
    'black': '#000000'
}

# Map each opponent model to a specific color (consistent across all plots).
# Keys must match the entries of `models` exactly, otherwise the lookup helpers
# silently fall back to black.
MODEL_COLORS = {
    'Classic Frequency Opponent Model': WONG_COLORS['blue'],
    'CUHK Frequency Opponent Model': WONG_COLORS['orange'],
    'Bayesian Opponent Model': WONG_COLORS['bluish_green'],
    'Stepwise COMB Opponent Model': WONG_COLORS['vermillion'],
    'Expectation COMB Opponent Model': WONG_COLORS['reddish_purple'],
    'Conflict-Based Opponent Model': WONG_COLORS['sky_blue'],
    'Frequency Window Opponent Model': WONG_COLORS['yellow']
}

# Map each opponent model to a distinct geometric marker (matplotlib marker codes).
# Keys must match the entries of `models` exactly, otherwise get_model_marker()
# falls back to 'o'.
MODEL_MARKERS = {
    'Classic Frequency Opponent Model': 'o',      # Circle
    'CUHK Frequency Opponent Model': 's',         # Square
    'Bayesian Opponent Model': '^',               # Triangle up
    'Stepwise COMB Opponent Model': 'D',          # Diamond
    'Expectation COMB Opponent Model': 'v',       # Triangle down
    'Conflict-Based Opponent Model': 'p',         # Pentagon
    'Frequency Window Opponent Model': 'X'        # X (filled)
}

def get_model_color(model_name):
    """Return the hex color registered for ``model_name``, or black when it has none."""
    fallback = WONG_COLORS['black']
    return MODEL_COLORS.get(model_name, fallback)

def get_model_marker(model_name):
    """Return the marker code registered for ``model_name``, or 'o' when it has none."""
    if model_name in MODEL_MARKERS:
        return MODEL_MARKERS[model_name]
    return 'o'

def get_model_colors_list():
    """Return one color per entry of ``models``, in the same order."""
    return list(map(get_model_color, models))

def get_model_markers_list():
    """Return one marker per entry of ``models``, in the same order."""
    return list(map(get_model_marker, models))

In [None]:
# Names of the negotiating agents. NOTE(review): not referenced by any code visible
# in this part of the notebook; the loaders take agent names from the session file
# names instead - confirm where this list is consumed.
agents = ['MICROAgent','HybridAgent', 'BoulwareAgent', 'SAGAAgent', 'CUHKAgent',
          "ConcederAgent","NiceTitForTat", "IAMhaggler", "PonPokoAgent", "HardHeaded"]

def show_color_scheme():
    """
    Draw a reference chart of the color assigned to every opponent model.

    One full-width horizontal bar is drawn per entry of ``models``, filled with
    the model's color and labelled with the upper-cased hex code. The figure is
    displayed and also returned so the caller can save it.

    Returns:
        The matplotlib Figure containing the chart.
    """
    palette = get_model_colors_list()
    rows = np.arange(len(models))
    # Tick labels drop the last two words of each model name to stay short
    short_names = [" ".join(name.split(" ")[0:-2]) for name in models]

    fig, ax = plt.subplots(figsize=(10, 5))

    # One unit-length bar per model, filled with that model's color
    bars = ax.barh(rows, [1] * len(models), color=palette, edgecolor='black', linewidth=1.5)

    ax.set_yticks(rows)
    ax.set_yticklabels(short_names, fontsize=12)
    ax.set_xlim(0, 1)
    ax.set_xticks([])
    ax.set_xlabel('')
    ax.set_title("Opponent Model Color Scheme\n(Wong's Colorblind-Safe Palette)", 
                 fontsize=14, fontweight='bold', pad=20)

    # Rows 0 and 3 get white text, every other row black text
    white_text_rows = (0, 3)
    for row, (_, hex_code) in enumerate(zip(bars, palette)):
        text_color = 'white' if row in white_text_rows else 'black'
        ax.text(0.5, row, hex_code.upper(), ha='center', va='center',
                fontsize=10, fontweight='bold', color=text_color)

    # Hide the frame on all four sides
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    plt.tight_layout()
    plt.show()

    return fig

In [None]:
# Root directory of the negotiation results; presumably passed as `base_dir` to the
# loading helpers below (call site not visible here - confirm).
# NOTE(review): machine-specific absolute path; the notebook is not portable until
# this is made configurable.
sessions_path = "/home/ubuntu/negoformer/Negolog-RL/results/oracle/1000"

# =============================================================================
# DOMAIN CATEGORIZATION - Load domain metadata
# =============================================================================
# Excel workbook read by load_domain_categories(); being a relative path, it is
# resolved against the kernel's current working directory.
domains_metadata_path = "../../domains/domains.xlsx"

def load_domain_categories():
    """
    Load domain categorization from the metadata workbook at ``domains_metadata_path``.

    Columns are detected by case-insensitive substring match on the header:
    'size' -> size, 'opp' -> opposition, 'bal' -> balance, 'domain'/'name' ->
    domain name. The category keywords are tested before the generic
    'domain'/'name' keywords, so a header such as "DomainSize" is treated as the
    size column instead of replacing the domain-name column. When several
    headers match the same role the last one wins, except that an exact
    'DomainName' column always takes precedence as the domain-name column.

    Returns:
        (domain_categories, df) where ``domain_categories`` is a dict with the
        keys 'size', 'opposition' and 'balance', each holding a list of
        (domain_name, value) tuples (empty when the column was not detected),
        and ``df`` is the metadata DataFrame with 'DomainName' cast to str.

    Raises:
        Whatever ``pd.read_excel`` raises when the workbook cannot be read.
    """
    df = pd.read_excel(domains_metadata_path)

    # Convert DomainName to string to handle mixed types
    if 'DomainName' in df.columns:
        df['DomainName'] = df['DomainName'].astype(str)

    # Display structure for debugging
    print("Domain metadata columns:", df.columns.tolist())
    print(f"\nLoaded {len(df)} domains")

    # Detect column names (flexible to handle different naming conventions)
    domain_col = None
    size_col = None
    opp_col = None
    bal_col = None

    for col in df.columns:
        # str() keeps non-string headers (e.g. integers) from raising here
        col_lower = str(col).lower()
        if 'size' in col_lower:
            size_col = col
        elif 'opp' in col_lower:      # also covers 'opposition'
            opp_col = col
        elif 'bal' in col_lower:      # also covers 'balance'
            bal_col = col
        elif 'domain' in col_lower or 'name' in col_lower:
            domain_col = col

    # The rest of the notebook addresses this column by its exact name, so it
    # must not be displaced by another header that merely contains 'name'.
    if 'DomainName' in df.columns:
        domain_col = 'DomainName'

    print(f"\nDetected columns:")
    print(f"  Domain: {domain_col}")
    print(f"  Size: {size_col}")
    print(f"  Opposition: {opp_col}")
    print(f"  Balance: {bal_col}")

    # Create categorized lists
    domain_categories = {
        'size': [],
        'opposition': [],
        'balance': []
    }

    if domain_col is not None:
        detected = {'size': size_col, 'opposition': opp_col, 'balance': bal_col}
        for category, value_col in detected.items():
            if value_col is not None:
                domain_categories[category] = list(zip(df[domain_col], df[value_col]))

    return domain_categories, df

def normalize_domain_name(name):
    """Return ``name`` as a lower-cased string without surrounding whitespace."""
    as_text = str(name)
    return as_text.lower().strip()

def create_domain_lookup(domain_results):
    """
    Map the normalized form of every key of ``domain_results`` to the key itself.

    When two keys normalize to the same string, the one iterated last is kept.
    """
    lookup = {}
    for actual_name in domain_results.keys():
        lookup[normalize_domain_name(actual_name)] = actual_name
    return lookup

def get_domains_by_category(domain_results, category_type='size'):
    """
    Order the domains of ``domain_results`` by a metadata category.

    Args:
        domain_results: Results dict keyed by domain name
        category_type: 'size', 'opposition', or 'balance'

    Returns:
        Domain names (spelled as in ``domain_results``) in ascending order of
        the category value. Only domains found in the metadata are returned.
        When no metadata is loaded, or the category has no entries, all domain
        names are returned in alphabetical order instead.
    """
    category_data = None
    if domain_metadata_df is not None:
        category_data = domain_categories.get(category_type, [])

    if not category_data:
        return sorted(domain_results.keys())

    # Names are compared in normalized form, so case/whitespace differences
    # between the metadata and the results do not prevent a match
    by_normalized = create_domain_lookup(domain_results)
    normalized_pairs = ((normalize_domain_name(name), value) for name, value in category_data)
    matched = [
        (by_normalized[key], value)
        for key, value in normalized_pairs
        if key in by_normalized
    ]

    ranked = sorted(matched, key=lambda pair: pair[1])
    return [domain for domain, _ in ranked]

def debug_domain_matching(domain_results):
    """
    Print a report comparing the metadata domain names with the result domain names.

    The report lists how many names each source has, a sample of each, the number
    of exact and of normalized (case/whitespace-insensitive) matches, and a sample
    of the names that appear on one side only. Nothing is returned.
    """
    rule = "="*60

    def first_five(names):
        return sorted(list(names))[:5]

    print("\n" + rule)
    print("DOMAIN MATCHING DEBUG")
    print(rule)

    if domain_metadata_df is None:
        print("No domain metadata loaded!")
        return

    # Names from both sources, as strings
    metadata_domains = set()
    if 'DomainName' in domain_metadata_df.columns:
        metadata_domains = set(str(d) for d in domain_metadata_df['DomainName'].values)
    results_domains = set(str(d) for d in domain_results.keys())

    print(f"\nDomains in metadata: {len(metadata_domains)}")
    print(f"Domains in results: {len(results_domains)}")

    print(f"\nFirst 5 metadata domains: {first_five(metadata_domains)}")
    print(f"First 5 results domains: {first_five(results_domains)}")

    # Names spelled identically on both sides
    exact_matches = metadata_domains & results_domains
    print(f"\nExact matches: {len(exact_matches)}")

    # Names that agree once normalized
    metadata_normalized = {normalize_domain_name(d): d for d in metadata_domains}
    results_normalized = {normalize_domain_name(d): d for d in results_domains}
    shared_keys = set(metadata_normalized.keys()) & set(results_normalized.keys())
    print(f"Case-insensitive matches: {len(shared_keys)}")

    if len(shared_keys) > 0:
        print("\nExample matches:")
        for key in sorted(shared_keys)[:5]:
            print(f"  Metadata: '{metadata_normalized[key]}' <-> Results: '{results_normalized[key]}'")

    # Names present on one side only (exact spelling)
    metadata_only = metadata_domains - results_domains
    results_only = results_domains - metadata_domains

    if metadata_only:
        print(f"\nIn metadata but NOT in results ({len(metadata_only)}): {first_five(metadata_only)}")
    if results_only:
        print(f"\nIn results but NOT in metadata ({len(results_only)}): {first_five(results_only)}")

    print(rule + "\n")

# Load the domain metadata once at startup. The "no metadata" defaults are put in
# place first, so a failed load only needs to report the error: the helpers above
# treat domain_metadata_df being None as "metadata unavailable".
domain_metadata_df = None
domain_categories = {'size': [], 'opposition': [], 'balance': []}
try:
    domain_categories, domain_metadata_df = load_domain_categories()
except Exception as e:
    print(f"Could not load domain metadata: {e}")

# Helper Methods

In [None]:
# =============================================================================
# CORE UTILITY FUNCTIONS
# =============================================================================

def get_domains(base_dir):
    """Return the set of names of the immediate sub-directories of ``base_dir``."""
    return {
        entry
        for entry in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, entry))
    }


In [None]:

# =============================================================================
# UNIFIED SESSION LOADING - Read each session file ONCE
# =============================================================================

def load_session_data(file_path):
    """
    Load all data from a session file in one pass.

    The workbook is opened once and the "Session" sheet plus every sheet named
    after an entry of ``models`` are parsed from that single handle.

    Args:
        file_path: ``pathlib.Path`` of the session workbook, named
            ``AgentA_AgentB_DomainX_ProcessY.xlsx``.

    Returns:
        A dict that always holds 'filename' and 'models' and, as far as loading
        got, 'agent_a', 'agent_b', 'session_df' and 'accept_row_idx'.
        'models' maps each model whose sheet could be read to
        ``{'full_df': DataFrame, 'last_row': Series}``.
        This function never raises for a bad file: when loading fails part-way,
        the keys collected so far are returned and an 'error' key describes the
        failure, so callers can detect and report unusable sessions.
    """
    result = {'filename': file_path.name, 'models': {}}

    try:
        # Parse agent names from filename: AgentA_AgentB_DomainX_ProcessY.xlsx
        parts = file_path.name.replace('.xlsx', '').split('_')
        result['agent_a'] = parts[0]
        result['agent_b'] = parts[1]

        # Open the workbook a single time; every sheet is parsed from this handle
        with pd.ExcelFile(file_path) as workbook:
            session_df = workbook.parse("Session")
            result['session_df'] = session_df

            # Row of the first 'Accept' action, or the last row when there is none
            accept_idx = session_df.index[session_df['Action'] == 'Accept']
            result['accept_row_idx'] = accept_idx[0] if len(accept_idx) > 0 else len(session_df) - 1

            available_sheets = set(workbook.sheet_names)
            for model in models:
                if model not in available_sheets:
                    continue
                try:
                    model_df = workbook.parse(model)
                    result['models'][model] = {
                        'full_df': model_df,
                        'last_row': model_df.iloc[-1]  # Cache last row
                    }
                except Exception:
                    # Best effort per sheet: an empty or unreadable model sheet
                    # is skipped without discarding the rest of the session
                    continue

    except Exception as e:
        # Keep the partial result, but record why loading stopped
        result['error'] = f"{type(e).__name__}: {e}"

    return result

def load_all_session_data(base_dir, n_jobs=32):
    """
    Load ALL session files across all domains ONCE.

    For every sub-directory ``<base_dir>/<domain>/sessions`` each regular file
    (Excel lock files starting with '~$' excluded) is loaded in parallel with
    ``load_session_data``. Sessions whose "Session" sheet could not be read are
    dropped and counted as failures instead of being reported as loaded.
    Domains and files are processed in sorted order so repeated runs produce
    the same ordering.

    Args:
        base_dir: Directory containing one sub-directory per domain.
        n_jobs: Number of joblib workers used per domain.

    Returns:
        {domain: [session_data, session_data, ...]} for every domain that has
        a "sessions" directory.
    """
    all_data = {}
    # get_domains returns an unordered set; sort for a reproducible order
    domains = sorted(get_domains(base_dir))

    print(f"\nLoading all session data from {len(domains)} domains...")

    failed_total = 0

    for domain in domains:
        try:
            sessions_dir = Path(base_dir) / domain / "sessions"
            if not sessions_dir.exists():
                continue

            # Skip sub-directories and the '~$...' lock files Excel leaves behind
            files = sorted(
                path for path in sessions_dir.iterdir()
                if path.is_file() and not path.name.startswith('~$')
            )

            # Parallel load all sessions for this domain
            loaded = Parallel(n_jobs=n_jobs)(
                delayed(load_session_data)(file)
                for file in tqdm(files, desc=f"Loading {domain}")
            )

            # The loader never raises; a session without 'session_df' failed to load
            usable = [data for data in loaded if data is not None and 'session_df' in data]
            failed_total += len(loaded) - len(usable)

            all_data[domain] = usable

        except Exception as e:
            print(f"Error loading domain {domain}: {e}")

    total_sessions = sum(len(sessions) for sessions in all_data.values())
    print(f"\nLoaded {total_sessions} total sessions across {len(all_data)} domains")
    if failed_total:
        print(f"Skipped {failed_total} session files that could not be loaded")

    return all_data

In [None]:
# =============================================================================
# EXTRACT FUNCTIONS - Process pre-loaded session data (no file I/O)
# =============================================================================

def extract_domain_results(session_data, metric_prefix=""):
    """
    Collect the final value of every metric for every model of one session.

    Values are read from the cached ``last_row`` of each model; the column
    looked up is ``metric_prefix + metric``. Models that were not loaded are
    left out, and so are metrics whose column is absent.

    Returns:
        {model: {metric: value}}
    """
    results = {}

    for model in models:
        if model not in session_data['models']:
            continue

        final_row = session_data['models'][model]['last_row']
        wanted = ((metric, f"{metric_prefix}{metric}") for metric in metrics)
        results[model] = {
            metric: final_row[column]
            for metric, column in wanted
            if column in final_row.index
        }

    return results

def extract_round_by_round(session_data, metric_prefix=""):
    """
    Compute the per-round mean of every metric for every model of one session.

    Each model sheet is cut at the accept row (and at the shorter of the two
    sheets), labelled row by row with the 'Round' column of the session sheet,
    and averaged per round. Only rounds that occur are returned - there is no
    NaN filling for missing rounds.

    Returns:
        {model: {metric: {round: mean}}}; an empty dict when the session has
        no session sheet or that sheet has no 'Round' column.
    """
    if 'session_df' not in session_data:
        return {}

    session_df = session_data['session_df']
    if 'Round' not in session_df.columns:
        return {}

    last_idx = session_data.get('accept_row_idx', len(session_df) - 1)
    per_model = {}

    for model in models:
        if model not in session_data['models']:
            continue

        model_df = session_data['models'][model]['full_df']

        # Keep rows up to and including the accept row, bounded by both sheets
        n_rows = min(last_idx + 1, len(model_df), len(session_df))
        trimmed = model_df.iloc[:n_rows].reset_index(drop=True)
        # reset_index returned a new frame, so the cached full_df is not modified
        trimmed['Round'] = session_df.iloc[:n_rows].reset_index(drop=True)['Round']

        wanted = [(metric, f"{metric_prefix}{metric}") for metric in metrics]
        per_model[model] = {
            metric: dict(trimmed.groupby('Round')[column].mean())
            for metric, column in wanted
            if column in trimmed.columns
        }

    return per_model

def extract_opponent_analysis(session_data, metric_prefix=""):
    """
    Extract per-agent final metrics from a pre-loaded session.

    For agent A the column ``metric_prefix + metric + "_A"`` is used when it
    exists, otherwise the unsuffixed ``metric_prefix + metric``; agent B works
    the same way with "_B". Values come from the cached ``last_row`` of each
    model. Metrics with no matching column are left out.

    Args:
        session_data: Pre-loaded session data
        metric_prefix: Prefix for metric columns (e.g., "Overall_" for Pareto metrics)

    Returns:
        {agent: {model: {metric: value}}} for the two agents of the session.
        An empty dict is returned when the agent names are missing from
        ``session_data`` (a session whose file name could not be parsed), so one
        bad session cannot abort an aggregation over many sessions.
    """
    agent_a = session_data.get('agent_a')
    agent_b = session_data.get('agent_b')
    if agent_a is None or agent_b is None:
        return {}

    results = {agent_a: {}, agent_b: {}}
    loaded_models = session_data.get('models', {})

    for model in models:
        if model not in loaded_models:
            continue

        last_row = loaded_models[model]['last_row']

        # Same lookup for both sides; only the column suffix differs
        for agent, suffix in ((agent_a, "_A"), (agent_b, "_B")):
            agent_metrics = results[agent].setdefault(model, {})
            for metric in metrics:
                # Try the agent-specific column first, then the shared one
                column = f"{metric_prefix}{metric}{suffix}"
                if column not in last_row.index:
                    column = f"{metric_prefix}{metric}"
                if column in last_row.index:
                    agent_metrics[metric] = last_row[column]

    return results
def extract_box_rmse(session_data):
    """
    Extract, per model and per box, the mean RMSE of each round of one session.

    Rows are read up to (not including) the first 'Accept' action, bounded by
    the accept row index and by the shorter of the session and model sheets.
    For every box index below ``min(BoxCount_A, BoxCount_B)`` (taken from the
    first row of the model sheet) the row value is the average of
    ``Box<i>_RMSE_A`` and ``Box<i>_RMSE_B``; rows where either is missing are
    ignored. Row values are then averaged per round.

    Returns:
        {model: {box_idx: [mean_rmse, ...]}} with one entry per round that has
        at least one value, in ascending round order. An empty dict is returned
        when the session has no usable session sheet / 'Round' column.

    NOTE(review): rounds without any value are omitted rather than padded, so a
    list position is not necessarily the round number. Callers that align these
    lists by position across sessions rely on every session covering the same
    rounds - confirm that this holds for the data.
    """
    results = {}

    session_df = session_data.get('session_df')
    if session_df is None or 'Round' not in session_df.columns:
        return results

    # An empty or all-NaN 'Round' column has no usable maximum: nothing to extract
    try:
        max_round = int(session_df['Round'].max())
    except (ValueError, TypeError, OverflowError):
        return results
    if max_round < 0:
        return results

    # Accept row might be missing
    accept_row_idx = session_data.get('accept_row_idx', len(session_df))
    rounds = session_df['Round']
    actions = session_df['Action'] if 'Action' in session_df.columns else None

    for model in models:
        if model not in session_data['models']:
            continue

        model_df = session_data['models'][model]['full_df']

        if 'BoxCount_A' not in model_df.columns or 'BoxCount_B' not in model_df.columns:
            continue

        try:
            box_count_a = int(model_df['BoxCount_A'].iloc[0])
            box_count_b = int(model_df['BoxCount_B'].iloc[0])
        except (IndexError, ValueError, TypeError, OverflowError):
            # Empty sheet or a non-numeric / NaN box count
            continue

        common_boxes = min(box_count_a, box_count_b)
        if common_boxes <= 0:
            continue

        # Resolve the RMSE columns of every box once instead of once per row
        box_columns = []
        for box_idx in range(common_boxes):
            rmse_a_col = f'Box{box_idx}_RMSE_A'
            rmse_b_col = f'Box{box_idx}_RMSE_B'
            if rmse_a_col in model_df.columns and rmse_b_col in model_df.columns:
                box_columns.append((box_idx, model_df[rmse_a_col], model_df[rmse_b_col]))

        # One bucket of row values per (box, round)
        model_data = {box_idx: [[] for _ in range(max_round + 1)] for box_idx in range(common_boxes)}

        n_rows = min(len(session_df), len(model_df), accept_row_idx + 1)
        for row_idx in range(n_rows):
            if actions is not None and actions.iloc[row_idx] == "Accept":
                break

            try:
                round_num = int(rounds.iloc[row_idx])
            except (ValueError, TypeError, OverflowError):
                continue

            # A negative round would otherwise index a bucket from the end
            if not 0 <= round_num <= max_round:
                continue

            for box_idx, col_a, col_b in box_columns:
                rmse_a = col_a.iloc[row_idx]
                rmse_b = col_b.iloc[row_idx]

                if pd.notna(rmse_a) and pd.notna(rmse_b):
                    model_data[box_idx][round_num].append((rmse_a + rmse_b) / 2)

        # Average each round's bucket; rounds with an empty bucket are left out
        results[model] = {
            box_idx: [np.mean(bucket) for bucket in model_data[box_idx] if bucket]
            for box_idx in range(common_boxes)
        }

    return results

In [None]:
# =============================================================================
# COMPUTE FUNCTIONS - Aggregate across pre-loaded sessions
# =============================================================================

def compute_domain_results(all_sessions, metric_prefix=""):
    """
    Gather the final metric values of every session, grouped by domain.

    Every domain gets an entry for every model and metric, even when no
    session contributed a value (the list is then empty).

    Returns: {domain: {model: {metric: [list of values]}}}
    """
    domain_results = {}

    for domain, session_list in all_sessions.items():
        # Pre-create an empty list for every (model, metric) pair
        per_model = {model: {metric: [] for metric in metrics} for model in models}
        domain_results[domain] = per_model

        for session_data in session_list:
            extracted = extract_domain_results(session_data, metric_prefix)

            for model, buckets in per_model.items():
                if model not in extracted:
                    continue
                found = extracted[model]
                for metric, bucket in buckets.items():
                    if metric in found:
                        bucket.append(found[metric])

    return domain_results

def _session_end_round(session_data):
    """Return the round at which a session ended, or None when it cannot be determined."""
    session_df = session_data.get('session_df')
    if session_df is None or 'Round' not in session_df.columns:
        return None

    accept_row_idx = session_data.get('accept_row_idx', len(session_df) - 1)
    try:
        return int(session_df.iloc[accept_row_idx]['Round'])
    except (IndexError, KeyError, TypeError, ValueError, OverflowError):
        # Empty sheet, out-of-range accept row, or a NaN / non-numeric round
        return None


def compute_round_by_round(all_sessions, metric_prefix="", n_jobs=32):
    """
    Compute round-by-round aggregated results from pre-loaded sessions.

    For each round, only sessions that actually reached that round contribute.
    The mean end round of every agent is also computed and summarised as
    activity milestones (minimum, 25th, 75th and 90th percentile of the
    per-agent means).

    Args:
        all_sessions: {domain: [session_data, ...]}
        metric_prefix: Prefix prepended to each metric column name.
        n_jobs: Number of joblib workers used for the per-session extraction.

    Returns:
        {model: {metric: {'mean': [...], 'std': [...]}}} where list position i
        corresponds to the i-th smallest round number that has data for that
        model/metric, plus a '_metadata' entry holding 'last_full_round',
        'round_75_pct_active', 'round_25_pct_active', 'round_10_pct_active',
        'total_agents' and 'agents_at_last_round'.
    """
    all_session_list = []
    for domain, session_list in all_sessions.items():
        all_session_list.extend(session_list)

    total_sessions = len(all_session_list)
    print(f"Processing {total_sessions} sessions for round-by-round analysis...")

    session_results = Parallel(n_jobs=n_jobs)(
        delayed(extract_round_by_round)(session_data, metric_prefix)
        for session_data in tqdm(all_session_list, desc="Extracting round-by-round")
    )

    # Pair every result with the session it came from BEFORE dropping the empty
    # ones. Filtering first and indexing all_session_list afterwards would match
    # results with the wrong sessions as soon as a single result is empty.
    paired_results = [
        (session_data, session_result)
        for session_data, session_result in zip(all_session_list, session_results)
        if session_result
    ]

    # Collect values per round: {model: {metric: {round_num: [values from all sessions]}}}
    round_data = {model: {metric: {} for metric in metrics} for model in models}

    # Track all session end rounds for each agent
    agent_session_ends = {}

    # Get all unique agents
    all_agents = set()
    for session_data in all_session_list:
        if 'agent_a' in session_data and 'agent_b' in session_data:
            all_agents.add(session_data['agent_a'])
            all_agents.add(session_data['agent_b'])

    total_agents = len(all_agents)
    print(f"Total unique agents: {total_agents}")

    # Process each session together with its own extraction result
    for session_data, session_result in paired_results:
        if 'agent_a' not in session_data or 'agent_b' not in session_data:
            continue

        end_round = _session_end_round(session_data)
        if end_round is None:
            continue

        # Both participants of the session share its end round
        agent_session_ends.setdefault(session_data['agent_a'], []).append(end_round)
        agent_session_ends.setdefault(session_data['agent_b'], []).append(end_round)

        for model in models:
            if model not in session_result:
                continue
            for metric in metrics:
                if metric not in session_result[model]:
                    continue
                # session_result[model][metric] is a dict: {round_num: value}
                for round_num, value in session_result[model][metric].items():
                    if pd.notna(value):  # Skip NaN values
                        round_data[model][metric].setdefault(round_num, []).append(value)

    # Calculate mean end round for each agent
    agent_mean_ends = {
        agent: np.mean(end_rounds)
        for agent, end_rounds in agent_session_ends.items()
        if end_rounds
    }

    # Get the sorted list of agent mean end rounds
    sorted_mean_ends = sorted(agent_mean_ends.values())

    # Calculate quartile rounds based on agent count
    # - 100% active (min): round where all agents still have sessions
    # - 75% active (25th percentile): round where 75% of agents still have sessions
    # - 25% active (75th percentile): round where only 25% of agents still have sessions
    # - 10% active (90th percentile): round where only 10% of agents still have sessions
    last_full_round = int(min(sorted_mean_ends)) if sorted_mean_ends else 0
    round_75_pct_active = int(np.percentile(sorted_mean_ends, 25)) if sorted_mean_ends else 0
    round_25_pct_active = int(np.percentile(sorted_mean_ends, 75)) if sorted_mean_ends else 0
    round_10_pct_active = int(np.percentile(sorted_mean_ends, 90)) if sorted_mean_ends else 0

    # Compute mean and std for each round
    aggregated_results = {model: {metric: {'mean': [], 'std': []} for metric in metrics} for model in models}

    for model in models:
        for metric in metrics:
            if not round_data[model][metric]:
                continue

            # Sort by round number
            sorted_rounds = sorted(round_data[model][metric].keys())

            for round_num in sorted_rounds:
                values = round_data[model][metric][round_num]
                if values:
                    aggregated_results[model][metric]['mean'].append(np.mean(values))
                    aggregated_results[model][metric]['std'].append(np.std(values))

    # Add metadata about the rounds and quartiles
    aggregated_results['_metadata'] = {
        'last_full_round': last_full_round,
        'round_75_pct_active': round_75_pct_active,
        'round_25_pct_active': round_25_pct_active,
        'round_10_pct_active': round_10_pct_active,
        'total_agents': total_agents,
        'agents_at_last_round': total_agents
    }

    print(f"Agent activity milestones:")
    print(f"  - 100% agents active until round {last_full_round}")
    print(f"  - 75% agents active until round {round_75_pct_active}")
    print(f"  - 25% agents active until round {round_25_pct_active}")
    print(f"  - 10% agents active until round {round_10_pct_active}")

    return aggregated_results

def compute_agent_analysis(all_sessions, metric_prefix=""):
    """
    Gather the final metric values of every session, grouped by agent.

    An agent's entry is created the first time the agent is seen and then
    holds a list for every model and metric (empty when nothing was recorded).

    Returns: {agent: {model: {metric: [list of values]}}}
    """
    agent_results = {}

    # Domains are irrelevant here; walk every session of every domain
    for session_list in all_sessions.values():
        for session_data in session_list:
            per_agent = extract_opponent_analysis(session_data, metric_prefix=metric_prefix)

            for agent, models_data in per_agent.items():
                buckets = agent_results.get(agent)
                if buckets is None:
                    buckets = {model: {metric: [] for metric in metrics} for model in models}
                    agent_results[agent] = buckets

                for model, metrics_data in models_data.items():
                    for metric, value in metrics_data.items():
                        buckets[model][metric].append(value)

    return agent_results

def compute_box_rmse_for_domain(session_list, n_jobs=1):
    """
    Aggregate the per-box RMSE curves of all sessions of a SINGLE domain.

    Each session is reduced with ``extract_box_rmse``; the resulting curves are
    grouped per model and box and then averaged position by position over the
    sessions that are long enough to have that position.

    Args:
        session_list: Pre-loaded sessions of one domain.
        n_jobs: Worker count for the extraction. Values above 1 use joblib;
            1 (or None) extracts serially, which is what to use when this
            function is itself called from inside another Parallel loop.

    Returns:
        {model: {box_idx: {'mean': [...], 'std': [...]}}}
    """
    workers = 1 if n_jobs is None else n_jobs

    if workers > 1:
        extracted = Parallel(n_jobs=workers)(
            delayed(extract_box_rmse)(session_data)
            for session_data in session_list
        )
    else:
        # Serial path, safe to run inside another worker process
        extracted = list(map(extract_box_rmse, session_list))

    # Group the curves of all non-empty session results by model and box
    curves_by_model = {model: {} for model in models}
    for per_session in filter(None, extracted):
        for model in models:
            if model not in per_session:
                continue
            for box_idx, values in per_session[model].items():
                curves_by_model[model].setdefault(box_idx, []).append(values)

    # Mean / std at every position across the curves that reach it
    aggregated_results = {}
    for model in models:
        per_box = {}
        for box_idx, curves in curves_by_model[model].items():
            if not curves:
                continue
            means, stds = [], []
            for position in range(max(map(len, curves))):
                present = [curve[position] for curve in curves if position < len(curve)]
                if present:
                    means.append(np.mean(present))
                    stds.append(np.std(present))
            per_box[box_idx] = {'mean': means, 'std': stds}
        aggregated_results[model] = per_box

    return aggregated_results

In [None]:
# =============================================================================
# AGGREGATION HELPERS
# =============================================================================

def aggregate_results_for_boxplot(domain_results):
    """
    Pool the per-domain value lists into one list per model and metric.

    Args:
        domain_results: {domain: {model: {metric: [values]}}}

    Returns:
        DataFrame indexed by ``models`` with ``metrics`` as columns; every cell
        holds the list of values of all domains concatenated in iteration
        order (an empty list when there were none).
    """
    # Every cell starts out as its own empty list
    empty_cells = [[[] for _ in metrics] for _ in models]
    overall_results = pd.DataFrame(empty_cells, index=models, columns=metrics)

    for domain_data in domain_results.values():
        for model in models:
            if model not in domain_data:
                continue
            model_values = domain_data[model]
            for metric in metrics:
                if metric not in model_values:
                    continue
                # The cell is a list object; extending it updates the frame in place
                overall_results.loc[model, metric].extend(model_values[metric])

    return overall_results

In [None]:
# =============================================================================
# CACHING - Save/load sessions to avoid reloading
# =============================================================================

import pickle
def load_sessions_cache(cache_path="session_cache.pkl"):
    """
    Return the pickled session cache, or None when no cache file exists.

    The cache is read with ``pickle``, which executes code embedded in the
    file, so only point this at cache files you created yourself.

    Args:
        cache_path: Location of the pickle file.

    Returns:
        The unpickled ``all_sessions`` mapping if ``cache_path`` exists,
        otherwise None. A one-line summary is printed on a successful load.
    """
    if not os.path.exists(cache_path):
        return None

    with open(cache_path, 'rb') as cache_file:
        all_sessions = pickle.load(cache_file)

    # Count the sessions over all keys for the summary message.
    total_sessions = 0
    for sessions in all_sessions.values():
        total_sessions += len(sessions)
    print(f"✓ Loaded {total_sessions} cached sessions from {cache_path}")

    return all_sessions


# Visualizations

In [None]:
def plot_all_metrics_subplots(overall_results, save_path=None):
    """
    Draw one boxplot panel per metric, with one box per opponent model.

    Each box is filled with the colour returned by ``get_model_color`` so a
    model looks the same here as in the other figures. A model is left out of
    a panel when its cell is not a non-empty list.

    Args:
        overall_results: DataFrame indexed by model name with one column per
            metric; each cell holds the list of values to summarise.
        save_path: If given, the figure is written there (150 dpi) before it
            is shown.

    Returns:
        The matplotlib figure.
    """
    metric_names = overall_results.columns
    n_panels = len(metric_names)

    fig, axes = plt.subplots(1, n_panels, figsize=(10 * n_panels, 12))
    # A single panel comes back as a bare Axes rather than an array of them.
    panel_axes = [axes] if n_panels == 1 else axes

    # Styling for the non-box artists, keyed by the boxplot artist group.
    outline_styles = {
        'whiskers': dict(color='black', linewidth=2),
        'caps': dict(color='black', linewidth=2),
        'medians': dict(color='black', linewidth=3),
        'fliers': dict(markeredgecolor='black', markersize=8, alpha=0.5),
    }

    for ax, metric in zip(panel_axes, metric_names):
        # Keep only the models that have a non-empty list of values.
        shown = []
        for model in overall_results.index:
            cell = overall_results.loc[model, metric]
            if isinstance(cell, list) and len(cell) > 0:
                shown.append((model, cell))

        data = [cell for _, cell in shown]
        # Drop the trailing two words ("Opponent Model") from the tick label.
        tick_labels = [" ".join(model.split(" ")[0:-2]) for model, _ in shown]
        box_colors = [get_model_color(model) for model, _ in shown]

        bp = ax.boxplot(data, patch_artist=True, widths=0.6)

        for box, fill in zip(bp['boxes'], box_colors):
            box.set_facecolor(fill)
            box.set_alpha(0.7)
            box.set_linewidth(2)

        for group, style in outline_styles.items():
            plt.setp(bp[group], **style)

        ax.set_xlabel('Opponent Model', fontsize=22, fontweight='bold', labelpad=12)
        ax.set_ylabel(metric, fontsize=22, fontweight='bold', labelpad=12)

        # Boxes sit at positions 1..N, so the ticks are placed there too.
        ax.set_xticks(range(1, len(tick_labels) + 1))
        ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=16, fontweight='bold')
        ax.tick_params(axis='y', labelsize=16)
        for tick_label in ax.get_yticklabels():
            tick_label.set_fontweight('bold')

        ax.grid(True, alpha=0.3, axis='y', linewidth=1.5)

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")

    plt.show()
    return fig

In [None]:
def plot_round_by_round_performance(aggregated_results, save_path=None, figsize=(28, 16), marker_interval=3):
    """
    Plot the per-round mean of every metric, one subplot per metric.

    Each subplot draws one line per opponent model (only the mean is drawn),
    using the colour and marker returned by ``get_model_color`` and
    ``get_model_marker``. Round 0 is skipped, and the x-axis shows the actual
    round number (the first plotted point is round 1), so the vertical
    milestone lines sit at the same round that the legend quotes.

    Milestone rounds are read from ``aggregated_results['_metadata']`` and are
    drawn only when their value is greater than zero:
      - Red solid: 100% agents active (``last_full_round``)
      - Orange dashed: 75% agents active (``round_75_pct_active``)
      - Purple dotted: 25% agents active (``round_25_pct_active``)
      - Green dash-dot: 10% agents active (``round_10_pct_active``)

    Args:
        aggregated_results: Mapping model -> metric -> {'mean': [...], 'std': [...]},
            optionally with a '_metadata' entry. Models or metrics that are
            missing are skipped.
        save_path: Optional path to save the figure (150 dpi).
        figsize: Figure size tuple.
        marker_interval: Show a marker every N rounds (default: 3).

    Returns:
        The matplotlib figure.
    """
    # Only needed to build proxy handles for the shared legend.
    from matplotlib.lines import Line2D

    # Two columns; enough rows for every metric (2x2 for the usual four).
    n_cols = 2
    n_rows = max(1, (len(metrics) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # (metadata key, legend caption, colour, line style) for each milestone.
    metadata = aggregated_results.get('_metadata', {})
    milestone_specs = [
        ('last_full_round', '100% active', 'red', '-'),
        ('round_75_pct_active', '75% active', '#FF8C00', '--'),
        ('round_25_pct_active', '25% active', '#8B008B', ':'),
        ('round_10_pct_active', '10% active', '#006400', '-.'),
    ]
    milestones = []
    for key, caption, milestone_color, linestyle in milestone_specs:
        round_number = metadata.get(key, 0)
        if round_number > 0:
            milestones.append((round_number, caption, milestone_color, linestyle))

    # Legend proxies keyed by short model name. Collected from every subplot
    # so a model that lacks the first metric still appears in the legend.
    legend_entries = {}

    for ax, metric in zip(axes, metrics):
        for model in models:
            series = aggregated_results.get(model, {}).get(metric)
            if not series:
                continue

            mean_array = np.asarray(series['mean'][1:], dtype=float)  # skip round 0
            if mean_array.size == 0:
                continue

            # Element 0 of the sliced array is round 1.
            rounds = np.arange(1, mean_array.size + 1)

            # Drop the trailing two words ("Opponent Model") for the legend.
            model_short = " ".join(model.split(" ")[0:-2])
            color = get_model_color(model)
            marker = get_model_marker(model)

            ax.plot(rounds, mean_array,
                    label=model_short,
                    color=color,
                    linewidth=1.5,
                    alpha=0.9)

            # Markers only at sampled rounds to keep dense curves readable.
            ax.scatter(rounds[::marker_interval], mean_array[::marker_interval],
                       marker=marker,
                       color=color,
                       s=100,
                       edgecolors='black',
                       linewidths=1.0,
                       zorder=5,
                       alpha=0.9)

            if model_short not in legend_entries:
                legend_entries[model_short] = Line2D(
                    [0], [0], color=color, linewidth=1.5,
                    marker=marker, markersize=10,
                    markeredgecolor='black', markeredgewidth=1.0)

        for round_number, _, milestone_color, linestyle in milestones:
            ax.axvline(x=round_number, color=milestone_color, linestyle=linestyle,
                       linewidth=2.5, alpha=0.8, zorder=10)

        ax.set_xlabel('Round', fontsize=20, fontweight='bold', labelpad=10)
        ax.set_ylabel(metric, fontsize=20, fontweight='bold', labelpad=10)

        ax.tick_params(axis='both', labelsize=16)
        for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
            tick_label.set_fontweight('bold')

        ax.grid(True, alpha=0.3, linewidth=1.5)

    # Hide grid cells that have no metric assigned to them.
    for unused_ax in axes[len(metrics):]:
        unused_ax.axis('off')

    handles = list(legend_entries.values())
    labels = list(legend_entries.keys())
    for round_number, caption, milestone_color, linestyle in milestones:
        handles.append(Line2D([0], [0], color=milestone_color,
                              linestyle=linestyle, linewidth=2.5))
        labels.append(f'{caption} (round {round_number})')

    # Single shared legend placed to the right of the subplots.
    fig.legend(handles, labels,
               loc='center left',
               bbox_to_anchor=(1.0, 0.5),
               fontsize=16,
               frameon=True,
               shadow=True,
               framealpha=0.95,
               edgecolor='black',
               fancybox=True,
               title='Opponent Model',
               title_fontsize=18,
               ncol=1)

    plt.tight_layout()
    # Make room for the legend.
    plt.subplots_adjust(right=0.82)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")

    plt.show()
    return fig

In [None]:
def plot_domain_heatmap(domain_results, metric='RMSE', figsize=(14, 10)):
    """
    Creates a heatmap showing opponent model performance across different domains.

    Each cell is the mean of the values stored for that (model, domain) pair.
    Rows are ordered best-first by their median across domains: ascending for
    'RMSE', descending for every other metric. Models without any data are
    always placed at the bottom.

    Args:
        domain_results: Nested mapping domain -> model -> metric -> values
        metric: Which metric to visualize ('RMSE', 'Spearman', 'KendallTau', 'Pearson')
        figsize: Figure size tuple

    Returns:
        matplotlib figure object. The figure is closed before returning, so
        the caller decides whether to display or save the returned object.
    """
    domain_names = sorted(domain_results.keys())

    # NaN marks (model, domain) pairs without data; they are masked in the plot.
    metric_matrix = np.full((len(models), len(domain_names)), np.nan)
    for j, domain in enumerate(domain_names):
        for i, model in enumerate(models):
            if model in domain_results[domain] and metric in domain_results[domain][model]:
                values = domain_results[domain][model][metric]
                if len(values) > 0:
                    metric_matrix[i, j] = np.mean(values)

    # Median per model, computed only for rows that have data. This avoids
    # the "All-NaN slice" warning; empty rows keep a NaN median.
    row_medians = np.full(len(models), np.nan)
    has_data = ~np.all(np.isnan(metric_matrix), axis=1)
    if has_data.any():
        row_medians[has_data] = np.nanmedian(metric_matrix[has_data], axis=1)

    # argsort places NaN last, so negating the key (rather than reversing the
    # result) gives a descending order that still keeps empty rows at the bottom.
    sort_key = row_medians if metric == 'RMSE' else -row_medians
    sorted_indices = np.argsort(sort_key, kind='stable')

    metric_matrix = metric_matrix[sorted_indices, :]
    sorted_models = [models[i] for i in sorted_indices]
    # Drop the trailing two words ("Opponent Model") from each row label.
    model_labels = [" ".join(model.split(" ")[0:-2]) for model in sorted_models]

    fig, ax = plt.subplots(figsize=figsize)

    # seaborn's 'rocket' runs from dark (low) to light (high). Reversing it for
    # RMSE makes the better end of the scale the lighter one for every metric.
    if metric == 'RMSE':
        cmap = 'rocket_r'
    else:
        cmap = 'rocket'

    sns.heatmap(metric_matrix, annot=True, fmt='.3f', cmap=cmap,
                xticklabels=domain_names, yticklabels=model_labels,
                ax=ax, cbar_kws={'label': metric},
                mask=np.isnan(metric_matrix))

    ax.set_xlabel('Domain', fontsize=12)
    ax.set_ylabel('Opponent Model', fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    fig.tight_layout()
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)

    return fig


def plot_domain_heatmap_by_categories(domain_results, metric='RMSE', figsize=(32, 28)):
    """
    Creates three stacked heatmaps of opponent model performance across domains,
    one per domain characteristic: size, opposition and balance.

    The domain order of each heatmap comes from ``get_domains_by_category``.
    Within each heatmap the rows are ordered best-first by their median across
    domains (ascending for 'RMSE', descending otherwise); models without any
    data are always placed at the bottom.

    Args:
        domain_results: Nested mapping domain -> model -> metric -> values
        metric: Which metric to visualize ('RMSE', 'Spearman', 'KendallTau', 'Pearson')
        figsize: Figure size tuple for the entire figure

    Returns:
        matplotlib figure object. The figure is closed before returning, so
        the caller decides whether to display or save the returned object.
    """
    fig, axes = plt.subplots(3, 1, figsize=figsize)

    categories = ['size', 'opposition', 'balance']
    category_titles = ['Domain Size', 'Domain Opposition', 'Domain Balance']

    # Both colormaps run from light (low values) to dark (high values).
    if metric == 'RMSE':
        cmap = 'YlOrRd'  # lower is better: yellow (good) to red (bad)
    else:
        cmap = 'YlGnBu'  # higher is better: yellow (bad) to blue (good)

    for ax, category, cat_title in zip(axes, categories, category_titles):
        domain_names = get_domains_by_category(domain_results, category)

        if not domain_names:
            ax.text(0.5, 0.5, f'No domain data for {cat_title}',
                    ha='center', va='center', transform=ax.transAxes, fontsize=20)
            ax.axis('off')
            continue

        # NaN marks (model, domain) pairs without data; they are masked in the plot.
        metric_matrix = np.full((len(models), len(domain_names)), np.nan)
        for j, domain in enumerate(domain_names):
            for i, model in enumerate(models):
                if model in domain_results[domain] and metric in domain_results[domain][model]:
                    values = domain_results[domain][model][metric]
                    if len(values) > 0:
                        metric_matrix[i, j] = np.mean(values)

        # Median per model, computed only for rows that have data. This avoids
        # the "All-NaN slice" warning; empty rows keep a NaN median.
        row_medians = np.full(len(models), np.nan)
        has_data = ~np.all(np.isnan(metric_matrix), axis=1)
        if has_data.any():
            row_medians[has_data] = np.nanmedian(metric_matrix[has_data], axis=1)

        # argsort places NaN last, so negating the key (rather than reversing
        # the result) keeps empty rows at the bottom for descending order too.
        sort_key = row_medians if metric == 'RMSE' else -row_medians
        sorted_indices = np.argsort(sort_key, kind='stable')

        metric_matrix = metric_matrix[sorted_indices, :]
        sorted_models = [models[i] for i in sorted_indices]
        # Drop the trailing two words ("Opponent Model") from each row label.
        model_labels = [" ".join(model.split(" ")[0:-2]) for model in sorted_models]

        sns.heatmap(metric_matrix, annot=True, fmt='.3f', cmap=cmap,
                    xticklabels=domain_names, yticklabels=model_labels,
                    ax=ax,
                    cbar_kws={'label': metric, 'shrink': 0.8},
                    mask=np.isnan(metric_matrix),
                    annot_kws={'fontsize': 11, 'weight': 'bold'},
                    linewidths=0.5, linecolor='white',  # cell borders for clarity
                    square=False)  # allow rectangular cells

        ax.set_xlabel('Domain', fontsize=22, fontweight='bold', labelpad=12)
        ax.set_ylabel('Opponent Model', fontsize=22, fontweight='bold', labelpad=12)

        ax.tick_params(axis='x', rotation=45, labelsize=16)
        ax.tick_params(axis='y', rotation=0, labelsize=16)
        for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
            tick_label.set_fontweight('bold')

        plt.setp(ax.get_xticklabels(), ha='right')

        # The colorbar of the heatmap that was just drawn on this axis.
        cbar = ax.collections[0].colorbar
        cbar.ax.tick_params(labelsize=14)
        cbar.set_label(metric, fontsize=18, fontweight='bold')

    fig.tight_layout()
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)

    return fig


def plot_opponent_histogram(agent_results, metric='RMSE', bins=None, figsize=(28, 24)):
    """
    Creates stacked histograms showing opponent model performance distribution per agent.

    One histogram per opponent model, where each agent is a different color in
    the stack. All histograms share the same bin edges so they can be compared.

    Args:
        agent_results: Nested mapping agent -> model -> metric -> values
        metric: Which metric to visualize ('RMSE', 'Spearman', 'KendallTau', 'Pearson')
        bins: Either a number of equal-width bins spanning the range of all
            plotted values, or an explicit sequence of bin edges (default: 20 bins)
        figsize: Figure size tuple

    Returns:
        matplotlib figure object (an empty grid when there is nothing to plot)
    """
    agent_names = sorted(agent_results.keys())

    # Three columns; enough rows for every model (3x3 for the usual seven).
    n_cols = 3
    n_rows = max(1, (len(models) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # tab10 has only ten distinct colours; switch to tab20 beyond that so
    # neighbouring agents in the stack do not share a colour.
    palette = plt.cm.tab10 if len(agent_names) <= 10 else plt.cm.tab20
    agent_colors = palette(np.linspace(0, 1, len(agent_names)))

    # Gather the (agent, colour, values) entries of each model once. The same
    # pass feeds the global value range used for the shared bin edges.
    entries_by_model = {}
    all_values = []
    for model in models:
        entries = []
        for agent_idx, agent in enumerate(agent_names):
            per_model = agent_results[agent]
            if model in per_model and metric in per_model[model]:
                values = per_model[model][metric]
                if len(values) > 0:
                    entries.append((agent, agent_colors[agent_idx], values))
                    all_values.extend(values)
        entries_by_model[model] = entries

    if not all_values:
        print("No data to plot")
        return fig

    if bins is None:
        bins = 20  # default number of bins

    if np.ndim(bins) == 0:
        # A bin count: equal-width bins over the global value range.
        global_min = np.min(all_values)
        global_max = np.max(all_values)
        if global_min == global_max:
            # All values identical: widen the range so the bins have a width.
            global_min, global_max = global_min - 0.5, global_max + 0.5
        bin_edges = np.linspace(global_min, global_max, int(bins) + 1)
    else:
        # Explicit bin edges supplied by the caller.
        bin_edges = np.asarray(bins, dtype=float)

    for ax, model in zip(axes, models):
        entries = entries_by_model[model]
        if not entries:
            ax.axis('off')
            continue

        ax.hist([values for _, _, values in entries],
                bins=bin_edges, stacked=True,
                color=[color for _, color, _ in entries],
                label=[agent for agent, _, _ in entries],
                alpha=0.85, edgecolor='black', linewidth=0.5)

        ax.set_xlabel(metric, fontsize=18, fontweight='bold', labelpad=10)
        ax.set_ylabel('Count', fontsize=18, fontweight='bold', labelpad=10)

        ax.tick_params(axis='both', labelsize=14)
        for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
            tick_label.set_fontweight('bold')

        ax.grid(True, alpha=0.3, linewidth=1.5, axis='y')

        # Legend only on the first subplot.
        if ax is axes[0]:
            ax.legend(title='Agent', title_fontsize=14,
                      fontsize=12, loc='upper right',
                      frameon=True, shadow=True, framealpha=0.9)

    # Hide grid cells that have no model assigned to them.
    for unused_ax in axes[len(models):]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

    return fig


def plot_opponent_heatmap(agent_results, metric='RMSE', figsize=(18, 12)):
    """
    Creates a heatmap showing opponent model performance for each agent.

    Each cell is the mean of the values stored for that (model, agent) pair.
    Rows are ordered best-first by their median across agents: ascending for
    'RMSE', descending for every other metric. Models without any data are
    always placed at the bottom.

    Args:
        agent_results: Nested mapping agent -> model -> metric -> values
        metric: Which metric to visualize ('RMSE', 'Spearman', 'KendallTau', 'Pearson')
        figsize: Figure size tuple

    Returns:
        matplotlib figure object. The figure is closed before returning, so
        the caller decides whether to display or save the returned object.
    """
    agent_names = sorted(agent_results.keys())

    # NaN marks (model, agent) pairs without data; they are masked in the plot.
    metric_matrix = np.full((len(models), len(agent_names)), np.nan)
    for j, agent in enumerate(agent_names):
        for i, model in enumerate(models):
            if model in agent_results[agent] and metric in agent_results[agent][model]:
                values = agent_results[agent][model][metric]
                if len(values) > 0:
                    metric_matrix[i, j] = np.mean(values)

    # Median per model, computed only for rows that have data. This avoids
    # the "All-NaN slice" warning; empty rows keep a NaN median.
    row_medians = np.full(len(models), np.nan)
    has_data = ~np.all(np.isnan(metric_matrix), axis=1)
    if has_data.any():
        row_medians[has_data] = np.nanmedian(metric_matrix[has_data], axis=1)

    # argsort places NaN last, so negating the key (rather than reversing the
    # result) gives a descending order that still keeps empty rows at the bottom.
    sort_key = row_medians if metric == 'RMSE' else -row_medians
    sorted_indices = np.argsort(sort_key, kind='stable')

    metric_matrix = metric_matrix[sorted_indices, :]
    sorted_models = [models[i] for i in sorted_indices]
    # Drop the trailing two words ("Opponent Model") from each row label.
    model_labels = [" ".join(model.split(" ")[0:-2]) for model in sorted_models]

    fig, ax = plt.subplots(figsize=figsize)

    # Both colormaps run from light (low values) to dark (high values).
    if metric == 'RMSE':
        cmap = 'YlOrRd'  # lower is better: yellow (good) to red (bad)
    else:
        cmap = 'YlGnBu'  # higher is better: yellow (bad) to blue (good)

    sns.heatmap(metric_matrix, annot=True, fmt='.3f', cmap=cmap,
                xticklabels=agent_names, yticklabels=model_labels,
                ax=ax,
                cbar_kws={'label': metric},
                mask=np.isnan(metric_matrix),
                annot_kws={'fontsize': 11, 'weight': 'bold'},
                linewidths=0.5, linecolor='white')

    ax.set_xlabel('Agent', fontsize=20, fontweight='bold', labelpad=12)
    ax.set_ylabel('Opponent Model', fontsize=20, fontweight='bold', labelpad=12)

    ax.tick_params(axis='x', rotation=45, labelsize=15)
    ax.tick_params(axis='y', rotation=0, labelsize=15)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontweight('bold')

    plt.setp(ax.get_xticklabels(), ha='right')

    # The colorbar of the heatmap that was just drawn on this axis.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=13)
    cbar.set_label(metric, fontsize=17, fontweight='bold')

    fig.tight_layout()
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)

    return fig

In [None]:
def plot_pareto_box_rmse_by_model(box_results, domain=None, figsize=(28, 24)):
    """
    Plot box-specific RMSE performance for all opponent models.

    Creates a three-column grid with one subplot per model (3x3 for seven
    models, with the spare cells hidden). Each subplot shows the mean RMSE
    over rounds with one line per box index, coloured along viridis from the
    lowest to the highest index. Round 0 is skipped and the x-axis starts at
    round 1.

    Args:
        box_results: Mapping model -> box index -> {'mean': [...], ...}.
            Models that are missing or have no boxes get a blank subplot.
        domain: Domain name. Currently unused because subplot titles are
            disabled; kept so existing calls keep working.
        figsize: Figure size tuple.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    n_cols = 3
    n_rows = max(1, (len(models) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for ax, model in zip(axes, models):
        model_boxes = box_results.get(model)
        if not model_boxes:
            ax.axis('off')
            continue

        box_indices = sorted(model_boxes.keys())

        # viridis: lowest box index = purple, highest = yellow.
        colors = plt.cm.viridis(np.linspace(0, 1, len(box_indices)))

        has_lines = False
        for color, box_idx in zip(colors, box_indices):
            mean_values = model_boxes[box_idx]['mean'][1:]  # skip round 0
            if len(mean_values) == 0:
                continue

            rounds = np.arange(1, len(mean_values) + 1)
            ax.plot(rounds, np.asarray(mean_values),
                    label=f'Box {box_idx}',
                    color=color,
                    linewidth=3,
                    markersize=6,
                    marker='o',
                    alpha=0.85)
            has_lines = True

        ax.set_xlabel('Round', fontsize=18, fontweight='bold', labelpad=10)
        ax.set_ylabel('RMSE', fontsize=18, fontweight='bold', labelpad=10)

        ax.tick_params(axis='both', labelsize=14)
        for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
            tick_label.set_fontweight('bold')

        # A legend without labelled lines only produces a matplotlib warning.
        if has_lines:
            ax.legend(title='Box Index', title_fontsize=13, loc='best',
                      fontsize=12, ncol=2, frameon=True, shadow=True,
                      framealpha=0.9, edgecolor='black')

        ax.grid(True, alpha=0.3, linewidth=1.5)

    # Hide grid cells that have no model assigned to them.
    for unused_ax in axes[len(models):]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Significance Tests

In [None]:
from scipy.stats import friedmanchisquare, wilcoxon, ttest_rel, levene, ks_2samp
from statsmodels.stats.anova import AnovaRM
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple
import logging

# Configure root logging at INFO level with timestamped lines, and create the
# logger that the significance-test helpers in this cell use to report failed
# or fallen-back tests. basicConfig does nothing if the root logger already
# has handlers configured.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def get_best_model_per_metric(
    ratings_df: pd.DataFrame,
    lower_is_better: Tuple[str, ...] = ("RMSE",)
) -> Dict[str, str]:
    """
    Pick the best model for every metric by its mean score.

    For error metrics a smaller mean is better, so the model with the lowest
    mean wins; for every other metric the model with the highest mean wins.

    Args:
        ratings_df: DataFrame indexed by model name with one column per
            metric; each cell holds a sequence of scores.
        lower_is_better: Names of the metrics for which a lower mean score is
            better. Defaults to ``("RMSE",)``.

    Returns:
        Dict mapping each metric name to the name of its best model.
    """
    best_models = {}
    for metric in ratings_df.columns:
        mean_scores = ratings_df[metric].apply(lambda scores: np.mean(np.array(scores)))
        if metric in lower_is_better:
            best_models[metric] = mean_scores.idxmin()
        else:
            best_models[metric] = mean_scores.idxmax()
    return best_models


def prepare_long_format_data(ratings_df: pd.DataFrame, metric: str) -> pd.DataFrame:
    """
    Reshape one metric column into long format (one row per subject and model).

    The number of subjects is taken from the length of the first model's
    score sequence.

    Args:
        ratings_df: DataFrame indexed by model name; each cell of ``metric``
            holds a sequence of scores.
        metric: Column to reshape.

    Returns:
        DataFrame with columns 'Subject', 'Model' and 'Value', ordered by
        subject and then by model.
    """
    first_scores = next(iter(ratings_df[metric]))
    records = [
        {
            'Subject': subject,
            'Model': model,
            'Value': ratings_df.loc[model, metric][subject],
        }
        for subject in range(len(first_scores))
        for model in ratings_df.index
    ]
    return pd.DataFrame(records)


def test_normality(data: np.ndarray) -> bool:
    _, p_value = ks_2samp(data, np.random.normal(np.mean(data), np.std(data), len(data)))
    return p_value >= 0.05


def calculate_cohens_d(group1: np.ndarray, group2: np.ndarray) -> float:
    """
    Calculate Cohen's d effect size for two groups using the pooled standard deviation.

    Args:
        group1: First group's data
        group2: Second group's data

    Returns:
        Cohen's d, i.e. (mean(group1) - mean(group2)) / pooled SD.
        Special cases:
          - NaN when a group is empty or both groups have a single value
            (the pooled SD is undefined).
          - 0.0 when the pooled SD is zero and the means are equal.
          - +inf / -inf when the pooled SD is zero and the means differ.
    """
    g1 = np.asarray(group1, dtype=float)
    g2 = np.asarray(group2, dtype=float)
    n1, n2 = g1.size, g2.size
    dof = n1 + n2 - 2
    if n1 == 0 or n2 == 0 or dof <= 0:
        return float('nan')

    # Sum of squared deviations per group. A single-value group contributes
    # nothing (its sample variance is undefined and its weight is zero).
    ss1 = (n1 - 1) * np.var(g1, ddof=1) if n1 > 1 else 0.0
    ss2 = (n2 - 1) * np.var(g2, ddof=1) if n2 > 1 else 0.0
    pooled_sd = np.sqrt((ss1 + ss2) / dof)

    mean_diff = np.mean(g1) - np.mean(g2)
    if pooled_sd == 0:
        # No spread at all: avoid 0/0 = NaN for identical groups.
        if mean_diff == 0:
            return 0.0
        return float(np.copysign(np.inf, mean_diff))

    return float(mean_diff / pooled_sd)


def perform_group_significance_test(ratings_df: pd.DataFrame, metric: str) -> Tuple[str, float]:
    """
    Run an omnibus significance test across all models for one metric.

    Test selection:
      - Fewer than two models: nothing to compare; a NaN p-value is returned.
      - Exactly two models: paired t-test when both samples pass
        ``test_normality`` and Levene's test does not reject equal variances,
        otherwise the Wilcoxon signed-rank test.
      - Three or more models: repeated-measures ANOVA when every sample
        passes ``test_normality``; the Friedman test otherwise, or when the
        ANOVA fails.

    Args:
        ratings_df: DataFrame indexed by model name; each cell of ``metric``
            holds that model's sequence of scores.
        metric: Column to test.

    Returns:
        Tuple of (name of the test that was used, p-value).
    """
    methods = list(ratings_df.index)

    # A single model (or none) cannot be compared with anything. NaN is never
    # below a significance threshold, so callers treat it as "not significant".
    if len(methods) < 2:
        return "Not applicable (fewer than two methods)", float('nan')

    samples = [np.array(ratings_df.loc[method, metric]) for method in methods]

    # Two methods: paired test.
    if len(methods) == 2:
        arr1, arr2 = samples
        norm1 = test_normality(arr1)
        norm2 = test_normality(arr2)
        if norm1 and norm2:
            _, lev_p = levene(arr1, arr2)
            if lev_p >= 0.05:
                _, p_value = ttest_rel(arr1, arr2)
                return "Paired t-test", p_value
        _, p_value = wilcoxon(arr1, arr2)
        return "Wilcoxon Signed-Rank Test", p_value

    # Three or more: parametric test only if every sample looks normal.
    normal_tests = [test_normality(sample) for sample in samples]
    if all(normal_tests):
        long_data = prepare_long_format_data(ratings_df, metric)
        try:
            anova = AnovaRM(long_data, 'Value', 'Subject', within=['Model']).fit()
            # The table is indexed by factor name, so take the first row by
            # position explicitly instead of relying on integer-key fallback.
            return "Repeated Measures ANOVA", anova.anova_table['Pr > F'].iloc[0]
        except Exception as e:
            logger.warning(f"RM ANOVA failed: {e}, falling back to Friedman test")

    # Friedman fallback.
    _, p_value = friedmanchisquare(*samples)
    return "Friedman Test", p_value


def perform_pairwise_comparison(
    best_scores: List[float],
    comparison_scores: List[float]
) -> Tuple[str, float, float]:
    """
    Compare the best model's scores with another model's scores.

    A paired t-test is used when both samples pass ``test_normality`` and
    Levene's test does not reject equal variances (p >= 0.05); in every other
    case the Wilcoxon signed-rank test is used.

    Args:
        best_scores: Scores of the best model.
        comparison_scores: Scores of the model it is compared with.

    Returns:
        Tuple of (name of the test used, p-value, Cohen's d effect size).
    """
    reference = np.array(best_scores)
    challenger = np.array(comparison_scores)

    cohens_d = calculate_cohens_d(reference, challenger)

    # Both normality checks are always evaluated, reference sample first.
    normality_flags = [test_normality(sample) for sample in (reference, challenger)]

    use_t_test = False
    if all(normality_flags):
        _, levene_p = levene(reference, challenger)
        use_t_test = levene_p >= 0.05

    if use_t_test:
        _, p_value = ttest_rel(reference, challenger)
        return "Paired t-test", p_value, cohens_d

    _, p_value = wilcoxon(reference, challenger)
    return "Wilcoxon Signed-Rank Test", p_value, cohens_d


def interpret_effect_size(effect_size: float) -> str:
    """
    Interpret Cohen's d effect size.

    Thresholds on the absolute value: below 0.2 "negligible", below 0.5
    "small", below 0.8 "medium", otherwise "large".

    Args:
        effect_size: The calculated Cohen's d value

    Returns:
        String interpretation of the effect size, or "undefined" when the
        effect size is NaN.
    """
    # NaN fails every "<" comparison and would otherwise be reported as "large".
    if np.isnan(effect_size):
        return "undefined"

    magnitude = abs(effect_size)
    if magnitude < 0.2:
        return "negligible"
    if magnitude < 0.5:
        return "small"
    if magnitude < 0.8:
        return "medium"
    return "large"


def compare_models(ratings_df: pd.DataFrame) -> pd.DataFrame:
    """
    Test, per metric, whether the best model differs from the other models.

    For every metric column a group-level test is run first. When it is
    significant (p < 0.05) and more than one model is present, the best model
    for that metric is compared pairwise against each other model; a pairwise
    comparison that raises is logged as a warning and skipped. Otherwise a
    single placeholder row without pairwise fields is recorded for the metric.

    Args:
        ratings_df: DataFrame with models as index and metrics as columns;
            each cell holds that model's scores for the metric.

    Returns:
        DataFrame with one row per pairwise comparison, or one placeholder
        row for each metric where no pairwise testing was performed.
    """
    rows: List[Dict] = []
    best_per_metric = get_best_model_per_metric(ratings_df)

    for metric in ratings_df.columns:
        winner = best_per_metric[metric]
        group_test, group_p = perform_group_significance_test(ratings_df, metric)

        # Fields shared by every row emitted for this metric.
        shared = {
            "Metric": metric,
            "Group Test": group_test,
            "Group p": group_p,
            "Best Method": winner,
        }

        run_pairwise = group_p < 0.05 and len(ratings_df.index) > 1
        if not run_pairwise:
            # Group test not significant (or only one model): placeholder row.
            rows.append({
                **shared,
                "Compared Method": None,
                "Pair Test": None,
                "Pair p": None,
                "Significant": False,
                "Effect Size (Cohen's d)": None,
                "Effect Interpretation": None,
            })
            continue

        for challenger in ratings_df.index:
            if challenger == winner:
                continue
            try:
                pair_test, pair_p, cohens_d = perform_pairwise_comparison(
                    ratings_df.loc[winner, metric],
                    ratings_df.loc[challenger, metric],
                )
                interpretation = interpret_effect_size(cohens_d)
                rows.append({
                    **shared,
                    "Compared Method": challenger,
                    "Pair Test": pair_test,
                    "Pair p": pair_p,
                    "Significant": pair_p < 0.05,
                    "Effect Size (Cohen's d)": cohens_d,
                    "Effect Interpretation": interpretation,
                })
            except Exception as exc:
                # Best effort: one failing pair must not abort the whole table.
                logger.warning(f"Comparison {winner} vs {challenger} failed: {exc}")

    return pd.DataFrame(rows)


# Results

In [None]:
# =============================================================================
# MASTER SESSION LOADING - Load once, use everywhere
# =============================================================================
import os

# Cache file looked up relative to the current working directory.
session_cache_file = "session_cache.pkl"

if not os.path.exists(session_cache_file):
    # No cache found: build the sessions from `sessions_path`.
    all_sessions = load_all_session_data(sessions_path, n_jobs=200)
    # NOTE(review): the save step below is disabled, so this cell never creates
    # the cache itself; the disabled call also targets "data/session_cache"
    # rather than "session_cache.pkl" - confirm the intended location before
    # re-enabling it.
    # save_sessions(all_sessions, base_path="data/session_cache")
else:
    # Cache found: load the sessions from it.
    all_sessions = load_sessions_cache(cache_path=session_cache_file)

## Overall Results

In [None]:
# Metric names used in the "Overall Results" section (same list as the one
# declared near the top of the notebook).
metrics = [
    "RMSE",
    "Spearman",
    "KendallTau",
    "Pearson",
]

In [None]:
# Compute domain results from the sessions loaded above (`all_sessions`).
# `domain_results` also feeds the domain heatmaps further down.
domain_results = compute_domain_results(all_sessions)

# Aggregate the domain results into the structure used by the boxplots,
# the mean ± std table and the significance tests below.
overall_results = aggregate_results_for_boxplot(domain_results)

# Debug domain matching
# NOTE(review): debugging aid left in the results section - confirm it is still needed.
debug_domain_matching(domain_results)

In [None]:
# Boxplots of the aggregated overall results; `save_path` points the output
# at "overall_boxplots.png" inside PLOTS_DIR.
plot_all_metrics_subplots(overall_results, save_path=PLOTS_DIR / "overall_boxplots.png")

In [None]:
def create_mean_std_table(overall_results):
    """
    Create a table showing mean ± std for each model across all metrics.

    Args:
        overall_results: DataFrame with models as index and metrics as columns,
                        where each cell contains a sequence of values
                        (list, tuple or NumPy array).

    Returns:
        DataFrame of strings formatted as "mean ± std" with four decimals
        (std is NumPy's default, i.e. ddof=0). Cells that are empty or do not
        hold a supported sequence are reported as "N/A". The index holds the
        model names with their last two words removed.
    """
    sequence_types = (list, tuple, np.ndarray)

    def _format_cell(values):
        # Only plain lists used to be summarised, so tuple/array cells were
        # silently reported as "N/A"; accept the common sequence types instead.
        if isinstance(values, sequence_types):
            samples = np.asarray(values)
            if samples.size > 0:
                return f"{np.mean(samples):.4f} ± {np.std(samples):.4f}"
        return "N/A"

    # One column of formatted strings per metric, rows in model-index order.
    summary_data = {
        metric: [
            _format_cell(overall_results.loc[model, metric])
            for model in overall_results.index
        ]
        for metric in overall_results.columns
    }

    # Shorten labels by dropping the trailing two words of each model name
    # (e.g. "Bayesian Opponent Model" -> "Bayesian").
    model_labels = [" ".join(model.split(" ")[0:-2]) for model in overall_results.index]

    return pd.DataFrame(summary_data, index=model_labels)

# Build the mean ± std summary from `overall_results`, print a caption line,
# then render the table through the notebook's rich display.
mean_std_table = create_mean_std_table(overall_results)
print("Mean ± Std for Each Model Across All Metrics:\n")
display(mean_std_table)

In [None]:
# Work on a copy so that the RMSE column can be replaced below without
# touching `overall_results`.
significane_tests = overall_results.copy()

def make_negative(x):
    """
    Negate a value.

    Args:
        x: Either a list of values or a single value supporting unary minus.

    Returns:
        A new list with every element negated when ``x`` is a list,
        otherwise ``-x``.
    """
    return [-item for item in x] if isinstance(x, list) else -x

# Flip the sign of every RMSE value in the copy.
# NOTE(review): presumably so that "higher is better" holds for RMSE like the
# correlation metrics when the best model is selected - confirm against
# get_best_model_per_metric.
significane_tests['RMSE'] = significane_tests['RMSE'].apply(make_negative)

In [None]:
# Run the significance tests on the RMSE-negated copy and write the resulting
# table to an Excel file (relative path, i.e. the current working directory).
compare_models(significane_tests).to_excel("overall_significance_test.xlsx")

### Domain Heatmaps

In [None]:
# Domain heatmap for Pearson, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(domain_results, metric='Pearson')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_domain_heatmap_pearson.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Domain heatmap for KendallTau, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(domain_results, metric='KendallTau')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_domain_heatmap_kendalltau.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Domain heatmap for Spearman, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(domain_results, metric='Spearman')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_domain_heatmap_spearman.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Domain heatmap for RMSE, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(domain_results, metric='RMSE')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_domain_heatmap_rmse.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

### Opponent Heatmaps (Per Agent)

In [None]:
# Compute opponent model performance per agent from the cached sessions.
# `agent_results` is the input of the per-agent heatmaps in the cells below.
agent_results = compute_agent_analysis(all_sessions)

In [None]:
# Per-agent heatmap for Pearson.
fig = plot_opponent_heatmap(agent_results, metric='Pearson')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_agent_heatmap_pearson.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Per-agent heatmap for KendallTau.
fig = plot_opponent_heatmap(agent_results, metric='KendallTau')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_agent_heatmap_kendalltau.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Per-agent heatmap for Spearman.
fig = plot_opponent_heatmap(agent_results, metric='Spearman')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_agent_heatmap_spearman.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Per-agent heatmap for RMSE.
fig = plot_opponent_heatmap(agent_results, metric='RMSE')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "overall_agent_heatmap_rmse.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

### Round-by-Round Performance Analysis

In [None]:
# Compute round-by-round results from the cached sessions.
# NOTE(review): n_jobs=200 is hard-coded (presumably the parallel worker
# count) - confirm it suits the machine this notebook runs on.
round_by_round_results = compute_round_by_round(all_sessions, n_jobs=200)

In [None]:
# Plot round-by-round performance for all metrics; `save_path` points the
# output at "overall_round_by_round.png" inside PLOTS_DIR.
plot_round_by_round_performance(round_by_round_results, save_path=PLOTS_DIR / "overall_round_by_round.png")

## Overall Pareto Metrics

In [None]:
# Metric names for the Pareto section. This rebinds the notebook-level
# `metrics`; note "Kendall" here versus "KendallTau" in the earlier sections.
metrics = [
    "RMSE",
    "Spearman",
    "Kendall",
    "Pearson",
]

In [None]:
# Compute the Pareto domain results from the cached sessions; they are
# selected by passing metric_prefix="Overall_".
pareto_domain_results = compute_domain_results(all_sessions, metric_prefix="Overall_")

# Aggregate into the structure used by the Pareto boxplots and
# significance tests below.
pareto_results = aggregate_results_for_boxplot(pareto_domain_results)

In [None]:
# Rebind `significane_tests` to a copy of the Pareto results so that the RMSE
# column can be replaced below without touching `pareto_results`.
significane_tests = pareto_results.copy()

def make_negative(x):
    """
    Return the negation of ``x``.

    Lists are negated element by element into a new list; any other value is
    negated with unary minus. (Same helper as the one defined in the
    "Overall Results" section.)
    """
    if not isinstance(x, list):
        return -x
    return [-element for element in x]

# Flip the sign of every RMSE value in the Pareto copy.
# NOTE(review): presumably so that "higher is better" holds for RMSE like the
# correlation metrics when the best model is selected - confirm against
# get_best_model_per_metric.
significane_tests['RMSE'] = significane_tests['RMSE'].apply(make_negative)

In [None]:
# Run the significance tests on the RMSE-negated Pareto copy and write the
# resulting table to an Excel file (relative path, i.e. the current working
# directory).
compare_models(significane_tests).to_excel("pareto_significance_test.xlsx")

In [None]:
# Boxplots of the Pareto results (computed with metric_prefix="Overall_");
# `save_path` points the output at "pareto_boxplots.png" inside PLOTS_DIR.
plot_all_metrics_subplots(pareto_results, save_path=PLOTS_DIR / "pareto_boxplots.png")

### Pareto Domain Heatmaps

In [None]:
# Pareto domain heatmap for Pearson, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(pareto_domain_results, metric='Pearson')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_domain_heatmap_pearson.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Pareto domain heatmap for Kendall, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(pareto_domain_results, metric='Kendall')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_domain_heatmap_kendall.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Pareto domain heatmap for Spearman, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(pareto_domain_results, metric='Spearman')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_domain_heatmap_spearman.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Pareto domain heatmap for RMSE, grouped by categories (Size, Opposition, Balance).
fig = plot_domain_heatmap_by_categories(pareto_domain_results, metric='RMSE')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_domain_heatmap_rmse.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

### Pareto Opponent Heatmaps (Per Agent)

In [None]:
# Compute Pareto opponent model performance per agent from the cached
# sessions (selected via metric_prefix="Overall_"). `pareto_agent_results`
# is the input of the Pareto per-agent heatmaps in the cells below.
pareto_agent_results = compute_agent_analysis(all_sessions, metric_prefix="Overall_")

In [None]:
# Pareto per-agent heatmap for Pearson.
fig = plot_opponent_heatmap(pareto_agent_results, metric='Pearson')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_agent_heatmap_pearson.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Pareto per-agent heatmap for Kendall.
fig = plot_opponent_heatmap(pareto_agent_results, metric='Kendall')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_agent_heatmap_kendall.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Pareto per-agent heatmap for Spearman.
fig = plot_opponent_heatmap(pareto_agent_results, metric='Spearman')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_agent_heatmap_spearman.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

In [None]:
# Pareto per-agent heatmap for RMSE.
fig = plot_opponent_heatmap(pareto_agent_results, metric='RMSE')
# Resolve the output file once so the save and the log line cannot drift apart.
heatmap_path = PLOTS_DIR / "pareto_agent_heatmap_rmse.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved: {heatmap_path}")

### Pareto Round-by-Round Performance Analysis

In [None]:
# Compute the Pareto round-by-round results from the cached sessions
# (selected via metric_prefix="Overall_").
# NOTE(review): unlike the earlier non-Pareto call, no n_jobs is passed here,
# so compute_round_by_round's default applies - confirm this is intended.
# Despite its "overall_" prefix, this variable holds the Pareto results.
overall_round_by_round_results = compute_round_by_round(all_sessions, metric_prefix="Overall_")

In [None]:
# Plot the Pareto ("Overall_"-prefixed) metrics by round; `save_path` points
# the output at "pareto_round_by_round.png" inside PLOTS_DIR.
plot_round_by_round_performance(overall_round_by_round_results, save_path=PLOTS_DIR / "pareto_round_by_round.png")