# Classification of the Feature Neurons from the SNN
This notebook performs the classification of the feature neurons from the SNN according to certain criteria. In the future, it would be interesting to implement a 3rd layer in the SNN to classify the feature neurons automatically. For now, this manual classification will classify each neuron as one of the following:
- **Silent Neuron**: Neuron that does not fire at all.
- **Noisy Neuron**: Neuron that fires randomly or without a relevant pattern.
- **Ripple Neuron**: Neuron that fires in the presence of a ripple.
- **Fast Ripple Neuron**: Neuron that fires in the presence of a fast ripple.

## Check WD (change if necessary) and file loading

In [63]:
import os

# Show the current working directory
curr_dir = os.getcwd()
print(curr_dir)

# When the notebook is not started from this file's folder, assume it was started from the
# folder that contains the thesis-lava repository and move into src/hfo/snn
if "/src/hfo/snn" not in curr_dir:
    file_location = f"{curr_dir}/thesis-lava/src/hfo/snn"
    print("File Location: ", file_location)

    os.chdir(file_location)
    print("New Working Directory: ", os.getcwd())

/home/monkin/Desktop/feup/thesis/thesis-lava/src/hfo/snn


## Define Controllable Parameters to load a specific example

In [64]:
from utils.input import MarkerType, band_to_file_name, BaselineAlgorithm, ModelDistStrategy

# Declare if using ripples, fast ripples, or both
chosen_band = MarkerType.RIPPLE     # RIPPLE, FAST_RIPPLE, or BOTH

# Specify the chosen Baseline Algorithm (selects which baseline thresholds file is loaded)
chosen_baseline_alg_suffix = BaselineAlgorithm.Q3
# Select the Distribution Model Strategy (part of the results file names)
selected_strategy = ModelDistStrategy.LOG_NORMAL

# Number of simulation steps of the SNN run to load (used to build TIME_SUFFIX)
sim_duration = 30000
# Input weights scale of the SNN run to load (used to build WEIGHT_SUFFIX)
weights_scale_input = 0.2
mean_dv = 20    # The mean of the Voltage Time Constant (used to build DV_SUFFIX)

# Range of random values to subtract from the excitatory time constants to get the inhibitory time constants
inh_subtract_range = [0.1, 2.0]

# ----- Feature Neuron Classification Parameters -------
# Minimum relevant spike ratio, in percent (0-100), for a Feature Neuron to be classified as a detector
relevant_ratio_threshold = 85.0
# Minimum number of detector-neuron spikes (votes) inside the confidence window to classify an event as relevant
num_consensus_neurons = 1  # alternative value previously noted here: 2

## Load the Voltage and Current Dynamics during the SNN run

#### Load the Baseline Thresholds from the output file from the baseline process

In [65]:
import numpy as np
from utils.io import preview_np_array

DATASET_FILENAME = "seeg_filtered_subset_90-119_segment500_200"

# Thresholds produced by the signal_to_spike baseline step:
# index 0 -> ripple band, index 1 -> fast ripple band
BASELINE_FILE = f"../signal_to_spike/baseline_results/{DATASET_FILENAME}_thresholds_{chosen_baseline_alg_suffix}.npy"
baseline_thresholds = np.load(BASELINE_FILE)

baseline_ripple_thresh, baseline_fr_thresh = (round(baseline_thresholds[band_idx], 4) for band_idx in (0, 1))

# Symmetric encoding for now: each DOWN threshold mirrors its UP threshold
ripple_thresh_up, ripple_thresh_down = baseline_ripple_thresh, -baseline_ripple_thresh
fr_thresh_up, fr_thresh_down = baseline_fr_thresh, -baseline_fr_thresh

print("Ripple Thresholds: ", ripple_thresh_up, ripple_thresh_down)
print("FR Thresholds: ", fr_thresh_up, fr_thresh_down)

Ripple Thresholds:  3.7058 -3.7058
FR Thresholds:  1.4961 -1.4961


In [66]:
from utils.input import MarkerType

def band_to_thresholds(band: MarkerType) -> tuple:
    """Return the (upper, lower) spike-encoding thresholds for the given band.

    Args:
        band (MarkerType): the band for which the thresholds are required.
            Only MarkerType.RIPPLE and MarkerType.FAST_RIPPLE are supported.

    Returns:
        tuple: (thresh_up, thresh_down), taken from the ripple_thresh_* or
        fr_thresh_* variables computed from the baseline thresholds file.

    Raises:
        ValueError: if `band` is any other value (e.g. MarkerType.BOTH), since no
            combined threshold is defined in this notebook. ValueError subclasses
            Exception, so existing `except Exception` handlers still catch it.
    """
    if band == MarkerType.RIPPLE:
        return ripple_thresh_up, ripple_thresh_down
    if band == MarkerType.FAST_RIPPLE:
        return fr_thresh_up, fr_thresh_down
    raise ValueError(f"Invalid band type: {band}")

In [67]:
band_file_name = band_to_file_name(chosen_band)

# Thresholds used to encode the chosen band (they are part of the results file names)
thresh_up, thresh_down = band_to_thresholds(chosen_band)
print("Thresholds: ", thresh_up, thresh_down)

# Pieces of the results file names, one per simulation parameter
INPUT_PATH = f"./results/{DATASET_FILENAME}"
TIME_SUFFIX = f"time0-{sim_duration}-1"
THRESH_SUFFIX = f"thresh{thresh_up}-{thresh_down}"
STRAT_SUFFIX = f"strat{selected_strategy}"
WEIGHT_SUFFIX = f"w{weights_scale_input}"
DV_SUFFIX = f"dv{mean_dv}"
INH_RANGE = f"inh{inh_subtract_range[0]}-{inh_subtract_range[1]}"
run_suffix = "_".join([DV_SUFFIX, WEIGHT_SUFFIX, THRESH_SUFFIX, INH_RANGE, STRAT_SUFFIX, TIME_SUFFIX])

# Voltage (v) and current (u) dynamics recorded during the SNN run
voltage_file_name = f"{INPUT_PATH}/{band_file_name}_v_dynamics_{run_suffix}.npy"
current_file_name = f"{INPUT_PATH}/{band_file_name}_u_dynamics_{run_suffix}.npy"
v_dynamics, u_dynamics = np.load(voltage_file_name), np.load(current_file_name)

for dynamics_array, dynamics_label in ((v_dynamics, "v_dynamics"), (u_dynamics, "u_dynamics")):
    preview_np_array(dynamics_array, dynamics_label, edge_items=3)

Thresholds:  3.7058 -3.7058
v_dynamics Shape: (30000, 256).
Preview: [[0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 ...
 [1.10310488e-02 2.16181831e-04 6.10945174e-05 ... 1.52673559e-03
  6.22963398e-04 8.41596687e-04]
 [1.06875171e-02 2.05994858e-04 5.82757594e-05 ... 1.46808152e-03
  6.04487006e-04 7.93832608e-04]
 [1.03546797e-02 1.96287919e-04 5.55870523e-05 ... 1.41168083e-03
  5.86558604e-04 7.48779337e-04]]
u_dynamics Shape: (30000, 256).
Preview: [[0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.0000000

## Load the SNN Configuration and extract the fields
The SNN configuration contains:
- Numpy structured array `ground_truth` with one record per annotated event (label, insertion time, …), not one entry per timestep
- Initial Time Offset
- Virtual Time Step Interval
- Number of Steps

In [68]:
from utils.snn import SNNSimConfig

# Load the SNN Config data
snn_config_file_name = f"{INPUT_PATH}/{band_file_name}_snn_config_{TIME_SUFFIX}.npy"

# Load the SNNConfig data as an element of the class SNNSimConfig.
# NOTE: allow_pickle=True unpickles arbitrary objects; only load config files produced by this project.
snn_config: SNNSimConfig = np.load(snn_config_file_name, allow_pickle=True).item()

# Extract the data fields from the SNNConfig object
ground_truth: np.ndarray = snn_config.ground_truth
init_offset = snn_config.init_offset
virtual_time_step_interval = snn_config.virtual_time_step_interval
num_steps = snn_config.num_steps

# (A stray `np.count_nonzero(ground_truth)` was removed here: its result was never displayed or used.)
preview_np_array(ground_truth, "ground_truth", edge_items=3)

print("init_offset: ", init_offset)
print("virtual_time_step_interval: ", virtual_time_step_interval)
print("num_steps: ", num_steps)

ground_truth Shape: (222,).
Preview: [('Fast-Ripple',   1000.  , 0.)
 ('Spike+Ripple+Fast-Ripple',   3206.54, 0.)
 ('Spike+Ripple',   3521.  , 0.) ... ('Spike+Ripple', 116216.  , 0.)
 ('Ripple+Fast-Ripple', 116769.  , 0.) ('Ripple', 119000.  , 0.)]
init_offset:  0
virtual_time_step_interval:  1
num_steps:  30000


## Find the timesteps where the network spiked
Let's find the timesteps where the network spiked and create a `dictionary` mapping each feature neuron to the timesteps where it spiked.

In [69]:
from utils.data_analysis import find_spike_times

# Map each feature neuron index to the list of times at which it spiked
neuron_spike_times = {}

# find_spike_times detects the spikes from the voltage/current dynamics and returns (step, neuron) pairs
spike_times_lif1 = find_spike_times(v_dynamics, u_dynamics)

total_spikes_count = 0
for sim_step, neuron_idx in spike_times_lif1:
    # Convert the simulation step into the real spike time
    real_spike_time = init_offset + sim_step * virtual_time_step_interval

    # setdefault creates the list on the neuron's first spike (keys keep first-spike order)
    neuron_spike_times.setdefault(neuron_idx, []).append(real_spike_time)
    total_spikes_count += 1

# Print the spike times for each feature neuron
print("neuron_spike_times: ", neuron_spike_times)

neuron_spike_times:  {25: [945, 956, 976, 987, 2346, 2620, 3161, 3180, 3192, 3223, 3235, 3245, 3247, 3256, 3281, 3293, 3303, 3305, 3508, 3554, 3567, 4008, 4128, 4137, 4140, 4151, 4312, 4582, 4718, 4798, 4847, 5009, 5032, 5181, 5211, 5354, 5438, 6116, 6174, 6182, 6655, 6668, 6677, 6687, 6697, 7240, 7247, 7592, 7594, 7605, 7631, 7963, 8036, 8061, 8074, 8078, 8091, 8095, 8115, 8832, 9117, 9138, 9579, 9665, 9686, 9723, 10019, 10066, 10186, 10208, 10228, 10523, 10532, 10540, 10544, 10995, 11068, 11210, 11361, 11434, 11468, 11638, 11640, 11652, 11665, 11677, 12093, 12112, 12132, 12146, 12155, 12166, 12177, 12199, 12212, 12224, 12235, 12737, 13109, 13199, 13211, 13224, 13444, 13561, 13563, 13574, 13586, 13599, 14598, 14605, 15238, 15244, 15253, 15259, 15446, 15549, 15567, 15754, 16039, 16101, 16129, 16141, 16154, 16194, 16442, 16773, 16785, 16798, 16810, 17003, 17027, 17769, 17903, 17991, 18136, 18144, 18773, 18785, 18796, 18809, 19062, 19064, 19074, 19086, 19088, 19099, 19112, 19763, 19770, 

## See how many Neurons are Silent before merging spikes belonging to the same event

In [70]:
# A neuron that never spiked has no entry in neuron_spike_times, so it is silent for the whole simulation
num_feature_neurons = v_dynamics.shape[1]
spiking_neuron_ids = set(neuron_spike_times)
silent_neurons = [neuron for neuron in range(num_feature_neurons) if neuron not in spiking_neuron_ids]

print(f"There are {len(silent_neurons)}/{num_feature_neurons} silent neurons")
print("Silent Neurons: ", silent_neurons)

print(f"There are {len(neuron_spike_times)} spiking neurons")

There are 25/256 silent neurons
Silent Neurons:  [6, 17, 22, 35, 56, 60, 63, 64, 67, 68, 80, 86, 94, 108, 121, 144, 176, 187, 191, 205, 209, 226, 229, 236, 249]
There are 231 spiking neurons


## Define the Time Window after the event insertion to consider as part of the event

In [71]:
from utils.input import RIPPLE_CONFIDENCE_WINDOW, FR_CONFIDENCE_WINDOW, BOTH_CONFIDENCE_WINDOW

# Time window after an HFO insertion that is still considered part of the event (depends on the band)
if chosen_band == MarkerType.FAST_RIPPLE:
    confidence_window = FR_CONFIDENCE_WINDOW
elif chosen_band == MarkerType.BOTH:
    confidence_window = BOTH_CONFIDENCE_WINDOW
else:
    confidence_window = RIPPLE_CONFIDENCE_WINDOW

## Merge the spikes that occur at a distance less than the confidence window
In order to calculate the relevance of the feature neurons, we need to merge the spikes of each neuron that occur less than the confidence window after the first spike of the current event. Otherwise, we would be counting the same event multiple times, which would distort the relevant spike ratio.

This is **probably** safe, since each annotated event occurs at a distance greater than the considered time window, usually at least 1 order of magnitude greater.

In [72]:
# Collapse the spikes of each neuron into events: a spike closer than `confidence_window` to the first
# spike of the current event belongs to that event and is dropped; any later spike opens a new event.
merged_spikes_counter = 0
for neuron_key, spike_times in neuron_spike_times.items():
    event_start = spike_times[0]
    merged_spike_times = [event_start]

    for spike_time in spike_times[1:]:
        if spike_time - event_start < confidence_window:
            # Same event as the current one
            merged_spikes_counter += 1
        else:
            # Start of a new event
            merged_spike_times.append(spike_time)
            event_start = spike_time

    # Replacing the value of an existing key is safe while iterating over items()
    neuron_spike_times[neuron_key] = merged_spike_times

# Print the spike times for each feature neuron
print("Merged neuron_spike_times: ", neuron_spike_times)

print(f"Merged a total of {merged_spikes_counter} spikes out of {total_spikes_count} spikes")

Merged neuron_spike_times:  {25: [945, 2346, 2620, 3161, 3281, 3508, 4008, 4128, 4312, 4582, 4718, 4847, 5009, 5181, 5354, 6116, 6655, 7240, 7592, 7963, 8091, 8832, 9117, 9579, 9723, 10019, 10186, 10523, 10995, 11210, 11361, 11638, 12093, 12224, 12737, 13109, 13444, 13574, 14598, 15238, 15446, 15567, 15754, 16039, 16194, 16442, 16773, 17003, 17769, 17903, 18136, 18773, 19062, 19763, 20169, 20485, 20670, 20809, 21008, 21782, 21985, 22111, 22242, 22614, 22747, 22874, 23017, 23695, 24194, 24671, 25067, 25782, 26218, 26628, 27013, 28232, 28556, 28729, 28877, 29189, 29415, 29542, 29882], 81: [946, 1540, 2152, 2344, 2620, 3162, 3282, 3508, 4008, 4137, 4312, 4446, 4582, 4719, 4848, 5009, 5181, 5353, 6117, 6655, 7240, 7592, 7964, 8091, 8650, 8833, 9110, 9580, 9709, 10020, 10186, 10523, 10996, 11210, 11361, 11628, 12101, 12223, 12703, 13060, 13199, 13444, 13574, 14069, 14593, 15238, 15446, 15566, 15755, 16040, 16165, 16442, 16774, 17003, 17768, 17903, 18139, 18774, 19056, 19763, 20486, 20670, 2

## Iterate over the SNN running time and classify the feature neurons

Create a numpy array with the classification of each feature neuron.

In [73]:
from utils.snn import NeuronClass

# Every feature neuron starts classified as SILENT; the neurons that spiked are reclassified later
num_neurons_in_layer = v_dynamics.shape[1]
feature_neuron_class = np.full(num_neurons_in_layer, NeuronClass.SILENT)
preview_np_array(feature_neuron_class, "feature_neuron_class", edge_items=3)

feature_neuron_class Shape: (256,).
Preview: [0 0 0 ... 0 0 0]


### Keep track of the relevant events each feature neuron detected

In [74]:
# Map every neuron that spiked to the list of its relevant spike times (filled in below)
relevant_neuron_spike_times = {neuron_idx: [] for neuron_idx in neuron_spike_times}

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {25: [], 81: [], 152: [], 201: [], 49: [], 100: [], 134: [], 142: [], 196: [], 211: [], 213: [], 45: [], 141: [], 82: [], 132: [], 157: [], 198: [], 233: [], 13: [], 19: [], 36: [], 41: [], 90: [], 105: [], 183: [], 34: [], 51: [], 62: [], 87: [], 99: [], 131: [], 224: [], 244: [], 251: [], 169: [], 0: [], 29: [], 158: [], 16: [], 97: [], 252: [], 15: [], 40: [], 180: [], 188: [], 37: [], 147: [], 219: [], 53: [], 170: [], 218: [], 23: [], 197: [], 137: [], 44: [], 133: [], 38: [], 255: [], 207: [], 221: [], 11: [], 42: [], 120: [], 189: [], 153: [], 246: [], 9: [], 43: [], 46: [], 73: [], 103: [], 126: [], 149: [], 155: [], 160: [], 163: [], 165: [], 171: [], 172: [], 178: [], 202: [], 203: [], 208: [], 212: [], 214: [], 223: [], 240: [], 118: [], 122: [], 1: [], 3: [], 10: [], 24: [], 26: [], 30: [], 31: [], 69: [], 101: [], 112: [], 113: [], 128: [], 139: [], 173: [], 190: [], 195: [], 199: [], 232: [], 235: [], 238: [], 239: [], 4: [], 5: [], 47: [], 5

### Define method that searches for a relevant event in the ground_truth np array

In [75]:
# Insertion times of the annotated events (field 1 of each ground-truth record), sorted ascending.
# Sorting explicitly guarantees the order that `has_relevant_event` (which stops scanning once it
# passes the spike time) and the metrics loop (which walks the annotations in order) rely on.
ground_truth_timestamps = np.sort(np.array([ann_event[1] for ann_event in ground_truth]))

preview_np_array(ground_truth_timestamps, "ground_truth_timestamps", edge_items=3)

ground_truth_timestamps Shape: (222,).
Preview: [  1000.     3206.54   3521.   ... 116216.   116769.   119000.  ]


In [76]:
def has_relevant_event(spike_time, confidence_window) -> float | None:
    """
    Searches for a relevant event in the ground truth array given a timestamp.
    The annotated event must be located before the spike time, since the annotation
    corresponds to the insertion of the event.
    Thus, the Event must be in the window [spike_time - confidence_window, spike_time]

    Returns:
    - The insertion time of the relevant event if found
    - None if no relevant event was found
    """
    # Check if spike_time is greater than the first event in the ground truth. If not, return False
    if ground_truth_timestamps[0] > spike_time:
        return None

    for event_time in ground_truth_timestamps:
        if spike_time - confidence_window <= event_time <= spike_time:
            return event_time
        
        if event_time > spike_time:
            # Since the events are sorted, if the event_time is greater than the spike_time, we can stop looking
            return None
    

In [77]:
# Keep, for each feature neuron that spiked, only the spikes preceded by an annotated event.
# Since the SNN spikes after the insertion of the event, the annotated event (located at its
# insertion time) must lie in the confidence window before the spike.
# The lists are rebuilt rather than appended to, so re-running this cell does not duplicate entries.
for neuron_idx, curr_spike_times in neuron_spike_times.items():
    relevant_neuron_spike_times[neuron_idx] = [
        curr_spike_time
        for curr_spike_time in curr_spike_times
        if has_relevant_event(curr_spike_time, confidence_window) is not None
    ]

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {25: [3281, 4847, 7240, 7592, 8091, 9117, 9723, 10523, 11210, 11638, 12224, 12737, 13574, 14598, 15238, 15567, 16194, 17769, 18136, 19763, 20169, 20809, 21782, 22242, 22747, 23695, 24194, 24671, 25782, 26218, 28556, 29542], 81: [3282, 4137, 4848, 7240, 7592, 8091, 9110, 9709, 10523, 11210, 12223, 12703, 13199, 13574, 14069, 14593, 15238, 15566, 16165, 17768, 18139, 19763, 20799, 21871, 22747, 23696, 24200, 24671, 26219, 27635, 28544, 29543], 152: [3294, 4138, 4848, 7241, 7606, 8092, 9140, 9724, 10524, 11210, 11640, 12224, 12744, 13212, 13587, 14600, 15244, 15567, 16195, 16787, 18142, 18786, 19064, 19771, 20820, 21871, 22260, 22747, 23696, 24200, 24702, 25785, 26219, 28245, 28571, 29544], 201: [4138, 7241, 7618, 8092, 9147, 10525, 11211, 11642, 12236, 14600, 15244, 15568, 16215, 18143, 18797, 19075, 19776, 20835, 21872, 23696, 24201, 24702, 25188, 25791, 26219, 28246, 29544], 49: [3305, 4138, 7242, 7595, 8096, 9142, 9725, 10524, 11211, 11640, 12224, 13212, 

## Statistics regarding the spiking neurons

In [78]:
# Percentage of each neuron's (merged) spikes that were preceded by an annotated event,
# listed in the same order as the keys of neuron_spike_times
relevant_spike_ratios = [
    (len(relevant_neuron_spike_times[neuron_idx]) / len(spikes)) * 100 if len(spikes) > 0 else 0.0
    for neuron_idx, spikes in neuron_spike_times.items()
]

In [79]:
PLOT_RELEVANCY_STATS = True

if PLOT_RELEVANCY_STATS:
    # relevant_spike_ratios follows the key order of neuron_spike_times, so zip pairs them up
    for neuron_idx, ratio in zip(neuron_spike_times, relevant_spike_ratios):
        print(f"Neuron {neuron_idx}: ")
        print(f"Spikes: {neuron_spike_times[neuron_idx]}")
        print(f"Relevant Spikes: {relevant_neuron_spike_times[neuron_idx]}")
        print(f"Relevant Spikes Ratio: {ratio}%")
        print("================================================================\n")

Neuron 25: 
Spikes: [945, 2346, 2620, 3161, 3281, 3508, 4008, 4128, 4312, 4582, 4718, 4847, 5009, 5181, 5354, 6116, 6655, 7240, 7592, 7963, 8091, 8832, 9117, 9579, 9723, 10019, 10186, 10523, 10995, 11210, 11361, 11638, 12093, 12224, 12737, 13109, 13444, 13574, 14598, 15238, 15446, 15567, 15754, 16039, 16194, 16442, 16773, 17003, 17769, 17903, 18136, 18773, 19062, 19763, 20169, 20485, 20670, 20809, 21008, 21782, 21985, 22111, 22242, 22614, 22747, 22874, 23017, 23695, 24194, 24671, 25067, 25782, 26218, 26628, 27013, 28232, 28556, 28729, 28877, 29189, 29415, 29542, 29882]
Relevant Spikes: [3281, 4847, 7240, 7592, 8091, 9117, 9723, 10523, 11210, 11638, 12224, 12737, 13574, 14598, 15238, 15567, 16194, 17769, 18136, 19763, 20169, 20809, 21782, 22242, 22747, 23695, 24194, 24671, 25782, 26218, 28556, 29542]
Relevant Spikes Ratio: 38.55421686746988%

Neuron 81: 
Spikes: [946, 1540, 2152, 2344, 2620, 3162, 3282, 3508, 4008, 4137, 4312, 4446, 4582, 4719, 4848, 5009, 5181, 5353, 6117, 6655, 7240, 

## Plot the Relevant Spike Ratio of the Feature Neurons that Spike

In [80]:
# Create a bar plot containing the % of relevant spikes for each feature neuron 
from utils.bar_plot import create_bar_fig  # Import the function to create the figure

# Define the x and y values
neuron_label = [f"Neu. {neuron_idx}" for neuron_idx in neuron_spike_times.keys()]

# Order the x-axis categories by descending relevant spike ratio.
# neuron_label and relevant_spike_ratios share the key order of neuron_spike_times, so zipping pairs each
# label with its ratio in O(n log n) (a list.index lookup inside the sort key would be O(n^2)).
# sorted() stays stable with reverse=True, so neurons with equal ratios keep their original order.
neurons_descending_rel_ratio = [
    label
    for label, _ in sorted(zip(neuron_label, relevant_spike_ratios), key=lambda pair: pair[1], reverse=True)
]

# Create the Relevant Spike Ratio bar plot
feat_neurons_relevancy_plot = create_bar_fig(
    title="Feature Neurons Relevant Spikes Ratio", 
    x_axis_label='Feature Neuron Index', 
    y_axis_label='Relevant Spike Ratio (%)',
    x=neuron_label,
    y=relevant_spike_ratios,
    x_range=neurons_descending_rel_ratio,
    sizing_mode="stretch_width",
    bar_width=0.5,
    tooltips=[("Neuron Index", "@x"), ("Relevant Spike Ratio", "@top%")]
)

In [96]:
import bokeh.plotting as bplt

# Display the relevancy bar plot with Bokeh
SHOW_PLOT = True
if SHOW_PLOT:
    bplt.show(feat_neurons_relevancy_plot)

In [94]:
EXPORT_RELEVANCY_PLOT = False
OUTPUT_FOLDER = f"./neuron_classification/{DATASET_FILENAME}"

if EXPORT_RELEVANCY_PLOT:
    # Make sure the output folder exists before writing to it
    folder_missing = not os.path.exists(OUTPUT_FOLDER)
    if folder_missing:
        os.makedirs(OUTPUT_FOLDER)

    file_path = f"{OUTPUT_FOLDER}/{band_file_name}_relevancy_barplot_{DV_SUFFIX}_{WEIGHT_SUFFIX}_{THRESH_SUFFIX}_{INH_RANGE}_{STRAT_SUFFIX}_{TIME_SUFFIX}.html"

    # Point Bokeh at the HTML file and save the figure there
    bplt.output_file(filename=file_path, title="Feature Neurons Relevancy (%) Bar Plot")
    bplt.save(feat_neurons_relevancy_plot)

## Classify the Feature Neurons according to their Spiking Activity
In this step, we will update the `feature_neuron_class` array with the classification of the `Noisy Neurons`, `Ripple Neurons` and `Fast Ripple Neurons`.

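A neuron is classified as a `Ripple` or `Fast Ripple` Neuron (depending on the analysed band) when its `Relevant Spike Ratio`, in percent, is at least `relevant_ratio_threshold` (currently **85%**). Otherwise, it is classified as a `Noisy` Neuron.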

In [83]:
from utils.snn import marker_type_to_neuron_class

# Class assigned to detector neurons for the analysed band (RIPPLE/FAST RIPPLE/BOTH); loop-invariant
detector_class = marker_type_to_neuron_class(chosen_band)

# Iterate over the feature neurons that spiked, paired with their relevant spike ratio (same key order)
for neuron_idx, relevant_spike_ratio in zip(neuron_spike_times.keys(), relevant_spike_ratios):
    if relevant_spike_ratio >= relevant_ratio_threshold:
        # Relevant enough: the neuron is a detector for the band we are analysing
        feature_neuron_class[neuron_idx] = detector_class
    else:
        # Classify the Neuron as Noisy
        feature_neuron_class[neuron_idx] = NeuronClass.NOISY

### Show the classification of the Feature Neurons

In [84]:
# Preview the class of every feature neuron after the classification step
preview_np_array(feature_neuron_class, "feature_neuron_class", edge_items=3)

feature_neuron_class Shape: (256,).
Preview: [1 1 1 ... 2 2 1]


# Calculate the Classification Metrics of the SNN based on the Feature Neurons
As of right now, we have a Spiking Neural Network that updates its state in real-time. This network is capable of detecting relevant oscillatory events through spikes.

Therefore, we have to define the conditions that identify whether an HFO event was detected. Since each Feature Neuron is modelled to be tuned to a specific frequency range, not all of them will fire simultaneously in the presence of an HFO event. Hence, we need a threshold for the number of detector-neuron spikes (votes) that must occur within the time window (`confidence_window`) to classify an event as an HFO. This is an exploratory approach. The threshold is set by `num_consensus_neurons` (currently 1; 2 is the other value considered as a starting point).

In [85]:
# Confusion Matrix counters, all starting at zero (one prediction is added per simulation step)
true_positive = false_positive = true_negative = false_negative = 0

### Build a numpy array containing all the spike times from the `Detector Neurons` ordered by spiking time

In [86]:
# Indices of the Feature Neurons classified as detectors for the chosen band (Ripple/FR/Both)
is_detector = feature_neuron_class == marker_type_to_neuron_class(chosen_band)
detector_indices = np.flatnonzero(is_detector)
preview_np_array(detector_indices, "detector_indices")

detector_indices Shape: (58,).
Preview: [ 12  18  27  32  48 ... 242 245 247 253 254]


In [87]:
# Pool the (merged) spike times of every HFO Detector Neuron into a single time-ordered array
detector_spike_times = np.array(sorted(
    spike_time
    for detector_idx in detector_indices
    for spike_time in neuron_spike_times[detector_idx]
))
preview_np_array(detector_spike_times, "detector_spike_times")

detector_spike_times Shape: (762,).
Preview: [ 3245  3246  3246  3256  3257 ... 29200 29200 29200 29544 29544]


In [88]:
import math

curr_detector_idx = 0   # Index of the next unprocessed detector neuron spike
curr_ground_truth_idx = 0   # Index of the next unprocessed ground truth event in the ordered array

# Stores an active annotation. If None, there is no active annotation
# An annotation is active in the interval [annotation_start, annotation_start + confidence_window]
active_annotation = None    # Run step (ceiled insertion timestamp) of the active annotation

# Go through every simulation step and calculate the classification metrics of the network.
# Each step adds exactly one prediction (TP, FP, TN or FN) to the confusion matrix.
for run_step in range(v_dynamics.shape[0]):
    is_end_detector_spikes = curr_detector_idx >= len(detector_spike_times)
    is_end_ground_truth = curr_ground_truth_idx >= len(ground_truth_timestamps)

    # Block that checks if a new annotation gets activated
    if not is_end_ground_truth:
        # Check if the next annotation was inserted at this run_step
        next_annotation_insertion = math.ceil(ground_truth_timestamps[curr_ground_truth_idx])
        if next_annotation_insertion == run_step:
            if active_annotation is not None:
                raise Exception("There is already an active annotation. This should not happen.")

            # Start a new active annotation
            active_annotation = next_annotation_insertion

            # Increment the curr_ground_truth_idx since this annotation will be processed
            curr_ground_truth_idx += 1

    # If there are still spikes from the Detector Neurons
    if not is_end_detector_spikes:
        next_spike_iter = detector_spike_times[curr_detector_idx]
        # If the next spike occurs in the current run_step, let's check what kind of prediction it is
        if next_spike_iter == run_step:
            # Increment the curr_detector_idx since this spike will be processed
            curr_detector_idx += 1

            # The first vote is the current spike (SNN needs num_consensus_neurons votes to classify the event
            # as relevant). Initialized here so BOTH branches below start from a fresh count; previously the
            # False Positive branch reused a stale count (or raised NameError if no True Positive came first).
            num_votes = 1

            if active_annotation is not None:
                # There is an annotated event in the confidence window: let's see if the SNN predicted it.
                # The annotated events are floats, so we need to ceil the end of the window because
                # we are comparing with the integer run_step (iteration index of the simulation) (TODO: Include the virtual_step in the calculation)
                end_relevant_event_time = math.ceil(active_annotation + confidence_window)
                # NOTE(review): a spike at run_step == active_annotation + confidence_window + 1 is still handled
                # here, because the expiry check at the bottom of the loop only runs on steps without a spike.

                # Count the remaining spikes in the window as votes. max() with run_step guarantees that any other
                # spike at this same step is consumed too; otherwise it would never match `== run_step` again and
                # every later detector spike would be silently skipped.
                window_end = max(end_relevant_event_time, run_step)
                while curr_detector_idx < len(detector_spike_times) and detector_spike_times[curr_detector_idx] <= window_end:
                    num_votes += 1
                    curr_detector_idx += 1

                if num_votes >= num_consensus_neurons:
                    # The SNN detected the event -> True Positive
                    true_positive += 1
                else:
                    # The SNN didn't detect the event (Not enough votes) -> False Negative
                    # False Negatives may also occur without any spikes in the network
                    false_negative += 1

                # The current annotation was processed, so we can deactivate it
                active_annotation = None
            else:
                # No annotated event is active, so a detection here would be a False Positive.
                # Calculate the window that would cause an SNN Detection
                end_relevant_event_iter = next_spike_iter + confidence_window

                # If an annotated event is inserted between the current run_step and end_relevant_event_iter,
                # end the window on the timestep before it, since we are analyzing a possible False Positive
                middle_relevant_event_time = has_relevant_event(end_relevant_event_iter, confidence_window)
                if middle_relevant_event_time is not None:
                    end_relevant_event_iter = math.ceil(middle_relevant_event_time) - 1

                # Count the remaining spikes in the window as votes (always including spikes at this step)
                window_end = max(end_relevant_event_iter, run_step)
                while curr_detector_idx < len(detector_spike_times) and detector_spike_times[curr_detector_idx] <= window_end:
                    num_votes += 1
                    curr_detector_idx += 1

                if num_votes >= num_consensus_neurons:
                    # The SNN detected an event that does not exist -> False Positive
                    false_positive += 1
                else:
                    # The SNN didn't detect any event -> True Negative
                    true_negative += 1

            # Go to the next iteration
            continue

    # If one of the 2 conditions is true:
    #   1. There are no more spikes from the Detector Neurons
    #   2. The current run_step does not contain a spike from the Detector Neurons
    # Then: let's check if the active annotation is still active and report a True Negative or False Negative
    # according to the answer.
    if active_annotation is not None:
        if run_step >= active_annotation + confidence_window + 1:   # The +1 is to consider an extra iteration because the annotation is a float ceiled
            # The SNN didn't detect the event in the interval [active_annotation, active_annotation + confidence_window]
            # Thus, it is a False Negative, and the active annotation is over
            active_annotation = None
            false_negative += 1
        else:
            # The active annotation is still active, so it is a True Negative (No event detected)
            true_negative += 1
    else:
        # If there is no active annotation, the SNN didn't detect any event rightfully so
        # So, it is a True Negative
        true_negative += 1

# Edge case: If there is an annotated event at a distance < confidence_window from the end of the simulation, the feature events
# won't have time to detect it. Thus, we don't consider that event.
# TODO: Consider this edge case?

## Show the Confusion Matrix

In [89]:
# Print the Confusion Matrix and the total number of predictions (one per simulation step)
confusion_matrix_entries = (
    ("True Positive: ", true_positive),
    ("False Positive: ", false_positive),
    ("True Negative: ", true_negative),
    ("False Negative: ", false_negative),
)
for entry_label, entry_count in confusion_matrix_entries:
    print(entry_label, entry_count)

total_predictions = sum(entry_count for _, entry_count in confusion_matrix_entries)
print("Total Predictions: ", total_predictions)

True Positive:  35
False Positive:  5
True Negative:  29943
False Negative:  17
Total Predictions:  30000


## Calculate Classification Metrics

In [90]:
# Calculate relevant metrics.
# Every ratio is guarded against a zero denominator and reported as NaN when undefined. This replaces
# exit(), which in a notebook does not just skip this cell but tries to shut down the interpreter/kernel.
import math


def _safe_pct(numerator, denominator):
    """Return `numerator / denominator` as a percentage, or NaN when `denominator` is 0."""
    return numerator / denominator * 100 if denominator else math.nan


if true_positive + false_positive == 0:
    print("No relevant predictions were made. Precision and F1 Score are undefined (reported as NaN).")

accuracy = _safe_pct(true_positive + true_negative, total_predictions)  # Proportion of correct predictions
precision = _safe_pct(true_positive, true_positive + false_positive)    # Proportion of positive predictions that are correct
recall = _safe_pct(true_positive, true_positive + false_negative)       # Proportion of annotated events captured by the model
specificity = _safe_pct(true_negative, true_negative + false_positive)  # Proportion of negatives identified correctly

# Harmonic mean of Precision and Recall (0 when both are 0, NaN when either is undefined)
if math.isnan(precision) or math.isnan(recall):
    f1_score = math.nan
elif precision + recall == 0:
    f1_score = 0.0
else:
    f1_score = (2 * precision * recall) / (precision + recall)

print(f"Accuracy: {accuracy}%")
print(f"Precision: {precision}%")
print(f"Recall (Sensitivity): {recall}%")
print(f"F1 Score: {f1_score}")
print(f"Specificity: {specificity}%")

Accuracy: 99.92666666666666%
Precision: 87.5%
Recall: 67.3076923076923%
F1 Score: 76.08695652173913
Specificity: 99.98330439428342%


# Export the results of the Classification to a JSON file
Export the results of the classification to a JSON file. This file will include:
- Frequency Band used (`Ripple`, `Fast Ripple` or `Both`)
- `Confidence Window` used
- Classification of each Feature Neuron (`Silent`, `Noisy`, `Ripple`, `Fast Ripple`, `Both`)
- Number of `Consensus Feature Neurons` to classify an event as a HFO
- Classification Metrics (`True Positives`, `False Positives`, `True Negatives`, `False Negatives`, `Accuracy`, `Precision`, `Recall`, `F1 Score`, `Specificity`)

In [91]:
# Count how many feature neurons ended up in each class
detector_class = marker_type_to_neuron_class(chosen_band)
num_silent_neurons, num_noisy_neurons, num_detector_neurons = (
    np.count_nonzero(feature_neuron_class == neuron_class)
    for neuron_class in (NeuronClass.SILENT, NeuronClass.NOISY, detector_class)
)

print(f"Number of Silent Neurons: {num_silent_neurons}")
print(f"Number of Noisy Neurons: {num_noisy_neurons}")
print(f"Number of Detector Neurons: {num_detector_neurons}")
print(f"Total: {num_silent_neurons + num_noisy_neurons + num_detector_neurons}")

Number of Silent Neurons: 25
Number of Noisy Neurons: 173
Number of Detector Neurons: 58
Total: 256


In [95]:
import json

# Export the results to a JSON file

# Create a dictionary with the results
json_results = {
    "freq_band": band_file_name,
    "confidence_window": confidence_window,
    "relevant_ratio_threshold": relevant_ratio_threshold,   # Classification threshold (%), also encoded in CLASSIF_SUFFIX
    "feature_neuron_class": {
        "array": feature_neuron_class.tolist(),
        "counts": {
            "silent_neurons": num_silent_neurons,
            "noisy_neurons": num_noisy_neurons,
            "detector_neurons": num_detector_neurons
        },
    },
    "num_consensus_neurons": num_consensus_neurons,
    "metrics": {
        "true_positive": true_positive,
        "false_positive": false_positive,
        "true_negative": true_negative,
        "false_negative": false_negative,
        "total_predictions": total_predictions,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity
    }
}

CLASSIF_SUFFIX = f"c_{relevant_ratio_threshold}_{num_consensus_neurons}"

EXPORT_JSON_FILE = False
if EXPORT_JSON_FILE:
    # The folder is otherwise only created by the relevancy-plot export cell, which may be disabled
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    json_file_name = f"{OUTPUT_FOLDER}/{band_file_name}_results_{DV_SUFFIX}_{WEIGHT_SUFFIX}_{THRESH_SUFFIX}_{INH_RANGE}_{STRAT_SUFFIX}_{TIME_SUFFIX}_{CLASSIF_SUFFIX}.json"
    with open(json_file_name, 'w') as f:
        json.dump(json_results, f)

## See the Characteristics of the Relevant Feature Neurons

In [93]:
# TODO: Export Characteristics of the neurons from the `hfo_detection` file to the JSON file so that we can analyze the results