In [1]:
import os
import traceback
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch

# Project specific imports
from dataset import CompositionalDataset, create_dataloaders
from model import NonlinearOscillatorNet, RNNModel, TransformerModel, HippoRNNModel, NMRNN_Spatial_ModulatedReadout, NMRNN_NoSpatial_ModulatedReadout, NMRNN_Spatial_FixedReadout
from training import train_model_comparative
from analysis import plot_learning_curves, perform_decodability_analysis
from utils import set_seed, get_device, count_parameters

# --- Configuration ---
# Single source of truth for the sequence length: both the dataset and the
# HiPPO discretization step (dt = 1 / seq_length) are derived from it, so they
# cannot silently drift apart when one of them is edited.
SEQ_LENGTH = 200

CONFIG = {
    "seed": 0,
    "num_task_coefficients": 5,
    "seq_length": SEQ_LENGTH,
    "train_samples": 32 * 10,  # Reduced for quicker testing, increase for real runs
    "test_samples": 32 * 5,    # Reduced for quicker testing; also sizes the validation split
    "batch_size": 32,
    "epochs": 200,             # Full run; drop to ~10-20 for a quick smoke test
    "lr": 1e-3,

    # Model-specific hidden sizes / main dimension
    "hidden_size_oscillator": 64,
    "hidden_size_rnn": 128,       # For standard GRU/LSTM
    "d_model_transformer": 64,
    "hidden_size_hippo": 128,     # N for HIPPO
    "hidden_size_nm_rnn": 128,    # n_rnn for nmRNN variants

    # Transformer specific
    "nhead_transformer": 1,
    "num_layers_transformer": 1,

    # HIPPORNN specific
    "hippo_method": 'legt',        # 'legs' or 'legt'
    "hippo_theta": 1.0,            # Required for 'legt'
    "hippo_dt": 1.0 / SEQ_LENGTH,  # Discretization step for HIPPO, tied to the sequence length
    "hippo_inv_eps": 1e-6,         # Epsilon for LegS matrix inversion regularization
    "hippo_clip_val": 50.0,        # Clipping for HIPPO state c_t

    # nmRNN specific (shared for variants where applicable)
    "nm_N_NM": 4,               # Number of neuromodulators
    "nm_activation": 'tanh',    # 'relu', 'tanh' (original code had 'relu-tanh', simplified here)
    # 0.05 = dt / tau with dt = 1 step and tau = 20 steps. The matching per-step
    # retention factor would be exp(-0.05) ~= 0.951. The original code used
    # math.exp(-20/100) ~= 0.819, i.e. a 20 ms step with a 100 ms tau.
    # NOTE(review): confirm whether the nmRNN `decay` argument expects dt/tau
    # (as configured here) or exp(-dt/tau) -- the two values differ a lot.
    "nm_decay": 0.05,
    "nm_bias": True,
    "nm_keepW0_spatial": False, # For the version with spatial connections
    "nm_keepW0_no_spatial": False,
    "nm_grad_clip": 1.0,
    "nm_spatial_ell": 0.1,      # For SpatialWeight
    "nm_spatial_scale": 1.0,    # For SpatialWeight

    # General task params
    "output_dim": 1,
    "input_dim": 1,
    "noise_level_data": 0.01,
    # Evaluated once, when this cell runs: calling run_experiment() again without
    # re-running this cell writes into (and overwrites files in) the same directory.
    "run_timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    "results_dir": "results"
}

def _build_models(device):
    """Instantiate every model under comparison from the module-level ``CONFIG``.

    Args:
        device: torch device from ``get_device()``; forwarded to each constructor.

    Returns:
        dict mapping a display name to an untrained model. Dict insertion order
        is the order in which ``run_experiment`` trains the models.
    """
    # Label the HiPPO model after the method actually configured, so switching
    # CONFIG["hippo_method"] to 'legs' does not produce a misleading "LegT" label.
    hippo_method = CONFIG["hippo_method"]
    hippo_label = {"legt": "LegT", "legs": "LegS"}.get(hippo_method, hippo_method)

    return {
        "ComplexOscillatorNet": NonlinearOscillatorNet(
            N_oscillators=CONFIG["hidden_size_oscillator"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            seq_length=CONFIG["seq_length"],
            seed=CONFIG["seed"]
        ),
        "RNN_GRU": RNNModel(
            hidden_size=CONFIG["hidden_size_rnn"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            num_layers=1,
            seed=CONFIG["seed"]
        ),
        "Transformer": TransformerModel(
            d_model=CONFIG["d_model_transformer"],
            device=device,
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            num_heads=CONFIG["nhead_transformer"],
            num_layers=CONFIG["num_layers_transformer"],
            seq_length=CONFIG["seq_length"],
            seed=CONFIG["seed"]
        ),
        f"HIPPORNN_{hippo_label}": HippoRNNModel(
            hidden_size=CONFIG["hidden_size_hippo"],
            outputdim=CONFIG["output_dim"],
            inputdim=CONFIG["input_dim"],
            method=hippo_method,
            theta=CONFIG["hippo_theta"],
            dt=CONFIG["hippo_dt"],
            inv_eps=CONFIG["hippo_inv_eps"],
            clip_val=CONFIG["hippo_clip_val"],
            device=device,
            seed=CONFIG["seed"]
        ),
        "NMRNN_Spatial_ModReadout": NMRNN_Spatial_ModulatedReadout(
            input_size=CONFIG["input_dim"],
            hidden_size=CONFIG["hidden_size_nm_rnn"],
            output_size=CONFIG["output_dim"],
            N_nm=CONFIG["nm_N_NM"],
            activation_fn_name=CONFIG["nm_activation"],
            decay=CONFIG["nm_decay"],
            bias=CONFIG["nm_bias"],
            keepW0=CONFIG["nm_keepW0_spatial"],
            spatial_ell=CONFIG["nm_spatial_ell"],
            spatial_scale=CONFIG["nm_spatial_scale"],
            grad_clip=CONFIG["nm_grad_clip"],
            device=device,
            seed=CONFIG["seed"]
        ),
        "NMRNN_NoSpatial_ModReadout": NMRNN_NoSpatial_ModulatedReadout(
            input_size=CONFIG["input_dim"],
            hidden_size=CONFIG["hidden_size_nm_rnn"],
            output_size=CONFIG["output_dim"],
            N_nm=CONFIG["nm_N_NM"],
            activation_fn_name=CONFIG["nm_activation"],
            decay=CONFIG["nm_decay"],
            bias=CONFIG["nm_bias"],
            keepW0=CONFIG["nm_keepW0_no_spatial"],
            grad_clip=CONFIG["nm_grad_clip"],
            device=device,
            seed=CONFIG["seed"]
        ),
        "NMRNN_Spatial_FixedReadout": NMRNN_Spatial_FixedReadout(
            input_size=CONFIG["input_dim"],
            hidden_size=CONFIG["hidden_size_nm_rnn"],
            output_size=CONFIG["output_dim"],
            N_nm=CONFIG["nm_N_NM"],  # still needed for the core recurrence, just not the readout
            activation_fn_name=CONFIG["nm_activation"],
            decay=CONFIG["nm_decay"],
            bias=CONFIG["nm_bias"],
            keepW0=CONFIG["nm_keepW0_spatial"],
            spatial_ell=CONFIG["nm_spatial_ell"],
            spatial_scale=CONFIG["nm_spatial_scale"],
            grad_clip=CONFIG["nm_grad_clip"],
            device=device,
            seed=CONFIG["seed"]
        ),
    }


def _has_finite_values(losses):
    """Return True if ``losses`` holds at least one finite scalar.

    The error path in ``run_experiment`` stores ``[nan] * epochs`` for a failed
    model. Such a list is truthy, so a plain ``any(...)`` over the curves cannot
    tell "every model failed" apart from "there is something to plot".
    Assumes each entry is a scalar (float, NumPy scalar or 0-d tensor) -- TODO
    confirm against train_model_comparative's return value.
    """
    if losses is None:
        return False
    return any(np.isfinite(float(value)) for value in losses)


def run_experiment():
    """Run the full comparative analysis experiment.

    Steps:
      1. Seed, select the device and create a timestamped results directory.
      2. Build the train/val/test dataloaders.
      3. For each model: train it, save its best state dict, and run a ridge
         decodability analysis (R2) on its test-set hidden states.
      4. Plot validation learning curves, if any model produced finite losses.
      5. Print the R2 scores and write them, CSV-formatted, to
         ``decodability_summary.txt``.

    If training or analysis fails for one model, the traceback is printed, the
    model is recorded as NaN, and the loop moves on to the next model.

    Returns:
        dict with keys ``"results_dir"``, ``"val_losses"``, ``"decodability"``
        and ``"model_paths"``, so results can be inspected from the notebook.
    """
    set_seed(CONFIG["seed"])
    device = get_device()
    print(f"Using device: {device}")

    run_results_dir = os.path.join(CONFIG["results_dir"], CONFIG["run_timestamp"])
    # makedirs also creates the parent CONFIG["results_dir"] when missing.
    os.makedirs(run_results_dir, exist_ok=True)
    print(f"Results will be saved in: {run_results_dir}")

    # --- 1. Dataset ---
    print("Loading dataset...")
    train_loader, val_loader, test_loader, (_input_basis, _output_basis) = create_dataloaders(
        num_train_samples=CONFIG["train_samples"],
        # The validation size falls back to the test size unless CONFIG sets "val_samples".
        num_val_samples=CONFIG.get("val_samples", CONFIG["test_samples"]),
        num_test_samples=CONFIG["test_samples"],
        num_basis=CONFIG["num_task_coefficients"],
        seq_length=CONFIG["seq_length"],
        batch_size=CONFIG["batch_size"],
        noise=CONFIG["noise_level_data"]
    )
    print("Dataset loaded.")

    # --- 2. Models ---
    models_to_test = _build_models(device)

    all_val_losses = {}
    decodability_results = {}
    trained_models_paths = {}

    # --- 3. Training & Evaluation Loop ---
    for model_name, model in models_to_test.items():
        print(f"\n--- Training {model_name} ---")
        model.to(device)
        print(f"Number of parameters: {count_parameters(model)}")

        try:
            val_losses, best_model_state, hidden_states_test, coeffs_test = train_model_comparative(
                model,
                model_name,
                train_loader,
                val_loader,
                test_loader,
                CONFIG["epochs"],
                CONFIG["lr"],
                device,
                CONFIG["num_task_coefficients"],
                run_results_dir,
                plot_intermediate_results=(len(models_to_test) == 1)
            )
            all_val_losses[model_name] = val_losses

            if best_model_state:
                model_path = os.path.join(run_results_dir, f"{model_name}_best.pt")
                torch.save(best_model_state, model_path)
                trained_models_paths[model_name] = model_path
                print(f"Saved best model for {model_name} to {model_path}")
            else:
                print(f"No best model state saved for {model_name} (possibly due to training issues).")

            # --- 4. Decodability Analysis ---
            if hidden_states_test is not None and coeffs_test is not None:
                print(f"\n--- Performing Decodability Analysis for {model_name} ---")
                decodability_score = perform_decodability_analysis(
                    model_name=model_name,
                    hidden_states=hidden_states_test,
                    coefficients=coeffs_test,
                    decoder_type='ridge',  # RidgeCV as a robust default
                    decoding_metric='r2',  # R-squared is more interpretable than MSE here
                    results_dir=run_results_dir,
                    device=device,
                )
                decodability_results[model_name] = decodability_score
                print(f"Decodability (R2 score) for {model_name}: {decodability_score:.4f}")
            else:
                print(f"Skipping decodability for {model_name} due to missing hidden states or coefficients.")

        except Exception as e:
            print(f"!!!!!! ERROR during training or analysis for {model_name}: {e} !!!!!!")
            traceback.print_exc()
            # Pad with NaN so the failed model still appears with a full-length curve.
            all_val_losses[model_name] = [float('nan')] * CONFIG["epochs"]
            decodability_results[model_name] = float('nan')

    # --- 5. Plot Learning Curves ---
    # Only plot when at least one curve has real (finite) values; the NaN
    # placeholders from failed models are truthy lists and must not count.
    if any(_has_finite_values(losses) for losses in all_val_losses.values()):
        curves_path = os.path.join(run_results_dir, "learning_curves.png")
        plot_learning_curves(all_val_losses, title="Validation Learning Curves", save_path=curves_path)
        print(f"\nLearning curves plotted to {curves_path}")

    # --- 6. Report Decodability ---
    print("\n--- Decodability Results (R2 Score) ---")
    if decodability_results:
        for model_name, score in decodability_results.items():
            print(f"{model_name}: {score:.4f}")
        with open(os.path.join(run_results_dir, "decodability_summary.txt"), "w") as f:
            f.write("Model,R2_Score\n")
            for model_name, score in decodability_results.items():
                f.write(f"{model_name},{score:.4f}\n")
    else:
        print("No decodability results to report.")

    print(f"\nExperiment finished. All results in {run_results_dir}")

    return {
        "results_dir": run_results_dir,
        "val_losses": all_val_losses,
        "decodability": decodability_results,
        "model_paths": trained_models_paths,
    }

def _main():
    """Entry point: run the comparative experiment once."""
    run_experiment()


# Inside a Jupyter kernel __name__ is also "__main__", so executing this cell
# launches the whole (long-running) experiment, not just the definitions above.
if __name__ == "__main__":
    _main()

Using device: cuda
Results will be saved in: results/20250508_082527
Loading dataset...
Dataset loaded.

--- Training ComplexOscillatorNet ---
Number of parameters: 4480
Epoch 1/200, Train Loss: 4.9201, Val Loss: 4.7620
  New best validation loss: 4.7620
Epoch 2/200, Train Loss: 4.5789, Val Loss: 4.4091
  New best validation loss: 4.4091
Epoch 3/200, Train Loss: 4.3282, Val Loss: 4.1753
  New best validation loss: 4.1753
Epoch 4/200, Train Loss: 4.0661, Val Loss: 3.9659
  New best validation loss: 3.9659
Epoch 5/200, Train Loss: 3.8658, Val Loss: 3.8017
  New best validation loss: 3.8017
Epoch 6/200, Train Loss: 3.7076, Val Loss: 3.6618
  New best validation loss: 3.6618
Epoch 7/200, Train Loss: 3.5907, Val Loss: 3.5625
  New best validation loss: 3.5625
Epoch 8/200, Train Loss: 3.5033, Val Loss: 3.4988
  New best validation loss: 3.4988
Epoch 9/200, Train Loss: 3.4445, Val Loss: 3.4559
  New best validation loss: 3.4559
Epoch 10/200, Train Loss: 3.3918, Val Loss: 3.4240
  New best val



  RidgeCV Decoder for ComplexOscillatorNet - Test MSE: 0.0605, Test R2: 0.9709 (best alpha: 1.2743)
Decodability (R2 score) for ComplexOscillatorNet: 0.9709

--- Training RNN_GRU ---
Number of parameters: 50561
Epoch 1/200, Train Loss: 4.8265, Val Loss: 5.2117
  New best validation loss: 5.2117
Epoch 2/200, Train Loss: 4.7924, Val Loss: 5.0764
  New best validation loss: 5.0764
Epoch 3/200, Train Loss: 4.7707, Val Loss: 5.1396
Epoch 4/200, Train Loss: 4.7394, Val Loss: 5.0988
Epoch 5/200, Train Loss: 4.7277, Val Loss: 5.1501
Epoch 6/200, Train Loss: 4.7182, Val Loss: 5.0424
  New best validation loss: 5.0424
Epoch 7/200, Train Loss: 4.7303, Val Loss: 5.1388
Epoch 8/200, Train Loss: 4.7113, Val Loss: 5.0679
Epoch 9/200, Train Loss: 4.7093, Val Loss: 5.0527
Epoch 10/200, Train Loss: 4.6897, Val Loss: 5.0944
Epoch 11/200, Train Loss: 4.6845, Val Loss: 5.0749
Epoch 12/200, Train Loss: 4.6761, Val Loss: 5.0849
Epoch 13/200, Train Loss: 4.6684, Val Loss: 5.0234
  New best validation loss: 5.



  RidgeCV Decoder for RNN_GRU - Test MSE: 0.0025, Test R2: 0.9988 (best alpha: 0.0018)
Decodability (R2 score) for RNN_GRU: 0.9988

--- Training Transformer ---
Number of parameters: 281345
Epoch 1/200, Train Loss: 6.3340, Val Loss: 5.1674
  New best validation loss: 5.1674
Epoch 2/200, Train Loss: 5.0713, Val Loss: 5.1637
  New best validation loss: 5.1637
Epoch 3/200, Train Loss: 4.9144, Val Loss: 5.0780
  New best validation loss: 5.0780
Epoch 4/200, Train Loss: 4.8872, Val Loss: 5.1059
Epoch 5/200, Train Loss: 4.8702, Val Loss: 5.2255
Epoch 6/200, Train Loss: 4.8611, Val Loss: 5.1043
Epoch 7/200, Train Loss: 4.8276, Val Loss: 5.0451
  New best validation loss: 5.0451
Epoch 8/200, Train Loss: 4.8400, Val Loss: 5.1266
Epoch 9/200, Train Loss: 4.7974, Val Loss: 5.0273
  New best validation loss: 5.0273
Epoch 10/200, Train Loss: 4.8014, Val Loss: 5.0947
Epoch 11/200, Train Loss: 4.7987, Val Loss: 5.2067
Epoch 12/200, Train Loss: 4.7881, Val Loss: 5.1057
Epoch 13/200, Train Loss: 4.7792



Decodability (R2 score) for Transformer: 0.9955

--- Training HIPPORNN_LegT ---
Number of parameters: 99716
Epoch 1/200, Train Loss: 5.0083, Val Loss: 5.1213
  New best validation loss: 5.1213
Epoch 2/200, Train Loss: 5.0074, Val Loss: 5.1210
  New best validation loss: 5.1210
Epoch 3/200, Train Loss: 5.0041, Val Loss: 5.1205
  New best validation loss: 5.1205
Epoch 4/200, Train Loss: 4.9989, Val Loss: 5.1176
  New best validation loss: 5.1176
Epoch 5/200, Train Loss: 4.9818, Val Loss: 5.1115
  New best validation loss: 5.1115
Epoch 6/200, Train Loss: 4.9218, Val Loss: 5.0782
  New best validation loss: 5.0782
Epoch 7/200, Train Loss: 4.7406, Val Loss: 4.9201
  New best validation loss: 4.9201
Epoch 8/200, Train Loss: 4.5687, Val Loss: 4.8765
  New best validation loss: 4.8765
Epoch 9/200, Train Loss: 4.3929, Val Loss: 4.7006
  New best validation loss: 4.7006
Epoch 10/200, Train Loss: 4.3576, Val Loss: 4.7061
Epoch 11/200, Train Loss: 4.1944, Val Loss: 4.4919
  New best validation los



  RidgeCV Decoder for HIPPORNN_LegT - Test MSE: 0.0044, Test R2: 0.9979 (best alpha: 0.0162)
Decodability (R2 score) for HIPPORNN_LegT: 0.9979

--- Training NMRNN_Spatial_ModReadout ---
Number of parameters: 66832
Epoch 1/200, Train Loss: 4.8398, Val Loss: 5.2257
  New best validation loss: 5.2257
Epoch 2/200, Train Loss: 4.7008, Val Loss: 4.9290
  New best validation loss: 4.9290
Epoch 3/200, Train Loss: 4.6273, Val Loss: 4.7932
  New best validation loss: 4.7932
Epoch 4/200, Train Loss: 4.5279, Val Loss: 4.7123
  New best validation loss: 4.7123
Epoch 5/200, Train Loss: 4.4619, Val Loss: 4.6278
  New best validation loss: 4.6278
Epoch 6/200, Train Loss: 4.3261, Val Loss: 4.4514
  New best validation loss: 4.4514
Epoch 7/200, Train Loss: 4.1500, Val Loss: 4.2777
  New best validation loss: 4.2777
Epoch 8/200, Train Loss: 3.9685, Val Loss: 4.2263
  New best validation loss: 4.2263
Epoch 9/200, Train Loss: 3.8286, Val Loss: 4.1152
  New best validation loss: 4.1152
Epoch 10/200, Train L



  RidgeCV Decoder for NMRNN_Spatial_ModReadout - Test MSE: 0.0548, Test R2: 0.9732 (best alpha: 2.6367)
Decodability (R2 score) for NMRNN_Spatial_ModReadout: 0.9732

--- Training NMRNN_NoSpatial_ModReadout ---
Number of parameters: 66832
Epoch 1/200, Train Loss: 4.8267, Val Loss: 5.2125
  New best validation loss: 5.2125
Epoch 2/200, Train Loss: 4.7041, Val Loss: 4.9662
  New best validation loss: 4.9662
Epoch 3/200, Train Loss: 4.6626, Val Loss: 4.9032
  New best validation loss: 4.9032
Epoch 4/200, Train Loss: 4.5697, Val Loss: 4.8280
  New best validation loss: 4.8280
Epoch 5/200, Train Loss: 4.4826, Val Loss: 4.6918
  New best validation loss: 4.6918
Epoch 6/200, Train Loss: 4.3409, Val Loss: 4.5045
  New best validation loss: 4.5045
Epoch 7/200, Train Loss: 4.1676, Val Loss: 4.3979
  New best validation loss: 4.3979
Epoch 8/200, Train Loss: 4.0104, Val Loss: 4.1805
  New best validation loss: 4.1805
Epoch 9/200, Train Loss: 3.8999, Val Loss: 4.0638
  New best validation loss: 4.06



  RidgeCV Decoder for NMRNN_NoSpatial_ModReadout - Test MSE: 0.0296, Test R2: 0.9856 (best alpha: 0.2976)
Decodability (R2 score) for NMRNN_NoSpatial_ModReadout: 0.9856

--- Training NMRNN_Spatial_FixedReadout ---
Number of parameters: 66449
Epoch 1/200, Train Loss: 4.9547, Val Loss: 5.1017
  New best validation loss: 5.1017
Epoch 2/200, Train Loss: 4.7681, Val Loss: 4.9848
  New best validation loss: 4.9848
Epoch 3/200, Train Loss: 4.6498, Val Loss: 4.9744
  New best validation loss: 4.9744
Epoch 4/200, Train Loss: 4.5583, Val Loss: 4.7323
  New best validation loss: 4.7323
Epoch 5/200, Train Loss: 4.4649, Val Loss: 4.6348
  New best validation loss: 4.6348
Epoch 6/200, Train Loss: 4.3317, Val Loss: 4.5233
  New best validation loss: 4.5233
Epoch 7/200, Train Loss: 4.1748, Val Loss: 4.4553
  New best validation loss: 4.4553
Epoch 8/200, Train Loss: 4.0285, Val Loss: 4.3778
  New best validation loss: 4.3778
Epoch 9/200, Train Loss: 3.8812, Val Loss: 4.1990
  New best validation loss: 



  RidgeCV Decoder for NMRNN_Spatial_FixedReadout - Test MSE: 0.1035, Test R2: 0.9497 (best alpha: 2.6367)
Decodability (R2 score) for NMRNN_Spatial_FixedReadout: 0.9497
Learning curves saved to results/20250508_082527/learning_curves.png

Learning curves plotted to results/20250508_082527/learning_curves.png

--- Decodability Results (R2 Score) ---
ComplexOscillatorNet: 0.9709
RNN_GRU: 0.9988
Transformer: 0.9955
HIPPORNN_LegT: 0.9979
NMRNN_Spatial_ModReadout: 0.9732
NMRNN_NoSpatial_ModReadout: 0.9856
NMRNN_Spatial_FixedReadout: 0.9497

Experiment finished. All results in results/20250508_082527


In [2]:
# Per-step retention factor exp(-dt/tau) that corresponds to the configured
# CONFIG["nm_decay"] = dt/tau = 0.05 (= 1/20). With the current config this
# evaluates to the same value as the former literal np.exp(-1.0 / 20.0).
# NOTE(review): confirm whether the nmRNN `decay` argument expects dt/tau
# (as configured) or this exp(-dt/tau) factor.
np.exp(-CONFIG["nm_decay"])

0.951229424500714