In [None]:
# Passive Regression Tuning with Cross-Validation and Multiple Trials
import os
import json
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from typing import Dict
import torch
from tqdm import tqdm
import time

from alnn.models import OneHiddenMLP
from alnn.training import train_passive
from alnn.evaluation import evaluate_regression
import torch.nn as nn
from alnn.training import TrainConfig

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from torch.utils.data import TensorDataset, DataLoader

# Output directory for figures, the resume checkpoint and the final tuning results.
SAVE_DIR = os.path.join(os.pardir, 'report', 'figures')
os.makedirs(SAVE_DIR, exist_ok=True)

# Datasets to tune on (the names handled by evaluate_config_cv).
DATASETS = ['diabetes', 'linnerud', 'california']

# Hyperparameter grid: every combination of the four lists below is evaluated.
LR = [1e-4, 3e-4, 1e-3, 3e-3]  # learning rates, kept conservative for regression
WD = [0.0, 1e-5, 1e-4]          # weight decay
HIDDEN = [32, 64, 128]          # hidden-layer widths
BS = [32, 64]                   # batch sizes

# How many estimates back each configuration's score.
N_TRIALS = 5  # independent random seeds per configuration
N_FOLDS = 5   # cross-validation folds per trial


In [None]:
def evaluate_config_cv(dataset: str, lr: float, wd: float, hidden: int, bs: int) -> Dict[str, float]:
    """Evaluate a configuration using cross-validation across multiple trials."""
    all_metrics = []
    
    for trial in range(N_TRIALS):
        # Set random seed for reproducibility
        torch.manual_seed(42 + trial)
        np.random.seed(42 + trial)
        
        trial_metrics = []
        
        # Load data for CV splits
        if dataset == "diabetes":
            ds = datasets.load_diabetes()
            y = ds.target.astype(np.float32)
        elif dataset == "linnerud":
            ds = datasets.load_linnerud()
            y = ds.target[:, 0].astype(np.float32)  # use one target (Weight)
        elif dataset == "california":
            ds = datasets.fetch_california_housing()
            y = ds.target.astype(np.float32)
        
        X = ds.data.astype(np.float32)
        
        # Use KFold for regression
        kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42 + trial)
        
        for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            
            # Standardize features
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_val_scaled = scaler.transform(X_val)
            
            # Convert to tensors
            X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
            y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(-1)
            X_val_tensor = torch.tensor(X_val_scaled, dtype=torch.float32)
            y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(-1)
            
            # Create datasets
            train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
            val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
            
            train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False)
            
            # Train model
            model = OneHiddenMLP(input_dim=X_train_scaled.shape[1], hidden_units=hidden, output_dim=1)
            loss_fn = nn.MSELoss()
            config = TrainConfig(learning_rate=lr, weight_decay=wd, batch_size=bs, max_epochs=200, patience=20, device='cpu')
            
            train_passive(model, train_loader, val_loader, loss_fn, config)
            
            # Evaluate
            metrics = evaluate_regression(model, val_loader, device='cpu')
            trial_metrics.append(metrics)
        
        # Average across folds for this trial
        trial_avg = {}
        for key in trial_metrics[0].keys():
            trial_avg[key] = np.mean([m[key] for m in trial_metrics])
        all_metrics.append(trial_avg)
    
    # Average across trials and compute std
    final_metrics = {}
    for key in all_metrics[0].keys():
        values = [m[key] for m in all_metrics]
        final_metrics[f'{key}_mean'] = float(np.mean(values))
        final_metrics[f'{key}_std'] = float(np.std(values, ddof=1))
    
    return final_metrics


In [None]:
# Load the tuning checkpoint if one exists, so an interrupted sweep can resume.
checkpoint_file = os.path.join(SAVE_DIR, 'passive_reg_checkpoint.json')
if os.path.exists(checkpoint_file):
    with open(checkpoint_file, 'r') as f:
        checkpoint = json.load(f)
    checkpoint.setdefault('completed_configs', 0)
    print(f"Resuming from checkpoint: {checkpoint['completed_configs']} configs completed")
else:
    checkpoint = {'completed_configs': 0, 'results': {}}
    print("Starting fresh run")

# BEST must be the *same* dict object as checkpoint['results'] so that saving the checkpoint
# also saves the per-dataset results. setdefault guarantees this even when the loaded file
# has no 'results' key, whereas .get(..., {}) would hand back a detached dict whose contents
# would never reach the checkpoint.
BEST = checkpoint.setdefault('results', {})

# Size of the full grid.
configs_per_dataset = len(LR) * len(WD) * len(HIDDEN) * len(BS)
total_configs = len(DATASETS) * configs_per_dataset
start_time = time.time()

# Configs are processed dataset by dataset, so the global completed count maps to
# (dataset index, config index within that dataset).
completed_configs = checkpoint['completed_configs']
resume_dataset_idx, resume_config_idx = divmod(completed_configs, configs_per_dataset)

print(f"Resuming from dataset {resume_dataset_idx}, config {resume_config_idx}")


In [None]:
# Run the grid search dataset by dataset, resuming each from its checkpointed history.


def _write_checkpoint(path: str, payload: dict) -> None:
    """Atomically write ``payload`` as JSON to ``path``.

    The data is written to a sibling temp file and then moved into place with
    ``os.replace``. An interruption mid-write therefore leaves the previous checkpoint
    intact instead of a truncated file that ``json.load`` can no longer parse.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp_path, path)


dataset_configs = len(LR) * len(WD) * len(HIDDEN) * len(BS)
session_start = time.time()
n_evaluated = 0  # configs trained in *this* session; resumed ones cost no time here

for dataset in DATASETS:
    entry = BEST.setdefault(dataset, {"best_cfg": None, "best_metric": np.inf, "history": []})
    best_metric = entry["best_metric"]
    best_cfg = entry["best_cfg"]
    hist = entry["history"]

    print(f"\n=== Tuning {dataset} ===")

    # `hist` gains exactly one entry per finished config, in itertools.product order, so its
    # length is this dataset's resume offset. Deriving the offset here (rather than from the
    # global counter computed in another cell) also means re-running this cell does not
    # re-evaluate and duplicate configs that are already done.
    start_config_idx = len(hist)
    if start_config_idx >= dataset_configs:
        print(f"Skipping {dataset} (already completed)")
        continue
    if start_config_idx:
        print(f"Resuming {dataset} from config {start_config_idx + 1}/{dataset_configs}")
    else:
        print(f"Starting {dataset} from config 1/{dataset_configs}")

    remaining_grid = itertools.islice(itertools.product(LR, WD, HIDDEN, BS), start_config_idx, None)

    # Using the bar as a context manager closes it even if a config raises or the run is interrupted.
    with tqdm(total=dataset_configs, desc=f"{dataset} configs",
              initial=start_config_idx, position=0, leave=True) as pbar:
        for config_idx, (lr, wd, hidden, bs) in enumerate(remaining_grid, start=start_config_idx + 1):
            print(f'Config {config_idx}/{dataset_configs}: lr={lr}, wd={wd}, hidden={hidden}, bs={bs}')

            res = evaluate_config_cv(dataset, lr, wd, hidden, bs)
            res.update({"lr": lr, "wd": wd, "hidden": hidden, "bs": bs})
            hist.append(res)
            n_evaluated += 1

            if res['rmse_mean'] < best_metric:
                best_metric = res['rmse_mean']
                best_cfg = {"lr": lr, "wd": wd, "hidden": hidden, "bs": bs}

            pbar.update(1)
            pbar.set_postfix({'best_rmse': f"{best_metric:.4f}"})

            # Checkpoint after every config. BEST is re-attached explicitly so the results are
            # persisted even if it is not the same object as checkpoint['results'].
            BEST[dataset] = {"best_cfg": best_cfg, "best_metric": best_metric, "history": hist}
            checkpoint['results'] = BEST
            checkpoint['completed_configs'] += 1
            _write_checkpoint(checkpoint_file, checkpoint)

    BEST[dataset] = {"best_cfg": best_cfg, "best_metric": best_metric, "history": hist}
    print(f"Best config for {dataset}: {best_cfg} (RMSE: {best_metric:.4f})")

# Save final results
with open(os.path.join(SAVE_DIR, 'passive_reg_best.json'), 'w') as f:
    json.dump(BEST, f, indent=2)

# The sweep finished, so the resume checkpoint is no longer needed.
if os.path.exists(checkpoint_file):
    os.remove(checkpoint_file)

# Time only this cell's work and average over the configs actually trained in this session.
# Dividing by the full grid size under-reports the per-config cost after a resume.
total_time = time.time() - session_start
print(f"\nTotal time: {total_time/3600:.2f} hours")
if n_evaluated:
    print(f"Average time per config: {total_time/n_evaluated:.2f} seconds")


In [None]:
# Plot each dataset's best configuration as its mean RMSE across trials with ±1 std error bars.
datasets_plot, means_plot, stds_plot = [], [], []
for dataset in DATASETS:
    history = BEST.get(dataset, {}).get('history', [])
    if not history:
        # Nothing was evaluated for this dataset. Skip it rather than crash on a missing best index.
        print(f"No tuning results for {dataset}; omitted from plot")
        continue
    # min() returns the first minimum, matching a strict "<" scan over the history.
    best = min(history, key=lambda h: h['rmse_mean'])
    datasets_plot.append(dataset)
    means_plot.append(best['rmse_mean'])
    stds_plot.append(best['rmse_std'])

fig, ax = plt.subplots(figsize=(8, 5))
ax.errorbar(datasets_plot, means_plot, yerr=stds_plot, fmt='o', capsize=5, capthick=2)
ax.set_ylabel('RMSE (best ± std)')
ax.set_title('Passive Regression Best RMSE (CV + Multiple Trials)')
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(SAVE_DIR, 'passive_reg_best_rmse.png'), dpi=200)
plt.show()

print(f'\nSaved passive regression tuning results to {SAVE_DIR}')
print(f'Used {N_TRIALS} trials × {N_FOLDS} folds = {N_TRIALS * N_FOLDS} evaluations per config')
