In [None]:
# Compare Active Learning Strategies - Classification (with CV + Multiple Trials)
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

# Directory that holds both the input result files and the generated outputs.
SAVE_DIR = os.path.join('..', 'report', 'figures')
os.makedirs(SAVE_DIR, exist_ok=True)


def _load_results(filename):
    """Return the parsed JSON content of ``filename`` located in ``SAVE_DIR``.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    with open(os.path.join(SAVE_DIR, filename), 'r', encoding='utf-8') as f:
        return json.load(f)


# Load results (these should be generated by the updated notebooks)
try:
    unc = _load_results('cls_uncertainty_results.json')
    sen = _load_results('cls_sensitivity_results.json')
    pas = _load_results('passive_cls_best.json')
except FileNotFoundError as e:
    print(f"Results file not found: {e}")
    print("Please run the updated exploration notebooks first to generate results with CV + multiple trials")
    # Re-raise rather than calling exit(): exit() is meant for interactive
    # shells and would end the interpreter/kernel instead of reporting the
    # failure. Raising stops this cell with a clear traceback, and a
    # "Run All" halts here before any cell that needs unc/sen/pas.
    raise

# Datasets and metrics iterated over by the comparison cells below.
DATASETS = ['iris', 'wine', 'breast_cancer']
METRICS = ['accuracy', 'f1_macro']


In [None]:
# Compare all datasets
def _plot_mean_band(ax, budgets, results, metric, marker, label):
    """Draw the mean curve of ``metric`` with a mean +/- std band on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        budgets: iterable of integer budgets used as x values.
        results: mapping from budget (string key) to a dict that holds
            ``f'{metric}_mean'`` and ``f'{metric}_std'``.
        metric: metric name prefix, e.g. ``'accuracy'``.
        marker: matplotlib marker style for the mean curve.
        label: legend label for the mean curve.

    Raises:
        KeyError: if a budget or metric entry is missing from ``results``.
    """
    means = np.array([results[str(b)][f'{metric}_mean'] for b in budgets])
    stds = np.array([results[str(b)][f'{metric}_std'] for b in budgets])
    ax.plot(budgets, means, marker=marker, label=label, linewidth=2)
    ax.fill_between(budgets, means - stds, means + stds, alpha=0.2)


for dataset in DATASETS:
    print(f"\n=== Comparing strategies for {dataset} ===")

    # Budgets are taken from the first uncertainty method of this dataset;
    # every other curve is indexed with these same budgets.
    budgets = sorted(int(b) for b in next(iter(unc[dataset].values())).keys())
    methods = list(unc[dataset].keys())

    for metric in METRICS:
        # Explicit figure/axes so the exact figure can be closed after saving.
        fig, ax = plt.subplots(figsize=(10, 6))

        # Plot uncertainty methods
        for m in methods:
            _plot_mean_band(ax, budgets, unc[dataset][m], metric, 'o', f'uncertainty_{m}')

        # Plot sensitivity method
        if dataset in sen:
            _plot_mean_band(ax, budgets, sen[dataset], metric, 's', 'sensitivity')

        # Plot passive baseline as horizontal line. Guarded like the
        # sensitivity results so a dataset missing from the passive file
        # still gets its comparison figure.
        # NOTE(review): the same 'best_metric' value is drawn on every
        # metric's figure - confirm which metric it actually reports.
        if dataset in pas:
            baseline = pas[dataset]['best_metric']
            ax.axhline(baseline, color='k', linestyle='--', label='passive_best', linewidth=2)
        else:
            print(f"No passive baseline found for {dataset}; baseline line omitted")

        ax.set_xlabel('Labeled budget (max_labels)')
        ax.set_ylabel(metric)
        ax.set_title(f'{dataset}: Strategies Comparison ({metric}) - CV + Multiple Trials')
        ax.grid(True, alpha=0.3)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        fig.tight_layout()
        fig.savefig(os.path.join(SAVE_DIR, f'cls_{dataset}_comparison_{metric}.png'), dpi=200, bbox_inches='tight')
        # Close this figure so figures do not accumulate across the loops.
        plt.close(fig)

        print(f"Saved {dataset} {metric} comparison")


In [None]:
# Build the summary table: one row per dataset/strategy at the largest budget
print("\n=== Summary Table ===")
summary_data = []

for dataset in DATASETS:
    # Budgets are read from the first uncertainty method of this dataset.
    first_method_results = next(iter(unc[dataset].values()))
    budgets = sorted(int(b) for b in first_method_results.keys())
    methods = list(unc[dataset].keys())

    # 'sensitivity' is appended as an extra pseudo-method; its results live
    # in `sen` rather than `unc` and are skipped when absent.
    for method in [*methods, 'sensitivity']:
        if method == 'sensitivity':
            if dataset not in sen:
                continue
            data = sen[dataset]
        else:
            data = unc[dataset][method]

        # Performance at the highest budget
        max_budget = max(budgets)
        at_max = data[str(max_budget)]

        summary_data.append({
            'dataset': dataset,
            'method': method,
            'budget': max_budget,
            'accuracy_mean': at_max['accuracy_mean'],
            'accuracy_std': at_max['accuracy_std'],
            'f1_mean': at_max['f1_macro_mean'],
            'f1_std': at_max['f1_macro_std'],
        })

# Tabulate the collected rows and show them rounded to 4 decimals
df = pd.DataFrame(summary_data)
print(df.round(4))

# Persist the (unrounded) summary next to the figures
summary_path = os.path.join(SAVE_DIR, 'cls_comparison_summary.csv')
df.to_csv(summary_path, index=False)
print(f'\nSaved comparison summary to {SAVE_DIR}')
print('All comparison figures and summary saved!')
