In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from datetime import datetime

# Project specific imports
from dataset import CompositionalDataset, create_dataloaders
from model import NonlinearOscillatorNet, RNNModel, TransformerModel, HippoRNNModel, NMRNN_Spatial_ModulatedReadout, NMRNN_NoSpatial_ModulatedReadout, NMRNN_Spatial_FixedReadout
from training import train_model_comparative
from analysis import plot_learning_curves, perform_decodability_analysis
from utils import set_seed, get_device, count_parameters

# --- Configuration ---
CONFIG = {
    "seed": 0,
    "num_task_coefficients": 10, 
    "seq_length": 200,          
    "train_samples": 32 * 10,  # Reduced for quicker testing, increase for real runs
    "test_samples": 32 * 5,   # Reduced for quicker testing
    "batch_size": 32,
    "epochs": 200, # Reduced for quick test, increase for real runs (e.g., 50-200)
    "lr": 1e-3,
    
    # Model-specific hidden sizes / main dimension
    "hidden_size_oscillator": 64, 
    "hidden_size_rnn": 128,       # For standard GRU/LSTM
    "d_model_transformer": 64,   
    "hidden_size_hippo": 128,     # N for HIPPO
    "hidden_size_nm_rnn": 128,    # n_rnn for nmRNN variants

    # Transformer specific
    "nhead_transformer": 1,
    "num_layers_transformer": 1,
    
    # HIPPORNN specific
    "hippo_method": 'legt', # 'legs' or 'legt'
    "hippo_theta": 1.0,     # Required for 'legt'
    "hippo_dt": 1.0 / 200,  # Discretization step for HIPPO (e.g., 1.0 / seq_length)
    "hippo_inv_eps": 1e-6, # Epsilon for LegS matrix inversion regularization
    "hippo_clip_val": 50.0, # Clipping for HIPPO state c_t

    # nmRNN specific (shared for variants where applicable)
    "nm_N_NM": 4,               # Number of neuromodulators
    "nm_activation": 'tanh',    # 'relu', 'tanh' (original code had 'relu-tanh', simplified here)
    "nm_decay": 0.05, # dt_sec / tau_rnn, e.g., (20ms/step) / (100ms tau) -> exp(-0.2)
                                     # Original: math.exp(-20/100) - assuming 20ms step, 100ms tau
    "nm_bias": True,
    "nm_keepW0_spatial": False, # For the version with spatial connections
    "nm_keepW0_no_spatial": False,
    "nm_grad_clip": 1.0,
    "nm_spatial_ell": 0.1,      # For SpatialWeight
    "nm_spatial_scale": 1.0,    # For SpatialWeight

    # General task params
    "output_dim": 1, 
    "input_dim": 1,  
    "noise_level_data": 0.01,
    "run_timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    "results_dir": "results"
}

def run_experiment():
    """
    Runs the full comparative analysis experiment.
    """
    set_seed(CONFIG["seed"])
    device = get_device()
    print(f"Using device: {device}")

    os.makedirs(CONFIG["results_dir"], exist_ok=True)
    run_results_dir = os.path.join(CONFIG["results_dir"], CONFIG["run_timestamp"])
    os.makedirs(run_results_dir, exist_ok=True)
    print(f"Results will be saved in: {run_results_dir}")

    # --- 1. Dataset ---
    print("Loading dataset...")
    train_loader, val_loader, test_loader, (input_basis, output_basis) = create_dataloaders(
        num_train_samples=CONFIG["train_samples"],
        num_val_samples=CONFIG["test_samples"], 
        num_test_samples=CONFIG["test_samples"],
        num_basis=CONFIG["num_task_coefficients"],
        seq_length=CONFIG["seq_length"],
        batch_size=CONFIG["batch_size"],
        noise=CONFIG["noise_level_data"]
    )
    print("Dataset loaded.")

    # --- 2. Models ---
    models_to_test = {
        "ComplexOscillatorNet": NonlinearOscillatorNet(
            N_oscillators=CONFIG["hidden_size_oscillator"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            seq_length=CONFIG["seq_length"], 
            seed=CONFIG["seed"]
        ),
        "RNN_GRU": RNNModel(
            hidden_size=CONFIG["hidden_size_rnn"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            num_layers=1, 
            seed=CONFIG["seed"]
        ),
        "Transformer": TransformerModel(
            d_model=CONFIG["d_model_transformer"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            num_heads=CONFIG["nhead_transformer"],
            num_layers=CONFIG["num_layers_transformer"],
            seq_length=CONFIG["seq_length"], 
            seed=CONFIG["seed"]
        ),
        "HIPPORNN_LegT": HippoRNNModel( # Using LegT by default as per config
            hidden_size=CONFIG["hidden_size_hippo"],
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            method=CONFIG["hippo_method"], 
            theta=CONFIG["hippo_theta"],
            dt=CONFIG["hippo_dt"],
            inv_eps=CONFIG["hippo_inv_eps"],
            clip_val=CONFIG["hippo_clip_val"],
            device=device, # Pass device
            seed=CONFIG["seed"]
        ),
        "NMRNN_Spatial_ModReadout": NMRNN_Spatial_ModulatedReadout(
            input_size=CONFIG["input_dim"],
            hidden_size=CONFIG["hidden_size_nm_rnn"],
            output_size=CONFIG["output_dim"],
            N_nm=CONFIG["nm_N_NM"],
            activation_fn_name=CONFIG["nm_activation"],
            decay=CONFIG["nm_decay"],
            bias=CONFIG["nm_bias"],
            keepW0=CONFIG["nm_keepW0_spatial"],
            spatial_ell=CONFIG["nm_spatial_ell"],
            spatial_scale=CONFIG["nm_spatial_scale"],
            grad_clip=CONFIG["nm_grad_clip"],
            device=device,
            seed=CONFIG["seed"]
        ),
        "NMRNN_NoSpatial_ModReadout": NMRNN_NoSpatial_ModulatedReadout(
            input_size=CONFIG["input_dim"],
            hidden_size=CONFIG["hidden_size_nm_rnn"],
            output_size=CONFIG["output_dim"],
            N_nm=CONFIG["nm_N_NM"],
            activation_fn_name=CONFIG["nm_activation"],
            decay=CONFIG["nm_decay"],
            bias=CONFIG["nm_bias"],
            keepW0=CONFIG["nm_keepW0_no_spatial"],
            grad_clip=CONFIG["nm_grad_clip"],
            device=device,
            seed=CONFIG["seed"]
        ),
        "NMRNN_Spatial_FixedReadout": NMRNN_Spatial_FixedReadout(
            input_size=CONFIG["input_dim"],
            hidden_size=CONFIG["hidden_size_nm_rnn"],
            output_size=CONFIG["output_dim"],
            N_nm=CONFIG["nm_N_NM"], # N_nm still needed for the core recurrence, just not readout
            activation_fn_name=CONFIG["nm_activation"],
            decay=CONFIG["nm_decay"],
            bias=CONFIG["nm_bias"],
            keepW0=CONFIG["nm_keepW0_spatial"],
            spatial_ell=CONFIG["nm_spatial_ell"],
            spatial_scale=CONFIG["nm_spatial_scale"],
            grad_clip=CONFIG["nm_grad_clip"],
            device=device,
            seed=CONFIG["seed"]
        ),
    }

    all_val_losses = {}
    decodability_results = {}
    trained_models_paths = {}

    # --- 3. Training & Evaluation Loop ---
    for model_name, model in models_to_test.items():
        print(f"\n--- Training {model_name} ---")
        model.to(device)
        print(f"Number of parameters: {count_parameters(model)}")

        try:
            val_losses, best_model_state, hidden_states_test, coeffs_test = train_model_comparative(
                model,
                model_name,
                train_loader,
                val_loader, 
                test_loader, 
                CONFIG["epochs"],
                CONFIG["lr"],
                device,
                CONFIG["num_task_coefficients"], 
                run_results_dir,
                plot_intermediate_results=(len(models_to_test) == 1) 
            )
            all_val_losses[model_name] = val_losses
            
            if best_model_state:
                model_path = os.path.join(run_results_dir, f"{model_name}_best.pt")
                torch.save(best_model_state, model_path)
                trained_models_paths[model_name] = model_path
                print(f"Saved best model for {model_name} to {model_path}")
            else:
                print(f"No best model state saved for {model_name} (possibly due to training issues).")

            # --- 4. Decodability Analysis ---
            if hidden_states_test is not None and coeffs_test is not None:
                print(f"\n--- Performing Decodability Analysis for {model_name} ---")
                decodability_score = perform_decodability_analysis(
                    model_name=model_name, 
                    hidden_states=hidden_states_test, 
                    coefficients=coeffs_test,       
                    decoder_type='ridge', # Using RidgeCV as a robust default
                    decoding_metric='r2', # R-squared is often more interpretable than MSE here
                    results_dir=run_results_dir,
                    device=device,
                )
                decodability_results[model_name] = decodability_score
                print(f"Decodability (R2 score) for {model_name}: {decodability_score:.4f}")
            else:
                print(f"Skipping decodability for {model_name} due to missing hidden states or coefficients.")
        
        except Exception as e:
            print(f"!!!!!! ERROR during training or analysis for {model_name}: {e} !!!!!!")
            import traceback
            traceback.print_exc()
            all_val_losses[model_name] = [float('nan')] * CONFIG["epochs"] # Log error for this model
            decodability_results[model_name] = float('nan')


    # --- 5. Plot Learning Curves ---
    if any(all_val_losses.values()): # Check if there's anything to plot
        plot_learning_curves(all_val_losses, title="Validation Learning Curves", save_path=os.path.join(run_results_dir, "learning_curves.png"))
        print(f"\nLearning curves plotted to {os.path.join(run_results_dir, 'learning_curves.png')}")

    # --- 6. Report Decodability ---
    print("\n--- Decodability Results (R2 Score) ---")
    if decodability_results:
        for model_name, score in decodability_results.items():
            print(f"{model_name}: {score:.4f}")
        with open(os.path.join(run_results_dir, "decodability_summary.txt"), "w") as f:
            f.write("Model,R2_Score\n")
            for model_name, score in decodability_results.items():
                f.write(f"{model_name},{score:.4f}\n")
    else:
        print("No decodability results to report.")
        
    print(f"\nExperiment finished. All results in {run_results_dir}")

if __name__ == "__main__":
    run_experiment()

Using device: cuda
Results will be saved in: results/20250508_090642
Loading dataset...
Dataset loaded.

--- Training ComplexOscillatorNet ---
Number of parameters: 4480
Epoch 1/200, Train Loss: 11.6030, Val Loss: 10.3057
  New best validation loss: 10.3057
Epoch 2/200, Train Loss: 11.1090, Val Loss: 9.9933
  New best validation loss: 9.9933
Epoch 3/200, Train Loss: 10.7814, Val Loss: 9.7488
  New best validation loss: 9.7488
Epoch 4/200, Train Loss: 10.5311, Val Loss: 9.5342
  New best validation loss: 9.5342
Epoch 5/200, Train Loss: 10.3239, Val Loss: 9.3870
  New best validation loss: 9.3870
Epoch 6/200, Train Loss: 10.1541, Val Loss: 9.2522
  New best validation loss: 9.2522
Epoch 7/200, Train Loss: 10.0302, Val Loss: 9.1912
  New best validation loss: 9.1912
Epoch 8/200, Train Loss: 9.9141, Val Loss: 9.1017
  New best validation loss: 9.1017
Epoch 9/200, Train Loss: 9.8186, Val Loss: 9.0746
  New best validation loss: 9.0746
Epoch 10/200, Train Loss: 9.7362, Val Loss: 9.0192
  New



  RidgeCV Decoder for ComplexOscillatorNet - Test MSE: 0.9731, Test R2: 0.4856 (best alpha: 48.3293)
Decodability (R2 score) for ComplexOscillatorNet: 0.4856

--- Training RNN_GRU ---
Number of parameters: 50561
Epoch 1/200, Train Loss: 11.5692, Val Loss: 10.5647
  New best validation loss: 10.5647
Epoch 2/200, Train Loss: 11.5412, Val Loss: 10.5350
  New best validation loss: 10.5350
Epoch 3/200, Train Loss: 11.5245, Val Loss: 10.5287
  New best validation loss: 10.5287
Epoch 4/200, Train Loss: 11.5168, Val Loss: 10.5208
  New best validation loss: 10.5208
Epoch 5/200, Train Loss: 11.5068, Val Loss: 10.5127
  New best validation loss: 10.5127
Epoch 6/200, Train Loss: 11.5263, Val Loss: 10.5105
  New best validation loss: 10.5105
Epoch 7/200, Train Loss: 11.5025, Val Loss: 10.5197
Epoch 8/200, Train Loss: 11.4876, Val Loss: 10.5057
  New best validation loss: 10.5057
Epoch 9/200, Train Loss: 11.4961, Val Loss: 10.5015
  New best validation loss: 10.5015
Epoch 10/200, Train Loss: 11.494



  RidgeCV Decoder for RNN_GRU - Test MSE: 1.2774, Test R2: 0.3344 (best alpha: 23.3572)
Decodability (R2 score) for RNN_GRU: 0.3344

--- Training Transformer ---
Number of parameters: 281345
Epoch 1/200, Train Loss: 12.9897, Val Loss: 10.9119
  New best validation loss: 10.9119
Epoch 2/200, Train Loss: 11.9179, Val Loss: 10.6967
  New best validation loss: 10.6967
Epoch 3/200, Train Loss: 11.6769, Val Loss: 10.5601
  New best validation loss: 10.5601
Epoch 4/200, Train Loss: 11.6080, Val Loss: 10.5436
  New best validation loss: 10.5436
Epoch 5/200, Train Loss: 11.6193, Val Loss: 10.5224
  New best validation loss: 10.5224
Epoch 6/200, Train Loss: 11.6058, Val Loss: 10.5285
Epoch 7/200, Train Loss: 11.6242, Val Loss: 10.5176
  New best validation loss: 10.5176
Epoch 8/200, Train Loss: 11.6093, Val Loss: 10.5414
Epoch 9/200, Train Loss: 11.6356, Val Loss: 10.6021
Epoch 10/200, Train Loss: 11.6419, Val Loss: 10.5319
Epoch 11/200, Train Loss: 11.5674, Val Loss: 10.5056
  New best validati



  RidgeCV Decoder for Transformer - Test MSE: 0.8049, Test R2: 0.5688 (best alpha: 0.6158)
Decodability (R2 score) for Transformer: 0.5688

--- Training HIPPORNN_LegT ---
Number of parameters: 99716
Epoch 1/200, Train Loss: 11.8282, Val Loss: 10.8220
  New best validation loss: 10.8220
Epoch 2/200, Train Loss: 11.8270, Val Loss: 10.8210
  New best validation loss: 10.8210
Epoch 3/200, Train Loss: 11.8265, Val Loss: 10.8188
  New best validation loss: 10.8188
Epoch 4/200, Train Loss: 11.8222, Val Loss: 10.8091
  New best validation loss: 10.8091
Epoch 5/200, Train Loss: 11.7960, Val Loss: 10.7627
  New best validation loss: 10.7627
Epoch 6/200, Train Loss: 11.6202, Val Loss: 10.5815
  New best validation loss: 10.5815
Epoch 7/200, Train Loss: 11.3578, Val Loss: 10.1934
  New best validation loss: 10.1934
Epoch 8/200, Train Loss: 11.1580, Val Loss: 10.2206
Epoch 9/200, Train Loss: 11.1265, Val Loss: 10.0510
  New best validation loss: 10.0510
Epoch 10/200, Train Loss: 11.0246, Val Loss: 



Decodability (R2 score) for HIPPORNN_LegT: 0.5609

--- Training NMRNN_Spatial_ModReadout ---
Number of parameters: 66832
Epoch 1/200, Train Loss: 11.7282, Val Loss: 10.6048
  New best validation loss: 10.6048
Epoch 2/200, Train Loss: 11.6301, Val Loss: 10.5059
  New best validation loss: 10.5059
Epoch 3/200, Train Loss: 11.6210, Val Loss: 10.5351
Epoch 4/200, Train Loss: 11.6003, Val Loss: 10.5323
Epoch 5/200, Train Loss: 11.5896, Val Loss: 10.4591
  New best validation loss: 10.4591
Epoch 6/200, Train Loss: 11.5634, Val Loss: 10.4581
  New best validation loss: 10.4581
Epoch 7/200, Train Loss: 11.5574, Val Loss: 10.4392
  New best validation loss: 10.4392
Epoch 8/200, Train Loss: 11.4995, Val Loss: 10.3434
  New best validation loss: 10.3434
Epoch 9/200, Train Loss: 11.4644, Val Loss: 10.4236
Epoch 10/200, Train Loss: 11.3868, Val Loss: 10.1851
  New best validation loss: 10.1851
Epoch 11/200, Train Loss: 11.2707, Val Loss: 10.0963
  New best validation loss: 10.0963
Epoch 12/200, Tra



Decodability (R2 score) for NMRNN_Spatial_ModReadout: 0.3681

--- Training NMRNN_NoSpatial_ModReadout ---
Number of parameters: 66832
Epoch 1/200, Train Loss: 11.7239, Val Loss: 10.6394
  New best validation loss: 10.6394
Epoch 2/200, Train Loss: 11.6091, Val Loss: 10.5036
  New best validation loss: 10.5036
Epoch 3/200, Train Loss: 11.6039, Val Loss: 10.4902
  New best validation loss: 10.4902
Epoch 4/200, Train Loss: 11.5820, Val Loss: 10.5107
Epoch 5/200, Train Loss: 11.5717, Val Loss: 10.4247
  New best validation loss: 10.4247
Epoch 6/200, Train Loss: 11.5141, Val Loss: 10.4303
Epoch 7/200, Train Loss: 11.4809, Val Loss: 10.4041
  New best validation loss: 10.4041
Epoch 8/200, Train Loss: 11.4040, Val Loss: 10.2663
  New best validation loss: 10.2663
Epoch 9/200, Train Loss: 11.3374, Val Loss: 10.1932
  New best validation loss: 10.1932
Epoch 10/200, Train Loss: 11.2723, Val Loss: 10.1545
  New best validation loss: 10.1545
Epoch 11/200, Train Loss: 11.1561, Val Loss: 10.0262
  Ne



  RidgeCV Decoder for NMRNN_NoSpatial_ModReadout - Test MSE: 1.1489, Test R2: 0.4031 (best alpha: 48.3293)
Decodability (R2 score) for NMRNN_NoSpatial_ModReadout: 0.4031

--- Training NMRNN_Spatial_FixedReadout ---
Number of parameters: 66449
Epoch 1/200, Train Loss: 11.8307, Val Loss: 10.6832
  New best validation loss: 10.6832
Epoch 2/200, Train Loss: 11.6829, Val Loss: 10.6117
  New best validation loss: 10.6117
Epoch 3/200, Train Loss: 11.6694, Val Loss: 10.6058
  New best validation loss: 10.6058
Epoch 4/200, Train Loss: 11.6348, Val Loss: 10.5904
  New best validation loss: 10.5904
Epoch 5/200, Train Loss: 11.5767, Val Loss: 10.5303
  New best validation loss: 10.5303
Epoch 6/200, Train Loss: 11.4937, Val Loss: 10.4616
  New best validation loss: 10.4616
Epoch 7/200, Train Loss: 11.3958, Val Loss: 10.4591
  New best validation loss: 10.4591
Epoch 8/200, Train Loss: 11.3424, Val Loss: 10.4205
  New best validation loss: 10.4205
Epoch 9/200, Train Loss: 11.2679, Val Loss: 10.3899
 



  RidgeCV Decoder for NMRNN_Spatial_FixedReadout - Test MSE: 1.5930, Test R2: 0.1639 (best alpha: 100.0000)
Decodability (R2 score) for NMRNN_Spatial_FixedReadout: 0.1639
Learning curves saved to results/20250508_090642/learning_curves.png

Learning curves plotted to results/20250508_090642/learning_curves.png

--- Decodability Results (R2 Score) ---
ComplexOscillatorNet: 0.4856
RNN_GRU: 0.3344
Transformer: 0.5688
HIPPORNN_LegT: 0.5609
NMRNN_Spatial_ModReadout: 0.3681
NMRNN_NoSpatial_ModReadout: 0.4031
NMRNN_Spatial_FixedReadout: 0.1639

Experiment finished. All results in results/20250508_090642


In [9]:
np.exp(-1.0 / 20.0)

0.951229424500714