In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from datetime import datetime
import traceback # For detailed error printing

# Project specific imports (ensure these files are in your PYTHONPATH or same directory)
# It's assumed these modules exist and are correctly implemented.
from dataset import CompositionalDataset, create_dataloaders
from model import NonlinearOscillatorNet, RNNModel, TransformerModel, HippoRNNModel, NMRNN_Spatial_ModulatedReadout, NMRNN_NoSpatial_ModulatedReadout, NMRNN_Spatial_FixedReadout, TransformerModelWithCausal
from train import train_model_comparative
from ana import plot_learning_curves, perform_decodability_analysis
from utils import set_seed, get_device, count_parameters

# --- Configuration ---
# Single source of truth for the sequence length: the HiPPO discretization step
# below is derived from it so the two values cannot drift apart.
_SEQ_LENGTH = 200

CONFIG = {
    "seed": 0,
    # "num_task_coefficients" is not set here; run_experiment() sets it per sweep point.
    "seq_length": _SEQ_LENGTH,
    "train_samples": 32 * 10,  # Small for quick testing; increase for real runs
    "test_samples": 32 * 5,    # Used as the size of both the validation and the test split
    "batch_size": 32,
    "epochs": 200,  # Training epochs per model and per sweep point
    "lr": 1e-3,

    # Model-specific hidden sizes / main dimension
    "hidden_size_oscillator": 64,
    "hidden_size_rnn": 64,      # For standard GRU/LSTM
    "d_model_transformer": 64,
    "hidden_size_hippo": 64,    # N for HIPPO
    "hidden_size_nm_rnn": 128,  # n_rnn for nmRNN variants

    # Transformer specific
    "nhead_transformer": 1,
    "num_layers_transformer": 1,

    # HIPPORNN specific
    "hippo_method": 'legt',         # 'legs' or 'legt'
    "hippo_theta": 1.0,             # Required for 'legt'
    "hippo_dt": 1.0 / _SEQ_LENGTH,  # Discretization step for HIPPO: 1 / seq_length
    "hippo_inv_eps": 1e-6,          # Epsilon for LegS matrix inversion regularization
    "hippo_clip_val": 50.0,         # Clipping for HIPPO state c_t

    # nmRNN specific (shared for variants where applicable)
    "nm_N_NM": 4,             # Number of neuromodulators
    "nm_activation": 'tanh',  # 'relu', 'tanh'
    "nm_decay": 0.05,
    "nm_bias": True,
    "nm_keepW0_spatial": False,
    "nm_keepW0_no_spatial": False,
    "nm_grad_clip": 1.0,
    "nm_spatial_ell": 0.1,
    "nm_spatial_scale": 1.0,

    # General task params
    "output_dim": 1,
    "input_dim": 1,
    "noise_level_data": 0.01,
    # Timestamp shared by the entire sweep; evaluated once, when this dict is built.
    "run_timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    "results_dir": "results_sweep",  # Main directory for sweep results
}

def get_model_definitions(current_config, device):
    """
    Build a fresh, ordered mapping of model name -> newly constructed model.

    Every call constructs brand-new instances, so each sweep point trains from
    scratch. Models are constructed (and inserted) in a fixed order, which is
    also the order in which they are iterated by the caller.

    Args:
        current_config: configuration dict with the same keys as ``CONFIG``.
        device: forwarded unchanged to every model constructor.

    Returns:
        dict mapping model name (str) to model instance.
    """
    cfg = current_config
    seed = cfg["seed"]

    # Keyword arguments shared by every NMRNN_* constructor call below.
    nm_common = dict(
        input_size=cfg["input_dim"],
        output_size=cfg["output_dim"],
        activation_fn_name=cfg["nm_activation"],
        decay=cfg["nm_decay"],
        bias=cfg["nm_bias"],
        grad_clip=cfg["nm_grad_clip"],
        device=device,
        seed=seed,
    )
    # Extra keyword arguments used only by the spatial NMRNN variants.
    nm_spatial = dict(
        keepW0=cfg["nm_keepW0_spatial"],
        spatial_ell=cfg["nm_spatial_ell"],
        spatial_scale=cfg["nm_spatial_scale"],
    )
    # Both transformer flavours take identical arguments.
    transformer_kwargs = dict(
        d_model=cfg["d_model_transformer"],
        device=device,
        outputdim=cfg["output_dim"],
        inputdim=cfg["input_dim"],
        num_heads=cfg["nhead_transformer"],
        num_layers=cfg["num_layers_transformer"],
        seq_length=cfg["seq_length"],
        seed=seed,
    )

    models = {}
    # NOTE: "RNN_Vanillia" is misspelled but kept verbatim: model names are
    # embedded in saved checkpoint file names and summary outputs.
    models["RNN_Vanillia"] = NMRNN_NoSpatial_ModulatedReadout(
        hidden_size=2 * cfg["hidden_size_nm_rnn"],  # doubled to match parameter count
        N_nm=0,  # plain RNN: no neuromodulators
        keepW0=cfg["nm_keepW0_no_spatial"],
        **nm_common,
    )
    models["ComplexOscillatorNet"] = NonlinearOscillatorNet(
        N_oscillators=cfg["hidden_size_oscillator"],
        device=device,
        outputdim=cfg["output_dim"],
        inputdim=cfg["input_dim"],
        seq_length=cfg["seq_length"],
        seed=seed,
    )
    models["RNN_GRU"] = RNNModel(
        hidden_size=cfg["hidden_size_rnn"],
        device=device,
        outputdim=cfg["output_dim"],
        inputdim=cfg["input_dim"],
        num_layers=1,
        seed=seed,
    )
    models["Transformer"] = TransformerModel(**transformer_kwargs)
    models["TransformerCausal"] = TransformerModelWithCausal(**transformer_kwargs)
    models["HIPPORNN_LegT"] = HippoRNNModel(
        hidden_size=cfg["hidden_size_hippo"],
        outputdim=cfg["output_dim"],
        inputdim=cfg["input_dim"],
        method=cfg["hippo_method"],
        theta=cfg["hippo_theta"],
        dt=cfg["hippo_dt"],
        inv_eps=cfg["hippo_inv_eps"],
        clip_val=cfg["hippo_clip_val"],
        device=device,
        seed=seed,
    )
    models["NMRNN_Spatial_ModReadout"] = NMRNN_Spatial_ModulatedReadout(
        hidden_size=cfg["hidden_size_nm_rnn"],
        N_nm=cfg["nm_N_NM"],
        **nm_spatial,
        **nm_common,
    )
    models["NMRNN_NoSpatial_ModReadout"] = NMRNN_NoSpatial_ModulatedReadout(
        hidden_size=cfg["hidden_size_nm_rnn"],
        N_nm=cfg["nm_N_NM"],
        keepW0=cfg["nm_keepW0_no_spatial"],
        **nm_common,
    )
    models["NMRNN_Spatial_FixedReadout"] = NMRNN_Spatial_FixedReadout(
        hidden_size=cfg["hidden_size_nm_rnn"],
        N_nm=cfg["nm_N_NM"],
        **nm_spatial,
        **nm_common,
    )
    return models

def plot_sweep_summary_metrics(metrics_data, x_values, save_dir, metric_titles):
    """
    Plot every collected metric against the number of task coefficients.

    One PNG per metric is written to ``save_dir`` as
    ``summary_plot_<metric_key>.png``. A model whose value list has a different
    length than ``x_values`` is skipped with a warning.

    Args:
        metrics_data: dict like ``{'metric_name': {'model_name': [values]}}``.
        x_values: sequence of num_task_coefficients values (the x axis).
        save_dir: directory to save plots in; created if it does not exist.
        metric_titles: dict mapping metric_name to ``(title, y-axis label)``.
            Metrics missing from it get a title derived from the key.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for metric_key, model_results in metrics_data.items():
        default_label = metric_key.replace("_", " ").title()
        title, ylabel = metric_titles.get(metric_key, (default_label, default_label))

        fig, ax = plt.subplots(figsize=(12, 7))
        try:
            plotted_any = False
            for model_name, y_values in model_results.items():
                if len(y_values) != len(x_values):
                    print(f"Warning: Mismatch in data length for {model_name} on metric {metric_key}. "
                          f"Expected {len(x_values)}, got {len(y_values)}. Skipping plot for this line.")
                    continue
                ax.plot(x_values, y_values, lw=3, label=model_name, marker='o', markersize=5)
                plotted_any = True

            ax.set_xlabel("Number of Task Coefficients")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            # A legend with no labelled artists only produces a matplotlib warning.
            if plotted_any:
                ax.legend(loc='best', fontsize='small')
            ax.grid(True, linestyle='--', alpha=0.7)
            fig.tight_layout()
            save_path = os.path.join(save_dir, f"summary_plot_{metric_key}.png")
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        print(f"Saved sweep summary plot: {save_path}")

def _target_variance(loader):
    """
    Return the variance of every target value yielded by ``loader``.

    The loader is expected (as everywhere in this script) to yield
    ``(inputs, targets, coefficients)`` triples. The unbiased estimator is used,
    matching the default of ``torch.std`` / ``torch.var``. Returns NaN when fewer
    than two target values are available.
    """
    chunks = [targets.detach().reshape(-1).to(torch.float64).cpu() for _, targets, _ in loader]
    if not chunks:
        return float("nan")
    all_targets = torch.cat(chunks)
    if all_targets.numel() < 2:
        return float("nan")
    return torch.var(all_targets).item()


def _variance_explained(losses, total_var):
    """
    Return ``1 - losses / total_var`` for a scalar loss or a sequence of losses.

    Scalars come back as a Python float, sequences as a float ndarray. ``None``
    entries become NaN. The result is NaN when ``total_var`` is not a positive
    finite number (e.g. a constant target), instead of raising ZeroDivisionError.
    """
    loss_arr = np.asarray(losses, dtype=float)
    if np.isfinite(total_var) and total_var > 0:
        result = 1.0 - loss_arr / total_var
    else:
        result = np.full(loss_arr.shape, np.nan)
    return float(result) if result.ndim == 0 else result


def _nan_safe_min(values):
    """Return the smallest non-NaN entry of ``values`` as a float, or NaN if there is none."""
    arr = np.asarray(values, dtype=float).reshape(-1)
    if arr.size == 0 or np.isnan(arr).all():
        return float("nan")
    return float(np.nanmin(arr))


def _last_or_nan(values):
    """Return the last entry of ``values`` as a float (NaN stays NaN), or NaN if empty."""
    arr = np.asarray(values, dtype=float).reshape(-1)
    return float(arr[-1]) if arr.size else float("nan")


def _save_overall_metrics(overall_metrics, save_path):
    """
    Save ``overall_metrics`` to a compressed ``.npz`` file at ``save_path``.

    Each top-level metric is stored as a dict ``{model_name: float ndarray}``.
    NumPy wraps each dict in a 0-d object array, so reading it back requires
    ``np.load(save_path, allow_pickle=True)[metric].item()``.
    """
    to_save = {}
    for metric, models_data in overall_metrics.items():
        to_save[metric] = {}
        for model, values in models_data.items():
            try:
                to_save[metric][model] = np.array(values, dtype=float)
            except (TypeError, ValueError):
                # Fall back element-wise: anything that is not a plain number becomes NaN.
                to_save[metric][model] = np.array(
                    [v if isinstance(v, (int, float)) else np.nan for v in values], dtype=float
                )
    np.savez_compressed(save_path, **to_save)


def run_experiment():
    """
    Run the comparative experiment, sweeping over ``num_task_coefficients``.

    For every coefficient count in the sweep this builds fresh dataloaders and
    fresh models, trains each model, saves its best state, runs a decodability
    analysis on its test hidden states, and writes per-run learning curves and a
    decodability summary. After the sweep it plots each aggregated metric
    against the coefficient count and saves the raw numbers to an ``.npz`` file.

    Every list in ``overall_metrics`` receives exactly one value per model per
    sweep point (all NaN for a model whose training or analysis raised), so the
    summary plots stay aligned with the sweep values.
    """
    set_seed(CONFIG["seed"])
    device = get_device()
    print(f"Using device: {device}")

    # Main results directory for the entire sweep
    base_sweep_results_dir = os.path.join(CONFIG["results_dir"], CONFIG["run_timestamp"])
    os.makedirs(base_sweep_results_dir, exist_ok=True)
    print(f"All sweep results will be saved in: {base_sweep_results_dir}")

    # Coefficient counts 1..19 inclusive.
    num_task_coefficients_sweep = np.arange(1, 20, 1)

    # Model names do not depend on num_task_coefficients; the models are built
    # here only to read their keys.
    temp_config_for_names = CONFIG.copy()
    temp_config_for_names["num_task_coefficients"] = 1  # Arbitrary value
    model_names = list(get_model_definitions(temp_config_for_names, device).keys())

    metric_keys = (
        "best_val_loss",
        "final_train_loss",
        "final_val_loss",
        "best_val_varex",
        "final_train_varex",
        "final_val_varex",
        "generalization_gap",
        "decodability_r2",  # R2 of the decoder fitted on hidden states
    )
    overall_metrics = {key: {name: [] for name in model_names} for key in metric_keys}
    nan = float("nan")

    # --- Sweep over num_task_coefficients ---
    for num_coeff in num_task_coefficients_sweep:
        current_num_coeff = int(num_coeff)  # numpy int -> Python int
        print(f"\n{'='*25} Running for num_task_coefficients: {current_num_coeff} {'='*25}")

        current_config = CONFIG.copy()
        current_config["num_task_coefficients"] = current_num_coeff

        coeff_run_results_dir = os.path.join(base_sweep_results_dir, f"coeffs_{current_num_coeff}")
        os.makedirs(coeff_run_results_dir, exist_ok=True)
        print(f"Results for coeffs={current_num_coeff} will be saved in: {coeff_run_results_dir}")

        # --- 1. Dataset (depends on num_coeff) ---
        print("Loading dataset...")
        train_loader, val_loader, test_loader, (_input_basis, _output_basis) = create_dataloaders(
            num_train_samples=current_config["train_samples"],
            num_val_samples=current_config["test_samples"],
            num_test_samples=current_config["test_samples"],
            num_basis=current_config["num_task_coefficients"],
            seq_length=current_config["seq_length"],
            batch_size=current_config["batch_size"],
            noise=current_config["noise_level_data"],
        )
        print("Dataset loaded.")

        # Variance-explained denominators: computed once per sweep point over the
        # full train/val sets, so every model is scored against the same values
        # (a per-model first batch differs between models when loaders shuffle).
        total_train_var = _target_variance(train_loader)
        total_val_var = _target_variance(val_loader)

        # --- 2. Models (re-initialized for each num_coeff run) ---
        models_to_test = get_model_definitions(current_config, device)

        # Per-run results, used for this sweep point's own reports.
        current_run_train_losses_epoch = {}
        current_run_val_losses_epoch = {}
        current_run_train_varex_epoch = {}
        current_run_val_varex_epoch = {}
        current_run_decodability_results = {}
        current_run_trained_models_paths = {}

        # --- 3. Training & Evaluation Loop for current num_coeff ---
        for model_name, model in models_to_test.items():
            print(f"\n--- Training {model_name} (coeffs={current_num_coeff}) ---")
            model.to(device)
            print(f"Number of parameters: {count_parameters(model)}")

            # Filled in as results become available; appended exactly once below.
            row = dict.fromkeys(metric_keys, nan)
            try:
                (train_losses_epoch, val_losses_epoch, best_model_state,
                 hidden_states_test, coeffs_test) = train_model_comparative(
                    model,
                    model_name,
                    train_loader,
                    val_loader,
                    test_loader,
                    current_config["epochs"],
                    current_config["lr"],
                    device,
                    current_config["num_task_coefficients"],
                    coeff_run_results_dir,  # Save intermediate plots/models for this coeff run
                    plot_intermediate_results=False,  # Typically off for sweeps to avoid clutter
                )

                current_run_train_losses_epoch[model_name] = train_losses_epoch
                current_run_val_losses_epoch[model_name] = val_losses_epoch
                current_run_train_varex_epoch[model_name] = _variance_explained(train_losses_epoch, total_train_var)
                current_run_val_varex_epoch[model_name] = _variance_explained(val_losses_epoch, total_val_var)

                # NaN-aware summaries (plain min() is order-dependent with NaNs).
                best_val_loss = _nan_safe_min(val_losses_epoch)
                final_train_loss = _last_or_nan(train_losses_epoch)
                final_val_loss = _last_or_nan(val_losses_epoch)
                row.update(
                    best_val_loss=best_val_loss,
                    final_train_loss=final_train_loss,
                    final_val_loss=final_val_loss,
                    best_val_varex=_variance_explained(best_val_loss, total_val_var),
                    final_train_varex=_variance_explained(final_train_loss, total_train_var),
                    final_val_varex=_variance_explained(final_val_loss, total_val_var),
                    # NaN propagates if either final loss is missing.
                    generalization_gap=final_val_loss - final_train_loss,
                )

                if best_model_state:
                    model_path = os.path.join(coeff_run_results_dir, f"{model_name}_coeffs_{current_num_coeff}_best.pt")
                    torch.save(best_model_state, model_path)
                    current_run_trained_models_paths[model_name] = model_path
                    print(f"Saved best model for {model_name} (coeffs={current_num_coeff}) to {model_path}")
                else:
                    print(f"No best model state saved for {model_name} (coeffs={current_num_coeff}).")

                # --- 4. Decodability Analysis for current num_coeff run ---
                decodability_score_r2 = nan
                if hidden_states_test is not None and coeffs_test is not None:
                    print(f"\n--- Performing Decodability Analysis for {model_name} (coeffs={current_num_coeff}) ---")
                    decodability_score_r2 = perform_decodability_analysis(
                        model_name=model_name,
                        hidden_states=hidden_states_test,
                        coefficients=coeffs_test,
                        decoder_type='ridge',
                        decoding_metric='r2',  # R-squared
                        results_dir=coeff_run_results_dir,  # Save plots for this specific run
                        device=device,
                    )
                    print(f"Decodability (R2 score) for {model_name} (coeffs={current_num_coeff}): {decodability_score_r2:.4f}")
                else:
                    print(f"Skipping decodability for {model_name} (coeffs={current_num_coeff}) due to missing data.")

                current_run_decodability_results[model_name] = decodability_score_r2
                row["decodability_r2"] = decodability_score_r2

            except Exception as e:
                print(f"!!!!!! ERROR during training or analysis for {model_name} (coeffs={current_num_coeff}): {e} !!!!!!")
                traceback.print_exc()
                # Discard any partial results: a failed model contributes one all-NaN
                # entry per metric (previously partial + NaN entries were both appended,
                # desynchronizing the list lengths from the sweep).
                row = dict.fromkeys(metric_keys, nan)
                current_run_train_losses_epoch[model_name] = [nan] * current_config["epochs"]
                current_run_val_losses_epoch[model_name] = [nan] * current_config["epochs"]
                current_run_train_varex_epoch[model_name] = [nan] * current_config["epochs"]
                current_run_val_varex_epoch[model_name] = [nan] * current_config["epochs"]
                current_run_decodability_results[model_name] = nan

            for key, value in row.items():
                overall_metrics[key][model_name].append(value)

        # --- 5. Plot Learning Curves for the current num_coeff run ---
        # Validation losses per epoch for all models at this num_coeff.
        if any(current_run_val_losses_epoch.values()):
            plot_learning_curves(
                current_run_val_losses_epoch,
                title=f"Validation Learning Curves (Coeffs={current_num_coeff})",
                save_path=os.path.join(coeff_run_results_dir, "learning_curves_validation.png"),
            )
            print(f"\nValidation learning curves for coeffs={current_num_coeff} plotted.")

        # --- 6. Report Decodability for the current num_coeff run ---
        if current_run_decodability_results:
            summary_path = os.path.join(coeff_run_results_dir, "decodability_summary.txt")
            with open(summary_path, "w") as f:
                f.write("Model,R2_Score\n")
                for model_name, score in current_run_decodability_results.items():
                    f.write(f"{model_name},{score:.4f}\n")
            print(f"Decodability summary for coeffs={current_num_coeff} saved to {summary_path}")

    # --- 7. After all num_coeff sweeps, Plot Aggregate Metrics ---
    print("\n\n--- Generating Sweep Summary Plots ---")
    metric_plot_titles = {
        "best_val_loss": ("Best Validation Loss vs. Task Coefficients", "Best Validation Loss"),
        "final_train_loss": ("Final Training Loss vs. Task Coefficients", "Final Training Loss"),
        "final_val_loss": ("Final Validation Loss vs. Task Coefficients", "Final Validation Loss"),
        "best_val_varex": ("Best Validation Variance Explained vs. Task Coefficients", "Best Validation Variance Explained"),
        "final_train_varex": ("Final Training Variance Explained vs. Task Coefficients", "Final Training Variance Explained"),
        "final_val_varex": ("Final Validation Variance Explained vs. Task Coefficients", "Final Validation Variance Explained"),
        "generalization_gap": ("Generalization Gap vs. Task Coefficients", "Generalization Gap (Val - Train)"),
        "decodability_r2": ("Hidden Layer Decodability (R²) vs. Task Coefficients", "Decodability (R²)"),
    }
    plot_sweep_summary_metrics(overall_metrics, num_task_coefficients_sweep, base_sweep_results_dir, metric_plot_titles)

    # --- 8. Save Overall Metrics Data ---
    overall_metrics_save_path = os.path.join(base_sweep_results_dir, "overall_metrics_data.npz")
    _save_overall_metrics(overall_metrics, overall_metrics_save_path)
    print(f"Overall metrics data saved to: {overall_metrics_save_path}")

    print(f"\nExperiment sweep finished. All results in {base_sweep_results_dir}")


# Entry point: run the whole sweep. In a notebook kernel __name__ is also
# "__main__", so executing this cell starts the sweep as well.
if __name__ == "__main__":
    run_experiment()


Using device: cuda
All sweep results will be saved in: results_sweep/20250511_023956




NameError: name 'TransformerModelWithCausal' is not defined