# Evaluation Results Analysis

This notebook analyzes model evaluation results to understand:
- Where models are making mistakes
- Answer distribution patterns
- Most common confusions
- Per-class accuracy

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from collections import Counter, defaultdict

%matplotlib inline
sns.set_style("whitegrid")

## Configuration

In [None]:
# Path to your results file (a JSON file; it is read with json.load in the "Load Results" cell)
RESULTS_FILE = "./output/results_paligemma_Controlled_Images_B_None_fouroption_False.json"

# Dataset configuration - UPDATE THESE to match your results file
# DATASET_NAME is looked up in the `dataset_loaders` mapping further down, so it
# must be one of that mapping's keys (e.g. "Controlled_Images_A", "COCO_QA_two_obj").
DATASET_NAME = "Controlled_Images_B"  # Change to match your dataset
DATA_DIR = "../data"  # Path to data directory (passed as root_dir to the dataset loader)

# You can also compare multiple models
# NOTE(review): RESULTS_FILES is only a commented-out sketch; no cell visible in
# this notebook reads it, so uncommenting it alone has no effect.
# RESULTS_FILES = {
#     "PaliGemma": "./output/results_paligemma_Controlled_Images_B_None_fouroption_False.json",
#     "Qwen2.5-VL": "./output/results_qwen_vllm_Controlled_Images_B_None_fouroption_False.json",
# }

## Load Results

In [None]:
def extract_spatial_answer(generation):
    """
    Extract a spatial relation keyword from a raw model generation.

    The text is lowercased and split into word tokens. Every character that is
    neither alphanumeric nor a hyphen is treated as a separator, so answers
    such as "Left." or "The answer is: right!" are recognised. Hyphens are
    kept inside tokens so that 'in-front' stays a single word.

    Relations are tested in the order of the vocabulary list below, not in
    the order they appear in the text, so "on the left" yields 'left'.

    Args:
        generation: Raw model output. None is treated as an empty string and
            any other non-string value is converted with str().

    Returns:
        The matched spatial relation (lowercase). If nothing matches, the
        first 20 characters of the lowercased, stripped generation, or
        'unknown' when that text is empty.
    """
    if generation is None:
        generation = ''
    elif not isinstance(generation, str):
        generation = str(generation)
    gen_lower = generation.lower().strip()

    # List of possible spatial relations (list order is the match priority)
    spatial_relations = [
        'left', 'right', 'above', 'below', 'top', 'bottom',
        'on', 'under', 'front', 'behind', 'in-front'
    ]

    # Replace punctuation with spaces, then tokenize; stray leading/trailing
    # hyphens (e.g. a "- left" bullet) are trimmed from each token.
    separated = ''.join(ch if (ch.isalnum() or ch == '-') else ' ' for ch in gen_lower)
    tokens = {token.strip('-') for token in separated.split()}

    for relation in spatial_relations:
        if relation in tokens:
            return relation

    # If no relation word is present, return the generation (truncated)
    return gen_lower[:20] if gen_lower else 'unknown'


# Load results (explicit UTF-8 so decoding does not depend on the platform default)
with open(RESULTS_FILE, 'r', encoding='utf-8') as f:
    results = json.load(f)

print(f"Loaded {len(results)} results")
if results:
    print("\nFirst result example:")
    print(json.dumps(results[0], indent=2))
else:
    # An empty file would otherwise fail on results[0] with an unhelpful IndexError
    print("\nResults file contains no entries.")

## Extract and Clean Predictions

In [None]:
# Pull the raw text, the parsed prediction and the normalised gold label out of
# every result entry. 'Golden' may be a string or a sequence; for a sequence the
# first element is used.
raw_generations = [entry['Generation'] for entry in results]
predicted_answers = [extract_spatial_answer(text) for text in raw_generations]
golden_answers = [
    (entry['Golden'] if isinstance(entry['Golden'], str) else entry['Golden'][0]).lower()
    for entry in results
]
correct_predictions = [
    predicted == expected
    for predicted, expected in zip(predicted_answers, golden_answers)
]

# Assemble everything into a DataFrame for easier analysis
prompts = [entry['Prompt'] for entry in results]
df = pd.DataFrame({
    'prompt': prompts,
    'raw_generation': raw_generations,
    'predicted': predicted_answers,
    'golden': golden_answers,
    'correct': correct_predictions,
})

print(f"\nDataFrame shape: {df.shape}")
print("\nFirst few rows:")
df.head()

## Overall Statistics

In [None]:
# Headline numbers reused by later cells (plots draw `accuracy` as a reference line)
total = len(df)
correct = int(df['correct'].sum())
# Guard against an empty results file so the notebook does not divide by zero
accuracy = 100 * correct / total if total else 0.0

print("="*60)
print("OVERALL STATISTICS")
print("="*60)
print(f"Total samples: {total}")
print(f"Correct predictions: {correct}")
print(f"Accuracy: {accuracy:.2f}%")
print("="*60)

## Per-Class Accuracy

Which spatial relations is the model struggling with?

In [None]:
# Per golden label: number correct, number of samples, and mean of `correct`
per_class = (
    df.groupby('golden')['correct']
    .agg(['sum', 'count', 'mean'])
    .round(4)
)
per_class.columns = ['correct', 'total', 'accuracy']
per_class['accuracy'] *= 100  # fraction -> percentage
per_class = per_class.sort_values('accuracy', ascending=False)

print("\nPer-Class Accuracy:")
print(per_class)

# Bar chart of accuracy per golden label, with the overall accuracy as a dashed line
fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(per_class.index, per_class['accuracy'], color='steelblue', alpha=0.8)
ax.set_xlabel('Spatial Relation (Golden Answer)', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title(f'Per-Class Accuracy\nOverall Accuracy: {accuracy:.2f}%', fontsize=14)
ax.set_ylim(0, 100)
ax.axhline(y=accuracy, color='r', linestyle='--', alpha=0.5, label=f'Overall: {accuracy:.1f}%')
plt.xticks(rotation=45, ha='right')

# Annotate each bar with its sample count, placed just above the bar top
for bar, n_samples in zip(bars, per_class['total']):
    x_center = bar.get_x() + bar.get_width()/2.
    ax.text(x_center, bar.get_height() + 2,
            f'n={int(n_samples)}', ha='center', va='bottom', fontsize=10)

plt.legend()
plt.tight_layout()
plt.show()

## Answer Distribution

Compare the golden and predicted answer distributions to spot labels the model over- or under-predicts.

In [None]:
# Frequency of every label among the gold answers and among the predictions
golden_dist = Counter(golden_answers)
predicted_dist = Counter(predicted_answers)

# Text tables, most frequent label first
for heading, dist, n_answers in (
    ("\nGolden Answer Distribution:", golden_dist, len(golden_answers)),
    ("\nPredicted Answer Distribution:", predicted_dist, len(predicted_answers)),
):
    print(heading)
    for answer, count in dist.most_common():
        pct = 100 * count / n_answers
        print(f"  {answer:15s}: {count:4d} ({pct:5.1f}%)")

# Side-by-side bar charts: golden on the left, predicted on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

for axis, dist, bar_color, title in (
    (ax1, golden_dist, 'green', 'Golden Answer Distribution'),
    (ax2, predicted_dist, 'orange', 'Predicted Answer Distribution'),
):
    axis.bar(dist.keys(), dist.values(), color=bar_color, alpha=0.6)
    axis.set_xlabel('Spatial Relation', fontsize=12)
    axis.set_ylabel('Count', fontsize=12)
    axis.set_title(title, fontsize=14)
    plt.setp(axis.xaxis.get_majorticklabels(), rotation=45, ha='right')

plt.tight_layout()
plt.show()

## Confusion Matrix

What does the model predict when it's wrong?

In [None]:
# Confusion matrix: rows are golden labels, columns are predicted labels.
# The label set is the union of both so free-form predictions get their own column.
all_labels = sorted(set(golden_answers) | set(predicted_answers))
label_to_idx = {label: idx for idx, label in enumerate(all_labels)}
n_labels = len(all_labels)
confusion = np.zeros((n_labels, n_labels))

for (gold, pred), n_cases in Counter(zip(golden_answers, predicted_answers)).items():
    confusion[label_to_idx[gold], label_to_idx[pred]] += n_cases

# Row-normalise to percentages; the small epsilon keeps all-zero rows
# (labels that never occur as a golden answer) from dividing by zero
row_totals = confusion.sum(axis=1, keepdims=True)
confusion_norm = confusion / (row_totals + 1e-10) * 100

# Draw the percentage view first, then the absolute counts
heatmap_specs = (
    (confusion_norm, '.1f', 'YlOrRd', 'Percentage (%)', 'Confusion Matrix (% of each golden answer)'),
    (confusion, '.0f', 'Blues', 'Count', 'Confusion Matrix (Absolute Counts)'),
)
for matrix, number_format, color_map, bar_label, title in heatmap_specs:
    fig, ax = plt.subplots(figsize=(14, 12))
    sns.heatmap(matrix, annot=True, fmt=number_format, cmap=color_map,
                xticklabels=all_labels, yticklabels=all_labels,
                cbar_kws={'label': bar_label}, ax=ax)
    ax.set_xlabel('Predicted Answer', fontsize=12)
    ax.set_ylabel('Golden Answer', fontsize=12)
    ax.set_title(title, fontsize=14)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

## Most Common Confusion Patterns

In [None]:
# For every golden label, list the five most frequent predictions and their share
print("\nMost Common Confusion Patterns:")
print("="*60)

for gold in sorted(set(golden_answers)):
    gold_mask = df['golden'] == gold
    n_gold = int(gold_mask.sum())
    pred_counts = Counter(df.loc[gold_mask, 'predicted'].values)

    print(f"\nWhen golden answer is '{gold}' (n={n_gold}):")
    for pred, count in pred_counts.most_common(5):
        is_correct = "✓" if pred == gold else "✗"
        pct = 100 * count / n_gold
        print(f"  {is_correct} {pred:15s}: {count:4d} ({pct:5.1f}%)")

# Load dataset to get original caption options
import sys
if '..' not in sys.path:  # avoid stacking duplicate entries when this cell is re-run
    sys.path.append('..')
from dataset_zoo.aro_datasets import (
    get_controlled_images_a, get_controlled_images_b,
    get_coco_qa_one_obj, get_coco_qa_two_obj,
    get_vg_qa_one_obj, get_vg_qa_two_obj
)

# Dataset name mapping: DATASET_NAME -> loader function
dataset_loaders = {
    "Controlled_Images_A": get_controlled_images_a,
    "Controlled_Images_B": get_controlled_images_b,
    "COCO_QA_one_obj": get_coco_qa_one_obj,
    "COCO_QA_two_obj": get_coco_qa_two_obj,
    "VG_QA_one_obj": get_vg_qa_one_obj,
    "VG_QA_two_obj": get_vg_qa_two_obj,
}

# Load the dataset
print(f"Loading dataset: {DATASET_NAME}")
if DATASET_NAME not in dataset_loaders:
    # Fail here rather than printing and continuing: later cells read `dataset`,
    # which would otherwise be undefined or left over from an earlier run.
    raise ValueError(
        f"Unknown dataset {DATASET_NAME}. "
        f"Available datasets: {list(dataset_loaders.keys())}"
    )

dataset = dataset_loaders[DATASET_NAME](image_preprocess=None, download=False, root_dir=DATA_DIR)
print(f"Dataset loaded: {len(dataset)} samples")

# Show sample
if len(dataset) > 0:
    print("\nSample from dataset:")
    sample = dataset[0]
    print(f"Caption options: {sample['caption_options']}")
else:
    print("\nDataset is empty; no sample to show.")

In [None]:
def _strip_article(text):
    """Trim whitespace and drop one leading article ('a', 'an', 'the') from text."""
    text = text.strip()
    for article in ('an ', 'a ', 'the '):
        if text.startswith(article):
            return text[len(article):].strip()
    return text


def _split_on_relation(text, phrases):
    """Split text around the first phrase (in priority order) that occurs in it.

    Returns an (object1, object2) tuple with leading articles removed, or None
    when no phrase occurs or either side would be empty.
    """
    for phrase in phrases:
        left, separator, right = text.partition(phrase)
        if separator:
            left, right = _strip_article(left), _strip_article(right)
            if left and right:
                return left, right
    return None


def extract_objects_from_caption(caption, dataset_type):
    """
    Extract objects from caption based on dataset format.

    Multi-word object names are preserved (e.g. "A beer bottle on a armchair"
    gives object1 = "beer bottle"), and leading articles are removed.

    Args:
        caption: Caption string from dataset. Non-string values yield the
            "nothing found" result instead of raising.
        dataset_type: Dataset name; one of 'Controlled_Images_A',
            'Controlled_Images_B', 'COCO_QA_one_obj', 'VG_QA_one_obj',
            'COCO_QA_two_obj', 'VG_QA_two_obj'. Any other value yields the
            "nothing found" result.

    Returns:
        dict with 'object1' and 'object2'. 'object2' is None for one-object
        datasets; both are None when the caption does not match the expected
        format.
    """
    not_found = {'object1': None, 'object2': None}
    if not isinstance(caption, str):
        return not_found

    caption_lower = caption.lower().strip()
    photo_prefix = 'a photo of '

    # Relation phrases separating two objects, multi-word phrases first so that
    # e.g. ' to the left of ' wins over the short ' on '.
    # NOTE(review): beyond the formats shown in the examples below, the extra
    # phrases (in front of / behind / above / below / under) are accepted in
    # case the datasets emit them - confirm against the dataset captions.
    pair_relations = (
        ' to the right of ', ' to the left of ', ' in front of ', ' in-front of ',
        ' behind ', ' above ', ' below ', ' under ', ' on ',
    )
    # Position suffixes for one-object captions.
    # NOTE(review): top/bottom are accepted alongside left/right - confirm
    # whether the one-object datasets actually use them.
    single_positions = (' on the left', ' on the right', ' on the top', ' on the bottom')

    if dataset_type in ['Controlled_Images_A', 'Controlled_Images_B']:
        # Format: "A {obj1} {relation} a {obj2}"
        # e.g., "A beer bottle on a armchair"
        pair = _split_on_relation(caption_lower, pair_relations)
        if pair is not None:
            return {'object1': pair[0], 'object2': pair[1]}
        # No known relation phrase: fall back to the positional heuristic
        # (second word and last word), which only suits single-word objects.
        words = caption_lower.split()
        if len(words) >= 4:
            return {'object1': words[1], 'object2': words[-1]}

    elif dataset_type in ['COCO_QA_one_obj', 'VG_QA_one_obj']:
        # Format: "A photo of a {object} on the {position}"
        # e.g., "A photo of a laptop on the left"
        if caption_lower.startswith(photo_prefix):
            text = caption_lower[len(photo_prefix):]
            for pos in single_positions:
                if text.endswith(pos):
                    # Slice the suffix off instead of str.replace, which would
                    # also remove the phrase from the middle of an object name
                    obj = _strip_article(text[:-len(pos)])
                    if obj:
                        return {'object1': obj, 'object2': None}

    elif dataset_type in ['COCO_QA_two_obj', 'VG_QA_two_obj']:
        # Format: "A photo of a {obj1} to the {position} of a {obj2}"
        # e.g., "A photo of a metal lamp to the right of a laptop"
        if caption_lower.startswith(photo_prefix):
            text = caption_lower[len(photo_prefix):]
            pair = _split_on_relation(text, pair_relations)
            if pair is not None:
                return {'object1': pair[0], 'object2': pair[1]}

    return not_found


# Extract objects for all samples.
# NOTE(review): this pairs row i of the results with dataset[i] and treats the
# first entry of caption_options as the correct caption - confirm both hold
# for the evaluation run that produced RESULTS_FILE.
print("Extracting objects from dataset...")

n_rows = len(df)
n_matched = min(n_rows, len(dataset))
if n_rows != len(dataset):
    # A size mismatch usually means RESULTS_FILE and DATASET_NAME do not belong together
    print(f"Warning: results have {n_rows} rows but dataset has {len(dataset)} samples; "
          f"only the first {n_matched} rows get objects.")

# Collect values in plain lists and assign once instead of writing cell by cell
object1_values = [None] * n_rows
object2_values = [None] * n_rows
for position in range(n_matched):
    correct_caption = dataset[position]['caption_options'][0]
    extracted = extract_objects_from_caption(correct_caption, DATASET_NAME)
    object1_values[position] = extracted['object1']
    object2_values[position] = extracted['object2']

df['object1'] = object1_values
df['object2'] = object2_values

print(f"Extracted objects for {n_matched} samples")
print("\nSample extractions:")
print(df[['prompt', 'object1', 'object2', 'golden', 'correct']].head(10))

In [None]:
# Accuracy grouped by object1 (the primary object being described)
print("="*60)
print("PER-OBJECT ACCURACY (Object 1)")
print("="*60)

obj1_stats = (
    df.groupby('object1')['correct']
    .agg(['sum', 'count', 'mean'])
    .round(4)
)
obj1_stats.columns = ['correct', 'total', 'accuracy']
obj1_stats['accuracy'] *= 100  # fraction -> percentage
obj1_stats = obj1_stats.sort_values('accuracy', ascending=False)

print("\nObject 1 Accuracy:")
print(obj1_stats)

# Horizontal bars for the 20 most frequent objects
top_objects = obj1_stats.nlargest(20, 'total')
bar_positions = range(len(top_objects))

fig, ax = plt.subplots(figsize=(14, 8))
bars = ax.barh(bar_positions, top_objects['accuracy'], color='steelblue', alpha=0.8)
ax.set_yticks(bar_positions)
ax.set_yticklabels(top_objects.index)
ax.set_xlabel('Accuracy (%)', fontsize=12)
ax.set_ylabel('Object', fontsize=12)
ax.set_title('Top 20 Objects by Frequency - Accuracy\n(Primary Object in Prompt)', fontsize=14, pad=20)
ax.set_xlim(0, 100)
ax.axvline(x=accuracy, color='r', linestyle='--', alpha=0.5, label=f'Overall: {accuracy:.1f}%')

# Sample count next to the end of each bar
for position, (bar_accuracy, n_samples) in enumerate(
        zip(top_objects['accuracy'], top_objects['total'])):
    ax.text(bar_accuracy + 2, position, f"n={int(n_samples)}",
            va='center', fontsize=9)

ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Accuracy grouped by object2 (the reference object)
print("="*60)
print("PER-OBJECT ACCURACY (Object 2 - Reference Object)")
print("="*60)

# Keep only the rows that actually have a second object
df_with_obj2 = df[df['object2'].notna()].copy()

if len(df_with_obj2) == 0:
    print("No samples with object2 found.")
else:
    obj2_stats = (
        df_with_obj2.groupby('object2')['correct']
        .agg(['sum', 'count', 'mean'])
        .round(4)
    )
    obj2_stats.columns = ['correct', 'total', 'accuracy']
    obj2_stats['accuracy'] *= 100  # fraction -> percentage
    obj2_stats = obj2_stats.sort_values('accuracy', ascending=False)

    print("\nObject 2 Accuracy:")
    print(obj2_stats)

    # Horizontal bars for the 20 most frequent reference objects
    top_objects2 = obj2_stats.nlargest(20, 'total')
    bar_positions = range(len(top_objects2))

    fig, ax = plt.subplots(figsize=(14, 8))
    bars = ax.barh(bar_positions, top_objects2['accuracy'], color='coral', alpha=0.8)
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(top_objects2.index)
    ax.set_xlabel('Accuracy (%)', fontsize=12)
    ax.set_ylabel('Object', fontsize=12)
    ax.set_title('Top 20 Objects by Frequency - Accuracy\n(Reference Object in Prompt)', fontsize=14, pad=20)
    ax.set_xlim(0, 100)
    ax.axvline(x=accuracy, color='r', linestyle='--', alpha=0.5, label=f'Overall: {accuracy:.1f}%')

    # Sample count next to the end of each bar
    for position, (bar_accuracy, n_samples) in enumerate(
            zip(top_objects2['accuracy'], top_objects2['total'])):
        ax.text(bar_accuracy + 2, position, f"n={int(n_samples)}",
                va='center', fontsize=9)

    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Accuracy per (object1, object2) combination, for two-object samples only
print("="*60)
print("OBJECT PAIR ANALYSIS")
print("="*60)

if len(df_with_obj2) == 0:
    print("No two-object samples found.")
else:
    # Label every row with its "object1 + object2" pair
    df_with_obj2['object_pair'] = df_with_obj2['object1'] + ' + ' + df_with_obj2['object2']

    pair_stats = (
        df_with_obj2.groupby('object_pair')['correct']
        .agg(['sum', 'count', 'mean'])
        .round(4)
    )
    pair_stats.columns = ['correct', 'total', 'accuracy']
    pair_stats['accuracy'] *= 100  # fraction -> percentage
    pair_stats = pair_stats.sort_values('total', ascending=False)

    print("\nTop 20 Most Frequent Object Pairs:")
    print(pair_stats.head(20))

    # Lowest-accuracy pairs among those with at least 3 samples
    frequent_pairs = pair_stats[pair_stats['total'] >= 3]
    worst_pairs = frequent_pairs.sort_values('accuracy').head(15)

    print("\nWorst Performing Object Pairs (n >= 3):")
    print(worst_pairs)

    # Plot only when at least one pair passed the sample-count filter
    if len(worst_pairs) > 0:
        bar_positions = range(len(worst_pairs))

        fig, ax = plt.subplots(figsize=(12, 8))
        bars = ax.barh(bar_positions, worst_pairs['accuracy'], color='crimson', alpha=0.7)
        ax.set_yticks(bar_positions)
        ax.set_yticklabels(worst_pairs.index, fontsize=10)
        ax.set_xlabel('Accuracy (%)', fontsize=12)
        ax.set_ylabel('Object Pair', fontsize=12)
        ax.set_title('Worst Performing Object Pairs\n(Minimum 3 samples)', fontsize=14, pad=20)
        ax.set_xlim(0, 100)
        ax.axvline(x=accuracy, color='r', linestyle='--', alpha=0.5, label=f'Overall: {accuracy:.1f}%')

        # Sample count next to the end of each bar
        for position, (bar_accuracy, n_samples) in enumerate(
                zip(worst_pairs['accuracy'], worst_pairs['total'])):
            ax.text(bar_accuracy + 2, position, f"n={int(n_samples)}",
                    va='center', fontsize=9)

        ax.legend()
        plt.tight_layout()
        plt.show()

## Explore Specific Errors

In [None]:
# Collect every misclassified row into its own frame
incorrect_df = df.loc[~df['correct']].copy()
n_incorrect = len(incorrect_df)
print(f"\nTotal incorrect predictions: {n_incorrect}")

# Display the first ten errors with the columns needed to judge them
print("\nFirst 10 errors:")
error_columns = ['prompt', 'raw_generation', 'predicted', 'golden']
incorrect_df[error_columns].head(10)

## Analyze Specific Confusion Pair

In [None]:
# Pick one (golden, predicted) combination and print up to five matching rows
GOLD_ANSWER = 'left'  # Change this to investigate different confusions
PRED_ANSWER = 'right'  # Change this to investigate different confusions

is_gold = df['golden'] == GOLD_ANSWER
is_pred = df['predicted'] == PRED_ANSWER
confusion_mask = is_gold & is_pred
confusion_cases = df[confusion_mask]

print(f"\nCases where golden='{GOLD_ANSWER}' but predicted='{PRED_ANSWER}': {len(confusion_cases)}")
print("\nExamples:")
examples = confusion_cases.head(5)
for prompt, generated, predicted, golden in zip(
        examples['prompt'], examples['raw_generation'],
        examples['predicted'], examples['golden']):
    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated}")
    print(f"Predicted: {predicted} | Golden: {golden}")
    print("-" * 80)

## Raw Generation Analysis

In [None]:
# Look at the raw generations to understand model behavior
SAMPLE_SEED = 42  # fixed seed so Restart & Run All prints the same examples every time
N_EXAMPLES = 5    # maximum number of rows shown per group

correct_rows = df[df['correct']]
incorrect_rows = df[~df['correct']]

print("\nSample raw generations (correct predictions):")
n_shown = min(N_EXAMPLES, len(correct_rows))
for idx, row in correct_rows.sample(n_shown, random_state=SAMPLE_SEED).iterrows():
    print(f"\nGolden: {row['golden']}")
    print(f"Raw: {row['raw_generation']}")
    print("-" * 60)

print("\n\nSample raw generations (incorrect predictions):")
n_shown = min(N_EXAMPLES, len(incorrect_rows))
for idx, row in incorrect_rows.sample(n_shown, random_state=SAMPLE_SEED).iterrows():
    print(f"\nGolden: {row['golden']} | Predicted: {row['predicted']}")
    print(f"Raw: {row['raw_generation']}")
    print("-" * 60)

## Save Analysis Summary

In [None]:
# Save summary to JSON
summary = {
    'overall': {
        'total': total,
        'correct': int(correct),
        'accuracy': float(accuracy)
    },
    'per_class_accuracy': per_class.to_dict('index'),
    'golden_distribution': dict(golden_dist),
    'predicted_distribution': dict(predicted_dist)
}

# Build the output path by swapping only a trailing '.json'. A plain
# str.replace would touch every '.json' in the path and, when the results path
# has no '.json' at all, would return the input path unchanged - the summary
# would then overwrite the results file itself.
json_suffix = '.json'
if RESULTS_FILE.endswith(json_suffix):
    output_stem = RESULTS_FILE[:-len(json_suffix)]
else:
    output_stem = RESULTS_FILE
output_file = output_stem + '_analysis.json'

with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print(f"\nAnalysis saved to: {output_file}")