This notebook computes the electrophysiological features of a single $\texttt{NOBLE}$ model prediction for a given stimulus amplitude, using eFEL.

In [1]:
import numpy as np
import torch, efel

from utils.input_builder import build_input, build_input_with_embeddings, extract_scaled_e_features
from utils.model_utils import load_model



**Step 1: Set up parameters**

In [2]:
# Run everything on the CPU for a single stimulus amplitude.
device = "cpu"
amplitude = 0.5

# Electrophysiological features whose values are fed to the model as embeddings.
features_to_embed = ["slope", "intercept"]

# Embedding settings later handed to build_input_with_embeddings.
embedding_config = dict(
    sine_embeddings_freq=9,
    scale_sine_embeddings="freq",
    amplitude_embeddings=False,
    hof_model_embeddings=1,
    e_features_to_embed=features_to_embed,
)

**Step 2: Load electrophysiological features**

In [None]:
# E-feature table for the PVALB_689331391 cell (one row per HoF model, judging by
# the "hof_model" column used below); scaling is done by the project helper.
features_csv_path = '../data/e_features/pvalb_689331391_ephys_sim_features.csv'
normalised_features = extract_scaled_e_features(
    neuron_identifier="PVALB_689331391",
    path_to_features=features_csv_path,
    features_to_embed=features_to_embed,
)

**Step 3: Load trained $\texttt{NOBLE}$ model**

In [None]:
# The checkpoint file name encodes the FNO architecture and embedding settings;
# the pieces below concatenate into a single path string.
model_path = (
    'noble_models/'
    'FNO_nmodes-256_in-23_out-1_nlayers-12_projectionratio-4_hc-24_'
    'AmpEmbeddings-FreqScaledNeRF-nfreq-9_'
    'HoFEmbeddings-FreqScaledNeRF-nfreq-1_'
    'bestepoch-296.pth'
)
model = load_model(model_path, device)

Loading FNO model with modes=256, in_channels=23, out_channels=1, nlayers=12, projection_ratio=4,  hidden_channels=24, device=cpu.



**Step 4: Sample a single model from latent space**

In this example, we choose $\text{HoF}_0$, a known model from the training set $\{\text{HoF}^{train}\}$.

In [5]:
# Keep only the rows belonging to HoF model 0.
sampled_hof_model = np.array([0])
is_sampled = normalised_features["hof_model"].isin(sampled_hof_model)
features_train = normalised_features[is_sampled]

# Embedding values of the selected model(s) as float32 tensors, keyed by feature name.
sampled_models = {
    name: torch.tensor(features_train[name].values, dtype=torch.float32)
    for name in ("intercept", "slope")
}

**Step 5: Build input with sampled model embeddings**

In [6]:
# Build one input per selected HoF row (presumably one row per model).
num_samples = len(features_train)
input_batch = build_input(amplitude, num_samples, device)

# Append the sine/HoF embeddings configured in Step 1 to the raw input.
embedding_kwargs = dict(
    input_batch=input_batch,
    embedding_config=embedding_config,
    features_to_embed=features_to_embed,
    normalised_features=normalised_features,
    device=device,
    sampled_embeddings=sampled_models,
)
input_batch_transformed = build_input_with_embeddings(**embedding_kwargs)

**Step 6: Generate the model prediction**

In [7]:
# Forward pass without autograd tracking. squeeze(0) is applied twice to drop
# the first two leading dimensions (each is a no-op unless that dimension has
# size 1), then the result is moved to the CPU as a NumPy array for eFEL.
with torch.no_grad():
    prediction = model(input_batch_transformed)
    output_batch = prediction.squeeze(0).squeeze(0).cpu().detach().numpy()

**Step 7: Compute electrophysiological features**

In [8]:
# eFEL feature sets: the spiking set is used when the predicted trace fires at
# least one action potential, the passive (non-spiking) set otherwise.
spiking_features = ['AHP1_depth_from_peak', 'AHP_depth', 'AHP_time_from_peak', 'AP1_peak',
                    'AP1_width', 'Spikecount', 'decay_time_constant_after_stim', 'depol_block',
                    'inv_first_ISI', 'mean_AP_amplitude', 'steady_state_voltage',
                    'steady_state_voltage_stimend', 'time_to_first_spike', 'voltage_base']

non_spiking_features = ['decay_time_constant_after_stim', 'sag_amplitude', 'steady_state_voltage',
                        'steady_state_voltage_stimend', 'voltage_base']

# Stimulus protocol timing in ms.
# NOTE(review): assumed to match the protocol the NOBLE model was trained on - confirm.
TRACE_DURATION_MS = 515
STIM_START_MS = 15
STIM_END_MS = 415

# Construct an eFEL trace where time is in ms and voltage is in mV
# (the model output is scaled by 1000, i.e. presumably it is in volts).
trace = {
    'T': np.linspace(0, TRACE_DURATION_MS, input_batch.shape[2]),
    'V': output_batch * 1000,
    'stim_start': [STIM_START_MS],
    'stim_end': [STIM_END_MS],
}

# get_feature_values returns one dict per trace; index [0] selects our only trace.
# Per the eFEL docs, each value is an array, or None if the feature could not be computed.
spikecount_result = efel.get_feature_values([trace], ['Spikecount'])[0]
spikecount = spikecount_result['Spikecount']
# Compare the spike count itself (not the returned list) against zero, so that
# silent traces actually get the non-spiking feature set.
is_spiking = spikecount is not None and len(spikecount) > 0 and spikecount[0] > 0

selected_features = spiking_features if is_spiking else non_spiking_features
features = efel.get_feature_values([trace], selected_features)[0]

print(f"{'Feature':<35} {'Mean Value':>15}")
print("-" * 50)
for feature, value in features.items():
    # Features eFEL could not compute (None / empty) are reported as N/A
    # instead of crashing np.mean or producing a NaN warning.
    mean_value = np.mean(value) if value is not None and len(value) > 0 else "N/A"
    print(f"{feature:<35} {mean_value:>15}")

Feature                                  Mean Value
--------------------------------------------------
AHP1_depth_from_peak                72.23538732110428
AHP_depth                           -8.259718551092535
AHP_time_from_peak                  2.466666666666754
AP1_peak                            -8.698474705072565
AP1_width                           43.77268163923692
Spikecount                                      6.0
decay_time_constant_after_stim      29.91447540812027
depol_block                                     1.0
inv_first_ISI                       23.36448598130973
mean_AP_amplitude                   41.0570774175681
steady_state_voltage                -70.63276954969625
steady_state_voltage_stimend        -68.65667398493909
time_to_first_spike                 106.39999999999738
voltage_base                        -71.66765210375739
