# Classification of the Feature Neurons from the SNN
This notebook performs the classification of the feature neurons from the SNN according to certain criteria. In the future, it would be interesting to implement a 3rd layer in the SNN to classify the feature neurons automatically. For now, this manual classification will classify each neuron as one of the following:
- **Silent Neuron**: Neuron that does not fire at all.
- **Noisy Neuron**: Neuron that fires randomly or without a relevant pattern.
- **Ripple Neuron**: Neuron that fires in the presence of a ripple.
- **Fast Ripple Neuron**: Neuron that fires in the presence of a fast ripple.

## Check WD (change if necessary) and file loading

In [82]:
# Report the directory the kernel is currently running in
import os
curr_dir = os.getcwd()
print(curr_dir)

# The relative paths used in this notebook expect the working directory to be
# the notebook's own folder (a path containing "/src/hfo/snn").
if "/src/hfo/snn" not in curr_dir:
    # Build the notebook folder path, assuming the current directory is the
    # parent of the "thesis-lava" checkout
    file_location = f"{curr_dir}/thesis-lava/src/hfo/snn"
    print("File Location: ", file_location)

    # Switch to the notebook folder
    os.chdir(file_location)

    # Confirm the new working directory
    print("New Working Directory: ", os.getcwd())

/home/monkin/Desktop/feup/thesis/thesis-lava/src/hfo/snn


## Load the Voltage and Current Dynamics during the SNN run

In [83]:
import numpy as np
from utils.io import preview_np_array
from utils.input import MarkerType, band_to_file_name

# Band analysed in this run: MarkerType.RIPPLE, MarkerType.FAST_RIPPLE, or MarkerType.BOTH
chosen_band = MarkerType.FAST_RIPPLE
band_file_name = band_to_file_name(chosen_band)

# Folder that holds the results of the SNN run being analysed
INPUT_PATH = "./results/custom_subset_90-119_segment500_200"

# Numpy files with the voltage (v) and current (u) dynamics of the run;
# the file names are prefixed with the band-specific name computed above
voltage_file_name = f"{INPUT_PATH}/{band_file_name}_v_dynamics_0.07dv_5ch_time1000-300-1.npy"
current_file_name = f"{INPUT_PATH}/{band_file_name}_u_dynamics_0.07dv_5ch_time1000-300-1.npy"

# Read both arrays from disk
v_dynamics = np.load(voltage_file_name)
u_dynamics = np.load(current_file_name)

# Show the shape and a short excerpt of each array
preview_np_array(v_dynamics, "v_dynamics", edge_items=3)
preview_np_array(u_dynamics, "u_dynamics", edge_items=3)

v_dynamics Shape: (300, 256).
Preview: [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-3.00000000e-01 -3.00000000e-01 -3.00000000e-01 ... -3.00000000e-01
  -3.00000000e-01 -3.00000000e-01]
 [-3.63637806e-02 -6.19632354e-02  2.10000000e-02 ... -1.82933897e-01
   1.99834875e-02 -1.20341243e-01]
 ...
 [ 1.08405053e-09  9.42952115e-10 -1.23785429e-10 ...  1.94098992e-09
   7.40639226e-10  2.13210245e-10]
 [ 1.00816699e-09  8.76945467e-10 -1.15120449e-10 ...  1.80512063e-09
   6.88794480e-10  1.98285528e-10]
 [ 9.37595301e-10  8.15559284e-10 -1.07062018e-10 ...  1.67876218e-09
   6.40578866e-10  1.84405541e-10]]
u_dynamics Shape: (300, 256).
Preview: [[ 0.00000000e+000  0.00000000e+000  0.00000000e+000 ...  0.00000000e+000
   0.00000000e+000  0.00000000e+000]
 [-3.00000000e-001 -3.00000000e-001 -3.00000000e-001 ... -3.00000000e-001
  -3.00000000e-001 -3.00000000e-001]
 [ 2.42636219e-001  2.17036765e-001  3.00000000e-001 ...  9.6066

## Load the SNN Configuration and extract the fields
The SNN configuration contains:
- Numpy array with the ground_truth for each timestep
- Initial Time Offset
- Virtual Time Step Interval
- Number of Steps

In [84]:
from utils.snn import SNNSimConfig

# Path of the file holding the SNN configuration of the run
snn_config_file_name = f"{INPUT_PATH}/{band_file_name}_snn_config.npy"

# The configuration is stored as a pickled object inside a single-element
# array, so .item() is used to unwrap it.
# SECURITY NOTE: allow_pickle=True can execute arbitrary code contained in the
# file. Only load configuration files that come from a trusted source.
snn_config: SNNSimConfig = np.load(snn_config_file_name, allow_pickle=True).item()

# Extract the data fields from the SNNSimConfig object
ground_truth: np.ndarray = snn_config.ground_truth
init_offset = snn_config.init_offset
virtual_time_step_interval = snn_config.virtual_time_step_interval
num_steps = snn_config.num_steps

preview_np_array(ground_truth, "ground_truth", edge_items=3)

# Print the count explicitly: a bare expression in the middle of a cell is
# evaluated but never displayed.
print("Non-zero ground_truth entries: ", np.count_nonzero(ground_truth))

print("init_offset: ", init_offset)
print("virtual_time_step_interval: ", virtual_time_step_interval)
print("num_steps: ", num_steps)

ground_truth Shape: (199,).
Preview: [('Fast-Ripple',   1000.  , 0.)
 ('Spike+Ripple+Fast-Ripple',   3206.54, 0.)
 ('Fast-Ripple',   3770.02, 0.) ... ('Fast-Ripple', 116096.  , 0.)
 ('Ripple+Fast-Ripple', 116769.  , 0.) ('Fast-Ripple', 119000.  , 0.)]
init_offset:  1000
virtual_time_step_interval:  1
num_steps:  300


## Find the timesteps where the network spiked
Let's find the timesteps where the network spiked and create a `dictionary` mapping each feature neuron to the timesteps where it spiked.

In [85]:
from utils.data_analysis import find_spike_times

# Map: feature neuron index -> list of the times (in ms) at which it spiked
neuron_spike_times = {}

# Detect the spikes from the recorded voltage and current dynamics.
# Each detected spike is unpacked below as a (time step, neuron index) pair.
spike_times_lif1 = find_spike_times(v_dynamics, u_dynamics)

for (spike_time, neuron_idx) in spike_times_lif1:
    # Convert the simulation step into a time in ms
    real_spike_time = init_offset + spike_time * virtual_time_step_interval

    # setdefault creates the list the first time a neuron spikes, then the
    # spike time is appended to that neuron's list
    neuron_spike_times.setdefault(neuron_idx, []).append(real_spike_time)

# Show the spike times collected for each feature neuron
print("neuron_spike_times: ", neuron_spike_times)


neuron_spike_times:  {0: [1007, 1017], 31: [1008], 5: [1009], 4: [1010], 47: [1011], 50: [1014], 121: [1018]}


## Define the Time Window after the event insertion to consider as part of the event

In [86]:
# Time window (in ms) after an HFO insertion that is still considered part of the event
ripple_confidence_window = 120  # window used when analysing Ripples
fr_confidence_window = 60  # window used when analysing Fast Ripples    (TODO: Could be changed)
both_confidence_window = 120  # window used when analysing both bands (Ripple or Fast Ripple)

# Pick the window matching the chosen band; any other band uses the Ripple window
if chosen_band == MarkerType.FAST_RIPPLE:
    confidence_window = fr_confidence_window
elif chosen_band == MarkerType.BOTH:
    confidence_window = both_confidence_window
else:
    confidence_window = ripple_confidence_window

## Merge the spikes that occur at a distance less than the confidence window
To calculate the efficiency of the feature neurons, we first merge the spikes of each neuron that are less than the confidence window apart. Otherwise, the same event would be counted multiple times, leading to an overestimation of the efficiency.

This is **probably** safe, since each annotated event occurs at a distance greater than the considered time window, usually at least 1 order of magnitude greater.

In [87]:
def _merge_close_spikes(spike_times, window):
    """
    Collapse spikes belonging to the same event into a single spike time.

    A spike starts a new event when it is at least `window` away from the
    first spike of the most recently kept event; otherwise it is dropped.

    Args:
        spike_times: spike times of one neuron (assumed to be in ascending
            order, as produced by the detection step - TODO confirm).
        window: minimum distance between two spikes for them to be counted
            as separate events.

    Returns:
        A new list with one spike time per event (empty if the input is empty).
    """
    merged = []
    for spike_time in spike_times:
        # merged[-1] is always the first spike of the latest event, so the
        # comparison never depends on how many spikes were skipped before.
        if not merged or spike_time - merged[-1] >= window:
            merged.append(spike_time)
    return merged


# Merge the spikes that occur at a distance less than the confidence window
for neuron_key in list(neuron_spike_times.keys()):
    # Replace the neuron's spike list with its merged version
    neuron_spike_times[neuron_key] = _merge_close_spikes(
        neuron_spike_times[neuron_key], confidence_window
    )

# Print the spike times for each feature neuron
print("Merged neuron_spike_times: ", neuron_spike_times)

Merged neuron_spike_times:  {0: [1007], 31: [1008], 5: [1009], 4: [1010], 47: [1011], 50: [1014], 121: [1018]}


## Iterate over the SNN running time and classify the feature neurons

Create a numpy array with the classification of each feature neuron.

In [68]:
from utils.snn import NeuronClass

# One class label per column of v_dynamics (i.e. per feature neuron);
# every neuron starts out labelled as SILENT
feature_neuron_class = np.full(v_dynamics.shape[1], NeuronClass.SILENT)
preview_np_array(feature_neuron_class, "feature_neuron_class", edge_items=3)

feature_neuron_class Shape: (256,).
Preview: [0 0 0 ... 0 0 0]


### Keep track of the relevant events each feature neuron detected

In [69]:
# Map: feature neuron index -> spike times that matched an annotated event.
# Start with an empty list for every feature neuron that spiked.
relevant_neuron_spike_times = {neuron_idx: [] for neuron_idx in neuron_spike_times}

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {0: [], 31: [], 5: [], 4: [], 47: [], 50: [], 121: []}


### Define a function that searches for a relevant event in the `ground_truth` array

In [70]:
# Timestamps of the annotated ground truth events (second field of each entry)
ground_truth_timestamps = np.array([ann_event[1] for ann_event in ground_truth])

# The event lookup expects the timestamps in ascending order, so fail loudly
# here instead of silently misclassifying spikes later on.
assert np.all(np.diff(ground_truth_timestamps) >= 0), "ground_truth timestamps are not sorted"

preview_np_array(ground_truth_timestamps, "ground_truth_timestamps", edge_items=3)

ground_truth_timestamps Shape: (199,).
Preview: [  1000.     3206.54   3770.02 ... 116096.   116769.   119000.  ]


In [71]:
def has_relevant_event(spike_time, confidence_window, event_times=None):
    """
    Check whether an annotated event explains a spike.

    The annotation marks the insertion of the event, so it must be located
    at or before the spike time. An event is relevant when it lies in the
    closed window [spike_time - confidence_window, spike_time].

    Args:
        spike_time: time of the spike to check.
        confidence_window: length of the window before the spike in which
            an annotated event is accepted.
        event_times: optional iterable of annotated event timestamps.
            Defaults to the module-level `ground_truth_timestamps`.

    Returns:
        True if at least one event lies inside the window, False otherwise.
    """
    if event_times is None:
        event_times = ground_truth_timestamps

    window_start = spike_time - confidence_window

    # Check every event so the result does not depend on the timestamps
    # being sorted; any() always yields a plain bool (never None).
    return any(window_start <= event_time <= spike_time for event_time in event_times)
    

In [72]:
from utils.snn import marker_type_to_neuron_class

# Iterate over the feature neurons that spiked
for neuron_idx in neuron_spike_times.keys():
    curr_spike_times = neuron_spike_times[neuron_idx]

    # The SNN spikes after the insertion of the event, while the annotation is
    # located at the insertion time. A spike is therefore relevant when the
    # ground truth has at least 1 annotated event within the confidence window
    # before the spike.
    # The list is rebuilt (not appended to) so that re-running this cell does
    # not duplicate entries.
    relevant_neuron_spike_times[neuron_idx] = [
        curr_spike_time
        for curr_spike_time in curr_spike_times
        if has_relevant_event(curr_spike_time, confidence_window)
    ]

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {0: [1007], 31: [1008], 5: [1009], 4: [1010], 47: [1011], 50: [1014], 121: [1018]}


## Statistics regarding the spiking neurons

In [90]:
# Report, for every feature neuron that spiked, how many of its spikes matched an annotated event
for neuron_idx, curr_spike_times in neuron_spike_times.items():
    # Spikes of this neuron that were matched to an annotated event
    relevant_spike_times = relevant_neuron_spike_times[neuron_idx]

    # Percentage of relevant spikes; 0.0 when the neuron has no spikes at all
    if len(curr_spike_times) > 0:
        relevant_ratio = (len(relevant_spike_times) / len(curr_spike_times))*100
    else:
        relevant_ratio = 0.0

    print(f"Neuron {neuron_idx}: ")
    print(f"Spikes: {curr_spike_times}")
    print(f"Relevant Spikes: {relevant_spike_times}")
    print(f"Relevant Spikes Ratio: {relevant_ratio}%")
    print("\n\n")

Neuron 0: 
Spikes: [1007]
Relevant Spikes: [1007]
Relevant Spikes Ratio: 100.0%



Neuron 31: 
Spikes: [1008]
Relevant Spikes: [1008]
Relevant Spikes Ratio: 100.0%



Neuron 5: 
Spikes: [1009]
Relevant Spikes: [1009]
Relevant Spikes Ratio: 100.0%



Neuron 4: 
Spikes: [1010]
Relevant Spikes: [1010]
Relevant Spikes Ratio: 100.0%



Neuron 47: 
Spikes: [1011]
Relevant Spikes: [1011]
Relevant Spikes Ratio: 100.0%



Neuron 50: 
Spikes: [1014]
Relevant Spikes: [1014]
Relevant Spikes Ratio: 100.0%



Neuron 121: 
Spikes: [1018]
Relevant Spikes: [1018]
Relevant Spikes Ratio: 100.0%



