This notebook computes the electrophysiological features of an ensemble of $\texttt{NOBLE}$ model predictions for a given stimulus amplitude and compares them with the features of the experimental data. If the experimental trace spikes, only the $\texttt{NOBLE}$ traces that also spike (according to their spike count) are kept; otherwise, only the non-spiking $\texttt{NOBLE}$ traces are kept. A $\texttt{NOBLE}$ trace is a mismatch when its spiking status disagrees with the experiment: the experimental trace spikes and the $\texttt{NOBLE}$ trace does not, or the experimental trace does not spike and the $\texttt{NOBLE}$ trace does. Mismatched traces are excluded from the comparison.

For the `PVALB 689331391` neuron, the experimental data provided in `data/experimental/PVALB/PVALB__689331391/Long Square.npy` contain:
- Amplitudes that elicit firing: 0.2 nA, 0.6 nA, 1.0 nA, 1.8 nA
- Amplitudes that do not elicit firing: -1.1 nA, -0.9 nA, -0.7 nA, -0.5 nA, -0.3 nA
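The selection rule above can be sketched as follows. This is a minimal illustration, not the code inside `utils.compare_features`: `count_spikes` and `select_matching_traces` are hypothetical helpers, a spike is assumed to be an upward crossing of a fixed voltage threshold (-20 mV is an assumed value), and eFEL's own `Spikecount` feature relies on its own peak-detection settings. The returned counts mirror the `used`/`excluded` keys printed in Step 8.

```python
import numpy as np

def count_spikes(voltage_mv, threshold_mv=-20.0):
    """Count upward threshold crossings (a simplified stand-in for eFEL's Spikecount)."""
    v = np.asarray(voltage_mv, dtype=float)
    above = v >= threshold_mv
    # A spike starts where the trace goes from below the threshold to at/above it.
    return int(np.count_nonzero(~above[:-1] & above[1:]))

def select_matching_traces(experimental_spikecount, noble_spikecounts):
    """Keep NOBLE traces whose spiking status (spikes > 0) matches the experiment."""
    experiment_spikes = experimental_spikecount > 0
    counts = np.asarray(noble_spikecounts)
    keep = (counts > 0) == experiment_spikes
    selected = np.flatnonzero(keep)
    return selected, {"used": int(keep.sum()), "excluded": int((~keep).sum())}

# Example: a spiking experiment keeps only the spiking predictions.
idx, counts = select_matching_traces(3, [0, 2, 5, 0])
print(idx, counts)  # [1 2] {'used': 2, 'excluded': 2}
```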

In [1]:
import numpy as np
import torch
import efel

from utils.input_builder import build_input, build_input_with_embeddings, extract_scaled_e_features
from utils.model_utils import load_model
from utils.compare_features import compare_experiment_and_noble_features



**Step 1: Set up parameters**

In [2]:
device    = "cpu"
amplitude = 1.8  # stimulus amplitude in nA

features_to_embed = ["slope", "intercept"]

embedding_config = {"sine_embeddings_freq": 9,
                    "scale_sine_embeddings": "freq",
                    "amplitude_embeddings": False,
                    "hof_model_embeddings": 1,
                    "e_features_to_embed": features_to_embed}
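The `FreqScaledNeRF-nfreq-9` and `nfreq-1` tags in the model filename suggest that `sine_embeddings_freq` and `hof_model_embeddings` set the number of frequencies of a NeRF-style sinusoidal encoding. The sketch below shows the standard NeRF encoding, $\gamma(p) = (\sin(2^0\pi p), \cos(2^0\pi p), \dots, \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p))$. The `scale_by_freq` option, which divides each pair by $2^k$, is only an assumed reading of `"scale_sine_embeddings": "freq"`; `nerf_embedding` is a hypothetical helper, not the implementation in `utils.input_builder`.

```python
import numpy as np

def nerf_embedding(x, n_freq, scale_by_freq=False):
    """NeRF-style sinusoidal encoding of scalar inputs.

    Returns an array of shape (len(x), 2 * n_freq) laid out as
    [sin(2**0*pi*x), cos(2**0*pi*x), sin(2**1*pi*x), cos(2**1*pi*x), ...].
    If scale_by_freq is True, each sin/cos pair is divided by 2**k
    (an assumption, not confirmed by the source).
    """
    x = np.asarray(x, dtype=float).reshape(-1, 1)   # (N, 1)
    freqs = 2.0 ** np.arange(n_freq)                 # (n_freq,)
    angles = np.pi * x * freqs                       # (N, n_freq)
    sin_part, cos_part = np.sin(angles), np.cos(angles)
    if scale_by_freq:
        sin_part, cos_part = sin_part / freqs, cos_part / freqs
    # Stack to (N, n_freq, 2), then flatten to interleave sin and cos per frequency.
    return np.stack([sin_part, cos_part], axis=-1).reshape(x.shape[0], 2 * n_freq)

emb = nerf_embedding([0.0, 0.5], n_freq=9, scale_by_freq=True)
print(emb.shape)  # (2, 18)
```

Under this sketch, each embedded scalar contributes `2 * n_freq` channels.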

**Step 2: Load electrophysiological features**

In [3]:
normalised_features = extract_scaled_e_features(neuron_identifier="PVALB_689331391", 
                                                path_to_features='../data/e_features/pvalb_689331391_ephys_sim_features.csv', 
                                                features_to_embed=features_to_embed)
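The source does not say which scaling `extract_scaled_e_features` applies to `slope` and `intercept`. Purely as an illustration, assuming a min-max mapping onto $[-1, 1]$ across all models, the normalisation of one feature column could look like this (`min_max_scale` is a hypothetical helper):

```python
import numpy as np

def min_max_scale(values, lo=-1.0, hi=1.0):
    """Linearly map values so that their minimum becomes lo and their maximum becomes hi."""
    v = np.asarray(values, dtype=float)
    v_min, v_max = v.min(), v.max()
    if v_max == v_min:
        # Constant feature: place every value at the midpoint of the target range.
        return np.full_like(v, (lo + hi) / 2.0)
    return lo + (v - v_min) * (hi - lo) / (v_max - v_min)

print(min_max_scale([2.0, 4.0, 6.0]))  # [-1.  0.  1.]
```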

**Step 3: Load trained $\texttt{NOBLE}$ model**

In [4]:
model_path = 'noble_models/FNO_nmodes-256_in-23_out-1_nlayers-12_projectionratio-4_hc-24_AmpEmbeddings-FreqScaledNeRF-nfreq-9_HoFEmbeddings-FreqScaledNeRF-nfreq-1_bestepoch-296.pth'
model = load_model(model_path, device)

Loading FNO model with modes=256, in_channels=23, out_channels=1, nlayers=12, projection_ratio=4,  hidden_channels=24, device=cpu.
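The printed hyperparameters match the `key-value` tokens in the model filename (`nmodes-256`, `in-23`, `out-1`, `nlayers-12`, `projectionratio-4`, `hc-24`). The hypothetical parser below only illustrates that correspondence; it is not how `load_model` is implemented, and it skips the multi-part embedding tokens such as `AmpEmbeddings-FreqScaledNeRF-nfreq-9`.

```python
import re

def parse_model_filename(path):
    """Extract integer hyperparameters written as single `key-value` tokens in the filename."""
    name = path.rsplit("/", 1)[-1].removesuffix(".pth")
    params = {}
    for token in name.split("_"):
        # Only tokens of the exact form <letters>-<digits> are kept.
        match = re.fullmatch(r"([A-Za-z]+)-(\d+)", token)
        if match:
            params[match.group(1)] = int(match.group(2))
    return params

path = ("noble_models/FNO_nmodes-256_in-23_out-1_nlayers-12_projectionratio-4_hc-24_"
        "AmpEmbeddings-FreqScaledNeRF-nfreq-9_HoFEmbeddings-FreqScaledNeRF-nfreq-1_bestepoch-296.pth")
print(parse_model_filename(path))
# {'nmodes': 256, 'in': 23, 'out': 1, 'nlayers': 12, 'projectionratio': 4, 'hc': 24, 'bestepoch': 296}
```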



**Step 4: Collect embeddings from $\{\text{HoF}^{\text{train}}\}$ models**

In [5]:
known_hof_models     = np.arange(0, 60)
hof_test             = np.array([32, 2, 12, 40, 5, 52, 21, 29, 16, 37])
hof_train            = np.setdiff1d(known_hof_models, hof_test)
features_train       = normalised_features[normalised_features["hof_model"].isin(hof_train)]

sampled_models = {"intercept": torch.tensor(features_train["intercept"].values, dtype=torch.float32),
                  "slope": torch.tensor(features_train["slope"].values, dtype=torch.float32)}
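Because the ensemble is built only from the training HoF models, it is worth checking that none of the ten held-out models leaks into `hof_train`. A quick sanity check of the split defined above:

```python
import numpy as np

known_hof_models = np.arange(0, 60)
hof_test = np.array([32, 2, 12, 40, 5, 52, 21, 29, 16, 37])
hof_train = np.setdiff1d(known_hof_models, hof_test)

# No held-out model may appear in the training pool, and together they cover all 60 models.
assert np.intersect1d(hof_train, hof_test).size == 0
assert np.union1d(hof_train, hof_test).size == known_hof_models.size
print(len(hof_train), len(hof_test))  # 50 10
```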

**Step 5: Build input with sampled model embeddings**

In [6]:
num_samples = len(features_train)
input_batch = build_input(amplitude, num_samples, device)
input_batch_transformed = build_input_with_embeddings(input_batch=input_batch, 
                                                      embedding_config=embedding_config, 
                                                      features_to_embed=features_to_embed, 
                                                      normalised_features=normalised_features, 
                                                      device=device,
                                                      sampled_embeddings=sampled_models)
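The exact channel layout produced by `build_input_with_embeddings` is not shown here. A common way to condition a Fourier neural operator on per-sample scalars is to repeat each embedding value along the time axis and append it as extra input channels; the hypothetical helper below sketches that idea with NumPy, assuming inputs of shape `(N, C, T)`.

```python
import numpy as np

def append_constant_channels(input_batch, embeddings):
    """Tile per-sample embeddings along time and append them as extra channels.

    input_batch: (N, C, T) array; embeddings: (N, E) array.
    Returns an (N, C + E, T) array.
    """
    n, _, t = input_batch.shape
    tiled = np.broadcast_to(embeddings[:, :, None], (n, embeddings.shape[1], t))
    return np.concatenate([input_batch, tiled], axis=1)

x = np.zeros((3, 1, 5))                          # 3 samples, 1 input channel, 5 time steps
e = np.arange(6, dtype=float).reshape(3, 2)      # 2 embedding values per sample
print(append_constant_channels(x, e).shape)      # (3, 3, 5)
```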

**Step 6: Generate ensemble predictions**

In [7]:
with torch.no_grad():
    output_batch = model(input_batch_transformed)
    output_batch = output_batch.squeeze(1).cpu().detach().numpy()
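After `squeeze(1)`, `output_batch` holds one predicted voltage trace per sampled model, with shape `(num_samples, T)` assuming the model returns `(num_samples, 1, T)`. Before extracting features, the ensemble can be summarised point-wise, for example to plot a median trace with an interquartile band. `ensemble_envelope` is a hypothetical helper:

```python
import numpy as np

def ensemble_envelope(traces):
    """Point-wise median and interquartile range of an (n_samples, n_time) array."""
    q1, median, q3 = np.percentile(traces, [25, 50, 75], axis=0)
    return median, q1, q3

traces = np.array([[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]])
median, q1, q3 = ensemble_envelope(traces)
print(median, q1, q3)  # [1. 2.] [0.5 1. ] [1.5 3. ]
```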

**Step 7: Compute electrophysiological features**

In [8]:
spiking_features = ['AHP1_depth_from_peak', 'AHP_depth', 'AHP_time_from_peak', 'AP1_peak',
                    'AP1_width', 'Spikecount', 'mean_frequency', 'decay_time_constant_after_stim', 'depol_block',
                    'inv_first_ISI', 'mean_AP_amplitude', 'steady_state_voltage',
                    'steady_state_voltage_stimend', 'time_to_first_spike', 'voltage_base']

non_spiking_features = ['decay_time_constant_after_stim', 'sag_amplitude', 'steady_state_voltage',
                        'steady_state_voltage_stimend', 'voltage_base']

## Build eFEL traces from the NOBLE model predictions
## Construct an eFEL trace where time is in ms and voltage is in mV
time = np.linspace(0, 515, input_batch.shape[2])
traces = []
for output_sample in output_batch:
    trace = {
        'T': time,
        'V': output_sample * 1000,   # V -> mV
        'stim_start': [15],          # ms
        'stim_end': [415]            # ms
    }
    traces.append(trace)

## Build an eFEL trace from the experimental recording at the same amplitude
experimental_data       = np.load('../data/experimental/PVALB/PVALB__689331391/Long Square.npy', allow_pickle=True)
experimental_amplitudes = [round(amp['stimulus_amplitude'] * 1e9, 1) for amp in experimental_data]
trial_idx               = np.where(np.array(experimental_amplitudes) == amplitude)[0][0]
experimental_trace      = experimental_data[trial_idx]

experimental_time = experimental_trace['time'] * 1000      # s -> ms
experimental_stim = experimental_trace['stimulus'] * 1e9   # A -> nA
experimental_resp = experimental_trace['response'] * 1000  # V -> mV

experimental_trace = {
    'T': experimental_time,
    'V': experimental_resp,
    'stim_start': [1000],  # ms
    'stim_end': [2000]     # ms
}
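The sweep lookup above relies on exact equality between rounded amplitudes, and `np.where(...)[0][0]` raises a bare `IndexError` when the requested amplitude was not recorded. A slightly more robust sketch compares amplitudes in nA with an explicit tolerance (`find_trial_index` and `atol_na` are hypothetical; 0.05 nA is an assumed tolerance) and reports the available amplitudes on failure:

```python
import numpy as np

def find_trial_index(stimulus_amplitudes_amps, target_na, atol_na=0.05):
    """Index of the first sweep whose amplitude, converted from A to nA, is within atol_na of target_na."""
    amps_na = np.asarray(stimulus_amplitudes_amps, dtype=float) * 1e9
    matches = np.flatnonzero(np.isclose(amps_na, target_na, rtol=0.0, atol=atol_na))
    if matches.size == 0:
        raise ValueError(f"No sweep at {target_na} nA; available: {np.round(amps_na, 2)}")
    return int(matches[0])

# With the notebook's data this would be called as:
# trial_idx = find_trial_index([amp['stimulus_amplitude'] for amp in experimental_data], amplitude)
print(find_trial_index([2e-10, 6e-10, 1.8e-9], 1.8))  # 2
```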

**Step 8: Compare $\texttt{NOBLE}$ and experimental features**
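`compare_experiment_and_noble_features` returns, for each selected feature, the experimental value together with the median and quartiles across the kept $\texttt{NOBLE}$ traces. Some features may be undefined for individual traces; assuming such values are stored as NaN, the per-feature summary could be computed as in this hypothetical helper:

```python
import numpy as np

def summarise_feature(values):
    """Median, Q1 and Q3 of one feature across the selected NOBLE traces, ignoring NaNs."""
    v = np.asarray(values, dtype=float)
    if v.size == 0 or np.all(np.isnan(v)):
        # Nothing to summarise: report NaN, matching the "nan - nan" fallback in the table.
        return np.nan, np.nan, np.nan
    q1, median, q3 = np.nanpercentile(v, [25, 50, 75])
    return median, q1, q3

print(summarise_feature([1.0, np.nan, 3.0, 2.0]))  # median 2.0, Q1 1.5, Q3 2.5
```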

In [9]:
selected_feature_names, experimental_values, noble_median, noble_q1, noble_q3, selection_counts = compare_experiment_and_noble_features(
    traces=traces,
    experimental_trace=experimental_trace,
    spiking_features=spiking_features,
    non_spiking_features=non_spiking_features,
)

print(f"{'Feature':<35} {'Median Experiment':>20} {'Median NOBLE':>18} {'IQR NOBLE (Q1-Q3)':>24}")
print("-" * 105)
for name in selected_feature_names:
    exp_val   = experimental_values.get(name, np.nan)
    nob_med   = noble_median.get(name, np.nan)
    q1        = noble_q1.get(name, np.nan)
    q3        = noble_q3.get(name, np.nan)
    iqr_range = f"{q1:.3f} - {q3:.3f}" if not (np.isnan(q1) or np.isnan(q3)) else "nan - nan"
    print(f"{name:<35} {exp_val:>20.3f} {nob_med:>18.3f} {iqr_range:>24}")

print("\nNOBLE traces used:", selection_counts["used"], f"| Excluded (mismatch spiking): {selection_counts['excluded']}")

Feature                                Median Experiment       Median NOBLE        IQR NOBLE (Q1-Q3)
---------------------------------------------------------------------------------------------------------
AHP1_depth_from_peak                              94.094             73.118          68.861 - 79.940
AHP_depth                                         -4.830             -6.476          -8.729 - -3.530
AHP_time_from_peak                                 2.172              2.430            2.013 - 2.757
AP1_peak                                          21.000             -3.268           -7.004 - 0.563
AP1_width                                          0.472              0.579            0.532 - 0.637
mean_frequency                                    69.323             57.286          30.045 - 68.082
decay_time_constant_after_stim                    12.371             14.915           6.550 - 34.122
depol_block                                        1.000              1.000           