# Probe Results Analysis

This notebook replicates the functionality of `scripts/plot_probes.py` for analyzing probe performance across layers.

In [None]:
import json
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path
import glob
import re
import seaborn as sns

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

In [None]:
# Configuration - update these paths before running the notebook
RESULTS_DIR = "../results/week-34"  # Directory searched for probe_results_logistic_*.json files
TAXONOMY_FILE = "../dataset/mcrae-x-things-taxonomy-simp.json"  # JSON file mapping attribute -> category
SAVE_PLOTS = True  # When True, each figure is also written to SAVE_DIR; set to False to only display
SAVE_DIR = "../plots/notebook_analysis"  # Directory to save plots (created on demand)
METRICS = ["f1", "accuracy"]  # Metrics for the layer overview; results must provide mean_<metric> fields

In [None]:
def load_results(results_dir):
    """Load every probe result file found in ``results_dir``.

    Files are matched with the glob ``probe_results_logistic_*.json``. The layer
    is derived from the file name: a stand-alone ``last`` token (not embedded in
    a longer word such as ``elastic``) maps to the string ``"last"``; otherwise
    the final run of digits in the name is used as an integer layer index.
    Files with neither are skipped.

    Args:
        results_dir: Directory to search (str or Path).

    Returns:
        dict mapping layer (``int`` or ``"last"``) to the parsed JSON content.
        Empty when no file matches.
    """
    results = {}

    pattern = str(Path(results_dir) / "probe_results_logistic_*.json")

    # glob returns files in arbitrary order; sorting makes the load order (and
    # therefore which file wins when two map to the same layer) deterministic.
    for file_path in sorted(glob.glob(pattern)):
        filename = Path(file_path).name

        # Only accept "last" as its own token so names like "elastic" are not
        # mistaken for the last layer.
        if re.search(r"(?<![A-Za-z])last(?![A-Za-z])", filename):
            layer = "last"
        else:
            numbers = re.findall(r"\d+", filename)
            if not numbers:
                continue  # No layer information in the name
            layer = int(numbers[-1])  # Take the last number

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if layer in results:
            print(f"Warning: {filename} replaces earlier data for layer {layer}")

        results[layer] = data
        print(f"Loaded {filename} -> layer {layer}")

    return results


def load_taxonomy(taxonomy_file):
    """Load the taxonomy JSON file and return its parsed content.

    Args:
        taxonomy_file: Path to the JSON file. The rest of the notebook uses the
            result as a dict mapping attribute name -> category.

    Returns:
        The object produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Read as UTF-8 explicitly: the locale default encoding would garble or
    # reject non-ASCII attribute names on some platforms.
    with open(taxonomy_file, "r", encoding="utf-8") as f:
        return json.load(f)


def extract_layer_performance(results, metric="f1"):
    """Summarise the per-attribute scores of every layer.

    Args:
        results: dict mapping layer (int or ``"last"``) to a dict holding an
            ``"individual_results"`` list; every entry must provide
            ``f"mean_{metric}"``.
        metric: Metric name used to pick the ``mean_<metric>`` field.

    Returns:
        A list with one dict per layer (keys ``layer``, ``mean``, ``std``,
        ``median``, ``min``, ``max``, ``n_attributes``), with scores expressed
        in percent. Numeric layers come first in ascending order and ``"last"``
        comes at the end. Layers without any individual result are skipped.
    """
    layer_data = []

    for layer, data in results.items():
        # Convert to percentage
        scores = [r[f"mean_{metric}"] * 100 for r in data["individual_results"]]

        if not scores:
            # np.min / np.max raise on empty input and the mean would be NaN,
            # so a layer without probes cannot contribute statistics.
            print(f"Skipping layer {layer}: no individual results")
            continue

        layer_data.append(
            {
                "layer": layer,
                "mean": np.mean(scores),
                "std": np.std(scores),
                "median": np.median(scores),
                "min": np.min(scores),
                "max": np.max(scores),
                "n_attributes": len(scores),
            }
        )

    def sort_key(item):
        # Tuple key: every numeric layer (however large) sorts before "last",
        # without comparing ints against strings.
        if item["layer"] == "last":
            return (1, 0)
        return (0, item["layer"])

    layer_data.sort(key=sort_key)
    return layer_data

In [None]:
# Load results and taxonomy
print(f"Loading results from: {RESULTS_DIR}")
results = load_results(RESULTS_DIR)

if not results:
    print("No results found! Please check the RESULTS_DIR path.")
else:
    # Layer keys mix ints with the string "last"; a bare sorted() raises
    # TypeError on that mix, so order numeric layers first and "last" at the end.
    loaded_layers = sorted(
        results, key=lambda layer: (layer == "last", 0 if layer == "last" else layer)
    )
    print(f"Found {len(results)} layers: {loaded_layers}")

# Load taxonomy
try:
    taxonomy = load_taxonomy(TAXONOMY_FILE)
    print(f"Loaded taxonomy with {len(taxonomy)} attributes")
    print(f"Categories: {sorted(set(taxonomy.values()))}")
except Exception as e:
    # Deliberately broad: without a usable taxonomy the notebook continues and
    # the category-based sections are skipped.
    print(f"Could not load taxonomy: {e}")
    taxonomy = {}

# Create save directory if needed
if SAVE_PLOTS:
    Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)
    print(f"Plots will be saved to: {SAVE_DIR}")

## 1. Overview Performance Across Layers

In [None]:
def overview_performance(results, metric="f1", save_path=None):
    """Plot the per-layer mean probe score with a ±1 std band.

    Per-layer statistics come from ``extract_layer_performance``. The figure
    marks the best and the worst layer, carries a small summary box, is
    optionally written to ``save_path`` and is then shown. The same numbers are
    printed as a text table afterwards.

    Args:
        results: Mapping of layer -> result data, as accepted by
            ``extract_layer_performance``.
        metric: Metric name; shown upper-cased in the labels.
        save_path: File to save the figure to; a falsy value skips saving.

    Returns:
        The list of per-layer statistics dicts, or ``None`` when there is no
        layer data.
    """
    stats_per_layer = extract_layer_performance(results, metric)
    if not stats_per_layer:
        print("No layer data found!")
        return

    # Unpack the per-layer dicts into parallel lists for plotting
    layers, means, stds, n_attributes = [], [], [], []
    for stats in stats_per_layer:
        layers.append(stats["layer"])
        means.append(stats["mean"])
        stds.append(stats["std"])
        n_attributes.append(stats["n_attributes"])

    metric_label = metric.upper()
    best_idx = np.argmax(means)
    worst_idx = np.argmin(means)
    lowest, highest = min(means), max(means)

    fig, ax = plt.subplots(figsize=(12, 6))
    x_pos = range(len(layers))

    # Mean curve
    ax.plot(
        x_pos,
        means,
        "o-",
        linewidth=3,
        markersize=8,
        color="steelblue",
        label="Mean Performance",
    )

    # Shaded band one standard deviation either side of the mean
    band_low = [m - s for m, s in zip(means, stds)]
    band_high = [m + s for m, s in zip(means, stds)]
    ax.fill_between(
        x_pos, band_low, band_high, alpha=0.3, color="steelblue", label="± STD"
    )

    ax.set_xlabel("Layer", fontsize=12)
    ax.set_ylabel(f"Mean {metric_label} Score (%)", fontsize=12)
    ax.set_title(
        f"Probe Performance Overview - {metric_label} Across Layers", fontsize=14
    )
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Positions are indices, so label each tick with the actual layer name
    ax.set_xticks(x_pos)
    ax.set_xticklabels([str(name) for name in layers])

    # Call-outs for the best and the worst layer
    callouts = (
        ("Best", best_idx, (10, 15), "lightgreen"),
        ("Worst", worst_idx, (10, -25), "lightcoral"),
    )
    for tag, idx, offset, box_color in callouts:
        ax.annotate(
            f"{tag}: Layer {layers[idx]}\n{means[idx]:.1f}%",
            xy=(idx, means[idx]),
            xytext=offset,
            textcoords="offset points",
            bbox=dict(boxstyle="round,pad=0.5", facecolor=box_color, alpha=0.8),
            arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=0"),
            fontsize=10,
        )

    # Summary box in the top-left corner (axes coordinates)
    summary_text = "\n".join(
        [
            "Summary:",
            f"Layers: {len(layers)}",
            f"Best: Layer {layers[best_idx]} ({means[best_idx]:.1f}%)",
            f"Worst: Layer {layers[worst_idx]} ({means[worst_idx]:.1f}%)",
            f"Range: {highest - lowest:.1f}%",
            f"Avg attributes/layer: {np.mean(n_attributes):.0f}",
        ]
    )
    ax.text(
        0.02,
        0.98,
        summary_text,
        transform=ax.transAxes,
        fontsize=9,
        verticalalignment="top",
        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8),
    )

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Saved plot to: {save_path}")

    plt.show()

    # Text version of the same summary
    print(f"\nOverview Performance Summary ({metric_label}):")
    print("=" * 40)
    for i, (layer, mean, std, n_attr) in enumerate(
        zip(layers, means, stds, n_attributes)
    ):
        if i == best_idx:
            status = " ⭐ BEST"
        elif i == worst_idx:
            status = " ⚠️  WORST"
        else:
            status = ""
        print(f"Layer {layer:>4}: {mean:.1f}% ± {std:.1f}% (n={n_attr:3d}){status}")

    print(
        f"\nPerformance range: {lowest:.1f}% - {highest:.1f}% (Δ={highest - lowest:.1f}%)"
    )

    return stats_per_layer


# Draw one overview figure per configured metric
for metric in METRICS:
    if SAVE_PLOTS:
        save_path = f"{SAVE_DIR}/overview_performance_{metric}.png"
    else:
        save_path = None
    layer_data = overview_performance(results, metric=metric, save_path=save_path)

## 2. Category Breakdown Analysis

In [None]:
def category_breakdown(results, layer, taxonomy, metric="f1", save_path=None):
    """Bar chart of one layer's mean probe score per semantic category.

    Each attribute of the layer is assigned to ``taxonomy[attribute]`` (or
    ``"unknown"`` when absent). Bars show the category mean with a std error
    bar, a dashed red line marks the mean baseline, and categories are ordered
    from best to worst. A text breakdown is printed after the figure.

    Args:
        results: Mapping of layer -> dict with an ``"individual_results"`` list
            whose entries provide ``"attribute"``, ``f"mean_{metric}"`` and
            ``f"baseline_mean_{metric}"``.
        layer: Key of the layer to plot.
        taxonomy: Mapping of attribute -> category.
        metric: Metric name used to select the score fields.
        save_path: File to save the figure to; a falsy value skips saving.

    Returns:
        ``(category_means, category_counts)`` dicts keyed by category, or
        ``None`` when ``layer`` is not in ``results``.
    """
    if layer not in results:
        print(f"Layer {layer} not found in results")
        return

    # Bucket probe scores and baselines (both in percent) by category
    category_scores = {}
    category_baselines = {}
    for entry in results[layer]["individual_results"]:
        attribute = entry["attribute"]
        probe_pct = entry[f"mean_{metric}"] * 100
        baseline_pct = entry[f"baseline_mean_{metric}"] * 100
        group = taxonomy.get(attribute, "unknown")
        category_scores.setdefault(group, []).append(probe_pct)
        category_baselines.setdefault(group, []).append(baseline_pct)

    # Per-category aggregates
    category_means = {cat: np.mean(vals) for cat, vals in category_scores.items()}
    category_stds = {cat: np.std(vals) for cat, vals in category_scores.items()}
    category_counts = {cat: len(vals) for cat, vals in category_scores.items()}
    category_baseline_means = {
        cat: np.mean(category_baselines[cat]) for cat in category_scores
    }

    # Best-performing category first
    categories = sorted(category_means, key=category_means.get, reverse=True)
    means = [category_means[cat] for cat in categories]
    stds = [category_stds[cat] for cat in categories]
    counts = [category_counts[cat] for cat in categories]
    baselines = [category_baseline_means[cat] for cat in categories]

    fig, ax = plt.subplots(figsize=(12, 8))
    colors = plt.cm.Set3(np.linspace(0, 1, len(categories)))
    x_pos = range(len(categories))

    bars = ax.bar(
        x_pos, means, yerr=stds, capsize=5, color=colors, alpha=0.7, edgecolor="black"
    )

    for idx, baseline in enumerate(baselines):
        # Dashed red line spanning the bar width at the baseline level
        ax.plot(
            [idx - 0.4, idx + 0.4],
            [baseline, baseline],
            color="red",
            linestyle="--",
            linewidth=2,
            alpha=0.8,
        )
        # Baseline value printed just below that line
        ax.text(
            idx,
            baseline - 2,
            f"{baseline:.1f}%",
            color="red",
            fontweight="bold",
            ha="center",
            va="top",
            fontsize=8,
            bbox=dict(
                boxstyle="round,pad=0.2", facecolor="white", edgecolor="red", alpha=0.8
            ),
        )

    ax.set_xlabel("Category")
    ax.set_ylabel(f"Mean {metric.upper()} Score (%)")
    ax.set_title(f"{metric.upper()} Performance by Category - Layer {layer}")
    ax.set_xticks(x_pos)
    ax.set_xticklabels(categories, rotation=45, ha="right")
    ax.grid(True, alpha=0.3, axis="y")

    # Mean and sample count above each bar, clear of the error bar
    for idx, bar in enumerate(bars):
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            bar.get_height() + stds[idx] + 1,
            f"{means[idx]:.1f}%\nn={counts[idx]}",
            ha="center",
            va="bottom",
            fontsize=9,
        )

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Saved plot to: {save_path}")
    plt.show()

    # Text version of the chart
    print(f"\nCategory Breakdown for Layer {layer}:")
    print("=" * 50)
    for cat in categories:
        print(
            f"{cat:15s}: {category_means[cat]:.1f}% ± {category_stds[cat]:.1f}% "
            f"(baseline: {category_baseline_means[cat]:.1f}%, n={category_counts[cat]})"
        )

    return category_means, category_counts


# Category breakdown for the best F1 layer and, when different, the last layer
if not (taxonomy and results):
    print("Skipping category analysis - no taxonomy or results found")
else:
    layer_data = extract_layer_performance(results, "f1")
    if layer_data:
        # First layer with the highest mean F1
        best_idx, best_entry = max(
            enumerate(layer_data), key=lambda pair: pair[1]['mean']
        )
        best_layer = best_entry['layer']
        print(f"Best performing layer: {best_layer} (F1: {best_entry['mean']:.1f}%)")

        # Best layer
        if SAVE_PLOTS:
            save_path = f"{SAVE_DIR}/category_breakdown_best.png"
        else:
            save_path = None
        category_breakdown(results, best_layer, taxonomy, save_path=save_path)

        # Last layer, unless it is the best layer and was just plotted
        if "last" in results and best_layer != "last":
            if SAVE_PLOTS:
                save_path = f"{SAVE_DIR}/category_breakdown_last.png"
            else:
                save_path = None
            category_breakdown(results, "last", taxonomy, save_path=save_path)

## 3. Performance Curves by Category

In [None]:
def performance_curves(
    results, taxonomy, categories=None, metric="f1", save_path=None, min_layers=3
):
    """Plot mean probe performance per semantic category across layers.

    Args:
        results: dict mapping layer (int or ``"last"``) to a dict with an
            ``"individual_results"`` list; every entry must provide
            ``"attribute"`` and ``f"mean_{metric}"``.
        taxonomy: dict mapping attribute -> category.
        categories: Categories to plot. Defaults to every category found in
            ``taxonomy``, sorted. Attributes whose category is not listed
            (including attributes missing from the taxonomy, which are treated
            as ``"unknown"``) are ignored.
        metric: Metric name used to select the ``mean_<metric>`` field.
        save_path: File to save the figure to; a falsy value skips saving.
        min_layers: Minimum number of layers with data a category needs in
            order to be plotted.

    Returns:
        ``(category_data, category_peaks)``. ``category_data`` maps each
        category to a list of per-layer dicts (``layer``, ``mean``, ``std``,
        ``n_attributes``) in layer order; ``category_peaks`` maps each plotted
        category to its best layer (``layer``, ``score``, ``n_attrs``). When no
        category has enough layers nothing is plotted and ``category_peaks`` is
        empty, so callers can always unpack the result.
    """

    def layer_order(layer):
        # Numeric layers ascending, the "last" pseudo-layer always at the end
        # (tuple key avoids int/str comparison and any magic sentinel value).
        return (1, 0) if layer == "last" else (0, layer)

    if not categories:
        # Get all unique categories from taxonomy
        categories = sorted(set(taxonomy.values()))

    # Initialize data structure for each category
    category_data = {cat: [] for cat in categories}

    sorted_layers = sorted(results.keys(), key=layer_order)
    # x position of every layer on the shared axis
    layer_to_x = {layer: x for x, layer in enumerate(sorted_layers)}

    # For each layer, calculate performance by category
    for layer in sorted_layers:
        individual_results = results[layer]["individual_results"]

        # Group scores by category for this layer
        layer_category_scores = {cat: [] for cat in categories}

        for result in individual_results:
            attr = result["attribute"]
            score = result[f"mean_{metric}"] * 100  # Convert to percentage
            category = taxonomy.get(attr, "unknown")

            if category in layer_category_scores:
                layer_category_scores[category].append(score)

        # Calculate mean for each category in this layer
        for cat in categories:
            cat_scores = layer_category_scores[cat]
            if cat_scores:  # If category has attributes in this layer
                category_data[cat].append(
                    {
                        "layer": layer,
                        "mean": np.mean(cat_scores),
                        "std": np.std(cat_scores),
                        "n_attributes": len(cat_scores),
                    }
                )

    # Keep only categories observed in enough layers to draw a curve
    categories_with_data = [
        cat for cat in categories if len(category_data[cat]) >= min_layers
    ]

    if not categories_with_data:
        print("No categories with sufficient data across layers!")
        # Same tuple shape as the normal return so callers can unpack it
        return category_data, {}

    # Create the plot
    plt.figure(figsize=(14, 8))

    # Color palette
    colors = plt.cm.tab10(np.linspace(0, 1, len(categories_with_data)))

    # Plot each category
    for i, cat in enumerate(categories_with_data):
        data = category_data[cat]
        means = [d["mean"] for d in data]
        stds = [d["std"] for d in data]
        x_positions = [layer_to_x[d["layer"]] for d in data]

        # Plot line with error bars
        plt.plot(
            x_positions,
            means,
            "o-",
            linewidth=2.5,
            markersize=6,
            color=colors[i],
            label=f"{cat}",
            alpha=0.8,
        )

        # Add error bars (shaded area)
        plt.fill_between(
            x_positions,
            [m - s for m, s in zip(means, stds)],
            [m + s for m, s in zip(means, stds)],
            alpha=0.15,
            color=colors[i],
        )

    # Formatting
    plt.xlabel("Layer", fontsize=12)
    plt.ylabel(f"Mean {metric.upper()} Score (%)", fontsize=12)
    plt.title(
        f"Performance Curves by Category - {metric.upper()} Across Layers",
        fontsize=14,
    )
    plt.grid(True, alpha=0.3)

    # Set x-axis
    x_labels = [str(layer) for layer in sorted_layers]
    plt.xticks(range(len(sorted_layers)), x_labels)

    # Legend (placed outside the axes on the right)
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Saved plot to: {save_path}")

    plt.show()

    # Print analysis
    print(f"\nPerformance Curves Analysis ({metric.upper()}):")
    print("=" * 50)

    # Find peak layer for each category
    category_peaks = {}
    for cat in categories_with_data:
        data = category_data[cat]
        if data:
            best_performance = max(data, key=lambda x: x["mean"])
            category_peaks[cat] = {
                "layer": best_performance["layer"],
                "score": best_performance["mean"],
                "n_attrs": best_performance["n_attributes"],
            }

    # Sort categories by their peak performance layer
    sorted_peaks = sorted(
        category_peaks.items(), key=lambda item: layer_order(item[1]["layer"])
    )

    print("Categories ranked by peak performance layer:")
    for cat, peak_info in sorted_peaks:
        print(
            f"  {cat:15s}: Peak at layer {peak_info['layer']:>4} "
            f"({peak_info['score']:.1f}%, n={peak_info['n_attrs']})"
        )

    return category_data, category_peaks


# Create performance curves plot
if taxonomy and results:
    save_path = f"{SAVE_DIR}/performance_curves.png" if SAVE_PLOTS else None
    # Guard the unpacking: a None return (nothing to plot) would otherwise
    # raise "TypeError: cannot unpack non-iterable NoneType object".
    curves_result = performance_curves(results, taxonomy, save_path=save_path)
    if curves_result is not None:
        category_data, category_peaks = curves_result
else:
    print("Skipping performance curves - no taxonomy or results found")

## 4. Individual Attribute Analysis

In [None]:
def attribute_breakdown(results, layer, taxonomy, metric="f1", save_path=None, top_n_per_category=10):
    """Bar chart of a layer's best attributes, grouped and coloured by category.

    Attributes are assigned to ``taxonomy[attribute]`` (``"unknown"`` when
    absent). Categories are ordered by mean score, best first; within each
    category the ``top_n_per_category`` highest-scoring attributes are kept.
    Every bar gets a dashed red baseline marker, and the same ranking is
    printed after the figure.

    Args:
        results: Mapping of layer -> dict with an ``"individual_results"`` list
            whose entries provide ``"attribute"``, ``f"mean_{metric}"`` and
            ``f"baseline_mean_{metric}"``.
        layer: Key of the layer to plot.
        taxonomy: Mapping of attribute -> category.
        metric: Metric name used to select the score fields.
        save_path: File to save the figure to; a falsy value skips saving.
        top_n_per_category: Maximum number of attributes shown per category.

    Returns:
        A list of ``(attribute, score, baseline, category)`` tuples covering
        every attribute of the layer (scores in percent), or ``None`` when the
        layer is missing or there is nothing to plot.
    """
    if layer not in results:
        print(f"Layer {layer} not found in results")
        return

    # One pass: record every attribute and bucket it under its category
    attr_data = []
    category_groups = {}
    for entry in results[layer]["individual_results"]:
        attr = entry["attribute"]
        score = entry[f"mean_{metric}"] * 100  # percent
        baseline = entry[f"baseline_mean_{metric}"] * 100
        category = taxonomy.get(attr, "unknown")
        attr_data.append((attr, score, baseline, category))
        category_groups.setdefault(category, []).append((attr, score, baseline))

    # Categories ordered by mean score, best first
    category_means = {
        cat: np.mean([member[1] for member in members])
        for cat, members in category_groups.items()
    }
    category_order = sorted(category_means, key=category_means.get, reverse=True)

    # One colour per category, assigned in display order
    palette = plt.cm.Set3(np.linspace(0, 1, len(category_order)))
    category_colors = {cat: palette[idx] for idx, cat in enumerate(category_order)}

    # Highest-scoring attributes of each category
    category_groups_top = {
        cat: sorted(category_groups[cat], key=lambda member: member[1], reverse=True)[
            :top_n_per_category
        ]
        for cat in category_order
    }

    # Flatten to parallel lists in display order
    attributes, scores, baselines, categories = [], [], [], []
    for cat in category_order:
        for attr, score, baseline in category_groups_top[cat]:
            attributes.append(attr)
            scores.append(score)
            baselines.append(baseline)
            categories.append(cat)

    if not attributes:
        print("No attributes to plot")
        return

    fig, ax = plt.subplots(figsize=(20, 10))
    x_pos = range(len(attributes))

    ax.bar(
        x_pos,
        scores,
        color=[category_colors[cat] for cat in categories],
        alpha=0.7,
        edgecolor="black",
        linewidth=0.5,
    )

    # Dashed red baseline marker across each bar
    for idx, baseline in enumerate(baselines):
        ax.plot(
            [idx - 0.4, idx + 0.4],
            [baseline, baseline],
            color="red",
            linestyle="--",
            linewidth=1.5,
            alpha=0.8,
        )

    # Record the bar positions of each category and draw a separator wherever
    # the category changes between neighbouring bars
    cat_positions = {}
    for idx, cat in enumerate(categories):
        cat_positions.setdefault(cat, []).append(idx)
        previous = categories[idx - 1] if idx else None
        if previous is not None and previous != cat:
            ax.axvline(x=idx-0.5, color="gray", linestyle="-", linewidth=2, alpha=0.6)

    # Category name centred over its group, near the top of the axes
    y_max = max(scores) * 1.1
    for cat, positions in cat_positions.items():
        ax.text(
            (positions[0] + positions[-1]) / 2,
            y_max * 0.95,
            cat,
            ha="center",
            va="center",
            fontweight="bold",
            fontsize=12,
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor=category_colors[cat],
                alpha=0.3,
                edgecolor="black",
            ),
        )

    ax.set_xlabel("Attributes (grouped by category)", fontsize=12)
    ax.set_ylabel(f"{metric.upper()} Score (%)", fontsize=12)
    ax.set_title(
        f"Top {top_n_per_category} Attributes per Category (Layer {layer})",
        fontsize=14,
    )
    ax.grid(True, alpha=0.3, axis="y")

    ax.set_xticks(x_pos)
    ax.set_xticklabels(attributes, rotation=90, ha="right", fontsize=8)
    ax.set_ylim(0, y_max)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Saved plot to: {save_path}")

    plt.show()

    # Text version of the ranking
    print(f"\nTop {top_n_per_category} attributes per category (Layer {layer}):")
    for cat in category_order:
        members = category_groups_top[cat]
        if not members:
            continue
        print(f"\n{cat}:")
        for rank, (attr, score, baseline) in enumerate(members, 1):
            print(f"  {rank:2d}. {attr:<30}: {score:.1f}% (baseline: {baseline:.1f}%)")

    return attr_data


# Attribute breakdown for the best layer found by the category analysis cell
if not taxonomy or not results or 'best_layer' not in locals():
    print("Skipping attribute breakdown - no taxonomy, results, or best layer identified")
else:
    if SAVE_PLOTS:
        save_path = f"{SAVE_DIR}/attribute_breakdown_best.png"
    else:
        save_path = None
    attribute_breakdown(results, best_layer, taxonomy, save_path=save_path)

## 5. Summary and Insights

In [None]:
# Generate comprehensive summary

# Reset before computing: testing `'best_layer_data' in locals()` would also
# accept a value left in the kernel by an earlier run of this cell, and then
# print recommendations that do not belong to the current results.
best_layer_data = None
worst_layer_data = None

if results:
    print("\n" + "="*60)
    print("PROBE ANALYSIS SUMMARY")
    print("="*60)

    # Overall statistics (always based on F1)
    total_layers = len(results)
    layer_data = extract_layer_performance(results, "f1")

    if layer_data:
        best_layer_data = max(layer_data, key=lambda x: x['mean'])
        worst_layer_data = min(layer_data, key=lambda x: x['mean'])
        avg_performance = np.mean([d['mean'] for d in layer_data])

        print("\nDataset Overview:")
        print(f"- Total layers analyzed: {total_layers}")
        print(f"- Average attributes per layer: {np.mean([d['n_attributes'] for d in layer_data]):.0f}")
        print(f"- Performance range: {worst_layer_data['mean']:.1f}% - {best_layer_data['mean']:.1f}%")
        print(f"- Average performance: {avg_performance:.1f}%")

        print("\nKey Findings:")
        print(f"- Best performing layer: {best_layer_data['layer']} ({best_layer_data['mean']:.1f}% F1)")
        print(f"- Worst performing layer: {worst_layer_data['layer']} ({worst_layer_data['mean']:.1f}% F1)")
        print(f"- Performance improvement: {best_layer_data['mean'] - worst_layer_data['mean']:.1f}% from worst to best")

        if taxonomy:
            print("\nCategory Analysis:")
            print(f"- Total categories: {len(set(taxonomy.values()))}")
            print(f"- Categories: {', '.join(sorted(set(taxonomy.values())))}")

            # category_peaks is produced by the performance-curves cell and
            # only exists if that cell ran and had something to plot.
            if 'category_peaks' in locals():
                print("\nCategory Peak Performance:")
                for cat, peak in sorted(category_peaks.items(), key=lambda x: x[1]['score'], reverse=True):
                    print(f"- {cat}: {peak['score']:.1f}% at layer {peak['layer']}")

    print("\nRecommendations:")
    if best_layer_data is not None:
        print(f"- Use layer {best_layer_data['layer']} for optimal probe performance")
        print(f"- Consider layers around {best_layer_data['layer']} for fine-tuning")

    if SAVE_PLOTS:
        print(f"\nAll plots saved to: {SAVE_DIR}")

else:
    print("\nNo results to summarize. Please check your RESULTS_DIR configuration.")