# 📊 Analysis & Figures for Report

This notebook generates publication-ready figures for the AGRE-KD feature distillation report.

**Figures to generate:**
1. γ (gamma) sweep bar chart - Main experimental result
2. Waterbirds dataset 2x2 visualization - Help readers understand spurious correlations
3. AGRE-KD vs AVER comparison line plot (optional)
4. Summary results table - Color-coded table with WGA and Avg Acc ± std
5. Horizontal bar comparison - Side-by-side WGA and Avg Acc visualization

**Output:** Figures saved to `blog/images/` for inclusion in the report.

---
## 1️⃣ Setup Environment

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install dependencies
!pip install -q wilds tqdm scikit-learn

# Verify GPU (optional for figure generation)
import torch
print(f"\n{'='*50}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"{'='*50}")

In [None]:
import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

# ============================================================
# CONFIGURE YOUR PATHS HERE
# ============================================================
GITHUB_REPO = 'dat-tran05/robust-ensemble-kd'
DRIVE_ROOT = '/content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd'

# Everything below is derived from the two settings above.
CODE_DIR = '/content/repo'
DATA_DIR = DRIVE_ROOT + '/data/waterbirds_v1.0'
LOG_DIR = DRIVE_ROOT + '/logs'
OUTPUT_DIR = CODE_DIR + '/blog/images'  # figures are written here

# Make sure the figure folder exists (this also creates CODE_DIR as a side effect).
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the resolved locations so a wrong path is spotted immediately.
for _label, _location in (
    ("Drive root", DRIVE_ROOT),
    ("Data", DATA_DIR),
    ("Logs", LOG_DIR),
    ("Output", OUTPUT_DIR),
):
    print(f"{_label}: {_location}")

In [None]:
# Clone or update repository
if os.path.exists(CODE_DIR):
    print("Repository exists, pulling latest...")
    %cd {CODE_DIR}
    !git pull
else:
    print("Cloning repository...")
    !git clone https://github.com/{GITHUB_REPO}.git {CODE_DIR}
    %cd {CODE_DIR}

# Navigate to code directory
%cd {CODE_DIR}/code

# Add to Python path
sys.path.insert(0, f'{CODE_DIR}/code')

print("\n‚úÖ Setup complete!")

---
## 2️⃣ Load Experimental Results

In [None]:
# Original results (no seed set - treat as additional data point 'og').
# These numbers are already percentages; the seeded values loaded below are
# multiplied by 100 before being stored in the same columns.
ORIGINAL_RESULTS = {
    'aver_baseline': {
        'wga': 83.33,
        'avg_acc': 92.79,
        'alpha': 1.0,
        'gamma': 0.0,
        'use_agre': False
    },
    'baseline_agrekd': {
        'wga': 85.05,
        'avg_acc': 91.42,
        'alpha': 1.0,
        'gamma': 0.0,
        'use_agre': True
    },
    'exp1_alpha07': {
        'wga': 82.55,
        'avg_acc': 92.34,
        'alpha': 0.7,
        'gamma': 0.0,
        'use_agre': True
    },
    'exp1_alpha09': {
        'wga': 83.80,
        'avg_acc': 93.01,
        'alpha': 0.9,
        'gamma': 0.0,
        'use_agre': True
    },
    'exp2_gamma01': {
        'wga': 85.10,
        'avg_acc': 90.94,
        'alpha': 1.0,
        'gamma': 0.1,
        'use_agre': True
    },
    'exp2_gamma025': {
        'wga': 86.29,
        'avg_acc': 92.96,
        'alpha': 1.0,
        'gamma': 0.25,
        'use_agre': True
    },
    'exp3_combined': {
        'wga': 83.18,
        'avg_acc': 92.72,
        'alpha': 0.7,
        'gamma': 0.1,
        'use_agre': True
    },
}

# Load seeded results (a JSON object keyed by run name); fall back to an empty
# dict so the notebook still produces figures from ORIGINAL_RESULTS alone.
log_path = os.path.join(LOG_DIR, 'seed_experiment_results.json')
if os.path.exists(log_path):
    with open(log_path, 'r') as f:
        seed_results = json.load(f)
    print(f"‚úÖ Loaded {len(seed_results)} seeded experiment results")
else:
    seed_results = {}
    print(f"‚ö†Ô∏è No seeded results found at {log_path}")

# Build combined dataframe: one row per individual run.
# Start with the unseeded originals, tagged with the pseudo-seed 'og'.
rows = [
    {
        'Experiment': base_exp,
        'Seed': 'og',
        'Œ±': data['alpha'],
        'Œ≥': data['gamma'],
        'Method': 'AGRE-KD' if data['use_agre'] else 'AVER',
        'WGA (%)': data['wga'],
        'Avg Acc (%)': data['avg_acc'],
    }
    for base_exp, data in ORIGINAL_RESULTS.items()
]

# Add seeded results
for exp_name, data in seed_results.items():
    # Runs are grouped by their base experiment; when the record does not carry
    # it explicitly, strip the trailing "_seed..." part of the run name.
    base_exp = data.get('base_exp', exp_name.rsplit('_seed', 1)[0])

    # Method is inferred from the experiment name; the order of the checks matters
    # (multi-layer first, then disagreement weighting, then AVER, else AGRE-KD).
    if base_exp.startswith('ml_'):
        method = 'Multi-Layer'
    elif 'disagree' in base_exp.lower():
        method = 'Disagree-Weight'
    elif 'aver' in base_exp.lower():
        method = 'AVER'
    else:
        method = 'AGRE-KD'

    rows.append({
        'Experiment': base_exp,
        'Seed': str(data.get('seed', 'N/A')),
        'Œ±': data.get('alpha', 1.0),
        'Œ≥': data['gamma'],
        'Method': method,
        'WGA (%)': round(data['wga'] * 100, 2),
        # A missing average accuracy is recorded as NaN (skipped by pandas
        # mean/std) rather than 0, which would silently drag the averages down.
        'Avg Acc (%)': round(data.get('avg_acc', float('nan')) * 100, 2),
    })

df = pd.DataFrame(rows)
df['Seed'] = df['Seed'].astype(str)

print(f"\n‚úÖ Combined {len(df)} total experiment runs")
print(f"   Unique experiments: {df['Experiment'].nunique()}")
print(f"   Seeds: {sorted(df['Seed'].unique())}")

In [None]:
# Collapse the per-run table into one summary row per experiment:
# mean/std of WGA, mean of average accuracy, and the number of runs.
agg_rows = []
for base_exp, subset in df.groupby('Experiment', sort=False):  # sort=False keeps first-seen order
    n = len(subset)
    wga_values = subset['WGA (%)']

    agg_rows.append({
        'Experiment': base_exp,
        # Method / alpha / gamma are read from the first run of the group.
        'Method': subset['Method'].iloc[0],
        'Œ±': subset['Œ±'].iloc[0],
        'Œ≥': subset['Œ≥'].iloc[0],
        'WGA_mean': wga_values.mean(),
        # A single run has no spread; report 0 instead of pandas' NaN.
        'WGA_std': 0 if n <= 1 else wga_values.std(),
        'Avg_Acc': subset['Avg Acc (%)'].mean(),
        'n': n,
    })

agg_df = pd.DataFrame(agg_rows)

print("\nAggregated Results:")
ranked = agg_df.sort_values('WGA_mean', ascending=False)
print(ranked.to_string(index=False))

---
## 3️⃣ Figure 1: γ Sweep Bar Chart (Main Result)

Clean bar chart showing WGA vs γ (feature distillation weight) for the AGRE-KD method.

In [None]:
# Reference table of experiment name -> gamma value for the main ablation.
# (Not consulted by the filter below, which works on the aggregated columns.)
gamma_exp_map = {
    'baseline_agrekd': 0.00,
    'gamma_005': 0.05,
    'exp2_gamma01': 0.10,  # if exists
    'exp2_gamma025': 0.25,
    'gamma_050': 0.50,
    'gamma_075': 0.75,
    'gamma_100': 1.00,
}

# The gamma sweep is the set of AGRE-KD experiments run with alpha = 1.0,
# ordered by increasing gamma.
uses_agrekd = agg_df['Method'] == 'AGRE-KD'
alpha_is_one = agg_df['Œ±'] == 1.0
gamma_data = agg_df.loc[uses_agrekd & alpha_is_one].sort_values('Œ≥')

print("Gamma sweep data:")
sweep_columns = ['Experiment', 'Œ≥', 'WGA_mean', 'WGA_std', 'n']
print(gamma_data[sweep_columns].to_string(index=False))

In [None]:
# Create the gamma-sweep bar chart: one bar per row of gamma_data (already sorted
# by gamma), with error bars for multi-run experiments and hatching for single runs.
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(10, 6))

gammas = gamma_data['Œ≥'].values
means = gamma_data['WGA_mean'].values
stds = gamma_data['WGA_std'].values
ns = gamma_data['n'].values

# Color scheme: highlight gamma=0.5 in green, the gamma=0 baseline in red, others in blue.
# Use hatching for n=1 (single run) experiments.
colors = []
hatches = []
for g, n in zip(gammas, ns):
    if g == 0.5:  # Optimal
        colors.append('#2ecc71')  # Green
    elif g == 0.0:  # Baseline
        colors.append('#e74c3c')  # Red
    else:
        colors.append('#3498db')  # Blue

    # Hatching for single-run experiments
    hatches.append('//' if n == 1 else '')

# Create bars
x = np.arange(len(gammas))
bars = ax.bar(x, means, color=colors, edgecolor='black', linewidth=1.2)

# Add hatching for n=1 experiments
for bar, hatch in zip(bars, hatches):
    bar.set_hatch(hatch)

# Error bars and "n=" labels. An error bar is drawn only where a spread exists
# (n > 1 and std > 0); the label sits just above the bar or its error bar.
for i, (m, s, n) in enumerate(zip(means, stds, ns)):
    has_spread = n > 1 and s > 0
    if has_spread:
        ax.errorbar(i, m, yerr=s, fmt='none', color='black', capsize=5, capthick=2, linewidth=2)
    y_pos = m + (s if has_spread else 0) + 0.3
    ax.annotate(f'n={n}', (i, y_pos), ha='center', va='bottom', fontsize=10, fontweight='bold')

# Baseline reference line. The sweep may not contain a gamma=0 row (e.g. when only
# part of the results is available), so look it up defensively instead of indexing [0].
baseline_rows = gamma_data.loc[gamma_data['Œ≥'] == 0.0, 'WGA_mean']
baseline_line = None
if len(baseline_rows) > 0:
    baseline_wga = baseline_rows.values[0]
    baseline_line = ax.axhline(y=baseline_wga, color='#e74c3c', linestyle='--', linewidth=2, alpha=0.7,
                               label=f'Baseline (Œ≥=0): {baseline_wga:.1f}%')
else:
    print("No gamma=0 baseline in the sweep; skipping the reference line.")

# Formatting
ax.set_xticks(x)
ax.set_xticklabels([f'{g:.2f}' for g in gammas], fontsize=12)
ax.set_xlabel('Œ≥ (Feature Distillation Weight)', fontsize=14)
ax.set_ylabel('Worst-Group Accuracy (%)', fontsize=14)
ax.set_title('Effect of Feature Distillation on Worst-Group Accuracy', fontsize=16, fontweight='bold')

# Set y-axis limits to show variation clearly
ax.set_ylim([82, 88])

# Legend: the bars need hand-made handles because they do not carry labels.
legend_elements = [
    Patch(facecolor='#e74c3c', edgecolor='black', label='Baseline (Œ≥=0)'),
    Patch(facecolor='#2ecc71', edgecolor='black', label='Optimal (Œ≥=0.5)'),
    Patch(facecolor='#3498db', edgecolor='black', label='Other Œ≥ values'),
    Patch(facecolor='white', edgecolor='black', hatch='//', label='Single run (n=1)'),
]
# Passing explicit handles suppresses automatically collected artists, so the
# labelled reference line has to be appended or its label would never be shown.
if baseline_line is not None:
    legend_elements.append(baseline_line)
ax.legend(handles=legend_elements, loc='lower right', fontsize=10)

# Add grid for readability
ax.grid(True, axis='y', alpha=0.3, linestyle='--')
ax.set_axisbelow(True)

plt.tight_layout()

# Save figure
save_path = os.path.join(OUTPUT_DIR, 'gamma_sweep.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"\n‚úÖ Saved: {save_path}")

plt.show()

---
## 4️⃣ Figure 2: Waterbirds Dataset Visualization

2x2 grid showing the 4 groups to help readers understand spurious correlations.

In [None]:
# Read the Waterbirds metadata table; it is used to locate one example image per group.
metadata_path = os.path.join(DATA_DIR, 'metadata.csv')

if not os.path.exists(metadata_path):
    # `metadata` stays undefined in this case; the next cell checks for that.
    print(f"‚ö†Ô∏è Metadata not found at {metadata_path}")
    print("Will use placeholder visualization")
else:
    metadata = pd.read_csv(metadata_path)
    print(f"‚úÖ Loaded metadata: {len(metadata)} images")
    print(f"\nColumns: {list(metadata.columns)}")
    print(f"\nGroup distribution (training split):")
    is_train = metadata['split'] == 0  # 0 = train
    train_meta = metadata[is_train]
    print(train_meta.groupby(['y', 'place']).size())

In [None]:
# Group definitions:
# y=0: landbird, y=1: waterbird
# place=0: land, place=1: water
# Group 0: landbird + land (majority)
# Group 1: landbird + water (minority)
# Group 2: waterbird + land (minority, hardest)
# Group 3: waterbird + water (majority)

# Keyed by (y, place); train_n is the number of training images in the group.
group_info = {
    (0, 0): {'name': 'Landbird + Land', 'type': 'majority', 'train_n': 3498},
    (0, 1): {'name': 'Landbird + Water', 'type': 'minority', 'train_n': 184},
    (1, 0): {'name': 'Waterbird + Land', 'type': 'minority (hardest)', 'train_n': 56},
    (1, 1): {'name': 'Waterbird + Water', 'type': 'majority', 'train_n': 1057},
}

# Fixed seed so the published figure shows the same images on every run.
EXAMPLE_SEED = 0
# Upper bound on how many candidate files are checked per group before giving up.
MAX_EXAMPLE_TRIES = 20

# Find one example image per group
example_images = {}
if 'metadata' in globals():  # the metadata cell leaves the name undefined when the CSV is missing
    train_rows = metadata[metadata['split'] == 0]
    for (y, place) in group_info:
        subset = train_rows[(train_rows['y'] == y) & (train_rows['place'] == place)]
        # Draw a reproducible random handful of candidates and keep the first one
        # that actually exists on disk, so one missing file does not leave the
        # whole group without an example.
        candidates = subset.sample(n=min(len(subset), MAX_EXAMPLE_TRIES), random_state=EXAMPLE_SEED)
        for img_filename in candidates['img_filename']:
            img_path = os.path.join(DATA_DIR, img_filename)
            if os.path.exists(img_path):
                example_images[(y, place)] = img_path
                print(f"Group ({y}, {place}): {img_filename}")
                break

print(f"\n‚úÖ Found {len(example_images)} example images")

In [None]:
# Draw the 2x2 overview of the four Waterbirds groups.
from matplotlib.patches import Patch

fig, axes = plt.subplots(2, 2, figsize=(10, 10))

# Layout: rows = bird type (landbird, waterbird), cols = background (land, water).
# The group key (y, place) is used directly as the (row, col) grid coordinate.
positions = {(y, place): (y, place) for y in (0, 1) for place in (0, 1)}

# Border (colour, width) per group type; every other type gets the majority style.
border_styles = {
    'minority (hardest)': ('#e74c3c', 4),
    'minority': ('#f39c12', 3),
}
majority_border = ('#2ecc71', 2)

for group_key, (row, col) in positions.items():
    ax = axes[row, col]
    info = group_info[group_key]

    # Show the example image when one was found, otherwise a grey placeholder.
    image_path = example_images.get(group_key)
    if image_path is not None:
        ax.imshow(Image.open(image_path))
    else:
        ax.text(0.5, 0.5, 'Image\nNot Found', ha='center', va='center', fontsize=14,
                transform=ax.transAxes)
        ax.set_facecolor('#f0f0f0')

    ax.set_xticks([])
    ax.set_yticks([])

    # Title: group name, training count and share of the 4795 training images.
    pct = info['train_n'] / 4795 * 100
    ax.set_title(f"{info['name']}\n(n={info['train_n']}, {pct:.1f}%)",
                 fontsize=12, fontweight='bold')

    # Coloured frame marks majority / minority / hardest-minority groups.
    edge_color, edge_width = border_styles.get(info['type'], majority_border)
    for spine in ax.spines.values():
        spine.set_edgecolor(edge_color)
        spine.set_linewidth(edge_width)

# Row labels (rotated, left margin) and column labels (top margin).
label_style = dict(va='center', ha='center', fontsize=14, fontweight='bold')
fig.text(0.02, 0.75, 'Landbird', rotation=90, **label_style)
fig.text(0.02, 0.25, 'Waterbird', rotation=90, **label_style)
fig.text(0.3, 0.98, 'Land Background', **label_style)
fig.text(0.7, 0.98, 'Water Background', **label_style)

# Legend explaining the border colours.
legend_elements = [
    Patch(facecolor='white', edgecolor=edge, linewidth=width, label=text)
    for edge, width, text in (
        ('#2ecc71', 3, 'Majority group'),
        ('#f39c12', 3, 'Minority group'),
        ('#e74c3c', 4, 'Hardest minority (WGA target)'),
    )
]
fig.legend(handles=legend_elements, loc='lower center', ncol=3, fontsize=11, bbox_to_anchor=(0.5, -0.02))

plt.suptitle('Waterbirds Dataset: Spurious Correlation Between Bird Type and Background', 
             fontsize=16, fontweight='bold', y=1.02)

# tight_layout first, then reserve room for the column labels and the legend.
plt.tight_layout()
plt.subplots_adjust(top=0.92, bottom=0.08)

# Save figure
save_path = os.path.join(OUTPUT_DIR, 'waterbirds_groups.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"\n‚úÖ Saved: {save_path}")

plt.show()

---
## 5️⃣ (Optional) AGRE-KD vs AVER Comparison

Line plot comparing gradient-based weighting (AGRE-KD) vs simple averaging (AVER) across γ values.

In [None]:
# Split the alpha = 1.0 experiments by method, each ordered by gamma.
alpha_is_one = agg_df['Œ±'] == 1.0
agrekd_data = agg_df[(agg_df['Method'] == 'AGRE-KD') & alpha_is_one].sort_values('Œ≥')
aver_data = agg_df[(agg_df['Method'] == 'AVER') & alpha_is_one].sort_values('Œ≥')

# Print both series with the same columns for a quick sanity check.
point_columns = ['Œ≥', 'WGA_mean', 'WGA_std', 'n']
for heading, series_df in (
    ("AGRE-KD data points:", agrekd_data),
    ("\nAVER data points:", aver_data),
):
    print(heading)
    print(series_df[point_columns].to_string(index=False))

In [None]:
# Line plot of WGA against gamma, one series per weighting method.
fig, ax = plt.subplots(figsize=(10, 6))


def _plot_wga_series(axis, data, fmt, color, label):
    """Draw one WGA-vs-gamma series; error bars are zeroed for single-run points."""
    spread = data['WGA_std'].where(data['n'] > 1, 0)
    axis.errorbar(data['Œ≥'], data['WGA_mean'], yerr=spread,
                  fmt=fmt, linewidth=2.5, markersize=10, capsize=5, capthick=2,
                  color=color, label=label)


# AGRE-KD line (always drawn)
_plot_wga_series(ax, agrekd_data, 'o-', '#3498db', 'AGRE-KD (gradient weighting)')

# AVER line (only when AVER results are available)
if len(aver_data) > 0:
    _plot_wga_series(ax, aver_data, 's--', '#e67e22', 'AVER (simple averaging)')

# Formatting
ax.set_xlabel('Œ≥ (Feature Distillation Weight)', fontsize=14)
ax.set_ylabel('Worst-Group Accuracy (%)', fontsize=14)
ax.set_title('AGRE-KD vs AVER: Gradient Weighting Improves WGA Across All Œ≥', fontsize=16, fontweight='bold')
ax.legend(fontsize=12, loc='lower right')
ax.grid(True, alpha=0.3, linestyle='--')

# Fixed y-range shared with the gamma-sweep figure
ax.set_ylim([82, 88])

plt.tight_layout()

# Save figure
save_path = os.path.join(OUTPUT_DIR, 'agrekd_vs_aver.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"\n‚úÖ Saved: {save_path}")

plt.show()

---
## 6️⃣ Figure 3: Summary Results Table (Heatmap Style)

Publication-quality table showing all methods ranked by WGA, with both WGA and Avg Acc columns.
Color-coded cells highlight best/worst results at a glance.

In [None]:
# ============================================================
# Publication-Quality Results Table (Heatmap Style)
# ============================================================
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# Build the data behind the summary table: one row per experiment with
# mean and std for both WGA and average accuracy, ranked by mean WGA.
table_rows = []
for base_exp, subset in df.groupby('Experiment', sort=False):  # first-seen order
    n = len(subset)
    multiple_runs = n > 1
    wga_values = subset['WGA (%)']
    acc_values = subset['Avg Acc (%)']

    table_rows.append({
        'Experiment': base_exp,
        # Method / alpha / gamma come from the first run of the experiment.
        'Method': subset['Method'].iloc[0],
        'Œ±': subset['Œ±'].iloc[0],
        'Œ≥': subset['Œ≥'].iloc[0],
        'WGA_mean': wga_values.mean(),
        # std is reported as 0 for single-run experiments
        'WGA_std': wga_values.std() if multiple_runs else 0,
        'Avg_Acc_mean': acc_values.mean(),
        'Avg_Acc_std': acc_values.std() if multiple_runs else 0,
        'n': n,
    })

# Best WGA first; Rank is the 1-based position after sorting.
table_df = (
    pd.DataFrame(table_rows)
    .sort_values('WGA_mean', ascending=False)
    .reset_index(drop=True)
)
table_df['Rank'] = range(1, len(table_df) + 1)

# Format strings for display
def format_metric(mean, std, n):
    """Render a metric with two decimals, appending the spread when there is one.

    The spread is shown only for multi-run results (n > 1) with a positive std;
    otherwise just the mean is returned.
    """
    text = f"{mean:.2f}"
    if n <= 1 or not std > 0:
        return text
    return f"{text} ¬± {std:.2f}"

# Add the human-readable "mean ± std" columns next to the numeric ones.
for display_name, stat_prefix in (('WGA (%)', 'WGA'), ('Avg Acc (%)', 'Avg_Acc')):
    table_df[display_name] = table_df.apply(
        lambda r, p=stat_prefix: format_metric(r[f'{p}_mean'], r[f'{p}_std'], r['n']),
        axis=1,
    )

print("Summary Table Data:")
summary_columns = ['Rank', 'Method', 'Œ±', 'Œ≥', 'WGA (%)', 'Avg Acc (%)', 'n']
print(table_df[summary_columns].to_string(index=False))

In [None]:
# ============================================================
# Create Publication-Quality Table Figure
# ============================================================

# Render the ranked summary table as a figure.

# Select columns for display
display_cols = ['Rank', 'Method', 'Œ±', 'Œ≥', 'WGA (%)', 'Avg Acc (%)', 'n']
display_df = table_df[display_cols].copy()

# Figure height grows with the number of rows
fig, ax = plt.subplots(figsize=(12, 0.5 * len(display_df) + 1.5))
ax.axis('off')

# Create the table. Header colours are set in the styling loop below, so no
# colColours argument is passed here (it would be overwritten anyway).
table = ax.table(
    cellText=display_df.values,
    colLabels=display_df.columns,
    cellLoc='center',
    loc='center',
)

# Style the table
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1.2, 1.8)  # Scale width and height

# WGA means drive the colour gradient of the WGA column
wga_means = table_df['WGA_mean'].values
wga_min, wga_max = wga_means.min(), wga_means.max()

# Diverging colormap: red = low, yellow = mid, green = high
cmap = plt.cm.RdYlGn

# Style header row (table row 0)
for j in range(len(display_df.columns)):
    cell = table[(0, j)]
    cell.set_text_props(weight='bold', color='white')
    cell.set_facecolor('#2c3e50')

# Style data rows with conditional formatting (table row i + 1 holds data row i)
for i in range(len(display_df)):
    # Position of this row's WGA within [min, max]; 0.5 when all rows are equal
    wga_norm = (wga_means[i] - wga_min) / (wga_max - wga_min) if wga_max > wga_min else 0.5

    for j, col in enumerate(display_df.columns):
        cell = table[(i + 1, j)]

        # Highlight WGA column with gradient (restricted to the 0.2-0.8 band of
        # the colormap so the text stays readable)
        if col == 'WGA (%)':
            cell.set_facecolor(cmap(0.2 + 0.6 * wga_norm))
            if wga_norm > 0.7:  # Best results get bold
                cell.set_text_props(weight='bold')
        # Highlight best row (rank 1)
        elif i == 0:
            cell.set_facecolor('#d5f5e3')  # Light green
            cell.set_text_props(weight='bold')
        # Alternating row colors for readability
        elif i % 2 == 0:
            cell.set_facecolor('#f8f9fa')
        else:
            cell.set_facecolor('#ffffff')

# Add title
plt.title('Summary of Results: Methods Ranked by Worst-Group Accuracy', 
          fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()

# Save figure
save_path = os.path.join(OUTPUT_DIR, 'results_table.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
print(f"\n‚úÖ Saved: {save_path}")

plt.show()

In [None]:
# ============================================================
# Alternative: Seaborn Heatmap Style (More Visual)
# ============================================================
import seaborn as sns

# Side-by-side horizontal bar charts of the key metrics.
# Left panel: WGA per experiment; right panel: average accuracy per experiment.
# Each panel is sorted independently by its own metric (lowest at the bottom).

# Prepare bar data
heatmap_data = table_df[['Experiment', 'Method', 'Œ±', 'Œ≥', 'WGA_mean', 'Avg_Acc_mean']].copy()


def _bar_label(r):
    """Tick label for one experiment row.

    Runs with alpha = 1 keep the short "(gamma=...)" / "(baseline)" form. Runs with
    any other alpha spell out both hyper-parameters, otherwise they would be
    indistinguishable from the alpha = 1 run with the same gamma.
    """
    if r['Œ±'] != 1.0:
        detail = f"Œ±={r['Œ±']:g}, Œ≥={r['Œ≥']:.2f}"
    elif r['Œ≥'] > 0:
        detail = f"Œ≥={r['Œ≥']:.2f}"
    else:
        detail = "baseline"
    return f"{r['Method']}\n({detail})"


heatmap_data['Label'] = heatmap_data.apply(_bar_label, axis=1)

# Any labels that still collide (same method, alpha and gamma) get the experiment
# name appended so every bar can be identified.
still_ambiguous = heatmap_data['Label'].duplicated(keep=False)
heatmap_data.loc[still_ambiguous, 'Label'] = (
    heatmap_data.loc[still_ambiguous, 'Label']
    + '\n[' + heatmap_data.loc[still_ambiguous, 'Experiment'] + ']'
)

# Metrics indexed by label
metrics_df = heatmap_data[['Label', 'WGA_mean', 'Avg_Acc_mean']].set_index('Label')
metrics_df.columns = ['WGA (%)', 'Avg Acc (%)']

# Create figure with two subplots - one for each metric
fig, axes = plt.subplots(1, 2, figsize=(14, 0.6 * len(metrics_df) + 2))
y_pos = np.arange(len(metrics_df))

# --- WGA panel ---
ax1 = axes[0]
colors_wga = sns.color_palette("RdYlGn", n_colors=len(metrics_df))
# Positional sort order as a plain integer array. np.argsort places NaN last,
# whereas Series.argsort marks NaN with -1, which would index the wrong row.
wga_order = np.argsort(metrics_df['WGA (%)'].to_numpy())
wga_sorted = metrics_df['WGA (%)'].to_numpy()[wga_order]

ax1.barh(y_pos, wga_sorted, color=list(colors_wga))
ax1.set_yticks(y_pos)
ax1.set_yticklabels(metrics_df.index[wga_order], fontsize=10)
ax1.set_xlabel('Worst-Group Accuracy (%)', fontsize=12)
ax1.set_title('WGA by Method', fontsize=14, fontweight='bold')
ax1.set_xlim([80, 88])

# Add value labels just right of each bar
for i, val in enumerate(wga_sorted):
    ax1.text(val + 0.1, i, f'{val:.1f}%', va='center', fontsize=10, fontweight='bold')

# --- Avg Acc panel ---
ax2 = axes[1]
colors_avg = sns.color_palette("Blues", n_colors=len(metrics_df))
avg_order = np.argsort(metrics_df['Avg Acc (%)'].to_numpy())
avg_sorted = metrics_df['Avg Acc (%)'].to_numpy()[avg_order]

ax2.barh(y_pos, avg_sorted, color=list(colors_avg))
ax2.set_yticks(y_pos)
ax2.set_yticklabels(metrics_df.index[avg_order], fontsize=10)
ax2.set_xlabel('Average Accuracy (%)', fontsize=12)
ax2.set_title('Avg Acc by Method', fontsize=14, fontweight='bold')
ax2.set_xlim([88, 95])

# Add value labels
for i, val in enumerate(avg_sorted):
    ax2.text(val + 0.1, i, f'{val:.1f}%', va='center', fontsize=10, fontweight='bold')

plt.suptitle('Method Comparison: WGA vs Average Accuracy', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()

# Save
save_path = os.path.join(OUTPUT_DIR, 'results_bars.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"\n‚úÖ Saved: {save_path}")

plt.show()

In [None]:
# ============================================================
# CONSOLIDATED RESULTS TABLE (One Row Per Method)
# Clean, publication-ready table for the report
# ============================================================

# Method families shown in the consolidated table. Each entry is a filter on
# table_df: a method name plus optional alpha / gamma constraints. Within each
# family the configuration with the highest mean WGA is reported.

method_categories = {
    'AGRE-KD + Features': {'method': 'AGRE-KD', 'alpha': 1.0, 'gamma_min': 0.01, 'gamma_max': 1.0},
    'AGRE-KD Baseline': {'method': 'AGRE-KD', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 0.0},
    'Disagree-Weight': {'method': 'Disagree-Weight', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 1.0},
    'Multi-Layer': {'method': 'Multi-Layer', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 1.0},
    'AVER + Features': {'method': 'AVER', 'alpha': 1.0, 'gamma_min': 0.01, 'gamma_max': 1.0},
    'AVER Baseline': {'method': 'AVER', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 0.0},
    'Combined (Œ±<1)': {'method': 'AGRE-KD', 'alpha_max': 0.99, 'gamma_min': 0.0, 'gamma_max': 1.0},
}

consolidated_rows = []

for category_name, filters in method_categories.items():
    # Collect one boolean condition per constraint that is present ...
    conditions = [table_df['Method'] == filters['method']]
    if 'alpha' in filters:
        conditions.append(table_df['Œ±'] == filters['alpha'])
    if 'alpha_max' in filters:
        conditions.append(table_df['Œ±'] < filters['alpha_max'])
    if 'gamma_min' in filters:
        conditions.append(table_df['Œ≥'] >= filters['gamma_min'])
    if 'gamma_max' in filters:
        conditions.append(table_df['Œ≥'] <= filters['gamma_max'])

    # ... and require all of them.
    mask = conditions[0]
    for extra in conditions[1:]:
        mask = mask & extra

    subset = table_df[mask]
    if subset.empty:
        # No experiment of this family was run; leave it out of the table.
        continue

    # Best configuration of the family = highest mean WGA
    best_row = subset.loc[subset['WGA_mean'].idxmax()]

    summary = {'Method': category_name, 'Best Œ≥': best_row['Œ≥']}
    for stat in ('WGA_mean', 'WGA_std', 'Avg_Acc_mean', 'Avg_Acc_std', 'n'):
        summary[stat] = best_row[stat]
    summary['is_baseline'] = 'Baseline' in category_name
    consolidated_rows.append(summary)

# Highest WGA first
consolidated_df = (
    pd.DataFrame(consolidated_rows)
    .sort_values('WGA_mean', ascending=False)
    .reset_index(drop=True)
)

# Format display strings.
# format_metric is deliberately not redefined here: it is defined in the cell that
# builds table_df, which has to run before this one anyway. A second identical
# `def` would silently shadow the first and let the two copies drift apart.
consolidated_df['WGA (%)'] = consolidated_df.apply(
    lambda r: format_metric(r['WGA_mean'], r['WGA_std'], r['n']), axis=1)
consolidated_df['Avg Acc (%)'] = consolidated_df.apply(
    lambda r: format_metric(r['Avg_Acc_mean'], r['Avg_Acc_std'], r['n']), axis=1)

print("Consolidated Results (one row per method):")
print(consolidated_df[['Method', 'Best Œ≥', 'WGA (%)', 'Avg Acc (%)', 'n']].to_string(index=False))

# ============================================================
# Consolidated table figure (compact: no title, no legend)
# ============================================================
display_cols = ['Method', 'Best Œ≥', 'WGA (%)', 'Avg Acc (%)', 'n']
display_df = consolidated_df[display_cols].copy()
column_names = list(display_df.columns)

# Height scales with the row count; no extra room is reserved for a legend.
fig, ax = plt.subplots(figsize=(10, 0.5 * len(display_df) + 0.8))
ax.axis('off')

table = ax.table(
    cellText=display_df.values,
    colLabels=display_df.columns,
    cellLoc='center',
    loc='center',
)

# Fixed font size and a slightly enlarged cell grid
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1.3, 1.8)

# WGA range used to place each row on the colour gradient
wga_means = consolidated_df['WGA_mean'].values
wga_min, wga_max = wga_means.min(), wga_means.max()
cmap = plt.cm.RdYlGn

# Header row (table row 0): white bold text on a dark background
for j in range(len(column_names)):
    header_cell = table[(0, j)]
    header_cell.set_text_props(weight='bold', color='white', fontsize=11)
    header_cell.set_facecolor('#2c3e50')

# Data rows (table row i + 1 holds data row i)
for i in range(len(display_df)):
    is_baseline = consolidated_df.iloc[i]['is_baseline']
    is_best = (i == 0)  # rows are sorted by WGA, so the first one is the best
    wga_norm = (wga_means[i] - wga_min) / (wga_max - wga_min) if wga_max > wga_min else 0.5

    for j, col in enumerate(column_names):
        # Decide background and weight first, then apply them in one place.
        if col == 'WGA (%)':
            face = cmap(0.25 + 0.5 * wga_norm)  # gradient colour for the WGA column
            bold = is_best
        elif is_best:
            face = '#d5f5e3'  # light green for the best row
            bold = True
        elif is_baseline:
            face = '#e8f4f8'  # light blue for baseline rows
            bold = False
        else:
            face = '#f8f9fa' if i % 2 == 0 else '#ffffff'  # zebra striping
            bold = False

        cell = table[(i + 1, j)]
        cell.set_facecolor(face)
        if bold:
            cell.set_text_props(weight='bold', fontsize=11)
        else:
            cell.set_text_props(fontsize=11)

# No title - the report caption handles it
# No legend - colors are self-explanatory

plt.tight_layout()
plt.subplots_adjust(top=0.95, bottom=0.05)

# Save figure
save_path = os.path.join(OUTPUT_DIR, 'results_table_consolidated.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none', pad_inches=0.1)
print(f"\n‚úÖ Saved: {save_path}")

plt.show()

---
## 7️⃣ Refined Results Analysis (Full Statistics)

Complete aggregated results with WGA ± std AND Avg Acc ± std for all experiments.

In [None]:
# ============================================================
# REFINED AGGREGATED RESULTS - Full Statistics
# WGA ¬± std AND Avg Acc ¬± std for all experiments
# ============================================================

# Ensure df is properly loaded (re-run if needed)
# Section banner: a 100-character rule above and below the heading.
print("=" * 100, "ALL EXPERIMENT RUNS (raw data)", "=" * 100, sep="\n")

# Assign categories for organization
def get_category(row):
    """Return the numbered category label for one experiment run.

    `row` is any mapping-like object (e.g. a DataFrame row) exposing the
    'Method', 'Œ±' and 'Œ≥' keys. The numeric prefix of the label gives the
    order in which categories sort alphabetically.

    Precedence, first match wins:
      1. 'Multi-Layer' and 'Disagree-Weight' methods get their own category,
         regardless of Œ± or Œ≥.
      2. Any remaining run with Œ± below 1.0 is 'Combined'.
      3. Otherwise the run is a baseline when Œ≥ == 0 and a feature-distillation
         run when it is not, split by method: 'AVER' versus everything else
         (which is labelled AGRE-KD).
    """
    method = row['Method']

    # Methods that map straight to a category without looking at Œ± / Œ≥.
    dedicated = {
        'Multi-Layer': '6. Multi-Layer',
        'Disagree-Weight': '5. Disagree Weighting',
    }
    if method in dedicated:
        return dedicated[method]

    if row['Œ±'] < 1.0:
        return '4. Combined (Œ±<1)'

    no_feature_term = row['Œ≥'] == 0
    if method == 'AVER':
        return '1a. Baseline (AVER)' if no_feature_term else '3. Feature Dist (AVER)'
    # Any method other than the ones handled above is labelled AGRE-KD.
    return '1b. Baseline (AGRE-KD)' if no_feature_term else '2. Feature Dist (AGRE-KD)'

# Columns shown in the raw per-run listing.
display_cols = ['Category', 'Experiment', 'Seed', 'Œ±', 'Œ≥', 'Method', 'WGA (%)', 'Avg Acc (%)']

# Label every run with its category (adds/overwrites the 'Category' column on df).
df['Category'] = df.apply(get_category, axis=1)

# Order runs by category, then experiment, then seed (all ascending, the default).
df_sorted = df.sort_values(by=['Category', 'Experiment', 'Seed'])

raw_listing = df_sorted[display_cols].to_string(index=False)
print(raw_listing)

# ============================================================
# AGGREGATED RESULTS WITH FULL STATISTICS
# ============================================================
print("\n" + "=" * 100)
print("AGGREGATED RESULTS BY EXPERIMENT (with Avg Acc ¬± std)")
print("=" * 100)


def _mean_and_std(values):
    """Return (mean, std) of a column; std is 0 when there is a single run."""
    spread = values.std() if len(values) > 1 else 0
    return values.mean(), spread


def _format_mean_std(mean_value, std_value):
    """Format as 'mean ¬± std', or just 'mean' when the std is not positive."""
    if std_value > 0:
        return f"{mean_value:.2f} ¬± {std_value:.2f}"
    return f"{mean_value:.2f}"


def _aggregate_experiment(name, runs):
    """Collapse all runs (seeds) of one experiment into a single summary row."""
    wga_mean, wga_std = _mean_and_std(runs['WGA (%)'])
    avg_acc_mean, avg_acc_std = _mean_and_std(runs['Avg Acc (%)'])

    # Record which seeds contributed, as a sorted comma-separated string.
    seeds_str = ','.join(sorted(map(str, runs['Seed'].unique())))

    # Method / Œ± / Œ≥ / category are taken from the first run of the experiment.
    return {
        'Category': runs['Category'].iloc[0],
        'Experiment': name,
        'Method': runs['Method'].iloc[0],
        'Œ±': runs['Œ±'].iloc[0],
        'Œ≥': runs['Œ≥'].iloc[0],
        'WGA (%)': _format_mean_std(wga_mean, wga_std),
        'Avg Acc (%)': _format_mean_std(avg_acc_mean, avg_acc_std),
        'n': len(runs),
        'Seeds': seeds_str,
        # Underscore-prefixed columns keep the raw numbers for sorting/plotting.
        '_wga_mean': wga_mean,
        '_wga_std': wga_std,
        '_avg_acc_mean': avg_acc_mean,
        '_avg_acc_std': avg_acc_std,
    }


# One summary row per experiment, in order of first appearance in df.
agg_full_rows = [
    _aggregate_experiment(name, df[df['Experiment'] == name])
    for name in df['Experiment'].unique()
]
agg_full_df = pd.DataFrame(agg_full_rows)

# Columns common to both printed views.
summary_cols = ['Method', 'Œ±', 'Œ≥', 'WGA (%)', 'Avg Acc (%)', 'n', 'Seeds']

# View 1: grouped by category, best WGA first within each category.
print("\n--- Grouped by Category ---")
agg_by_cat = agg_full_df.sort_values(['Category', '_wga_mean'], ascending=[True, False])
print(agg_by_cat[['Category', 'Experiment'] + summary_cols].to_string(index=False))

# View 2: global ranking by mean WGA, with a 1-based Rank column.
print("\n" + "-" * 100)
print("--- Ranked by WGA ---")
agg_by_wga = agg_full_df.sort_values('_wga_mean', ascending=False).reset_index(drop=True)
agg_by_wga['Rank'] = range(1, len(agg_by_wga) + 1)
print(agg_by_wga[['Rank', 'Experiment'] + summary_cols].to_string(index=False))

In [None]:
# ============================================================
# KEY FINDINGS SUMMARY
# ============================================================
print("\n" + "=" * 100)
print("KEY FINDINGS")
print("=" * 100)

# Best overall: the aggregated row with the highest mean WGA.
best = agg_full_df.loc[agg_full_df['_wga_mean'].idxmax()]
print(f"\nüèÜ Best overall: {best['Experiment']} ({best['Method']}, Œ±={best['Œ±']}, Œ≥={best['Œ≥']})")
print(f"   WGA: {best['WGA (%)']} (n={best['n']}, seeds: {best['Seeds']})")
print(f"   Avg Acc: {best['Avg Acc (%)']}")

# Best AGRE-KD with features: AGRE-KD rows with Œ≥ > 0 and a positive mean WGA.
agre_features = agg_full_df[(agg_full_df['Method'] == 'AGRE-KD') & (agg_full_df['_wga_mean'] > 0) & (agg_full_df['Œ≥'] > 0)]
if len(agre_features) > 0:
    best_agre = agre_features.loc[agre_features['_wga_mean'].idxmax()]
    print(f"\nüìä Best AGRE-KD + Features: {best_agre['Experiment']} (Œ≥={best_agre['Œ≥']})")
    print(f"   WGA: {best_agre['WGA (%)']} | Avg Acc: {best_agre['Avg Acc (%)']}")

# Baseline comparison: the experiment literally named 'baseline_agrekd'.
baseline = agg_full_df[agg_full_df['Experiment'] == 'baseline_agrekd']
if len(baseline) > 0:
    baseline_row = baseline.iloc[0]
    print(f"\nüìà AGRE-KD Baseline (Œ≥=0):")
    print(f"   WGA: {baseline_row['WGA (%)']} | Avg Acc: {baseline_row['Avg Acc (%)']}")

    # Improvement of the best feature-distillation run over the baseline.
    if len(agre_features) > 0:
        improvement = agre_features['_wga_mean'].max() - baseline_row['_wga_mean']
        # Use the '+' sign option rather than a literal '+', so a regression
        # prints as "-0.42%" instead of the malformed "+-0.42%".
        print(f"   Best improvement: {improvement:+.2f}%")

# AVER vs AGRE-KD comparison: only printed when both baseline experiments exist.
aver_baseline = agg_full_df[agg_full_df['Experiment'] == 'aver_baseline']
if len(aver_baseline) > 0 and len(baseline) > 0:
    aver_row = aver_baseline.iloc[0]
    agre_row = baseline.iloc[0]
    diff = agre_row['_wga_mean'] - aver_row['_wga_mean']
    print(f"\nüîÑ AGRE-KD vs AVER Baseline:")
    print(f"   AGRE-KD: {agre_row['WGA (%)']} | AVER: {aver_row['WGA (%)']} | Diff: {diff:+.2f}%")

print("\n" + "=" * 100)

In [None]:
# ============================================================
# SEED VARIANCE ANALYSIS BY EXPERIMENT
# Shows individual seed results + statistics
# ============================================================
print("\n" + "=" * 100)
print("SEED VARIANCE ANALYSIS BY EXPERIMENT")
print("=" * 100)

# For every experiment (alphabetical order): list each seed's result, then
# print mean / std / range across seeds when more than one seed is available.
for exp in sorted(df['Experiment'].unique()):
    subset = df[df['Experiment'] == exp]
    if len(subset) > 0:
        wgas = subset['WGA (%)'].values
        avg_accs = subset['Avg Acc (%)'].values
        seeds_list = subset['Seed'].values
        # Œ≥ and method are taken from the first run of the experiment.
        gamma = subset['Œ≥'].iloc[0]
        method = subset['Method'].iloc[0]

        print(f"\n{exp} ({method}, Œ≥={gamma}):")
        for seed, wga, acc in zip(seeds_list, wgas, avg_accs):
            marker = " (og)" if seed == 'og' else ""
            print(f"  Seed {seed}{marker}: WGA={wga:.2f}%, Avg Acc={acc:.2f}%")

        if len(wgas) > 1:
            # ddof=1 (sample std) matches pandas' Series.std() default, which
            # the aggregated results table uses; NumPy's default ddof=0 would
            # report a smaller Std for the same runs. Safe here because this
            # branch only runs with at least two values.
            wga_std = np.std(wgas, ddof=1)
            avg_acc_std = np.std(avg_accs, ddof=1)
            print(f"  ‚Üí WGA:     Mean={np.mean(wgas):.2f}% | Std={wga_std:.2f}% | Range={np.max(wgas)-np.min(wgas):.2f}%")
            print(f"  ‚Üí Avg Acc: Mean={np.mean(avg_accs):.2f}% | Std={avg_acc_std:.2f}% | Range={np.max(avg_accs)-np.min(avg_accs):.2f}%")

print("\n" + "=" * 100)

In [None]:
# ============================================================
# UPDATED CONSOLIDATED TABLE FOR REPORT
# One row per method category with full WGA and Avg Acc ¬± std
# ============================================================

# Use agg_full_df which has proper statistics
# Report categories and the filters that select their rows from agg_full_df.
# 'alpha' is an exact match, 'alpha_max' a strict upper bound, and the gamma
# bounds are inclusive.
method_categories = {
    'AGRE-KD + Features': {'method': 'AGRE-KD', 'alpha': 1.0, 'gamma_min': 0.01, 'gamma_max': 1.0},
    'AGRE-KD Baseline': {'method': 'AGRE-KD', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 0.0},
    'Disagree-Weight': {'method': 'Disagree-Weight', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 1.0},
    'Multi-Layer': {'method': 'Multi-Layer', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 1.0},
    'AVER + Features': {'method': 'AVER', 'alpha': 1.0, 'gamma_min': 0.01, 'gamma_max': 1.0},
    'AVER Baseline': {'method': 'AVER', 'alpha': 1.0, 'gamma_min': 0.0, 'gamma_max': 0.0},
    'Combined (Œ±<1)': {'method': 'AGRE-KD', 'alpha_max': 0.99, 'gamma_min': 0.0, 'gamma_max': 1.0},
}

# (filter key, agg_full_df column, Series comparison method) for each optional filter.
filter_rules = (
    ('alpha', 'Œ±', 'eq'),
    ('alpha_max', 'Œ±', 'lt'),
    ('gamma_min', 'Œ≥', 'ge'),
    ('gamma_max', 'Œ≥', 'le'),
)

final_consolidated_rows = []

for category_name, filters in method_categories.items():
    # Start from the method match, then AND in every filter the category declares.
    mask = agg_full_df['Method'].eq(filters['method'])
    for filter_key, column, comparison in filter_rules:
        if filter_key in filters:
            mask = mask & getattr(agg_full_df[column], comparison)(filters[filter_key])

    subset = agg_full_df[mask]

    # Categories with no matching experiment are left out of the table.
    if not len(subset):
        continue

    # Represent the category by its experiment with the highest mean WGA.
    best_row = subset.loc[subset['_wga_mean'].idxmax()]

    final_consolidated_rows.append({
        'Method': category_name,
        'Best Œ≥': best_row['Œ≥'],
        'WGA (%)': best_row['WGA (%)'],  # pre-formatted string from agg_full_df
        'Avg Acc (%)': best_row['Avg Acc (%)'],  # pre-formatted string from agg_full_df
        'n': best_row['n'],
        '_wga_mean': best_row['_wga_mean'],
        '_avg_acc_mean': best_row['_avg_acc_mean'],
        'is_baseline': 'Baseline' in category_name,
    })

# Best category first; a fresh 0..n-1 index so that row 0 is the best one.
final_consolidated_df = (
    pd.DataFrame(final_consolidated_rows)
    .sort_values('_wga_mean', ascending=False)
    .reset_index(drop=True)
)

print("\n" + "=" * 100)
print("FINAL CONSOLIDATED TABLE FOR REPORT")
print("=" * 100)
print(final_consolidated_df[['Method', 'Best Œ≥', 'WGA (%)', 'Avg Acc (%)', 'n']].to_string(index=False))

# ============================================================
# Generate the figure
# ============================================================
# Columns rendered in the figure (helper columns starting with '_' are left out).
display_cols = ['Method', 'Best Œ≥', 'WGA (%)', 'Avg Acc (%)', 'n']
display_df = final_consolidated_df[display_cols].copy()

# Figure height grows with the number of table rows.
n_table_rows = len(display_df)
fig, ax = plt.subplots(figsize=(10, 0.5 * n_table_rows + 0.8))
ax.axis('off')

table = ax.table(
    cellText=display_df.values,
    colLabels=display_df.columns,
    cellLoc='center',
    loc='center',
)

# Fixed 11pt font and enlarged cells.
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1.3, 1.8)

# Mean WGA per row drives the colour gradient of the WGA column.
wga_means = final_consolidated_df['_wga_mean'].values
wga_min, wga_max = wga_means.min(), wga_means.max()
cmap = plt.cm.RdYlGn

# Header row: bold white text on a dark background.
for j in range(len(display_df.columns)):
    header_cell = table[(0, j)]
    header_cell.set_text_props(weight='bold', color='white', fontsize=11)
    header_cell.set_facecolor('#2c3e50')

# Data rows (table row 0 is the header, so data row i lives at table row i + 1).
baseline_flags = final_consolidated_df['is_baseline'].values
for i, is_baseline in enumerate(baseline_flags):
    is_best = (i == 0)  # rows are sorted by mean WGA, so the first one is the best

    # Position of this row's WGA within [min, max]; 0.5 when all rows are equal.
    if wga_max > wga_min:
        wga_norm = (wga_means[i] - wga_min) / (wga_max - wga_min)
    else:
        wga_norm = 0.5

    # Background for every column except WGA: best row > baseline > zebra stripes.
    if is_best:
        row_fill = '#d5f5e3'
    elif is_baseline:
        row_fill = '#e8f4f8'
    elif i % 2 == 0:
        row_fill = '#f8f9fa'
    else:
        row_fill = '#ffffff'

    # Only the best row is bold, in every column.
    text_props = {'weight': 'bold', 'fontsize': 11} if is_best else {'fontsize': 11}

    for j, col in enumerate(display_df.columns):
        cell = table[(i + 1, j)]
        if col == 'WGA (%)':
            # Restrict the gradient to the middle half of the colormap.
            cell.set_facecolor(cmap(0.25 + 0.5 * wga_norm))
        else:
            cell.set_facecolor(row_fill)
        cell.set_text_props(**text_props)

plt.tight_layout()
plt.subplots_adjust(top=0.95, bottom=0.05)

# Write the table as a 300-dpi PNG on a white background into OUTPUT_DIR.
save_path = os.path.join(OUTPUT_DIR, 'results_table_consolidated.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none', pad_inches=0.1)
print(f"\n‚úÖ Saved: {save_path}")

plt.show()