In [None]:
# Active Learning Exploration - Classification (Uncertainty) with Cross-Validation
import os
import json
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from typing import Dict, List
import torch
from tqdm import tqdm
import time

from alnn.experiments import ActiveConfig, run_active_classification
from alnn.training import TrainConfig

# Output directory: figures, JSON results and resumable checkpoints all land here.
SAVE_DIR = os.path.join(os.pardir, 'report', 'figures')
os.makedirs(SAVE_DIR, exist_ok=True)

# Experiment matrix: every dataset is crossed with every uncertainty method,
# and each pair is evaluated at every labeling budget.
DATASETS = ['iris', 'wine', 'breast_cancer']
METHODS = ['entropy', 'margin', 'least_confidence']
BUDGETS = list(range(40, 201, 40))  # 40, 80, 120, 160, 200 (used as max_labels)

# Repetitions: seeded runs averaged per configuration.
N_TRIALS = 5
# NOTE(review): N_FOLDS is only echoed in the final summary print; it is not
# passed to any call visible in this notebook -- confirm whether any
# cross-validation actually happens inside alnn.
N_FOLDS = 3

# Hyperparameter search space (full Cartesian product is explored).
LRS = [0.001, 0.003, 0.01]      # learning rates
WDS = [0.0, 0.00001, 0.0001]    # weight decays
HIDDENS = [32, 64, 128]         # hidden units
BSS = [32, 64]                  # batch sizes
INITS = [10, 20, 40]            # initially labeled samples
QUERIES = [5, 10, 20]           # samples queried per active-learning round


In [None]:
def tune_hparams(dataset: str, method: str, tune_budget: int) -> tuple:
    """Grid-search hyperparameters for one dataset / uncertainty-method pair.

    Every combination of LRS x WDS x HIDDENS x BSS x INITS x QUERIES is scored
    by the mean ``accuracy`` returned by ``run_active_classification`` over
    ``N_TRIALS`` seeded runs at a fixed label budget. The search state is
    written to a JSON checkpoint after each configuration so an interrupted
    search can resume at the first configuration that was not yet scored.

    NOTE(review): no cross-validation split is performed here; the score is
    whatever ``run_active_classification`` reports as ``accuracy``.

    Args:
        dataset: Dataset name forwarded to ``run_active_classification``.
        method: Uncertainty method forwarded as ``uncertainty_method``.
        tune_budget: ``max_labels`` used for every tuning run.

    Returns:
        ``(TrainConfig, ActiveConfig, hidden_units)`` for the best-scoring
        configuration, or ``None`` if no configuration ever scored above
        ``-inf`` (e.g. every mean accuracy was NaN).
    """
    param_names = ('lr', 'wd', 'hidden', 'bs', 'init', 'query')
    # Materialize the grid so the iteration order is fixed and indexable,
    # which is what makes "skip the first N" on resume well defined.
    grid = list(itertools.product(LRS, WDS, HIDDENS, BSS, INITS, QUERIES))
    total_configs = len(grid)

    def make_train_config(lr, wd, bs):
        # Build the training configuration for one grid point.
        return TrainConfig(learning_rate=lr, weight_decay=wd, batch_size=bs,
                           max_epochs=200, patience=20, device='cpu')

    def make_active_config(init, query):
        # Build the active-learning configuration for one grid point.
        return ActiveConfig(initial_labeled=init, query_batch=query,
                            max_labels=tune_budget, device='cpu')

    checkpoint_file = os.path.join(SAVE_DIR, f'cls_uncertainty_{dataset}_{method}_checkpoint.json')
    completed_configs = 0
    best_acc = -float('inf')
    best_params = None  # plain dict of hyperparameters -> JSON-serializable

    if os.path.exists(checkpoint_file):
        try:
            with open(checkpoint_file, 'r') as f:
                saved = json.load(f)
            saved_params = saved.get('best_config')
            if saved_params is not None and set(saved_params) != set(param_names):
                raise ValueError('unexpected best_config layout')
            saved_completed = min(max(int(saved['completed_configs']), 0), total_configs)
            saved_best = float(saved.get('best_acc', -float('inf')))
        except (ValueError, KeyError, TypeError, AttributeError):
            # A truncated or foreign checkpoint must not abort (or silently
            # corrupt) the search; fall back to a clean start.
            print(f"Ignoring unreadable checkpoint {checkpoint_file}; starting fresh hyperparameter tuning")
        else:
            completed_configs, best_acc, best_params = saved_completed, saved_best, saved_params
            print(f"Resuming hyperparameter tuning from checkpoint: {completed_configs} configs completed")
    else:
        print("Starting fresh hyperparameter tuning")

    # `initial` already accounts for resumed configs, so skipped iterations
    # below must not call pbar.update().
    pbar = tqdm(total=total_configs, desc=f"Tuning {dataset}-{method}",
                initial=completed_configs, position=0, leave=True)
    try:
        for config_idx, (lr, wd, hidden, bs, init, query) in enumerate(grid):
            if config_idx < completed_configs:
                continue

            print(f'Tuning config {config_idx+1}/{total_configs}: lr={lr}, wd={wd}, hidden={hidden}, bs={bs}, init={init}, query={query}')

            # Score this configuration over N_TRIALS seeded runs.
            metrics = []
            for seed in range(N_TRIALS):
                torch.manual_seed(42 + seed)
                res = run_active_classification(dataset_name=dataset, strategy='uncertainty',
                                                uncertainty_method=method, hidden_units=hidden,
                                                train_config=make_train_config(lr, wd, bs),
                                                active_config=make_active_config(init, query))
                metrics.append(res['accuracy'])

            avg_acc = float(np.mean(metrics))
            if avg_acc > best_acc:
                best_acc = avg_acc
                best_params = dict(zip(param_names, (lr, wd, hidden, bs, init, query)))

            pbar.update(1)
            pbar.set_postfix({'best_acc': f"{best_acc:.4f}"})

            # Checkpoint after each config. Write to a temp file and rename so
            # an interruption mid-write cannot leave a truncated checkpoint.
            state = {'completed_configs': config_idx + 1,
                     'best_acc': best_acc,
                     'best_config': best_params}
            tmp_file = checkpoint_file + '.tmp'
            with open(tmp_file, 'w') as f:
                json.dump(state, f, indent=2)
            os.replace(tmp_file, checkpoint_file)
    finally:
        pbar.close()

    # Search finished: the checkpoint is no longer needed.
    if os.path.exists(checkpoint_file):
        os.remove(checkpoint_file)

    print(f"Best config for {dataset}-{method}: accuracy={best_acc:.4f}")
    if best_params is None:
        return None
    return (make_train_config(best_params['lr'], best_params['wd'], best_params['bs']),
            make_active_config(best_params['init'], best_params['query']),
            best_params['hidden'])


def evaluate_curve(dataset: str, method: str, budgets: List[int]) -> Dict[int, Dict[str, float]]:
    """Evaluate the active-learning curve for one dataset / method pair.

    Hyperparameters are first tuned with ``tune_hparams`` at the median entry
    of ``budgets``. Then, for every budget, ``run_active_classification`` is run
    ``N_TRIALS`` times (seeds 42, 43, ...) and each metric in the returned dict
    is summarized as ``<metric>_mean`` and ``<metric>_std`` (sample std,
    ``ddof=1``). Per-budget results are checkpointed to JSON so an interrupted
    evaluation can resume without recomputing finished budgets.

    Args:
        dataset: Dataset name forwarded to ``run_active_classification``.
        method: Uncertainty method forwarded as ``uncertainty_method``.
        budgets: Label budgets (``max_labels``) to evaluate; must be non-empty.

    Returns:
        Mapping ``budget -> {"<metric>_mean": float, "<metric>_std": float}``
        with integer budget keys.
    """
    tune_budget = sorted(budgets)[len(budgets)//2]
    # NOTE(review): tuning is re-run on every call, including resumed ones;
    # resumed budgets therefore rely on tuning picking the same configuration.
    tcfg, acfg_base, hidden_units = tune_hparams(dataset, method, tune_budget)

    # Load checkpoint if exists (keys are budget strings, as stored by JSON).
    checkpoint_file = os.path.join(SAVE_DIR, f'cls_uncertainty_{dataset}_{method}_curve_checkpoint.json')
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            checkpoint = json.load(f)
        print(f"Resuming curve evaluation from checkpoint: {len(checkpoint['results'])} budgets completed")
        results = checkpoint['results']
    else:
        checkpoint = {'results': {}}
        results = {}
        print("Starting fresh curve evaluation")

    # Only budgets that are both requested and already stored count as done;
    # they are reported once through `initial` and not again inside the loop.
    resumed_keys = {str(b) for b in budgets if str(b) in results}
    already_done = sum(1 for b in budgets if str(b) in resumed_keys)
    pbar = tqdm(total=len(budgets), desc=f"Curve {dataset}-{method}",
                initial=already_done, position=0, leave=True)

    try:
        for max_labels in budgets:
            key = str(max_labels)
            if key in results:
                # A repeated budget computed earlier in this same call still
                # needs its tick; resumed ones were pre-counted above.
                if key not in resumed_keys:
                    pbar.update(1)
                continue

            print(f'Evaluating {dataset}-{method} at budget {max_labels}')
            metrics = []
            for seed in range(N_TRIALS):
                torch.manual_seed(42 + seed)
                acfg = ActiveConfig(initial_labeled=acfg_base.initial_labeled, query_batch=acfg_base.query_batch, max_labels=max_labels, device=acfg_base.device)
                res = run_active_classification(dataset_name=dataset, strategy='uncertainty', uncertainty_method=method, hidden_units=hidden_units, train_config=tcfg, active_config=acfg)
                metrics.append(res)

            # Aggregate every metric reported by the first trial across trials.
            keys = metrics[0].keys()
            summary = {f'{k}_mean': float(np.mean([m[k] for m in metrics])) for k in keys}
            summary.update({f'{k}_std': float(np.std([m[k] for m in metrics], ddof=1)) for k in keys})
            results[key] = summary

            pbar.update(1)

            # Save checkpoint after each budget; temp file + rename keeps the
            # checkpoint readable even if the write is interrupted.
            checkpoint['results'] = results
            tmp_file = checkpoint_file + '.tmp'
            with open(tmp_file, 'w') as f:
                json.dump(checkpoint, f, indent=2)
            os.replace(tmp_file, checkpoint_file)
    finally:
        pbar.close()

    # Clean up checkpoint file
    if os.path.exists(checkpoint_file):
        os.remove(checkpoint_file)

    return {int(k): v for k, v in results.items()}


In [None]:
# Run the full experiment: for every dataset and uncertainty method, compute
# the learning curve, save one figure per metric, and checkpoint after each
# completed dataset so the cell can be re-run to resume.
checkpoint_file = os.path.join(SAVE_DIR, 'cls_uncertainty_main_checkpoint.json')
if os.path.exists(checkpoint_file):
    with open(checkpoint_file, 'r') as f:
        checkpoint = json.load(f)
    print(f"Resuming from checkpoint: {checkpoint['completed_datasets']} datasets completed")
    # JSON object keys are always strings, so budgets come back as "40", "80", ...
    # Restore integer keys so resumed curves match freshly computed ones
    # (otherwise later sorting/plotting treats them lexicographically).
    all_results = {
        ds_name: {
            method_name: {int(budget): stats for budget, stats in method_curve.items()}
            for method_name, method_curve in per_method.items()
        }
        for ds_name, per_method in checkpoint.get('results', {}).items()
    }
else:
    checkpoint = {'completed_datasets': 0, 'results': {}}
    all_results = {}
    print("Starting fresh run")

start_time = time.time()

for dataset in DATASETS:
    if dataset in all_results:
        print(f"\n=== Skipping {dataset} (already completed) ===")
        continue

    print(f"\n=== Processing {dataset} ===")
    all_results[dataset] = {}

    for method in METHODS:
        print(f"\n--- Method: {method} ---")
        curve = evaluate_curve(dataset, method, BUDGETS)
        all_results[dataset][method] = curve

        # Plot curves: one figure per metric, mean +/- one std band.
        for metric in ['accuracy', 'f1_macro']:
            budgets = sorted(curve.keys())
            means = np.array([curve[b][f'{metric}_mean'] for b in budgets])
            stds = np.array([curve[b][f'{metric}_std'] for b in budgets])
            plt.figure(figsize=(8, 5))
            plt.plot(budgets, means, marker='o', label=method, linewidth=2)
            plt.fill_between(budgets, means - stds, means + stds, alpha=0.2)
            plt.xlabel('Labeled budget (max_labels)')
            plt.ylabel(metric)
            plt.title(f'{dataset} - {method} ({metric}) - CV + Multiple Trials')
            plt.grid(True, alpha=0.3)
            plt.legend()
            fname = f'cls_{dataset}_uncertainty_{method}_{metric}.png'
            plt.tight_layout()
            plt.savefig(os.path.join(SAVE_DIR, fname), dpi=200)
            plt.close()  # figures are saved, not shown; avoid accumulating them

    # Save checkpoint after each dataset (only fully processed datasets are
    # ever written, so a resumed run restarts an interrupted dataset).
    checkpoint['completed_datasets'] += 1
    checkpoint['results'] = all_results
    with open(checkpoint_file, 'w') as f:
        json.dump(checkpoint, f, indent=2)

# Save final results
with open(os.path.join(SAVE_DIR, 'cls_uncertainty_results.json'), 'w') as f:
    json.dump(all_results, f, indent=2)

# Clean up checkpoint file
if os.path.exists(checkpoint_file):
    os.remove(checkpoint_file)

total_time = time.time() - start_time
print(f"\nTotal time: {total_time/3600:.2f} hours")
print(f'\nSaved figures and results to {SAVE_DIR}')
# NOTE(review): N_FOLDS is not passed to any call in this notebook; confirm
# that folds are really applied before relying on this summary line.
# The multiplication sign is written as an escape so it cannot be mangled by
# a wrong file encoding again.
print(f'Used {N_TRIALS} trials \u00d7 {N_FOLDS} folds = {N_TRIALS * N_FOLDS} evaluations per config')


In [None]:
# Plot summary comparison: one figure per dataset with two stacked panels
# (accuracy on top, f1_macro below), overlaying all uncertainty methods.
for dataset in DATASETS:
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))

    for ax, metric in zip(axes, ['accuracy', 'f1_macro']):
        for method in METHODS:
            if dataset in all_results and method in all_results[dataset]:
                # Budgets may be strings if the results were reloaded from a
                # JSON checkpoint; normalize to int so they sort and plot
                # numerically rather than lexicographically.
                curve = {int(b): stats for b, stats in all_results[dataset][method].items()}
                budgets = sorted(curve)
                means = np.array([curve[b][f'{metric}_mean'] for b in budgets])
                stds = np.array([curve[b][f'{metric}_std'] for b in budgets])
                ax.plot(budgets, means, marker='o', label=f'{method}', linewidth=2)
                ax.fill_between(budgets, means - stds, means + stds, alpha=0.2)

        ax.set_xlabel('Labeled budget (max_labels)')
        ax.set_ylabel(metric)
        ax.set_title(f'{dataset} - Uncertainty Methods Comparison ({metric})')
        ax.grid(True, alpha=0.3)
        # Only draw a legend when something was plotted, to avoid the
        # "no labeled artists" warning on an empty panel.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(SAVE_DIR, f'cls_{dataset}_uncertainty_comparison.png'), dpi=200)
    plt.show()

print('\nAll comparison plots saved!')
