# Classification of the Feature Neurons from the SNN
This notebook classifies the feature neurons of the SNN according to their spiking activity. In the future, a third layer could be added to the SNN to classify the feature neurons automatically. For now, this manual classification assigns each neuron one of the following classes:
- **Silent Neuron**: Neuron that does not fire at all.
- **Noisy Neuron**: Neuron that fires randomly or without a relevant pattern.
- **Ripple Neuron**: Neuron that fires in the presence of a ripple.
- **Fast Ripple Neuron**: Neuron that fires in the presence of a fast ripple.

## Check WD (change if necessary) and file loading

In [2]:
# Show current directory
import os
curr_dir = os.getcwd()
print(curr_dir)

# Check if the current WD is the file location
if "/src/hfo/snn" not in os.getcwd():
    # Set working directory to this file location
    file_location = f"{os.getcwd()}/thesis-lava/src/hfo/snn"
    print("File Location: ", file_location)

    # Change the current working Directory
    os.chdir(file_location)

    # New Working Directory
    print("New Working Directory: ", os.getcwd())

/home/monkin/Desktop/feup/thesis
File Location:  /home/monkin/Desktop/feup/thesis/thesis-lava/src/hfo/snn
New Working Directory:  /home/monkin/Desktop/feup/thesis/thesis-lava/src/hfo/snn


## Load the Voltage and Current Dynamics during the SNN run

In [3]:
import numpy as np
from utils.io import preview_np_array
from utils.input import MarkerType, band_to_file_name

# Declare if using ripples, fast ripples, or both
chosen_band = MarkerType.FAST_RIPPLE     # RIPPLE, FAST_RIPPLE, or BOTH
band_file_name = band_to_file_name(chosen_band)

INPUT_PATH = f"./results/custom_subset_90-119_segment500_200"
TIME_SUFFIX = "time0-120000-1"

# Load the voltage and current data from the numpy files
voltage_file_name = f"{INPUT_PATH}/{band_file_name}_v_dynamics_0.07dv_5ch_{TIME_SUFFIX}.npy"
current_file_name = f"{INPUT_PATH}/{band_file_name}_u_dynamics_0.07dv_5ch_{TIME_SUFFIX}.npy"

v_dynamics = np.load(voltage_file_name)
u_dynamics = np.load(current_file_name)

preview_np_array(v_dynamics, "v_dynamics", edge_items=3)
preview_np_array(u_dynamics, "u_dynamics", edge_items=3)

v_dynamics Shape: (120000, 256).
Preview: [[0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 ...
 [6.26404417e-32 2.11770672e-31 1.33287944e-31 ... 1.21019874e-32
  7.87830517e-32 1.12042268e-31]
 [5.82556108e-32 1.96946725e-31 1.23957788e-31 ... 1.12548483e-32
  7.32682381e-32 1.04199309e-31]
 [5.41777180e-32 1.83160455e-31 1.15280743e-31 ... 1.04670089e-32
  6.81394614e-32 9.69053578e-32]]
u_dynamics Shape: (120000, 256).
Preview: [[0.00000000e+000 0.00000000e+000 0.00000000e+000 ... 0.00000000e+000
  0.00000000e+000 0.00000000e+000]
 [0.00000000e+000 0.00000000e+000 0.00000000e+000 ... 0.00000000e+000
  0.00000000e+000 0.00000000e+000]
 [0.00000000e+000 0.00000000e+000 0.00000000e+000 ... 0.00000000e+000
  0.00000000e+000 0.00000000e+000]
 

## Load the SNN Configuration and extract the fields
The SNN configuration contains:
- NumPy array with the annotated ground-truth events (one record per event, including its type and timestamp)
- Initial Time Offset
- Virtual Time Step Interval
- Number of Steps

In [4]:
from utils.snn import SNNSimConfig

# Load the SNN Config data
snn_config_file_name = f"{INPUT_PATH}/{band_file_name}_snn_config_{TIME_SUFFIX}.npy"

# Load the SNN configuration as an SNNSimConfig object (pickled inside the .npy file)
snn_config: SNNSimConfig = np.load(snn_config_file_name, allow_pickle=True).item()

# Extract the data fields from the SNNSimConfig object
ground_truth: np.ndarray = snn_config.ground_truth
init_offset = snn_config.init_offset
virtual_time_step_interval = snn_config.virtual_time_step_interval
num_steps = snn_config.num_steps

preview_np_array(ground_truth, "ground_truth", edge_items=3)

print("init_offset: ", init_offset)
print("virtual_time_step_interval: ", virtual_time_step_interval)
print("num_steps: ", num_steps)

ground_truth Shape: (199,).
Preview: [('Fast-Ripple',   1000.  , 0.)
 ('Spike+Ripple+Fast-Ripple',   3206.54, 0.)
 ('Fast-Ripple',   3770.02, 0.) ... ('Fast-Ripple', 116096.  , 0.)
 ('Ripple+Fast-Ripple', 116769.  , 0.) ('Fast-Ripple', 119000.  , 0.)]
init_offset:  0
virtual_time_step_interval:  1
num_steps:  120000


## Find the timesteps where the network spiked
Let's find the timesteps where the network spiked and create a `dictionary` mapping each feature neuron to the timesteps where it spiked.

In [5]:
from utils.data_analysis import find_spike_times

# Create a map storing the spike times for each feature neuron
neuron_spike_times = {}

# Call the find_spike_times util function, which detects the spikes from the voltage and current dynamics and returns a (time step, neuron index) pair per spike
spike_times_lif1 = find_spike_times(v_dynamics, u_dynamics)

total_spikes_count = 0
for (spike_time, neuron_idx) in spike_times_lif1:
    # Calculate the spike time in ms
    real_spike_time = init_offset + spike_time * virtual_time_step_interval

    # If the neuron index is not in the map (first time the feature neuron spikes), add it
    if neuron_idx not in neuron_spike_times:
        neuron_spike_times[neuron_idx] = [real_spike_time]
    # Otherwise, append the spike time to the list of spike times for that neuron
    else:
        neuron_spike_times[neuron_idx].append(real_spike_time)
    total_spikes_count += 1

    # print(f"Spike time: {real_spike_time} (iter. {spike_time}) at neuron: {neuron_idx}")

# Print the spike times for each feature neuron
print("neuron_spike_times: ", neuron_spike_times)


neuron_spike_times:  {2: [1007, 3226, 3770, 4138, 4144, 4147, 4152, 4775, 4788, 6290, 6678, 6685, 7221, 7240, 8106, 8111, 9104, 9112, 9683, 10056, 10518, 11131, 11643, 12130, 12145, 13199, 13203, 14055, 14591, 14595, 15205, 15216, 15556, 15561, 16134, 16785, 16798, 16802, 17030, 17749, 17752, 17757, 18122, 18131, 18784, 20155, 20766, 21179, 21792, 22667, 23021, 24214, 24657, 25120, 25125, 27053, 27064, 28538, 28549, 29269, 29277, 30260, 30748, 31727, 31730, 31742, 33096, 33719, 34617, 38197, 38200, 38558, 39230, 41232, 41235, 42286, 43313, 44103, 44109, 45148, 45161, 45169, 45551, 46270, 47108, 48226, 48231, 50727, 50737, 50741, 51044, 51556, 51572, 52112, 52751, 52754, 53301, 53306, 53309, 54244, 54250, 54255, 55012, 55761, 56093, 56107, 57313, 57321, 57555, 57573, 58575, 58585, 59266, 59270, 59706, 60039, 60717, 60721, 61167, 61513, 61516, 61526, 61531, 62647, 63291, 64636, 64651, 66608, 66611, 67246, 68048, 68787, 69130, 69138, 69142, 69794, 69805, 70763, 70771, 71292, 71295, 71300,
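The create-or-append logic in the loop above can be folded into `collections.defaultdict`. The sketch below assumes, as the loop's unpacking suggests, that `find_spike_times` returns `(step, neuron_idx)` pairs; `group_spikes_by_neuron` is a hypothetical helper name, not part of `utils`.

```python
from collections import defaultdict


def group_spikes_by_neuron(spike_pairs, init_offset=0, step_interval=1):
    """Map each neuron index to the list of its spike times, in input order."""
    spikes_by_neuron = defaultdict(list)
    for step, neuron_idx in spike_pairs:
        # Convert the simulation step into real time, as in the cell above
        spikes_by_neuron[neuron_idx].append(init_offset + step * step_interval)
    # Return a plain dict so missing keys raise instead of silently creating empty lists
    return dict(spikes_by_neuron)


# Toy input with the same (step, neuron_idx) layout as spike_times_lif1
pairs = [(7, 2), (8, 1), (11, 2), (30, 1)]
print(group_spikes_by_neuron(pairs))
```

Converting back to a plain `dict` keeps later lookups such as `neuron_spike_times[neuron_idx]` strict.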

## Define the Time Window after the event insertion to consider as part of the event

In [6]:
# Define the time window after an HFO insertion to consider as part of the event
ripple_confidence_window = 120  # Let's give a 120ms window after the Ripple to consider as part of the event
fr_confidence_window = 60  # Let's give a 60ms window after the Fast Ripple to consider as part of the event    (TODO: Could be changed)
both_confidence_window = 120  # Let's give a 120ms window after the HFO event (Ripple or Fast Ripple) insertion to consider as part of the event

confidence_window = ripple_confidence_window
if chosen_band == MarkerType.FAST_RIPPLE:
    confidence_window = fr_confidence_window
elif chosen_band == MarkerType.BOTH:
    confidence_window = both_confidence_window
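The `if`/`elif` chain above can also be written as a lookup table. In this sketch, `Band` is a hypothetical stand-in for `utils.input.MarkerType`, and the window values are the ones chosen in the cell above.

```python
from enum import Enum


class Band(Enum):
    """Hypothetical stand-in for utils.input.MarkerType."""
    RIPPLE = "ripple"
    FAST_RIPPLE = "fast_ripple"
    BOTH = "both"


# Confidence window (ms) after an HFO insertion, per analysed band
CONFIDENCE_WINDOWS_MS = {
    Band.RIPPLE: 120,
    Band.FAST_RIPPLE: 60,
    Band.BOTH: 120,
}


def confidence_window_for(band):
    # A KeyError flags a band with no configured window instead of silently using the Ripple default
    return CONFIDENCE_WINDOWS_MS[band]


print(confidence_window_for(Band.FAST_RIPPLE))  # 60
```

Unlike the original chain, which falls back to the Ripple window for any unmatched band, the table fails loudly when a new band is added without a window.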

## Merge the spikes that occur closer together than the confidence window
To compute the relevant spike ratio of each feature neuron, spikes that occur closer together than the confidence window must be merged into a single spike. Otherwise, the same event would be counted several times, overestimating the number of events the neuron detected.

Each merged group is anchored on its first spike: any later spike less than `confidence_window` after that first spike is dropped. This is **probably** safe, since consecutive annotated events are usually separated by at least an order of magnitude more than the confidence window.
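The merge rule can be isolated in a small function. This is a sketch equivalent to the loop in the next cell (the last kept spike plays the role of `comparison_idx`); `merge_spikes` is a hypothetical name. With the first spikes of neuron 2 and the 60 ms Fast Ripple window, the burst 4138, 4144, 4147, 4152 collapses to 4138, matching the merged output printed below.

```python
def merge_spikes(spike_times, window):
    """
    Collapse spikes that fall less than `window` after the first spike of their event.
    Returns (merged_times, number_of_dropped_spikes). Assumes spike_times is sorted.
    """
    if not spike_times:
        return [], 0

    merged = [spike_times[0]]
    dropped = 0
    for t in spike_times[1:]:
        if t - merged[-1] < window:
            # Same event as the last kept spike (the first spike of that event)
            dropped += 1
        else:
            # Far enough from the event start: this spike begins a new event
            merged.append(t)
    return merged, dropped


# First spikes of neuron 2, taken from the output above
print(merge_spikes([1007, 3226, 3770, 4138, 4144, 4147, 4152, 4775, 4788], 60))
```

Because the comparison is against the first spike of the event rather than the previous spike, a long train of closely spaced spikes is split every `window` ms instead of being merged into one event.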

In [7]:
# Merge the spikes that occur at a distance less than the confidence window
merged_spikes_counter = 0
for neuron_key in neuron_spike_times.keys():
    spike_times = neuron_spike_times[neuron_key]

    # Array with the merged spike times
    merged_spike_times = [spike_times[0]]

    curr_idx = 1
    comparison_idx = 0
    while curr_idx < len(spike_times):
        spike_time = spike_times[curr_idx]
        prev_spike_time = spike_times[comparison_idx]

        if spike_time - prev_spike_time < confidence_window:
            # This spike is part of the same event
            curr_idx += 1
            merged_spikes_counter += 1
            continue
        else:
            # This spike is part of a new event
            merged_spike_times.append(spike_time)
            comparison_idx = curr_idx   # Update the comparison index to the new spike time
            curr_idx += 1

    # Update the neuron_spike_times map with the merged spike times
    neuron_spike_times[neuron_key] = merged_spike_times

# Print the spike times for each feature neuron
print("Merged neuron_spike_times: ", neuron_spike_times)

print(f"Merged a total of {merged_spikes_counter} spikes out of {total_spikes_count} spikes")

Merged neuron_spike_times:  {2: [1007, 3226, 3770, 4138, 4775, 6290, 6678, 7221, 8106, 9104, 9683, 10056, 10518, 11131, 11643, 12130, 13199, 14055, 14591, 15205, 15556, 16134, 16785, 17030, 17749, 18122, 18784, 20155, 20766, 21179, 21792, 22667, 23021, 24214, 24657, 25120, 27053, 28538, 29269, 30260, 30748, 31727, 33096, 33719, 34617, 38197, 38558, 39230, 41232, 42286, 43313, 44103, 45148, 45551, 46270, 47108, 48226, 50727, 51044, 51556, 52112, 52751, 53301, 54244, 55012, 55761, 56093, 57313, 57555, 58575, 59266, 59706, 60039, 60717, 61167, 61513, 62647, 63291, 64636, 66608, 67246, 68048, 68787, 69130, 69794, 70763, 71292, 72718, 73168, 73715, 74056, 75154, 75721, 76238, 76623, 77151, 77719, 78257, 79540, 80239, 80739, 81749, 82062, 82791, 83060, 84184, 85105, 85772, 86270, 86638, 87128, 89050, 90129, 90648, 91249, 91578, 92273, 92619, 93287, 94030, 94623, 95169, 95634, 96705, 97608, 99270, 99632, 100293, 100575, 101680, 102267, 102711, 103261, 103574, 105792, 106037, 106767, 107172, 1

## Iterate over the SNN running time and classify the feature neurons

Create a numpy array with the classification of each feature neuron.

In [8]:
from utils.snn import NeuronClass

# Create numpy array containing the class of each feature neuron
feature_neuron_class = np.full(shape=(v_dynamics.shape[1]), fill_value=NeuronClass.SILENT)
preview_np_array(feature_neuron_class, "feature_neuron_class", edge_items=3)

feature_neuron_class Shape: (256,).
Preview: [0 0 0 ... 0 0 0]


### Keep track of the relevant events each feature neuron detected

In [9]:
# Map each feature neuron to the spike times that coincide with an annotated event
relevant_neuron_spike_times = {}

# Create an empty list for each feature neuron with spikes
for neuron_idx in neuron_spike_times.keys():
    relevant_neuron_spike_times[neuron_idx] = []

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {2: [], 1: [], 18: [], 0: [], 5: [], 87: [], 4: [], 11: [], 99: [], 52: [], 42: [], 82: [], 13: [], 32: [], 3: [], 116: [], 16: [], 6: [], 149: [], 106: [], 28: [], 20: [], 45: [], 71: [], 34: [], 77: [], 51: [], 7: [], 47: [], 202: [], 75: [], 244: [], 26: [], 92: [], 27: [], 66: [], 12: [], 176: [], 59: [], 8: [], 23: [], 9: [], 57: [], 204: [], 15: [], 19: [], 21: [], 41: [], 145: [], 65: [], 108: [], 62: [], 175: [], 36: [], 10: [], 153: [], 17: [], 114: [], 29: [], 93: [], 205: [], 31: [], 127: [], 163: [], 25: [], 142: [], 113: [], 170: [], 136: [], 162: [], 235: [], 35: [], 249: [], 39: [], 234: [], 255: [], 211: [], 55: [], 96: [], 168: [], 184: [], 110: [], 229: [], 30: [], 141: [], 89: [], 139: [], 83: [], 169: [], 179: [], 166: [], 132: [], 46: [], 241: [], 129: [], 200: [], 97: [], 119: []}


### Define a function that searches for a relevant event in the `ground_truth` array

In [10]:
# Extract the event timestamps from the ground truth and sort them (has_relevant_event relies on this order)
ground_truth_timestamps = np.sort(np.array([ann_event[1] for ann_event in ground_truth]))

preview_np_array(ground_truth_timestamps, "ground_truth_timestamps", edge_items=3)

ground_truth_timestamps Shape: (199,).
Preview: [  1000.     3206.54   3770.02 ... 116096.   116769.   119000.  ]
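If `ground_truth` is a NumPy structured array, as its tuple-style preview suggests, the timestamp column can be read in one vectorised step instead of a list comprehension. The dtype and field names below are hypothetical; the sketch selects the second field by position so the real names do not need to be known.

```python
import numpy as np

# Hypothetical dtype mirroring the ground_truth preview: (event type, timestamp in ms, third float field)
annotation_dtype = np.dtype([("label", "U32"), ("time", "f8"), ("extra", "f8")])
ground_truth_demo = np.array(
    [("Fast-Ripple", 3770.02, 0.0), ("Fast-Ripple", 1000.0, 0.0), ("Ripple+Fast-Ripple", 116769.0, 0.0)],
    dtype=annotation_dtype,
)

# Select the second field by position, so the real field names don't need to be known
time_field = ground_truth_demo.dtype.names[1]

# Vectorised column access; np.sort guarantees ascending order
timestamps = np.sort(ground_truth_demo[time_field])
print(timestamps)
```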


In [11]:
def has_relevant_event(spike_time, confidence_window):
    """
    Searches the ground truth for an annotated event that explains a spike.
    The annotation marks the insertion of the event, so it must precede the spike:
    the event has to lie in the window [spike_time - confidence_window, spike_time].
    Relies on ground_truth_timestamps being sorted in ascending order.
    """
    for event_time in ground_truth_timestamps:
        if spike_time - confidence_window <= event_time <= spike_time:
            return True

        if event_time > spike_time:
            # Since the events are sorted, no later event can fall inside the window
            break

    # No annotated event within the confidence window
    return False
    

In [12]:
# Iterate over the feature neurons that spiked
for neuron_idx in neuron_spike_times.keys():
    curr_spike_times = neuron_spike_times[neuron_idx]
    
    # For each Spike
    for curr_spike_time in curr_spike_times:
        # Since the SNN will spike after the insertion of the event, we need to consider the confidence window before the SNN spike.
        # The Annotated event is located at the insertion time (before)

        # Check if the ground truth contains at least 1 annotated event within the confidence window
        if has_relevant_event(curr_spike_time, confidence_window):
            # Add the event to the relevant neuron spike times
            relevant_neuron_spike_times[neuron_idx].append(curr_spike_time)

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {2: [1007, 3226, 4138, 4775, 6290, 8106, 9104, 9683, 10056, 10518, 11131, 11643, 12130, 13199, 14055, 14591, 15556, 16134, 16785, 17030, 17749, 18784, 20155, 20766, 21179, 21792, 22667, 23021, 24214, 24657, 27053, 28538, 29269, 30260, 30748, 33096, 33719, 34617, 38197, 38558, 39230, 41232, 42286, 43313, 44103, 45148, 45551, 47108, 48226, 50727, 51044, 51556, 52112, 52751, 53301, 54244, 55012, 55761, 56093, 57313, 58575, 59266, 59706, 60039, 60717, 61167, 62647, 63291, 64636, 66608, 67246, 68048, 69130, 69794, 70763, 72718, 73168, 74056, 75154, 75721, 76238, 76623, 77719, 79540, 80739, 81749, 82062, 82791, 85772, 86270, 86638, 87128, 89050, 90129, 90648, 91249, 91578, 92273, 93287, 94030, 94623, 95169, 96705, 97608, 99270, 99632, 100293, 101680, 102267, 102711, 103574, 105792, 106037, 106767, 107172, 107701, 109518, 110684, 113147, 113610, 114303, 116785, 119006], 1: [1008, 3209, 3777, 4142, 4781, 6286, 7230, 8107, 9108, 9689, 11117, 11647, 12141, 12683, 13
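`has_relevant_event` scans the timestamps linearly for every spike. Since the timestamps are sorted, a binary search with `np.searchsorted` answers the same question in logarithmic time, and it also vectorises over all spikes of a neuron. A sketch with hypothetical function names:

```python
import numpy as np


def has_relevant_event_sorted(spike_time, event_times, confidence_window):
    """
    True if some annotated event lies in [spike_time - confidence_window, spike_time].
    event_times must be a NumPy array sorted in ascending order.
    """
    # Index of the first event at or after the start of the window
    first_idx = np.searchsorted(event_times, spike_time - confidence_window, side="left")
    # The window holds an event iff that event exists and does not come after the spike
    return bool(first_idx < len(event_times) and event_times[first_idx] <= spike_time)


def relevant_spike_mask(spike_times, event_times, confidence_window):
    """Vectorised variant: one boolean per spike time."""
    spike_times = np.asarray(spike_times, dtype=float)
    if len(event_times) == 0:
        return np.zeros(spike_times.shape, dtype=bool)
    first_idx = np.searchsorted(event_times, spike_times - confidence_window, side="left")
    # Clip the indices so the lookup stays in bounds; out-of-range ones fail the first test anyway
    safe_idx = np.minimum(first_idx, len(event_times) - 1)
    return (first_idx < len(event_times)) & (event_times[safe_idx] <= spike_times)


events = np.array([1000.0, 3206.54, 3770.02])
print(has_relevant_event_sorted(1007, events, 60))
print(relevant_spike_mask([1007, 1100, 3226, 5000], events, 60))
```

Applied to each neuron's merged spike list with `ground_truth_timestamps` and `confidence_window`, the mask selects the same spikes as the loop in `In [12]`.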

## Statistics regarding the spiking neurons

In [13]:
# Create an array containing the relevant spike ratio for each feature neuron 
# in the order of their keys
relevant_spike_ratios = []
for neuron_idx in neuron_spike_times.keys():
    # Get all the spikes of the current neuron
    curr_spike_times = neuron_spike_times[neuron_idx]
    # Get the relevant spikes of the current neuron
    relevant_spike_times = relevant_neuron_spike_times[neuron_idx]

    # Calculate the ratio of relevant spikes
    relevant_spike_ratio = (len(relevant_spike_times) / len(curr_spike_times)) * 100 if len(curr_spike_times) > 0 else 0.0
    relevant_spike_ratios.append(relevant_spike_ratio)
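Written out, the ratio computed above for neuron $n$ is as follows, where $S_n$ is its set of merged spike times, $E$ the set of annotated event timestamps and $w$ the `confidence_window`:

$$
R_n = \frac{\lvert S_n^{\mathrm{rel}} \rvert}{\lvert S_n \rvert} \times 100\%,
\qquad
S_n^{\mathrm{rel}} = \{\, s \in S_n : \exists\, e \in E,\ s - w \le e \le s \,\}
$$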

In [14]:
PLOT_RELEVANCY_STATS = True

if PLOT_RELEVANCY_STATS:
    # Iterate over the feature neurons that spiked
    for iter_idx, neuron_idx in enumerate(neuron_spike_times.keys()):
        # Get all the spikes of the current neuron
        curr_spike_times = neuron_spike_times[neuron_idx]

        # Get the relevant spikes of the current neuron
        relevant_spike_times = relevant_neuron_spike_times[neuron_idx]

        print(f"Neuron {neuron_idx}: ")
        print(f"Spikes: {curr_spike_times}")
        print(f"Relevant Spikes: {relevant_spike_times}")
        print(f"Relevant Spikes Ratio: {relevant_spike_ratios[iter_idx]}%")
        print("================================================================\n")

Neuron 2: 
Spikes: [1007, 3226, 3770, 4138, 4775, 6290, 6678, 7221, 8106, 9104, 9683, 10056, 10518, 11131, 11643, 12130, 13199, 14055, 14591, 15205, 15556, 16134, 16785, 17030, 17749, 18122, 18784, 20155, 20766, 21179, 21792, 22667, 23021, 24214, 24657, 25120, 27053, 28538, 29269, 30260, 30748, 31727, 33096, 33719, 34617, 38197, 38558, 39230, 41232, 42286, 43313, 44103, 45148, 45551, 46270, 47108, 48226, 50727, 51044, 51556, 52112, 52751, 53301, 54244, 55012, 55761, 56093, 57313, 57555, 58575, 59266, 59706, 60039, 60717, 61167, 61513, 62647, 63291, 64636, 66608, 67246, 68048, 68787, 69130, 69794, 70763, 71292, 72718, 73168, 73715, 74056, 75154, 75721, 76238, 76623, 77151, 77719, 78257, 79540, 80239, 80739, 81749, 82062, 82791, 83060, 84184, 85105, 85772, 86270, 86638, 87128, 89050, 90129, 90648, 91249, 91578, 92273, 92619, 93287, 94030, 94623, 95169, 95634, 96705, 97608, 99270, 99632, 100293, 100575, 101680, 102267, 102711, 103261, 103574, 105792, 106037, 106767, 107172, 107701, 109518

## Plot the Relevant Spike Ratio of the Feature Neurons that Spike

In [15]:
# Create a bar plot containing the % of relevant spikes for each feature neuron 
from utils.bar_plot import create_bar_fig  # Import the function to create the figure

# Define the x and y values
neuron_label = [f"Neu. {neuron_idx}" for neuron_idx in neuron_spike_times.keys()]

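# Sort the x-axis categories (the bokeh range factors) by descending relevant spike ratio
neurons_descending_rel_ratio = [
    label for label, _ in sorted(zip(neuron_label, relevant_spike_ratios), key=lambda pair: pair[1], reverse=True)
]

# Create the bar plot with the relevant spike ratio of each feature neuron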
feat_neurons_relevancy_plot = create_bar_fig(
    title="Feature Neurons Relevant Spikes Ratio", 
    x_axis_label='Feature Neuron Index', 
    y_axis_label='Relevant Spike Ratio (%)',
    x=neuron_label,
    y=relevant_spike_ratios,
    x_range=neurons_descending_rel_ratio,
    sizing_mode="stretch_width",
    bar_width=0.5,
    tooltips=[("Neuron Index", "@x"), ("Relevant Spike Ratio", "@top%")]
)

In [16]:
import bokeh.plotting as bplt

showPlot = True
if showPlot:
    # Show the plot
    bplt.show(feat_neurons_relevancy_plot)

In [17]:
EXPORT_RELEVANCY_PLOT = False
OUTPUT_FOLDER = "./neuron_classification/custom_subset_90-119_segment500_200"

if EXPORT_RELEVANCY_PLOT:
    file_path = f"{OUTPUT_FOLDER}/{band_file_name}_relevancy_barplot_0.07dv_time{init_offset}-{num_steps}-{virtual_time_step_interval}.html"

    # Customize the output file settings
    bplt.output_file(filename=file_path, title="Feature Neurons Relevancy (%) Bar Plot")

    # Save the plot
    bplt.save(feat_neurons_relevancy_plot)

## Classify the Feature Neurons according to their Spiking Activity
In this step, we will update the `feature_neuron_class` array with the classification of the `Noisy Neurons`, `Ripple Neurons` and `Fast Ripple Neurons`.

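We set a **threshold of 90%** on the `Relevant Spike Ratio`: a neuron at or above it is classified as a `Ripple` or `Fast Ripple` Neuron, depending on the analysed band; otherwise, it is classified as `Noisy`. Neurons that never spiked keep the default `Silent` class.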
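The complete per-neuron decision, including the default `Silent` class, fits in one small function. This is a sketch: `Label` and `classify_neuron` are hypothetical names, and the notebook itself uses `utils.snn.NeuronClass` and `marker_type_to_neuron_class`, whose numeric values may differ.

```python
from enum import IntEnum


class Label(IntEnum):
    """Hypothetical labels; the notebook uses utils.snn.NeuronClass instead."""
    SILENT = 0
    NOISY = 1
    DETECTOR = 2  # Ripple / Fast Ripple / Both detector, depending on the analysed band


def classify_neuron(num_spikes, num_relevant_spikes, threshold_pct=90.0):
    # Neurons that never spike keep the default Silent class
    if num_spikes == 0:
        return Label.SILENT
    ratio_pct = 100.0 * num_relevant_spikes / num_spikes
    # At or above the threshold, the neuron is treated as a detector of the analysed band
    return Label.DETECTOR if ratio_pct >= threshold_pct else Label.NOISY


print(classify_neuron(10, 9).name)  # DETECTOR
```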

In [18]:
from utils.snn import marker_type_to_neuron_class

relevant_ratio_threshold = 90.0

# Iterate over the feature neurons that spiked
for iter_idx, neuron_idx in enumerate(neuron_spike_times.keys()):
    # Get all the spikes of the current neuron
    curr_spike_times = neuron_spike_times[neuron_idx]
    # Get the relevant spikes of the current neuron
    relevant_spike_times = relevant_neuron_spike_times[neuron_idx]
    # Get the relevant spikes ratio of the current neuron
    relevant_spike_ratio = relevant_spike_ratios[iter_idx]

    if relevant_spike_ratio >= relevant_ratio_threshold:
        # Classify the Neuron as a RIPPLE/FAST RIPPLE/BOTH detector based on the band we are analyzing
        feature_neuron_class[neuron_idx] = marker_type_to_neuron_class(chosen_band)
    else:
        # Classify the Neuron as Noisy
        feature_neuron_class[neuron_idx] = NeuronClass.NOISY

### Show the classification of the Feature Neurons

In [19]:
preview_np_array(feature_neuron_class, "feature_neuron_class", edge_items=3)

feature_neuron_class Shape: (256,).
Preview: [1 1 1 ... 0 0 3]


# Calculate the Classification Metrics of the SNN based on the Feature Neurons
At this point, we have a Spiking Neural Network that updates its state in real time and signals relevant oscillatory events through spikes.

We therefore have to define the conditions under which an HFO event counts as detected. Since each Feature Neuron is tuned to a specific frequency range, not all of them fire simultaneously in the presence of an HFO event. Hence, we need a threshold on the number of Feature Neurons that must fire to classify an event as an HFO. This is an exploratory choice; as a starting point, we require at least 2 neurons to fire within the time window (`confidence_window`).
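One possible way to turn the 2-neuron criterion into confusion-matrix counts is to score events and network detections rather than individual time steps, since a per-step count labels every step of a detection window separately and inflates the false positives. The sketch below is an assumption, not the notebook's method: `evaluate_detections` is a hypothetical helper, a "detection" is defined here as a spike time at which at least `min_active_neurons` distinct neurons fired within the preceding window (merged like the per-neuron spikes), and true negatives are left out because event-level scoring has no natural count of negatives unless the recording is split into fixed segments.

```python
import numpy as np


def evaluate_detections(spikes_by_neuron, event_times, window, min_active_neurons, sim_end):
    """
    Returns (TP, FN, FP) for a population of feature neurons.
    TP/FN are counted per annotated event, FP per network detection not explained by any event.
    """
    # Flatten every spike into time-sorted (time, neuron) arrays
    times, neurons = [], []
    for neuron_idx, spike_times in spikes_by_neuron.items():
        times.extend(spike_times)
        neurons.extend([neuron_idx] * len(spike_times))
    order = np.argsort(times, kind="stable")
    times = np.asarray(times, dtype=float)[order]
    neurons = np.asarray(neurons, dtype=int)[order]

    def active_neurons(start, end):
        # Number of distinct neurons with at least one spike in [start, end]
        lo = np.searchsorted(times, start, side="left")
        hi = np.searchsorted(times, end, side="right")
        return len(np.unique(neurons[lo:hi]))

    all_events = np.sort(np.asarray(event_times, dtype=float))
    # Edge case: events closer than `window` to the end of the run cannot be detected, so skip them
    scored_events = all_events[all_events + window <= sim_end]

    tp = fn = 0
    for event_time in scored_events:
        # The event counts as detected if enough distinct neurons fire in [event, event + window]
        if active_neurons(event_time, event_time + window) >= min_active_neurons:
            tp += 1
        else:
            fn += 1

    # Network detections: spike times at which enough distinct neurons fired within the preceding
    # window, merged like the per-neuron spikes (anchored on the first detection of each burst)
    detections = []
    for t in times:
        if active_neurons(t - window, t) >= min_active_neurons:
            if not detections or t - detections[-1] >= window:
                detections.append(t)

    # A detection with no annotated event in [detection - window, detection] is a false positive
    fp = 0
    for d in detections:
        idx = np.searchsorted(all_events, d - window, side="left")
        if not (idx < len(all_events) and all_events[idx] <= d):
            fp += 1

    return tp, fn, fp


# Toy data: one detected event, one missed event, and one unexplained two-neuron burst
toy_spikes = {0: [105, 500], 1: [130], 2: [900], 3: [520]}
print(evaluate_detections(toy_spikes, [100, 800], window=60, min_active_neurons=2, sim_end=1000))
```

In this notebook it would be called as `evaluate_detections(neuron_spike_times, ground_truth_timestamps, confidence_window, 2, init_offset + num_steps * virtual_time_step_interval)`.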

In [None]:
# Initialize the variables that store the values of the Confusion Matrix
true_positive = 0
false_positive = 0
true_negative = 0
false_negative = 0 
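Once the four counters are filled in, the standard classification metrics follow directly from them:

$$
\mathrm{Precision} = \frac{TP}{TP + FP},\qquad
\mathrm{Recall} = \frac{TP}{TP + FN},\qquad
\mathrm{Specificity} = \frac{TN}{TN + FP},\qquad
F_1 = \frac{2\,TP}{2\,TP + FP + FN}
$$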

In [23]:
# Go through the voltage dynamics array and calculate the classification metrics of the network
for run_step, feat_neurons_step in enumerate(v_dynamics):
    # Check if there is a spike in the feature neurons in the interval [run_step, run_step+confidence_window]

    # Check if there is an annotated event in the ground_truth for the current run_step

    # If feature neuron spiked + annotated event -> True Positive
    # If feature neuron spiked + no annotated event -> False Positive
    # If feature neuron didn't spike + annotated event -> False Negative
    # If feature neuron didn't spike + no annotated event -> True Negative

    # Edge case: if there is an annotated event at a distance < confidence_window from the end of the simulation,
    # the feature neurons won't have time to detect it. Thus, we don't consider that event.
    pass  # TODO: implement the steps outlined above

feat_neurons_step length: 256
