# Classification of the Feature Neurons from the SNN
This notebook performs the classification of the feature neurons from the SNN according to certain criteria. In the future, it would be interesting to implement a 3rd layer in the SNN to classify the feature neurons automatically. For now, this manual classification will classify each neuron as one of the following:
- **Silent Neuron**: Neuron that does not fire at all.
- **Noisy Neuron**: Neuron that fires randomly or without a relevant pattern.
- **Ripple Neuron**: Neuron that fires in the presence of a ripple.
- **Fast Ripple Neuron**: Neuron that fires in the presence of a fast ripple.

## Check WD (change if necessary) and file loading

In [2]:
# Print the directory the kernel was started in
import os
curr_dir = os.getcwd()
print(curr_dir)

# Later cells use relative paths (e.g. ./results/..., ./neuron_classification/...),
# so move into src/hfo/snn when the kernel was started from the repository's parent folder.
if "/src/hfo/snn" not in curr_dir:
    file_location = f"{curr_dir}/thesis-lava/src/hfo/snn"
    print("File Location: ", file_location)

    # Switch into the notebook folder and confirm the change
    os.chdir(file_location)
    print("New Working Directory: ", os.getcwd())

/home/monkin/Desktop/feup/thesis/thesis-lava/src/hfo/snn


## Load the Voltage and Current Dynamics during the SNN run

In [3]:
import numpy as np
from utils.io import preview_np_array
from utils.input import MarkerType, band_to_file_name

# Frequency band under analysis: MarkerType.RIPPLE, MarkerType.FAST_RIPPLE or MarkerType.BOTH
chosen_band = MarkerType.FAST_RIPPLE
band_file_name = band_to_file_name(chosen_band)

# Dataset identification: the (symmetric) thresholds are embedded in the results folder name
thresh_up = 1.8974
thresh_down = -thresh_up
DATASET_NAME = f"custom_subset_90-119_segment500_200_thresh{thresh_up}-{thresh_down}"
INPUT_PATH = f"./results/{DATASET_NAME}"
TIME_SUFFIX = "time0-120000-1"

# Recorded dynamics of the feature neurons: v (voltage) and u (current)
voltage_file_name = f"{INPUT_PATH}/{band_file_name}_v_dynamics_0.07dv_5ch_{TIME_SUFFIX}.npy"
current_file_name = f"{INPUT_PATH}/{band_file_name}_u_dynamics_0.07dv_5ch_{TIME_SUFFIX}.npy"
v_dynamics, u_dynamics = np.load(voltage_file_name), np.load(current_file_name)

preview_np_array(v_dynamics, "v_dynamics", edge_items=3)
preview_np_array(u_dynamics, "u_dynamics", edge_items=3)

v_dynamics Shape: (120000, 256).
Preview: [[0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 ...
 [0.18756436 0.34578193 0.10013955 ... 0.13897003 0.24468999 0.06984873]
 [0.05018655 0.33788161 0.01963627 ... 0.12346785 0.20876799 0.03114039]
 [0.26011964 0.64542083 0.28117835 ... 0.42674998 0.50239221 0.32775098]]
u_dynamics Shape: (120000, 256).
Preview: [[ 0.          0.          0.         ...  0.          0.
   0.        ]
 [ 0.          0.          0.         ...  0.          0.
   0.        ]
 [ 0.          0.          0.         ...  0.          0.
   0.        ]
 ...
 [-0.17796511 -0.12175926 -0.13167818 ... -0.16395164 -0.15528676
  -0.16970102]
 [-0.12424831  0.01630441 -0.07349351 ... -0.00577428 -0.0187937
  -0.03381893]
 [ 0.21344615  0.33119094  0.26291662 ...  0.31192487  0.30823798
   0.29879041]]


## Load the SNN Configuration and extract the fields
The SNN configuration contains:
- Numpy array with the ground_truth for each timestep
- Initial Time Offset
- Virtual Time Step Interval
- Number of Steps

In [4]:
from utils.snn import SNNSimConfig

# The SNN configuration is saved as a single pickled SNNSimConfig object in a .npy file
snn_config_file_name = f"{INPUT_PATH}/{band_file_name}_snn_config_{TIME_SUFFIX}.npy"

# SECURITY NOTE: allow_pickle=True can execute arbitrary code embedded in the file.
# Only load configuration files produced by this project's own SNN runs.
snn_config: SNNSimConfig = np.load(snn_config_file_name, allow_pickle=True).item()

# Extract the data fields from the SNNSimConfig object
ground_truth: np.ndarray = snn_config.ground_truth
init_offset = snn_config.init_offset
virtual_time_step_interval = snn_config.virtual_time_step_interval
num_steps = snn_config.num_steps

preview_np_array(ground_truth, "ground_truth", edge_items=3)

print("init_offset: ", init_offset)
print("virtual_time_step_interval: ", virtual_time_step_interval)
print("num_steps: ", num_steps)

ground_truth Shape: (199,).
Preview: [('Fast-Ripple',   1000.  , 0.)
 ('Spike+Ripple+Fast-Ripple',   3206.54, 0.)
 ('Fast-Ripple',   3770.02, 0.) ... ('Fast-Ripple', 116096.  , 0.)
 ('Ripple+Fast-Ripple', 116769.  , 0.) ('Fast-Ripple', 119000.  , 0.)]
init_offset:  0
virtual_time_step_interval:  1
num_steps:  120000


## Find the timesteps where the network spiked
Let's find the timesteps where the network spiked and create a `dictionary` mapping each feature neuron to the timesteps where it spiked.

In [5]:
from utils.data_analysis import find_spike_times

# Detect every (spike_time, neuron_idx) event from the recorded voltage/current dynamics
spike_times_lif1 = find_spike_times(v_dynamics, u_dynamics)

# Map each feature neuron index to the list of its spike times
# (keys appear in the order the neurons first show up in spike_times_lif1)
neuron_spike_times = {}
total_spikes_count = 0
for spike_time, neuron_idx in spike_times_lif1:
    # Convert the simulation iteration into run time using the configured offset and step interval
    real_spike_time = init_offset + spike_time * virtual_time_step_interval
    neuron_spike_times.setdefault(neuron_idx, []).append(real_spike_time)
    total_spikes_count += 1

# Print the spike times for each feature neuron
print("neuron_spike_times: ", neuron_spike_times)


neuron_spike_times:  {70: [21, 170, 338, 721, 1023, 1380, 1520, 2033, 2249, 2436, 2814, 2888, 3672, 4022, 4113, 4158, 4322, 4803, 4992, 5098, 5674, 6140, 6204, 6516, 6654, 7970, 8121, 8125, 8373, 8429, 8500, 8802, 9702, 9710, 10387, 10530, 10781, 10840, 10936, 11218, 11389, 11441, 11536, 11653, 11657, 11812, 12105, 12157, 12552, 13212, 13473, 13555, 13659, 14044, 14566, 14602, 15863, 16065, 16455, 17059, 17385, 17634, 17907, 17980, 18141, 18149, 18415, 18727, 18731, 18763, 18800, 18888, 19076, 19255, 19265, 20172, 20389, 20669, 21044, 21164, 21323, 21666, 21744, 21799, 21803, 21964, 22677, 22681, 22839, 23041, 23286, 23735, 23891, 24606, 24687, 24692, 24995, 25139, 25372, 25420, 25636, 25781, 25788, 25957, 26666, 26836, 27036, 27075, 27522, 27558, 28000, 28101, 28140, 28251, 28340, 28410, 29221, 29281, 29502, 29545, 29907, 30096, 30274, 30770, 30948, 30982, 31226, 31439, 32021, 32125, 32522, 33088, 33093, 33366, 33703, 33977, 34595, 34626, 34910, 34967, 34994, 35028, 35573, 35901, 3624

## Define the Time Window after the event insertion to consider as part of the event

In [6]:
from utils.input import RIPPLE_CONFIDENCE_WINDOW, FR_CONFIDENCE_WINDOW, BOTH_CONFIDENCE_WINDOW

# Time window after an HFO insertion that still counts as part of that event.
# Each band has its own window; the ripple window is the fallback.
if chosen_band == MarkerType.FAST_RIPPLE:
    confidence_window = FR_CONFIDENCE_WINDOW
elif chosen_band == MarkerType.BOTH:
    confidence_window = BOTH_CONFIDENCE_WINDOW
else:
    confidence_window = RIPPLE_CONFIDENCE_WINDOW

## Merge the spikes that occur at a distance less than the confidence window
In order to calculate the efficiency of the feature neurons, we need to merge the spikes of each neuron that occur less than one confidence window after the first spike of the current event; only that first spike is kept. Otherwise, we would be counting the same event multiple times, which would lead to an overestimation of the efficiency.

This is **probably** safe, since each annotated event occurs at a distance greater than the considered time window, usually at least 1 order of magnitude greater.

In [7]:
# Collapse the spikes of each neuron into events: a spike is merged (dropped) when it lies
# less than confidence_window after the first spike of the event being built; otherwise
# it is kept and becomes the start of a new event.
merged_spikes_counter = 0
for neuron_key, spike_times in neuron_spike_times.items():
    merged_spike_times = [spike_times[0]]
    for spike_time in spike_times[1:]:
        if spike_time - merged_spike_times[-1] < confidence_window:
            # Same event as the last kept spike
            merged_spikes_counter += 1
        else:
            # Far enough from the last kept spike: a new event starts here
            merged_spike_times.append(spike_time)

    # Replace the raw spike list with the merged one (same key, so iteration is unaffected)
    neuron_spike_times[neuron_key] = merged_spike_times

# Print the spike times for each feature neuron
print("Merged neuron_spike_times: ", neuron_spike_times)

print(f"Merged a total of {merged_spikes_counter} spikes out of {total_spikes_count} spikes")

Merged neuron_spike_times:  {70: [21, 170, 338, 721, 1023, 1380, 1520, 2033, 2249, 2436, 2814, 2888, 3672, 4022, 4113, 4322, 4803, 4992, 5098, 5674, 6140, 6204, 6516, 6654, 7970, 8121, 8373, 8500, 8802, 9702, 10387, 10530, 10781, 10936, 11218, 11389, 11536, 11653, 11812, 12105, 12552, 13212, 13473, 13555, 13659, 14044, 14566, 15863, 16065, 16455, 17059, 17385, 17634, 17907, 17980, 18141, 18415, 18727, 18800, 18888, 19076, 19255, 20172, 20389, 20669, 21044, 21164, 21323, 21666, 21744, 21964, 22677, 22839, 23041, 23286, 23735, 23891, 24606, 24687, 24995, 25139, 25372, 25636, 25781, 25957, 26666, 26836, 27036, 27522, 28000, 28101, 28251, 28340, 28410, 29221, 29281, 29502, 29907, 30096, 30274, 30770, 30948, 31226, 31439, 32021, 32125, 32522, 33088, 33366, 33703, 33977, 34595, 34910, 34994, 35573, 35901, 36246, 36381, 36643, 36902, 37178, 37375, 38178, 38328, 38515, 38836, 38968, 39103, 39246, 39509, 39818, 40817, 41063, 41253, 42027, 42238, 42468, 42543, 42899, 43655, 44406, 44471, 44546, 

## Iterate over the SNN running time and classify the feature neurons

Create a numpy array with the classification of each feature neuron.

In [8]:
from utils.snn import NeuronClass

# One class label per feature neuron (one per column of v_dynamics). Every neuron starts
# as SILENT and is reclassified further below if it spiked.
feature_neuron_class = np.full(v_dynamics.shape[1], NeuronClass.SILENT)
preview_np_array(feature_neuron_class, "feature_neuron_class", edge_items=3)

feature_neuron_class Shape: (256,).
Preview: [0 0 0 ... 0 0 0]


### Keep track of the relevant events each feature neuron detected

In [9]:
# Relevant (event-aligned) spike times of every neuron that spiked; filled in further below
relevant_neuron_spike_times = {neuron_idx: [] for neuron_idx in neuron_spike_times}

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {70: [], 90: [], 112: [], 13: [], 63: [], 2: [], 9: [], 17: [], 47: [], 216: [], 1: [], 135: [], 157: [], 51: [], 30: [], 251: [], 218: [], 91: [], 213: [], 111: [], 40: [], 0: [], 21: [], 19: [], 175: [], 27: [], 23: [], 192: [], 97: [], 185: [], 35: [], 204: [], 140: [], 98: [], 89: [], 72: [], 96: [], 130: [], 4: [], 25: [], 10: [], 5: [], 113: [], 3: [], 34: [], 236: [], 80: [], 6: [], 8: [], 46: [], 176: [], 79: [], 190: [], 172: [], 28: [], 53: [], 7: [], 119: [], 24: [], 18: [], 150: [], 54: [], 147: [], 74: [], 15: [], 55: [], 177: [], 14: [], 16: [], 101: [], 99: [], 214: [], 181: [], 22: [], 152: [], 156: [], 183: [], 215: [], 82: [], 160: [], 133: [], 92: [], 138: [], 58: [], 239: [], 88: [], 66: [], 250: [], 110: [], 52: [], 95: [], 33: [], 38: [], 45: [], 178: [], 48: [], 231: [], 36: [], 232: [], 109: [], 168: [], 225: [], 118: [], 76: [], 49: [], 105: [], 149: [], 180: []}


### Define a function that searches for a relevant event in the `ground_truth` NumPy array

In [10]:
# Timestamps (second field of each annotation) of the ground truth events, in ascending order.
# has_relevant_event() and the confusion-matrix loop both walk this array in order and stop
# early, so sort it explicitly instead of relying on the order stored in the file.
ground_truth_timestamps = np.sort(np.array([ann_event[1] for ann_event in ground_truth]))

preview_np_array(ground_truth_timestamps, "ground_truth_timestamps", edge_items=3)

ground_truth_timestamps Shape: (199,).
Preview: [  1000.     3206.54   3770.02 ... 116096.   116769.   119000.  ]


In [11]:
def has_relevant_event(spike_time, confidence_window) -> float | None:
    """
    Searches for a relevant event in the ground truth array given a timestamp.
    The annotated event must be located before the spike time, since the annotation
    corresponds to the insertion of the event.
    Thus, the Event must be in the window [spike_time - confidence_window, spike_time]

    Returns:
    - The insertion time of the relevant event if found
    - None if no relevant event was found
    """
    # Check if spike_time is greater than the first event in the ground truth. If not, return False
    if ground_truth_timestamps[0] > spike_time:
        return None

    for event_time in ground_truth_timestamps:
        if spike_time - confidence_window <= event_time <= spike_time:
            return event_time
        
        if event_time > spike_time:
            # Since the events are sorted, if the event_time is greater than the spike_time, we can stop looking
            return None
    

In [12]:
# A spike is "relevant" when an annotated event was inserted at most confidence_window before it.
# The SNN spikes after the event insertion, so the window is looked up backwards from each spike.
for neuron_idx, curr_spike_times in neuron_spike_times.items():
    for curr_spike_time in curr_spike_times:
        # Check if the ground truth contains at least 1 annotated event within the confidence window
        if has_relevant_event(curr_spike_time, confidence_window) is not None:
            relevant_neuron_spike_times[neuron_idx].append(curr_spike_time)

print("relevant_neuron_spike_times: ", relevant_neuron_spike_times)

relevant_neuron_spike_times:  {70: [1023, 4803, 8121, 9702, 10530, 11653, 13212, 17059, 18141, 18800, 20172, 21164, 22677, 23041, 24687, 25139, 25781, 26666, 28251, 29281, 30274, 30770, 32125, 33088, 39246, 41253, 44546, 45562, 47121, 48227, 50307, 50749, 53671, 55032, 55776, 57587, 58594, 59283, 60041, 61174, 62263, 69146, 69801, 70775, 72732, 73179, 73746, 75174, 75744, 76646, 77737, 80252, 81761, 84738, 85134, 86644, 87150, 91586, 92283, 92635, 98061, 100314, 106063, 106789, 108310, 109526, 110325, 111109, 116097, 116789, 119018], 90: [4163, 9705, 13215, 17052, 18144, 20169, 24690, 26669, 31745, 39248, 40166, 52773, 57298, 59289, 60053, 71619, 86280, 92286, 92637, 95662, 106791, 111111, 116795], 112: [4161, 50309, 54271, 58596, 59286, 82812, 108543, 119004], 13: [12685, 26676, 33077, 34047, 37536, 39220, 45143, 50721, 51030, 59278, 60030, 68797, 70782, 71603, 81751, 85775, 86276, 90131, 91276, 98655, 119002], 63: [24697, 53648, 60732], 2: [3232, 9115, 12151, 12681, 13213, 16805, 170

## Statistics regarding the spiking neurons

In [13]:
# Percentage of each spiking neuron's (merged) spikes that are relevant. The list follows the
# key order of neuron_spike_times, so later cells can pair it positionally with those keys.
relevant_spike_ratios = [
    (len(relevant_neuron_spike_times[neuron_idx]) / len(spike_list)) * 100 if len(spike_list) > 0 else 0.0
    for neuron_idx, spike_list in neuron_spike_times.items()
]

In [14]:
PLOT_RELEVANCY_STATS = True

if PLOT_RELEVANCY_STATS:
    # Dump, per spiking neuron, all its spikes, the relevant ones and the ratio between them
    for neuron_idx, ratio in zip(neuron_spike_times, relevant_spike_ratios):
        print(f"Neuron {neuron_idx}: ")
        print(f"Spikes: {neuron_spike_times[neuron_idx]}")
        print(f"Relevant Spikes: {relevant_neuron_spike_times[neuron_idx]}")
        print(f"Relevant Spikes Ratio: {ratio}%")
        print("================================================================\n")

Neuron 70: 
Spikes: [21, 170, 338, 721, 1023, 1380, 1520, 2033, 2249, 2436, 2814, 2888, 3672, 4022, 4113, 4322, 4803, 4992, 5098, 5674, 6140, 6204, 6516, 6654, 7970, 8121, 8373, 8500, 8802, 9702, 10387, 10530, 10781, 10936, 11218, 11389, 11536, 11653, 11812, 12105, 12552, 13212, 13473, 13555, 13659, 14044, 14566, 15863, 16065, 16455, 17059, 17385, 17634, 17907, 17980, 18141, 18415, 18727, 18800, 18888, 19076, 19255, 20172, 20389, 20669, 21044, 21164, 21323, 21666, 21744, 21964, 22677, 22839, 23041, 23286, 23735, 23891, 24606, 24687, 24995, 25139, 25372, 25636, 25781, 25957, 26666, 26836, 27036, 27522, 28000, 28101, 28251, 28340, 28410, 29221, 29281, 29502, 29907, 30096, 30274, 30770, 30948, 31226, 31439, 32021, 32125, 32522, 33088, 33366, 33703, 33977, 34595, 34910, 34994, 35573, 35901, 36246, 36381, 36643, 36902, 37178, 37375, 38178, 38328, 38515, 38836, 38968, 39103, 39246, 39509, 39818, 40817, 41063, 41253, 42027, 42238, 42468, 42543, 42899, 43655, 44406, 44471, 44546, 45128, 45562,

## Plot the Relevant Spike Ratio of the Feature Neurons that Spike

In [15]:
# Create a bar plot containing the % of relevant spikes for each feature neuron
from utils.bar_plot import create_bar_fig  # Import the function to create the figure

# One categorical x label per spiking neuron, aligned with relevant_spike_ratios
neuron_label = [f"Neu. {neuron_idx}" for neuron_idx in neuron_spike_times.keys()]

# Order the x-axis categories by descending relevant-spike ratio. Pairing each label with its
# ratio avoids a quadratic neuron_label.index() lookup inside the sort key. Python's sort stays
# stable with reverse=True, so neurons with equal ratios keep their original relative order.
neurons_descending_rel_ratio = [
    label
    for label, _ in sorted(zip(neuron_label, relevant_spike_ratios), key=lambda pair: pair[1], reverse=True)
]

# Create the relevancy bar plot
feat_neurons_relevancy_plot = create_bar_fig(
    title="Feature Neurons Relevant Spikes Ratio", 
    x_axis_label='Feature Neuron Index', 
    y_axis_label='Relevant Spike Ratio (%)',
    x=neuron_label,
    y=relevant_spike_ratios,
    x_range=neurons_descending_rel_ratio,
    sizing_mode="stretch_width",
    bar_width=0.5,
    tooltips=[("Neuron Index", "@x"), ("Relevant Spike Ratio", "@top%")]
)

In [16]:
from bokeh import plotting as bplt

# Render the relevancy bar plot inline (set showPlot = False to skip it)
showPlot = True
if showPlot:
    bplt.show(feat_neurons_relevancy_plot)

In [17]:
EXPORT_RELEVANCY_PLOT = True
OUTPUT_FOLDER = f"./neuron_classification/{DATASET_NAME}"

if EXPORT_RELEVANCY_PLOT:
    # Create the output folder if needed; exist_ok avoids the check-then-create race
    # of os.path.exists() followed by os.makedirs()
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    file_path = f"{OUTPUT_FOLDER}/{band_file_name}_relevancy_barplot_0.07dv_time{init_offset}-{num_steps}-{virtual_time_step_interval}.html"

    # Customize the output file settings
    bplt.output_file(filename=file_path, title="Feature Neurons Relevancy (%) Bar Plot")

    # Save the plot
    bplt.save(feat_neurons_relevancy_plot)

## Classify the Feature Neurons according to their Spiking Activity
In this step, we will update the `feature_neuron_class` array with the classification of the `Noisy Neurons`, `Ripple Neurons` and `Fast Ripple Neurons`.

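We set a **threshold of 90%** (a ratio of 0.9) for the `Relevant Spike Ratio` to classify a neuron as a detector of the analysed band (`Ripple`, `Fast Ripple` or `Both`). Neurons below the threshold are classified as `Noisy`.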

In [18]:
from utils.snn import marker_type_to_neuron_class

# Minimum % of relevant spikes for a neuron to be classified as a detector of chosen_band
relevant_ratio_threshold = 90.0

# Class given to neurons that pass the threshold (loop-invariant, so computed once)
detector_class = marker_type_to_neuron_class(chosen_band)

# relevant_spike_ratios follows the key order of neuron_spike_times; strict=True fails loudly
# if the two ever get out of sync instead of silently classifying the wrong neurons.
for neuron_idx, relevant_spike_ratio in zip(neuron_spike_times.keys(), relevant_spike_ratios, strict=True):
    if relevant_spike_ratio >= relevant_ratio_threshold:
        # RIPPLE / FAST RIPPLE / BOTH detector, depending on the band being analysed
        feature_neuron_class[neuron_idx] = detector_class
    else:
        # Spikes too often without an annotated event nearby
        feature_neuron_class[neuron_idx] = NeuronClass.NOISY

### Show the classification of the Feature Neurons

In [19]:
# Display the per-neuron classes after the relevance-based classification
preview_np_array(feature_neuron_class, "feature_neuron_class", edge_items=3)

feature_neuron_class Shape: (256,).
Preview: [1 1 1 ... 0 0 0]


# Calculate the Classification Metrics of the SNN based on the Feature Neurons
As of right now, we have a Spiking Neural Network that updates its state in real time. This network is capable of detecting relevant oscillatory events through spikes.

Therefore, we have to define the conditions that determine whether an HFO event was detected. Since each Feature Neuron is modelled to be tuned to a specific frequency range, not all of them will fire simultaneously in the presence of an HFO event. Hence, we need to define a threshold for the number of Feature Neurons that need to fire in order to classify an event as an HFO. This is an exploratory approach, but as a starting point, let's say that we need 2 neurons firing during the time window (`confidence_window`).

In [20]:
# Confusion-matrix counters, filled in by the evaluation loop further below
true_positive = false_positive = true_negative = false_negative = 0

# Number of detector-neuron spikes ("votes") that must fall inside one confidence window
# before the SNN output is treated as an HFO detection
num_consensus_neurons = 2

### Build a numpy array containing all the spike times from the `Detector Neurons` ordered by spiking time

In [21]:
# Indices of the feature neurons classified as detectors of the chosen band (Ripple/FR/Both)
detector_indices = np.flatnonzero(feature_neuron_class == marker_type_to_neuron_class(chosen_band))
preview_np_array(detector_indices, "detector_indices")

detector_indices Shape: (12,).
Preview: [  7  16  25  38  53 ... 110 149 180 225 232]


In [22]:
# Pool the (merged) spike times of every detector neuron into one chronologically sorted array
detector_spike_times = np.array(sorted(
    spike_time
    for neuron_idx in detector_indices
    for spike_time in neuron_spike_times[neuron_idx]
))
preview_np_array(detector_spike_times, "detector_spike_times")

detector_spike_times Shape: (73,).
Preview: [  3213   6697   7231   7241  11132 ... 109529 110314 111098 113456 116798]


In [23]:
import math

curr_detector_idx = 0   # Index of the next unprocessed detector-neuron spike
curr_ground_truth_idx = 0   # Index of the next annotation in the ordered ground truth array

# Stores an active annotation. If None, there is no active annotation
# An annotation is active in the interval [annotation_start, annotation_start + confidence_window]
active_annotation = None    # Timestep (ceiled insertion time) of the active annotation

# Walk the simulation one step at a time. Every step adds exactly one entry to the confusion
# matrix, so total_predictions ends up equal to the number of simulated steps.
for run_step in range(len(v_dynamics)):
    is_end_detector_spikes = curr_detector_idx >= len(detector_spike_times)
    is_end_ground_truth = curr_ground_truth_idx >= len(ground_truth)

    # Block that checks if a new annotation gets activated
    if not is_end_ground_truth:
        # Annotation timestamps are floats; an annotation becomes active at the first integer step at/after it
        next_annotation_insertion = math.ceil(ground_truth_timestamps[curr_ground_truth_idx])
        if next_annotation_insertion == run_step:
            if active_annotation is not None:
                raise Exception("There is already an active annotation. This should not happen.")

            # Start a new active annotation
            active_annotation = next_annotation_insertion

            # Increment the curr_ground_truth_idx since this annotation will be processed
            curr_ground_truth_idx += 1

    # If there are still spikes from the Detector Neurons
    if not is_end_detector_spikes:
        next_spike_iter = detector_spike_times[curr_detector_idx]
        # If the next spike occurs in the current run_step, let's check what kind of prediction it is
        if next_spike_iter == run_step:
            # Increment the curr_detector_idx since this spike will be processed
            curr_detector_idx += 1

            # The first vote is the current spike (the SNN needs num_consensus_neurons votes to report an event).
            # Initialised here for BOTH branches: the False-Positive branch used to reuse a stale value left
            # over from an earlier True-Positive evaluation (or raise NameError if there was none).
            num_votes = 1

            # Check if there is an annotated event in the ground_truth within the confidence window
            if active_annotation is not None:
                # There is a relevant event: let's see if the SNN predicted it.
                # The annotated events are floats, so we need to ceil the end_relevant_event_time because
                # we are comparing with the integer run_step (iteration index of the simulation) (TODO: Include the virtual_step in the calculation)
                end_relevant_event_time = math.ceil(active_annotation + confidence_window)

                # Count the remaining spikes inside the annotation's confidence window as extra votes
                while curr_detector_idx < len(detector_spike_times) and detector_spike_times[curr_detector_idx] <= end_relevant_event_time:
                    num_votes += 1
                    curr_detector_idx += 1

                if num_votes >= num_consensus_neurons:
                    # The SNN detected the event -> True Positive
                    true_positive += 1
                else:
                    # The SNN didn't detect the event (Not enough votes) -> False Negative
                    # False Negatives may also occur without any spikes in the network
                    false_negative += 1

                # The current annotation was processed, so we can deactivate it
                active_annotation = None
            else:
                # No annotated event is active: if the SNN reports an event anyway, it is a False Positive.
                # Calculate the window that would cause an SNN Detection
                end_relevant_event_iter = next_spike_iter + confidence_window

                # If an annotated event starts inside that window, stop the window one step before it,
                # since only the prediction of a False Positive is being analysed here
                middle_relevant_event_time = has_relevant_event(end_relevant_event_iter, confidence_window)
                if middle_relevant_event_time is not None:
                    end_relevant_event_iter = math.ceil(middle_relevant_event_time) - 1

                # Count the remaining spikes inside the window as extra votes
                while curr_detector_idx < len(detector_spike_times) and detector_spike_times[curr_detector_idx] <= end_relevant_event_iter:
                    num_votes += 1
                    curr_detector_idx += 1

                if num_votes >= num_consensus_neurons:
                    # The SNN detected an event that does not exist -> False Positive
                    false_positive += 1
                else:
                    # The SNN didn't detect any event -> True Negative
                    true_negative += 1

            # Go to the next iteration
            continue

    # Reached when there are no more detector spikes or none occurs at this run_step:
    # check whether the active annotation (if any) has expired and report TN / FN accordingly.
    if active_annotation is not None:
        if run_step >= active_annotation + confidence_window + 1:   # The +1 is to consider an extra iteration because the annotation is a float ceiled
            # The active annotation is over
            active_annotation = None

            # The SNN didn't detect the event in the interval [active_annotation, active_annotation + confidence_window]
            # Thus, it is a False Negative
            false_negative += 1
        else:
            # The active annotation is still active, so it is a True Negative (No event detected)
            true_negative += 1
    else:
        # If there is no active annotation, the SNN didn't detect any event rightfully so
        # So, it is a True Negative
        true_negative += 1

# Edge case: If there is an annotated event at a distance < confidence_window from the end of the simulation, the feature events
# won't have time to detect it. Thus, we don't consider that event.
# TODO: Consider this edge case?

## Show the Confusion Matrix

In [24]:
# Report the confusion matrix accumulated by the evaluation loop
confusion_counts = (
    ("True Positive", true_positive),
    ("False Positive", false_positive),
    ("True Negative", true_negative),
    ("False Negative", false_negative),
)
for count_label, count in confusion_counts:
    print(f"{count_label}: ", count)

# Every prediction falls in exactly one of the four cells
total_predictions = true_positive + false_positive + true_negative + false_negative
print("Total Predictions: ", total_predictions)

True Positive:  8
False Positive:  1
True Negative:  119800
False Negative:  191
Total Predictions:  120000


## Calculate Classification Metrics

In [25]:
# Calculate relevant metrics (all in %, except the F1 score, which is derived from two %s).
# A metric whose denominator is zero (e.g. no positive predictions at all) is reported as 0.0
# instead of raising ZeroDivisionError, so the rest of the notebook (JSON export) still runs.
accuracy = (true_positive + true_negative) / total_predictions * 100 if total_predictions else 0.0     # Proportion of correct predictions
precision = true_positive / (true_positive + false_positive) * 100 if (true_positive + false_positive) else 0.0  # Proportion of positive predictions that are correct
recall = true_positive / (true_positive + false_negative) * 100 if (true_positive + false_negative) else 0.0     # Proportion of actual events captured by the model
f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0.0                    # Harmonic mean of Precision and Recall
specificity = true_negative / (true_negative + false_positive) * 100 if (true_negative + false_positive) else 0.0  # Proportion of actual negatives identified correctly

print(f"Accuracy: {accuracy}%")
print(f"Precision: {precision}%")
print(f"Recall: {recall}%")
print(f"F1 Score: {f1_score}")
print(f"Specificity: {specificity}%")

Accuracy: 99.83999999999999%
Precision: 88.88888888888889%
Recall: 4.0201005025125625%
F1 Score: 7.692307692307692
Specificity: 99.99916528242669%


# Export the results of the Classification to a JSON file
Export the results of the classification to a JSON file. This file will include:
- Frequency Band used (`Ripple`, `Fast Ripple` or `Both`)
- `Confidence Window` used
- Classification of each Feature Neuron (`Silent`, `Noisy`, `Ripple`, `Fast Ripple`, `Both`)
- Number of `Consensus Feature Neurons` to classify an event as an HFO
- Classification Metrics (`True Positives`, `False Positives`, `True Negatives`, `False Negatives`, `Accuracy`, `Precision`, `Recall`, `F1 Score`, `Specificity`)

In [26]:
# Count how many feature neurons ended up in each class
num_silent_neurons, num_noisy_neurons, num_detector_neurons = (
    np.count_nonzero(feature_neuron_class == neuron_class)
    for neuron_class in (NeuronClass.SILENT, NeuronClass.NOISY, marker_type_to_neuron_class(chosen_band))
)

print(f"Number of Silent Neurons: {num_silent_neurons}")
print(f"Number of Noisy Neurons: {num_noisy_neurons}")
print(f"Number of Detector Neurons: {num_detector_neurons}")
print(f"Total: {num_silent_neurons + num_noisy_neurons + num_detector_neurons}")

Number of Silent Neurons: 148
Number of Noisy Neurons: 96
Number of Detector Neurons: 12
Total: 256


In [27]:
import json

# Export the results to a JSON file

# Collect the run parameters, the per-neuron classification and the metrics
json_results = {
    "freq_band": band_file_name,
    "confidence_window": confidence_window,
    "feature_neuron_class": {
        "array": feature_neuron_class.tolist(),
        "counts": {
            "silent_neurons": num_silent_neurons,
            "noisy_neurons": num_noisy_neurons,
            "detector_neurons": num_detector_neurons
        },
    },
    "num_consensus_neurons": num_consensus_neurons,
    "metrics": {
        "true_positive": true_positive,
        "false_positive": false_positive,
        "true_negative": true_negative,
        "false_negative": false_negative,
        "total_predictions": total_predictions,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity
    }
}

EXPORT_JSON_FILE = True
if EXPORT_JSON_FILE:
    # OUTPUT_FOLDER is otherwise only created by the relevancy-plot export cell (when
    # EXPORT_RELEVANCY_PLOT is True); make sure it exists so open() does not fail.
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    json_file_name = f"{OUTPUT_FOLDER}/{band_file_name}_results_0.07dv_time{init_offset}-{num_steps}-{virtual_time_step_interval}_classif{num_consensus_neurons}.json"
    with open(json_file_name, 'w') as f:
        # indent keeps the results file human-readable
        json.dump(json_results, f, indent=4)