In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from datetime import datetime

# Project specific imports
from dataset import CompositionalDataset, create_dataloaders
from model import NonlinearOscillatorNet, RNNModel, TransformerModel, HippoRNNModel, NMRNN_Spatial_ModulatedReadout, NMRNN_NoSpatial_ModulatedReadout, NMRNN_Spatial_FixedReadout
from training import train_model_comparative
from analysis import plot_learning_curves, perform_decodability_analysis
from utils import set_seed, get_device, count_parameters

# --- Configuration ---
# Base values that other entries are derived from. Change them here so the
# dependent settings below (sample counts, HiPPO discretization step) stay
# consistent instead of drifting apart.
SEQ_LENGTH = 200
BATCH_SIZE = 32

CONFIG = {
    "seed": 0,
    "num_task_coefficients": 25,
    "seq_length": SEQ_LENGTH,
    "train_samples": BATCH_SIZE * 10,  # 10 batches; small for quick testing, increase for real runs
    "test_samples": BATCH_SIZE * 5,    # 5 batches; small for quick testing
    "batch_size": BATCH_SIZE,
    "epochs": 200,  # number of training epochs; lower it for a quick smoke test
    "lr": 1e-3,

    # Model-specific hidden sizes / main dimension
    "hidden_size_oscillator": 64,
    "hidden_size_rnn": 128,       # For standard GRU/LSTM
    "d_model_transformer": 64,
    "hidden_size_hippo": 128,     # N for HIPPO
    "hidden_size_nm_rnn": 128,    # n_rnn for nmRNN variants

    # Transformer specific
    "nhead_transformer": 1,
    "num_layers_transformer": 1,

    # HIPPORNN specific
    "hippo_method": 'legt',   # 'legs' or 'legt'
    "hippo_theta": 1.0,       # Required for 'legt'
    "hippo_dt": 1.0 / SEQ_LENGTH,  # Discretization step for HIPPO, tied to the sequence length
    "hippo_inv_eps": 1e-6,    # Epsilon for LegS matrix inversion regularization
    "hippo_clip_val": 50.0,   # Clipping for HIPPO state c_t

    # nmRNN specific (shared for variants where applicable)
    "nm_N_NM": 4,               # Number of neuromodulators
    "nm_activation": 'tanh',    # 'relu', 'tanh' (original code had 'relu-tanh', simplified here)
    # Leak/decay term for the nmRNN: 0.05 == 1/20.
    # NOTE(review): presumably interpreted as dt/tau (the matching per-step
    # retention would be exp(-0.05) ~ 0.951). The original code used
    # exp(-20/100) ~ 0.82 instead; confirm in model.py which convention `decay` uses.
    "nm_decay": 0.05,
    "nm_bias": True,
    "nm_keepW0_spatial": False,  # For the version with spatial connections
    "nm_keepW0_no_spatial": False,
    "nm_grad_clip": 1.0,
    "nm_spatial_ell": 0.1,       # For SpatialWeight
    "nm_spatial_scale": 1.0,     # For SpatialWeight

    # General task params
    "output_dim": 1,
    "input_dim": 1,
    "noise_level_data": 0.01,
    "run_timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    "results_dir": "results"
}

def _build_models(device):
    """Instantiate every architecture under comparison.

    All hyperparameters are read from the module-level ``CONFIG``; every model
    receives ``CONFIG["seed"]``.

    Args:
        device: torch device forwarded to each model constructor.

    Returns:
        dict mapping a display name to an untrained model. Insertion order is
        the order in which models are trained and reported.
    """
    # Arguments shared by all three nmRNN variants.
    nm_shared = dict(
        input_size=CONFIG["input_dim"],
        hidden_size=CONFIG["hidden_size_nm_rnn"],
        output_size=CONFIG["output_dim"],
        N_nm=CONFIG["nm_N_NM"],  # still needed for the core recurrence in the fixed-readout variant
        activation_fn_name=CONFIG["nm_activation"],
        decay=CONFIG["nm_decay"],
        bias=CONFIG["nm_bias"],
        grad_clip=CONFIG["nm_grad_clip"],
        device=device,
        seed=CONFIG["seed"],
    )
    # Extra arguments for the variants with spatial connectivity.
    spatial_kwargs = dict(
        spatial_ell=CONFIG["nm_spatial_ell"],
        spatial_scale=CONFIG["nm_spatial_scale"],
    )
    return {
        "ComplexOscillatorNet": NonlinearOscillatorNet(
            N_oscillators=CONFIG["hidden_size_oscillator"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            seq_length=CONFIG["seq_length"],
            seed=CONFIG["seed"]
        ),
        "RNN_GRU": RNNModel(
            hidden_size=CONFIG["hidden_size_rnn"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            num_layers=1,
            seed=CONFIG["seed"]
        ),
        "Transformer": TransformerModel(
            d_model=CONFIG["d_model_transformer"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            num_heads=CONFIG["nhead_transformer"],
            num_layers=CONFIG["num_layers_transformer"],
            seq_length=CONFIG["seq_length"],
            seed=CONFIG["seed"]
        ),
        "HIPPORNN_LegT": HippoRNNModel(  # method comes from CONFIG["hippo_method"] ('legt' by default)
            hidden_size=CONFIG["hidden_size_hippo"],
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            method=CONFIG["hippo_method"],
            theta=CONFIG["hippo_theta"],
            dt=CONFIG["hippo_dt"],
            inv_eps=CONFIG["hippo_inv_eps"],
            clip_val=CONFIG["hippo_clip_val"],
            device=device,
            seed=CONFIG["seed"]
        ),
        "NMRNN_Spatial_ModReadout": NMRNN_Spatial_ModulatedReadout(
            **nm_shared, keepW0=CONFIG["nm_keepW0_spatial"], **spatial_kwargs
        ),
        "NMRNN_NoSpatial_ModReadout": NMRNN_NoSpatial_ModulatedReadout(
            **nm_shared, keepW0=CONFIG["nm_keepW0_no_spatial"]
        ),
        "NMRNN_Spatial_FixedReadout": NMRNN_Spatial_FixedReadout(
            **nm_shared, keepW0=CONFIG["nm_keepW0_spatial"], **spatial_kwargs
        ),
    }


def _report_failure(stage, model_name, exc):
    """Print a loud error banner and the traceback of the exception currently being handled.

    Must be called from inside an ``except`` clause so ``print_exc`` has an active exception.
    """
    import traceback
    print(f"!!!!!! ERROR during {stage} for {model_name}: {exc} !!!!!!")
    traceback.print_exc()


def _save_best_state(model_name, best_model_state, run_results_dir):
    """Save ``best_model_state`` to ``<run_results_dir>/<model_name>_best.pt``.

    Returns:
        The checkpoint path, or None when there was no (or an empty) state to save.
    """
    if not best_model_state:
        print(f"No best model state saved for {model_name} (possibly due to training issues).")
        return None
    model_path = os.path.join(run_results_dir, f"{model_name}_best.pt")
    torch.save(best_model_state, model_path)
    print(f"Saved best model for {model_name} to {model_path}")
    return model_path


def _run_decodability(model_name, hidden_states_test, coeffs_test, run_results_dir, device):
    """Fit a ridge decoder from hidden states to task coefficients and report its R2 score.

    Returns:
        The score returned by ``perform_decodability_analysis``, or None when the
        analysis is skipped because hidden states or coefficients are missing.
    """
    if hidden_states_test is None or coeffs_test is None:
        print(f"Skipping decodability for {model_name} due to missing hidden states or coefficients.")
        return None
    print(f"\n--- Performing Decodability Analysis for {model_name} ---")
    decodability_score = perform_decodability_analysis(
        model_name=model_name,
        hidden_states=hidden_states_test,
        coefficients=coeffs_test,
        decoder_type='ridge',   # Using RidgeCV as a robust default
        decoding_metric='r2',   # R-squared is often more interpretable than MSE here
        results_dir=run_results_dir,
        device=device,
    )
    print(f"Decodability (R2 score) for {model_name}: {decodability_score:.4f}")
    return decodability_score


def _train_and_analyze(model_name, model, loaders, device, run_results_dir, plot_intermediate):
    """Train one model, checkpoint its best state, then measure decodability.

    Training and post-training analysis are guarded separately. A failure while
    checkpointing or decoding therefore does not discard the validation curve of
    a model that trained successfully.

    Args:
        model_name: display name used in logs and file names.
        model: model to train; moved to ``device`` here.
        loaders: ``(train_loader, val_loader, test_loader)``.
        device: torch device.
        run_results_dir: directory for checkpoints and analysis artifacts.
        plot_intermediate: forwarded as ``plot_intermediate_results``.

    Returns:
        ``(val_losses, score)``. ``val_losses`` is what ``train_model_comparative``
        returned, or ``CONFIG["epochs"]`` NaNs if training raised. ``score`` is the
        decodability score, NaN if a stage raised, or None if analysis was skipped.
    """
    train_loader, val_loader, test_loader = loaders
    print(f"\n--- Training {model_name} ---")
    model.to(device)
    print(f"Number of parameters: {count_parameters(model)}")

    try:
        val_losses, best_model_state, hidden_states_test, coeffs_test = train_model_comparative(
            model,
            model_name,
            train_loader,
            val_loader,
            test_loader,
            CONFIG["epochs"],
            CONFIG["lr"],
            device,
            CONFIG["num_task_coefficients"],
            run_results_dir,
            plot_intermediate_results=plot_intermediate,
        )
    except Exception as e:
        _report_failure("training", model_name, e)
        return [float('nan')] * CONFIG["epochs"], float('nan')

    try:
        _save_best_state(model_name, best_model_state, run_results_dir)
        score = _run_decodability(model_name, hidden_states_test, coeffs_test, run_results_dir, device)
    except Exception as e:
        # Keep the real learning curve; only the analysis result is marked as failed.
        _report_failure("analysis", model_name, e)
        score = float('nan')
    return val_losses, score


def _write_decodability_summary(decodability_results, run_results_dir):
    """Print decodability scores and write them to ``decodability_summary.txt`` as CSV."""
    print("\n--- Decodability Results (R2 Score) ---")
    if not decodability_results:
        print("No decodability results to report.")
        return
    for model_name, score in decodability_results.items():
        print(f"{model_name}: {score:.4f}")
    with open(os.path.join(run_results_dir, "decodability_summary.txt"), "w") as f:
        f.write("Model,R2_Score\n")
        f.writelines(f"{model_name},{score:.4f}\n" for model_name, score in decodability_results.items())


def run_experiment():
    """
    Runs the full comparative analysis experiment.

    The steps are:
    1. seed the RNGs;
    2. build the data loaders;
    3. build every model listed in ``_build_models``;
    4. train and analyse each model in turn;
    5. plot validation learning curves;
    6. write a decodability summary.

    All artifacts go to ``CONFIG["results_dir"]/CONFIG["run_timestamp"]``.
    """
    set_seed(CONFIG["seed"])
    device = get_device()
    print(f"Using device: {device}")

    run_results_dir = os.path.join(CONFIG["results_dir"], CONFIG["run_timestamp"])
    os.makedirs(run_results_dir, exist_ok=True)  # also creates CONFIG["results_dir"] if missing
    print(f"Results will be saved in: {run_results_dir}")

    # --- 1. Dataset ---
    print("Loading dataset...")
    # NOTE: the validation split reuses CONFIG["test_samples"] for its size.
    train_loader, val_loader, test_loader, (_input_basis, _output_basis) = create_dataloaders(
        num_train_samples=CONFIG["train_samples"],
        num_val_samples=CONFIG["test_samples"],
        num_test_samples=CONFIG["test_samples"],
        num_basis=CONFIG["num_task_coefficients"],
        seq_length=CONFIG["seq_length"],
        batch_size=CONFIG["batch_size"],
        noise=CONFIG["noise_level_data"]
    )
    print("Dataset loaded.")

    # --- 2. Models ---
    models_to_test = _build_models(device)
    loaders = (train_loader, val_loader, test_loader)
    plot_intermediate = len(models_to_test) == 1

    # --- 3./4. Training, checkpointing & decodability analysis ---
    all_val_losses = {}
    decodability_results = {}
    for model_name, model in models_to_test.items():
        val_losses, score = _train_and_analyze(
            model_name, model, loaders, device, run_results_dir, plot_intermediate
        )
        all_val_losses[model_name] = val_losses
        if score is not None:  # None means the analysis was skipped, not failed
            decodability_results[model_name] = score

    # --- 5. Plot Learning Curves ---
    curves_path = os.path.join(run_results_dir, "learning_curves.png")
    # len() instead of truthiness so numpy-array loss histories are handled too.
    if any(len(losses) > 0 for losses in all_val_losses.values()):
        plot_learning_curves(all_val_losses, title="Validation Learning Curves", save_path=curves_path)
        print(f"\nLearning curves plotted to {curves_path}")

    # --- 6. Report Decodability ---
    _write_decodability_summary(decodability_results, run_results_dir)

    print(f"\nExperiment finished. All results in {run_results_dir}")


# In a notebook kernel __name__ is "__main__", so executing this cell runs the
# experiment; importing this code from another module does not.
if __name__ == "__main__":
    run_experiment()

Using device: cuda
Results will be saved in: results/20250508_082552
Loading dataset...
Dataset loaded.

--- Training ComplexOscillatorNet ---
Number of parameters: 4480
Epoch 1/200, Train Loss: 29.2817, Val Loss: 27.0559
  New best validation loss: 27.0559
Epoch 2/200, Train Loss: 28.8193, Val Loss: 26.8465
  New best validation loss: 26.8465
Epoch 3/200, Train Loss: 28.4933, Val Loss: 26.7277
  New best validation loss: 26.7277
Epoch 4/200, Train Loss: 28.2453, Val Loss: 26.6262
  New best validation loss: 26.6262
Epoch 5/200, Train Loss: 28.0021, Val Loss: 26.4998
  New best validation loss: 26.4998
Epoch 6/200, Train Loss: 27.7891, Val Loss: 26.4488
  New best validation loss: 26.4488
Epoch 7/200, Train Loss: 27.5828, Val Loss: 26.3815
  New best validation loss: 26.3815
Epoch 8/200, Train Loss: 27.4085, Val Loss: 26.3616
  New best validation loss: 26.3616
Epoch 9/200, Train Loss: 27.2289, Val Loss: 26.2961
  New best validation loss: 26.2961
Epoch 10/200, Train Loss: 27.0854, Val



  RidgeCV Decoder for ComplexOscillatorNet - Test MSE: 1.9435, Test R2: 0.0197 (best alpha: 100.0000)
Decodability (R2 score) for ComplexOscillatorNet: 0.0197

--- Training RNN_GRU ---
Number of parameters: 50561
Epoch 1/200, Train Loss: 28.9905, Val Loss: 26.9323
  New best validation loss: 26.9323
Epoch 2/200, Train Loss: 28.8746, Val Loss: 26.8757
  New best validation loss: 26.8757
Epoch 3/200, Train Loss: 28.7561, Val Loss: 26.8288
  New best validation loss: 26.8288
Epoch 4/200, Train Loss: 28.6743, Val Loss: 26.9001
Epoch 5/200, Train Loss: 28.6995, Val Loss: 26.8490
Epoch 6/200, Train Loss: 28.6867, Val Loss: 26.8019
  New best validation loss: 26.8019
Epoch 7/200, Train Loss: 28.6547, Val Loss: 26.8064
Epoch 8/200, Train Loss: 28.6559, Val Loss: 26.8093
Epoch 9/200, Train Loss: 28.6430, Val Loss: 26.8298
Epoch 10/200, Train Loss: 28.6341, Val Loss: 26.8200
Epoch 11/200, Train Loss: 28.5926, Val Loss: 26.8174
Epoch 12/200, Train Loss: 28.6163, Val Loss: 26.8166
Epoch 13/200, Tr



  RidgeCV Decoder for RNN_GRU - Test MSE: 2.0504, Test R2: -0.0138 (best alpha: 100.0000)
Decodability (R2 score) for RNN_GRU: -0.0138

--- Training Transformer ---
Number of parameters: 281345
Epoch 1/200, Train Loss: 30.4577, Val Loss: 27.2994
  New best validation loss: 27.2994
Epoch 2/200, Train Loss: 29.3786, Val Loss: 27.4270
Epoch 3/200, Train Loss: 29.2044, Val Loss: 27.1546
  New best validation loss: 27.1546
Epoch 4/200, Train Loss: 29.0669, Val Loss: 27.1297
  New best validation loss: 27.1297
Epoch 5/200, Train Loss: 29.1172, Val Loss: 27.1599
Epoch 6/200, Train Loss: 29.0119, Val Loss: 27.0418
  New best validation loss: 27.0418
Epoch 7/200, Train Loss: 29.0533, Val Loss: 27.0663
Epoch 8/200, Train Loss: 29.0064, Val Loss: 27.0269
  New best validation loss: 27.0269
Epoch 9/200, Train Loss: 29.0012, Val Loss: 27.0739
Epoch 10/200, Train Loss: 29.0311, Val Loss: 27.0153
  New best validation loss: 27.0153
Epoch 11/200, Train Loss: 29.0557, Val Loss: 27.1309
Epoch 12/200, Tr



Decodability (R2 score) for Transformer: -0.0607

--- Training HIPPORNN_LegT ---
Number of parameters: 99716
Epoch 1/200, Train Loss: 29.1846, Val Loss: 27.1487
  New best validation loss: 27.1487
Epoch 2/200, Train Loss: 29.1803, Val Loss: 27.1451
  New best validation loss: 27.1451
Epoch 3/200, Train Loss: 29.1789, Val Loss: 27.1427
  New best validation loss: 27.1427
Epoch 4/200, Train Loss: 29.1761, Val Loss: 27.1397
  New best validation loss: 27.1397
Epoch 5/200, Train Loss: 29.1567, Val Loss: 27.1133
  New best validation loss: 27.1133
Epoch 6/200, Train Loss: 28.9731, Val Loss: 26.8958
  New best validation loss: 26.8958
Epoch 7/200, Train Loss: 28.3446, Val Loss: 26.6900
  New best validation loss: 26.6900
Epoch 8/200, Train Loss: 27.9358, Val Loss: 26.3796
  New best validation loss: 26.3796
Epoch 9/200, Train Loss: 27.6824, Val Loss: 26.8921
Epoch 10/200, Train Loss: 27.6186, Val Loss: 26.3572
  New best validation loss: 26.3572
Epoch 11/200, Train Loss: 27.4479, Val Loss: 2



  RidgeCV Decoder for HIPPORNN_LegT - Test MSE: 1.9466, Test R2: 0.0255 (best alpha: 100.0000)
Decodability (R2 score) for HIPPORNN_LegT: 0.0255

--- Training NMRNN_Spatial_ModReadout ---
Number of parameters: 66832
Epoch 1/200, Train Loss: 29.1701, Val Loss: 27.1344
  New best validation loss: 27.1344
Epoch 2/200, Train Loss: 29.1389, Val Loss: 27.1077
  New best validation loss: 27.1077
Epoch 3/200, Train Loss: 29.0919, Val Loss: 27.0870
  New best validation loss: 27.0870
Epoch 4/200, Train Loss: 29.0718, Val Loss: 27.0821
  New best validation loss: 27.0821
Epoch 5/200, Train Loss: 29.0333, Val Loss: 27.0661
  New best validation loss: 27.0661
Epoch 6/200, Train Loss: 29.0101, Val Loss: 27.0494
  New best validation loss: 27.0494
Epoch 7/200, Train Loss: 28.9528, Val Loss: 27.0778
Epoch 8/200, Train Loss: 28.9064, Val Loss: 27.0522
Epoch 9/200, Train Loss: 28.8871, Val Loss: 27.0795
Epoch 10/200, Train Loss: 28.8165, Val Loss: 27.0432
  New best validation loss: 27.0432
Epoch 11/20



  RidgeCV Decoder for NMRNN_Spatial_ModReadout - Test MSE: 2.1325, Test R2: -0.0514 (best alpha: 100.0000)
Decodability (R2 score) for NMRNN_Spatial_ModReadout: -0.0514

--- Training NMRNN_NoSpatial_ModReadout ---
Number of parameters: 66832
Epoch 1/200, Train Loss: 29.1609, Val Loss: 27.1411
  New best validation loss: 27.1411
Epoch 2/200, Train Loss: 29.1348, Val Loss: 27.1301
  New best validation loss: 27.1301
Epoch 3/200, Train Loss: 29.1311, Val Loss: 27.0978
  New best validation loss: 27.0978
Epoch 4/200, Train Loss: 29.0981, Val Loss: 27.0950
  New best validation loss: 27.0950
Epoch 5/200, Train Loss: 29.0582, Val Loss: 27.0803
  New best validation loss: 27.0803
Epoch 6/200, Train Loss: 29.0365, Val Loss: 27.0803
  New best validation loss: 27.0803
Epoch 7/200, Train Loss: 28.9733, Val Loss: 27.0308
  New best validation loss: 27.0308
Epoch 8/200, Train Loss: 28.9389, Val Loss: 27.0580
Epoch 9/200, Train Loss: 28.8833, Val Loss: 27.0378
Epoch 10/200, Train Loss: 28.8248, Val



  RidgeCV Decoder for NMRNN_NoSpatial_ModReadout - Test MSE: 2.2094, Test R2: -0.0950 (best alpha: 100.0000)
Decodability (R2 score) for NMRNN_NoSpatial_ModReadout: -0.0950

--- Training NMRNN_Spatial_FixedReadout ---
Number of parameters: 66449
Epoch 1/200, Train Loss: 29.2476, Val Loss: 27.1694
  New best validation loss: 27.1694
Epoch 2/200, Train Loss: 29.1388, Val Loss: 27.1253
  New best validation loss: 27.1253
Epoch 3/200, Train Loss: 29.1254, Val Loss: 27.1215
  New best validation loss: 27.1215
Epoch 4/200, Train Loss: 29.1134, Val Loss: 27.1010
  New best validation loss: 27.1010
Epoch 5/200, Train Loss: 29.0973, Val Loss: 27.1018
Epoch 6/200, Train Loss: 29.0643, Val Loss: 27.1138
Epoch 7/200, Train Loss: 29.0373, Val Loss: 27.1050
Epoch 8/200, Train Loss: 29.0018, Val Loss: 27.1277
Epoch 9/200, Train Loss: 28.9768, Val Loss: 27.1047
Epoch 10/200, Train Loss: 28.9088, Val Loss: 27.1537
Epoch 11/200, Train Loss: 28.7991, Val Loss: 27.1349
Epoch 12/200, Train Loss: 28.6865, V



  RidgeCV Decoder for NMRNN_Spatial_FixedReadout - Test MSE: 2.1470, Test R2: -0.0601 (best alpha: 100.0000)
Decodability (R2 score) for NMRNN_Spatial_FixedReadout: -0.0601
Learning curves saved to results/20250508_082552/learning_curves.png

Learning curves plotted to results/20250508_082552/learning_curves.png

--- Decodability Results (R2 Score) ---
ComplexOscillatorNet: 0.0197
RNN_GRU: -0.0138
Transformer: -0.0607
HIPPORNN_LegT: 0.0255
NMRNN_Spatial_ModReadout: -0.0514
NMRNN_NoSpatial_ModReadout: -0.0950
NMRNN_Spatial_FixedReadout: -0.0601

Experiment finished. All results in results/20250508_082552


In [9]:
# Per-step retention factor implied by CONFIG["nm_decay"] (0.05 == 1.0 / 20.0).
# NOTE(review): assumes the nmRNN interprets `decay` as dt/tau, so that the
# retention is exp(-dt/tau); confirm in model.py.
np.exp(-CONFIG["nm_decay"])

0.951229424500714