# SNN that detects Channel bursts in a single channel
This notebook is a simple example of how to use a Spiking Neural Network (SNN) to detect bursts of spikes within individual recording channels (channel bursts).

## Definition of a network burst
A network burst is a sequence of spikes that occur in a short time window. The definition of a network burst is not unique and depends on the context. 

In this notebook, we **consider any sequence of 4 or more spikes that occurs within 20 ms to be a channel burst**.
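To make the definition concrete, here is a minimal sliding-window check. It is a hypothetical sketch: `detect_bursts` is not the project's `find_channel_bursts` helper used later, whose exact rules (for example its minimum inter-burst interval) may differ. It assumes spike times in ms, treats "within 20 ms" as inclusive, reports the time of the spike that completes each burst, and does not reuse spikes that were already assigned to a burst.

```python
from typing import List, Sequence


def detect_bursts(spike_times: Sequence[float], min_spikes: int = 4,
                  max_duration: float = 20.0) -> List[float]:
    """Hypothetical burst detector for a single channel.

    A burst is any run of at least `min_spikes` spikes whose first and last
    spikes are at most `max_duration` ms apart. Returns the time of the
    spike that completes each burst.
    """
    times = sorted(spike_times)
    bursts = []
    start = 0  # index of the oldest spike still inside the current window
    for end in range(len(times)):
        # Slide the window start forward until it spans at most max_duration
        while times[end] - times[start] > max_duration:
            start += 1
        if end - start + 1 >= min_spikes:
            bursts.append(times[end])
            start = end + 1  # assumption: spikes are not shared between bursts
    return bursts
```

For example, `detect_bursts([0, 5, 10, 15])` reports one burst at 15 ms, while four spikes spaced 10 ms apart (`[0, 10, 20, 30]`) never fit 4 spikes into a 20 ms window and report none.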

### Check WD (change if necessary) and file loading

In [1]:
# Show current directory
import os
curr_dir = os.getcwd()
print(curr_dir)

# Check if the current working directory is this file's location
if "/src/network_bursts" not in curr_dir:
    # Set the working directory to this file's location
    file_location = os.path.join(curr_dir, "thesis-lava", "src", "network_bursts")
    print("File Location: ", file_location)

    # Change the current working directory
    os.chdir(file_location)

    # New working directory
    print("New Working Directory: ", os.getcwd())

/home/monkin/Desktop/feup/thesis
File Location:  /home/monkin/Desktop/feup/thesis/thesis-lava/src/network_bursts
New Working Directory:  /home/monkin/Desktop/feup/thesis/thesis-lava/src/network_bursts


In [2]:
from lava.proc.lif.process import LIF, LIFRefractory
from lava.proc.dense.process import Dense

LIF?

[0;31mInit signature:[0m [0mLIF[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Leaky-Integrate-and-Fire (LIF) neural Process.

LIF dynamics abstracts to:
u[t] = u[t-1] * (1-du) + a_in         # neuron current
v[t] = v[t-1] * (1-dv) + u[t] + bias  # neuron voltage
s_out = v[t] > vth                    # spike if threshold is exceeded
v[t] = 0                              # reset at spike

Parameters
----------
shape : tuple(int)
    Number and topology of LIF neurons.
u : float, list, numpy.ndarray, optional
    Initial value of the neurons' current.
v : float, list, numpy.ndarray, optional
    Initial value of the neurons' voltage (membrane potential).
du : float, optional
    Inverse of decay time-constant for current decay. Currently, only a
    single decay can be set for the entire population of neurons.
dv : float, optional
    Inverse of decay time-constant for voltage decay. Curr

In [3]:
# Define the number of neurons in each LIF layer (one neuron per input channel)
n1 = 251  # A simple program with only 1 LIF layer

## Choose the LIF model and simulation parameters

In [4]:
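import numpy as np

# LIF neuron parameters
v_th = 1            # Voltage threshold
v_init = 0          # Initial voltage (membrane potential)
du = 0.45           # Inverse of the decay time-constant for the current
dv = 0.05           # Inverse of the decay time-constant for the voltage
use_refractory = False
refrac_period = 20  # Refractory period (in time steps) of the LIFRefractory process

# Scale the weights
weights_scale = 0.2

# Simulation Parameters
init_offset = 0     # Alternative: 1000
virtual_time_step_interval = 1
num_steps = 6 * (10**4)  # Previously tried: 30000, 1000, 3000, 26500. TODO: Check the number of steps to run the simulation for

# Input File (alternatives: "./data/custom_4spikes_20ms_ch1.csv", "../lab_data/lab_data_1-8channels.csv")
file_path = "../lab_data/lab_data_all_channels.csv"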
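To get a feel for how these parameters interact, the update rules printed in the `LIF` docstring can be simulated by hand for a single neuron fed by one channel through a weight of `weights_scale`. This is a simplified sketch, not Lava itself: `simulate_lif` is a hypothetical helper, it assumes bias 0 and one input spike per step at most, and it ignores any timing offsets the Lava runtime may introduce between processes.

```python
from typing import List, Sequence


def simulate_lif(input_spike_steps: Sequence[int], num_steps: int,
                 weight: float = 0.2, du: float = 0.45, dv: float = 0.05,
                 vth: float = 1.0) -> List[int]:
    """Simulate one LIF neuron driven by a single weighted input channel.

    Uses the rules from the LIF docstring (bias = 0):
        u[t] = u[t-1] * (1 - du) + a_in
        v[t] = v[t-1] * (1 - dv) + u[t]
        spike if v[t] > vth, then v[t] = 0
    Returns the time steps at which the neuron spikes.
    """
    input_steps = set(input_spike_steps)
    u = v = 0.0
    out_spikes = []
    for t in range(num_steps):
        a_in = weight if t in input_steps else 0.0  # weighted input spike
        u = u * (1 - du) + a_in
        v = v * (1 - dv) + u
        if v > vth:
            out_spikes.append(t)
            v = 0.0  # only the voltage is reset; the current keeps decaying
    return out_spikes


print(simulate_lif([0], 100))           # a single input spike
print(simulate_lif([0, 1, 2, 3], 50))   # four spikes on consecutive steps
```

In this simplified model a single input spike lifts `v` to at most about 0.36, so isolated spikes never reach `v_th = 1`. Three spikes on consecutive steps already slightly exceed the threshold, so the neuron's firing condition does not have to coincide exactly with the 4-spikes-in-20-ms rule used for the ground truth.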

### Create the LIF Processes

In [5]:
# Create Processes
lif1 = LIF(
    shape=(n1,),  # n1 neurons, one per channel
    vth=v_th,  # TODO: Verify these initial values
    v=v_init,
    dv=dv,    # Inverse of decay time-constant for voltage decay
    du=du,  # Inverse of decay time-constant for current decay
    bias_mant=0,
    bias_exp=0,
    name="lif1"
)

refrac_lif1 = LIFRefractory(
    shape=(n1,),  # n1 neurons, one per channel
    vth=v_th,  # TODO: Verify these initial values
    v=v_init,
    dv=dv,    # Inverse of decay time-constant for voltage decay
    du=du,  # Inverse of decay time-constant for current decay
    bias_mant=0,
    bias_exp=0,
    refractory_period=refrac_period,
    name="lif1"
)

### Select which LIF process to use

In [6]:
selected_lif = refrac_lif1 if use_refractory else lif1

## Create the Custom Input Layer

In [7]:
from utils.input import read_spike_events

# Call the function to read the spike events
spike_events = read_spike_events(file_path)
print("Spike events: ", spike_events.shape, spike_events[:10])

Spike events:  (40020, 2) [[ 99.5 229. ]
 [303.6   7. ]
 [502.5 229. ]
 [510.6  71. ]
 [528.   54. ]
 [540.9   7. ]
 [589.3 225. ]
 [631.6 100. ]
 [633.8 100. ]
 [758.3 229. ]]


In [8]:
# Find the max channel idx
max_channel_idx = int(np.max(spike_events[:, 1]))
print("Max Channel Idx: ", max_channel_idx)

Max Channel Idx:  251


## Generate the Ground Truth

In [9]:
from ground_truth import find_channel_bursts

# Generate the Ground Truth -> Find the channel bursts according to the conditions specified
CAUSALITY_WINDOW = refrac_period

# Parameters to Calculate Channel Bursts
num_spikes_to_burst = 4
max_burst_duration = CAUSALITY_WINDOW
min_inter_burst_interval = CAUSALITY_WINDOW
# Calculate the Channel Bursts
gt, gt_detailed = find_channel_bursts(spike_events, num_spikes_to_burst, max_burst_duration, min_inter_burst_interval, verbose=False)

print(f"Ground Truth detected {sum([len(gt_detailed[channel]) for channel in gt_detailed])} bursts")
print("Ground Truth: ", gt)
print("Ground Truth Detailed: ", gt_detailed)

Ground Truth detected 3701 bursts
Ground Truth:  {195.0: [1430.5, 2794.3, 2874.8, 4192.8, 5013.7, 5879.0, 5901.8, 5912.3, 5944.0, 5954.0, 5988.900000000001, 6001.8, 9848.7, 18772.7, 18806.7, 21585.2, 22081.4, 22176.9, 22208.7, 22220.3, 26736.5, 29603.5, 30794.9, 30814.8, 30844.7, 33297.5, 33408.8, 33441.6, 33451.0, 33481.6, 33508.5, 36258.1, 37081.4, 41398.2, 43103.1, 44341.5, 45168.0, 45201.7, 46132.9, 46789.1, 47369.0, 47413.6, 47447.6, 47457.3, 47492.5, 47505.7, 47536.6, 47549.1, 50238.5, 53910.9, 54706.4, 55153.5, 56964.5, 58732.1, 58762.7, 58772.0, 58803.3, 61612.4, 63493.9, 65648.0, 65682.1, 65691.7, 65724.2, 67098.8, 67996.3, 68151.7, 69477.4, 69517.1, 72148.3, 74914.2, 77350.0, 78235.90000000001, 78464.8, 79529.40000000001, 79557.40000000001, 79568.90000000001, 79599.90000000001, 81339.8, 81370.90000000001, 81385.40000000001, 82535.0, 84119.5, 84972.40000000001, 86599.90000000001, 87508.40000000001, 87536.7, 87545.7, 87576.40000000001, 87599.2, 87609.0, 87642.0, 88376.0, 89218.

In [10]:
# Crop the ground truth according to the simulation time
gt_cropped = gt.copy()
for elec_idx in gt_cropped:
    gt_cropped[elec_idx] = [spike_time for spike_time in gt_cropped[elec_idx] if (
        spike_time >= init_offset and
        spike_time < init_offset + num_steps*virtual_time_step_interval
        )
    ]

print(f"Ground Truth Cropped has {sum([len(gt_cropped[channel]) for channel in gt_cropped])} bursts")
print("Ground Truth Cropped: ", gt_cropped)

Ground Truth Cropped has 753 bursts
Ground Truth Cropped:  {195.0: [1430.5, 2794.3, 2874.8, 4192.8, 5013.7, 5879.0, 5901.8, 5912.3, 5944.0, 5954.0, 5988.900000000001, 6001.8, 9848.7, 18772.7, 18806.7, 21585.2, 22081.4, 22176.9, 22208.7, 22220.3, 26736.5, 29603.5, 30794.9, 30814.8, 30844.7, 33297.5, 33408.8, 33441.6, 33451.0, 33481.6, 33508.5, 36258.1, 37081.4, 41398.2, 43103.1, 44341.5, 45168.0, 45201.7, 46132.9, 46789.1, 47369.0, 47413.6, 47447.6, 47457.3, 47492.5, 47505.7, 47536.6, 47549.1, 50238.5, 53910.9, 54706.4, 55153.5, 56964.5, 58732.1, 58762.7, 58772.0, 58803.3], 229.0: [2796.5, 2882.0, 5019.3, 5884.1, 5913.7, 5924.7, 5956.900000000001, 5965.5, 5999.7, 13772.1, 18772.8, 18805.8, 21583.7, 22088.2, 22192.3, 29613.7, 30807.7, 30822.4, 30850.7, 33417.5, 33450.0, 33460.6, 33494.8, 37053.5, 41407.0, 44336.8, 45183.9, 47418.4, 47451.6, 47465.7, 47499.3, 47512.5, 47540.8, 55171.2, 57386.8, 58742.6, 58771.1, 58789.0], 240.0: [2802.1, 2892.2000000000003, 5012.8, 5889.7, 5904.0, 5936.40

### Map the input channels to the corresponding indexes in the input layer
Since the channel numbers in the input file are arbitrary labels (here they start at 1), we need to map each input channel to the corresponding index in the input layer. This is done by the `channel_map` dictionary.

In [11]:
# Map the channels of the input file to the respective index in the output list of SpikeEventGen
# channel_map = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7}

# Create a channel map based on the max channel idx
channel_map = {channel: idx for idx, channel in enumerate(range(1, max_channel_idx+1))}
print("Channel Map: ", channel_map)
print("Number of channels: ", len(channel_map))

Channel Map:  {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 26: 25, 27: 26, 28: 27, 29: 28, 30: 29, 31: 30, 32: 31, 33: 32, 34: 33, 35: 34, 36: 35, 37: 36, 38: 37, 39: 38, 40: 39, 41: 40, 42: 41, 43: 42, 44: 43, 45: 44, 46: 45, 47: 46, 48: 47, 49: 48, 50: 49, 51: 50, 52: 51, 53: 52, 54: 53, 55: 54, 56: 55, 57: 56, 58: 57, 59: 58, 60: 59, 61: 60, 62: 61, 63: 62, 64: 63, 65: 64, 66: 65, 67: 66, 68: 67, 69: 68, 70: 69, 71: 70, 72: 71, 73: 72, 74: 73, 75: 74, 76: 75, 77: 76, 78: 77, 79: 78, 80: 79, 81: 80, 82: 81, 83: 82, 84: 83, 85: 84, 86: 85, 87: 86, 88: 87, 89: 88, 90: 89, 91: 90, 92: 91, 93: 92, 94: 93, 95: 94, 96: 95, 97: 96, 98: 97, 99: 98, 100: 99, 101: 100, 102: 101, 103: 102, 104: 103, 105: 104, 106: 105, 107: 106, 108: 107, 109: 108, 110: 109, 111: 110, 112: 111, 113: 112, 114: 113, 115: 114, 116: 115, 117: 116, 118: 117, 119: 118, 120: 119, 121

### Create the `SpikeEventGen` object

In [12]:
from utils.spike_event_gen import SpikeEventGen

# Create the Input Process
spike_event_gen = SpikeEventGen(shape=(n1,), spike_events=spike_events, name="CustomInput", channel_map=channel_map,
                                virtual_time_step_interval=virtual_time_step_interval, init_offset=init_offset)
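`SpikeEventGen` is a custom process from `utils.spike_event_gen`, and its implementation is not shown here. Conceptually, it has to turn the `(time, channel)` rows of `spike_events` into one spike vector per simulation step. The hypothetical `bin_spike_events` below sketches that conversion with NumPy, assuming step `t` covers the interval `[init_offset + t * step_interval, init_offset + (t + 1) * step_interval)` and that several spikes of one channel within a step collapse into a single spike.

```python
import numpy as np


def bin_spike_events(spike_events: np.ndarray, channel_map: dict,
                     num_steps: int, step_interval: float = 1.0,
                     init_offset: float = 0.0) -> np.ndarray:
    """Hypothetical sketch: convert (time, channel) rows into a boolean
    raster of shape (num_steps, n_neurons), where raster[t, i] is True if
    the channel mapped to neuron i spiked during step t."""
    raster = np.zeros((num_steps, len(channel_map)), dtype=bool)
    for time, channel in spike_events:
        step = int((time - init_offset) // step_interval)
        if 0 <= step < num_steps:  # events outside the simulated range are dropped
            raster[step, channel_map[int(channel)]] = True
    return raster


events = np.array([[0.4, 1.0], [2.5, 3.0], [2.9, 1.0], [10.0, 2.0]])
raster = bin_spike_events(events, {1: 0, 2: 1, 3: 2}, num_steps=5)
print(raster.astype(int))
```

Row `t` of such a raster is what the input layer would emit at step `t`; the event at 10.0 ms is dropped because only 5 steps are simulated.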

## Connect the Custom Input Layer to the middle layer

To define the connectivity between the `SpikeEventGen` and the `LIF` population, we use a `Dense` layer.

In [13]:
# Define the matrix of weights
dense_weights = np.eye(n1)

# dense_weights = np.round(dense_weights * weights_scale).astype(np.int32)
dense_weights = dense_weights * weights_scale

# Instantiate a Dense layer to connect the SpikeEventGen to the middle layer
dense_input = Dense(weights=dense_weights, name="DenseInput")     # 1-1 connectivity with the middle layer (an n1 x n1 scaled identity matrix)

# Note: when the SpikeEventGen is connected through the Dense layer, the custom input's values are rounded to 0 or 1 (they are not passed on as floats)
# Connect the SpikeEventGen to the Dense Layer
spike_event_gen.s_out.connect(dense_input.s_in)

# Connect the Dense_Input to the LIF1 Layer
dense_input.a_out.connect(selected_lif.a_in)

# Connect the SpikeEventGen layer directly to the LIF1 layer
# spike_event_gen.s_out.connect(lif1.a_in)
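Conceptually, the `Dense` layer turns the binary spike vector coming from `SpikeEventGen` into an input activation for the LIF layer via a matrix-vector product with its weights. The small NumPy illustration below uses 5 channels instead of `n1` and made-up spikes; it shows the arithmetic only, not Lava's internal implementation.

```python
import numpy as np

n = 5                                 # small stand-in for n1 = 251
weights_scale = 0.2
weights = np.eye(n) * weights_scale   # one-to-one (diagonal) connectivity

# Binary spike vector for one time step: channels 1 and 3 spiked
s_in = np.array([0, 1, 0, 1, 0])

# Output activation sent to the LIF layer: a_out = W @ s_in
a_out = weights @ s_in
print(a_out)
```

Because the weight matrix is diagonal, each LIF neuron receives input only from its own channel, so every neuron acts as an independent detector for one channel.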

### Take a look at the connections in the Input Layer

In [14]:
for proc in [spike_event_gen, dense_input, selected_lif]:
    for port in proc.in_ports:
        print(f"Proc: {proc.name:<5} Port Name: {port.name:<5} Size: {port.size}")
    for port in proc.out_ports:
        print(f"Proc: {proc.name:<5} Port Name: {port.name:<5} Size: {port.size}")

Proc: CustomInput Port Name: s_out Size: 251
Proc: DenseInput Port Name: s_in  Size: 251
Proc: DenseInput Port Name: a_out Size: 251
Proc: lif1  Port Name: a_in  Size: 251
Proc: lif1  Port Name: s_out Size: 251


In [15]:
# Weights of the Input Dense Layer
dense_input.weights.get()

array([[0.2, 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0.2, 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0.2, ..., 0. , 0. , 0. ],
       ...,
       [0. , 0. , 0. , ..., 0.2, 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0.2, 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0.2]])

### Record Internal Vars over time
To record the evolution of the internal variables over time, we need a `Monitor`. For this example, we want to record the membrane potential of the `LIF` Layer, hence we need 1 `Monitors`.

We can define the `Var` that a `Monitor` should record, as well as the recording duration, using the `probe` function

In [16]:
from lava.proc.monitor.process import Monitor

monitor_lif1_v = Monitor()
monitor_lif1_u = Monitor()

# Connect the monitors to the variables we want to monitor
monitor_lif1_v.probe(selected_lif.v, num_steps)
monitor_lif1_u.probe(selected_lif.u, num_steps)

## Execution
Now that we have defined the network, we can execute it by calling the `run` method on one of its processes (here, the selected LIF process); Lava then runs all the processes connected to it.

### Run Configuration and Conditions

In [17]:
from lava.magma.core.run_conditions import RunContinuous, RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg

# run_condition = RunContinuous()   # TODO: Change to this one
run_condition = RunSteps(num_steps=num_steps)
run_cfg = Loihi1SimCfg(select_tag="floating_pt")   # Select the floating-point ProcessModels, which handle the non-integer weights (0.2) and decay constants used here

### Execute

In [18]:
selected_lif.run(condition=run_condition, run_cfg=run_cfg)



### Retrieve recorded data

In [19]:
data_lif1_v = monitor_lif1_v.get_data()
data_lif1_u = monitor_lif1_u.get_data()

# Merge the dictionaries so that data_lif1["lif1"] contains both voltage ("v") and current ("u")
# (builds a new dict instead of mutating a shallow copy of data_lif1_v)
data_lif1 = {"lif1": {**data_lif1_v["lif1"], **data_lif1_u["lif1"]}}

In [20]:
data_lif1

{'lif1': {'v': array([[0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
          0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
         [0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
          0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
         [0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
          0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
         ...,
         [2.74064676e-027, 0.00000000e+000, 1.43279037e-322, ...,
          1.43279037e-322, 1.43279037e-322, 0.00000000e+000],
         [2.60361442e-027, 0.00000000e+000, 1.43279037e-322, ...,
          1.43279037e-322, 1.43279037e-322, 0.00000000e+000],
         [2.47343370e-027, 0.00000000e+000, 1.43279037e-322, ...,
          1.43279037e-322, 1.43279037e-322, 0.00000000e+000]]),
  'u': array([[0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
          0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
         [0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
    

In [21]:
# Check the shape to verify that the voltage was recorded at every step
len(data_lif1['lif1']['v'])     # One row per simulation step, so this should equal num_steps

60000

## Find the time steps at which the LIF neurons spike (detected channel bursts)

In [22]:
from utils.data_analysis import find_spike_times

# Call the find_spike_times util function, which detects the spikes from the recorded voltage and current arrays
spike_times_lif1 = find_spike_times(data_lif1['lif1']['v'], data_lif1['lif1']['u'])

# Add the initial offset to the spike times (column 0 of each [time, neuron_idx] row)
for spike_times in spike_times_lif1:
    spike_times[0] += init_offset
    
print("Spike times: ", spike_times_lif1)

Spike times:  [[ 1431   194]
 [ 1436     0]
 [ 2793   194]
 ...
 [58830    64]
 [58842    70]
 [58871    53]]
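`find_spike_times` comes from the project's `utils.data_analysis` module, whose code is not shown. One plausible rule, used in the hypothetical `find_spike_times_sketch` below, relies on the reset: with positive weights and zero bias, `v` can only be exactly 0 while `u > 0` if the neuron has just spiked and been reset. This assumes the monitor records `v` after the reset; the real helper may use a different criterion.

```python
import numpy as np


def find_spike_times_sketch(v: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Hypothetical spike detector for recorded LIF traces.

    v and u have shape (num_steps, num_neurons). A step where the voltage
    is 0 although the current is positive is treated as a spike (the LIF
    reset sets v to 0). Returns rows of [step, neuron_idx], sorted by step.
    """
    spiked = (v == 0) & (u > 0)
    return np.argwhere(spiked)


v = np.array([[0.0, 0.0], [0.2, 0.0], [0.0, 0.5], [0.1, 0.3]])
u = np.array([[0.0, 0.0], [0.2, 0.0], [0.4, 0.5], [0.2, 0.3]])
print(find_spike_times_sketch(v, u))
```

`np.argwhere` returns the indices in row-major order, so the rows come out sorted by time step and then by neuron index, matching the `[time, neuron_idx]` layout printed above.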


## View the Voltage and Current dynamics with an interactive plot

Grab the data from the recorded variables

In [23]:
lif1_voltage_vals = data_lif1['lif1']['v']
lif1_current_vals = data_lif1['lif1']['u']

print("Number of recorded steps:", len(lif1_voltage_vals))

# print("voltage head: ", lif1_voltage_vals[:10])

Number of recorded steps: 60000


## Assemble the values to be plotted

In [24]:
chosen_channels = [0, 3, 5, 6, 40, 45, 59, 174]
num_chosen_ch = len(chosen_channels)

voltage_lif1_y_arrays = [ ( lif1_voltage_vals[:, chosen_ch], f"Neuron. {chosen_ch}" ) for chosen_ch in chosen_channels ]

print("Voltage Lif1 Y Arrays: ", voltage_lif1_y_arrays[:5])

Voltage Lif1 Y Arrays:  [(array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
       2.74064676e-27, 2.60361442e-27, 2.47343370e-27]), 'Neuron. 0'), (array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
       1.24596700e-28, 1.18366865e-28, 1.12448522e-28]), 'Neuron. 3'), (array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
       4.56259549e-63, 4.33446571e-63, 4.11774243e-63]), 'Neuron. 5'), (array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
       4.08171474e-28, 3.87762900e-28, 3.68374755e-28]), 'Neuron. 6'), (array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
       1.82384245e-28, 1.73265033e-28, 1.64601781e-28]), 'Neuron. 40')]


In [25]:
current_lif1_y_arrays = [ ( lif1_current_vals[:, chosen_ch], f"Neuron. {chosen_ch}" ) for chosen_ch in chosen_channels ]
print("Current Lif1 Y Arrays: ", current_lif1_y_arrays[:5])

Current Lif1 Y Arrays:  [(array([0.00000000e+000, 0.00000000e+000, 0.00000000e+000, ...,
       6.50939300e-311, 3.58016615e-311, 1.96909138e-311]), 'Neuron. 0'), (array([0.00000e+000, 0.00000e+000, 0.00000e+000, ..., 1.37888e-318,
       7.58381e-319, 4.17110e-319]), 'Neuron. 3'), (array([0.e+000, 0.e+000, 0.e+000, ..., 5.e-324, 5.e-324, 5.e-324]), 'Neuron. 5'), (array([0.0000e+000, 0.0000e+000, 0.0000e+000, ..., 3.8275e-319,
       2.1051e-319, 1.1578e-319]), 'Neuron. 6'), (array([0.0e+000, 0.0e+000, 0.0e+000, ..., 7.4e-323, 4.0e-323, 2.0e-323]), 'Neuron. 40')]


In [26]:
from utils.line_plot import create_fig  # Import the function to create the figure
from bokeh.models import Range1d

# Define the x values (time steps, shifted by the initial offset)
x = [val + init_offset for val in range(num_steps)]

# Define the box annotation parameters (shades the region between 0 and the voltage threshold)
box_annotation_voltage = {
    "bottom": 0,
    "top": v_th,
    "left": init_offset,
    "right": init_offset + num_steps,
    "fill_alpha": 0.03,
    "fill_color": "green"
}

# Create the LIF1 Voltage
voltage_lif1_plot = create_fig(
    title="LIF1 Voltage dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Voltage (V)',
    x=x, 
    y_arrays=voltage_lif1_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    box_annotation_params=box_annotation_voltage,
    y_range=Range1d(-0.05, 1.05)
)


# Create the LIF1 Current
current_lif1_plot = create_fig(
    title="LIF1 Current dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Current (U)',
    x=x, 
    y_arrays=current_lif1_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    x_range=voltage_lif1_plot.x_range,    # Link the x-axis range to the voltage plot
)

# bplt.show(voltage_lif1_plot)

## Show the Plots assembled in a grid

In [27]:
import bokeh.plotting as bplt
from bokeh.layouts import gridplot

showPlot = True
if showPlot:
    # Create array of plots to be shown
    plots = [voltage_lif1_plot, current_lif1_plot]

    if len(plots) == 1:
        grid = plots[0]
    else:   # Create a grid layout
        grid = gridplot(plots, ncols=2, sizing_mode="stretch_both")

    # Show the plot
    bplt.show(grid)

## Export the plot to a file

In [31]:
export = False

OUT_FOLDER = "./results/channel_burst"
REFRAC_SUFFIX = refrac_period if use_refractory else "no_refrac"
OUT_FILENAME = f"lab_ch1-{n1}_4spikes_20ms_{v_th}_{du}_{dv}_{REFRAC_SUFFIX}_{weights_scale}_{num_steps}steps"

if export:
    # Use a separate variable so the input file path (file_path) is not overwritten
    html_file_path = f"{OUT_FOLDER}/{OUT_FILENAME}.html"

    # Customize the output file settings
    bplt.output_file(filename=html_file_path, title="Channel Burst detection - LIF1 Voltage and Current dynamics")

    # Save the plot
    bplt.save(grid)

## Calculate Detection Metrics

### Convert the spike times to the same format as the Ground Truth

In [32]:
# Invert the mapping of the electrodes to the neuron indices
channel_map_inv = {v: k for k, v in channel_map.items()}
print("Channel Map Inverted: ", channel_map_inv)

Channel Map Inverted:  {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12, 12: 13, 13: 14, 14: 15, 15: 16, 16: 17, 17: 18, 18: 19, 19: 20, 20: 21, 21: 22, 22: 23, 23: 24, 24: 25, 25: 26, 26: 27, 27: 28, 28: 29, 29: 30, 30: 31, 31: 32, 32: 33, 33: 34, 34: 35, 35: 36, 36: 37, 37: 38, 38: 39, 39: 40, 40: 41, 41: 42, 42: 43, 43: 44, 44: 45, 45: 46, 46: 47, 47: 48, 48: 49, 49: 50, 50: 51, 51: 52, 52: 53, 53: 54, 54: 55, 55: 56, 56: 57, 57: 58, 58: 59, 59: 60, 60: 61, 61: 62, 62: 63, 63: 64, 64: 65, 65: 66, 66: 67, 67: 68, 68: 69, 69: 70, 70: 71, 71: 72, 72: 73, 73: 74, 74: 75, 75: 76, 76: 77, 77: 78, 78: 79, 79: 80, 80: 81, 81: 82, 82: 83, 83: 84, 84: 85, 85: 86, 86: 87, 87: 88, 88: 89, 89: 90, 90: 91, 91: 92, 92: 93, 93: 94, 94: 95, 95: 96, 96: 97, 97: 98, 98: 99, 99: 100, 100: 101, 101: 102, 102: 103, 103: 104, 104: 105, 105: 106, 106: 107, 107: 108, 108: 109, 109: 110, 110: 111, 111: 112, 112: 113, 113: 114, 114: 115, 115: 116, 116: 117, 117: 118, 118: 119, 119:

In [33]:
predicted_spikes = {}
for (spike_time, neuron_idx) in spike_times_lif1:
    # print(f"spike time: {spike_time} neuron idx: {neuron_idx}")

    # Map the neuron index to the electrode channel index
    electrode_idx = channel_map_inv[neuron_idx]

    # Add the spike time to the predicted spikes
    curr_ch_spikes = predicted_spikes.get(electrode_idx, [])
    curr_ch_spikes.append(spike_time)
    predicted_spikes[electrode_idx] = curr_ch_spikes

print(f"Predicted {sum([len(predicted_spikes[channel]) for channel in predicted_spikes])} spikes")
print("Predicted Spikes: ", predicted_spikes)

Predicted 1531 spikes
Predicted Spikes:  {195: [1431, 2793, 2800, 2875, 2883, 2891, 4193, 5014, 5022, 5028, 5042, 5879, 5886, 5894, 5902, 5910, 5918, 5925, 5933, 5941, 5949, 5957, 5964, 5972, 5979, 5986, 5997, 6004, 6015, 6023, 6045, 9848, 10177, 10261, 13771, 18771, 18779, 18786, 18797, 18804, 18813, 20756, 21584, 21600, 22078, 22087, 22097, 22106, 22176, 22182, 22189, 22197, 22203, 22212, 22221, 26736, 29602, 29612, 29619, 29628, 30793, 30812, 30821, 30829, 30836, 30843, 33298, 33408, 33415, 33423, 33430, 33437, 33444, 33451, 33459, 33468, 33477, 33484, 33494, 33527, 36256, 37077, 39274, 41397, 43102, 44339, 45167, 45173, 45180, 45188, 45197, 46132, 46139, 46791, 46814, 47369, 47411, 47418, 47426, 47434, 47442, 47448, 47456, 47463, 47470, 47478, 47485, 47493, 47503, 47513, 47522, 47532, 47540, 47550, 49914, 50237, 50248, 53909, 54706, 55153, 55164, 56965, 58731, 58737, 58745, 58754, 58761, 58768, 58775, 58782, 58794, 58804], 1: [1436, 2888, 2902, 5011, 5018, 5030, 5902, 5909, 5917, 5

### Calculate the Confusion Matrix

In [34]:
# Initialize the variables that store the values of the Confusion Matrix
true_positive = 0
false_positive = 0
true_negative = 0
false_negative = 0

In [35]:
# Go through the ground truth and check whether the predicted spikes are correct
# TODO: For now, I'm using the symmetric causality window [gt_burst_time - CAUSALITY_WINDOW, gt_burst_time + CAUSALITY_WINDOW].
# NOTE: electrodes with predictions but no ground-truth bursts are not visited by this loop, so their predictions are not counted as FPs.
for electrode_idx in gt_cropped:
    # Get the ground truth burst times for the current electrode
    gt_burst_times = gt_cropped[electrode_idx]

    # Get the predicted burst times for the current electrode
    pred_burst_times = predicted_spikes.get(electrode_idx, [])

    # A ground-truth burst is a TP if at least one prediction lies within ±CAUSALITY_WINDOW, otherwise it is a FN
    for gt_burst_time in gt_burst_times:
        # TODO: Test the causality window only after the annotated burst time, i.e. [gt_burst_time, gt_burst_time + CAUSALITY_WINDOW]
        if any(abs(pred_burst_time - gt_burst_time) <= CAUSALITY_WINDOW for pred_burst_time in pred_burst_times):
            true_positive += 1
        else:
            false_negative += 1

    # A prediction is a FP if no ground-truth burst lies within ±CAUSALITY_WINDOW
    for pred_burst_time in pred_burst_times:
        if all(abs(pred_burst_time - gt_burst_time) > CAUSALITY_WINDOW for gt_burst_time in gt_burst_times):
            false_positive += 1

# Calculate the True Negative value: every remaining time step counts as a TN
# TN = num_steps - (TP + FP + FN)
true_negative = num_steps - true_positive - false_positive - false_negative
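The nested `any`/`all` loops above compare every ground-truth burst with every prediction, which costs O(N·M) per electrode. A hypothetical faster sketch of the same symmetric-window rule sorts the times once and uses binary search from Python's standard `bisect` module; `has_match` and `count_matches` are illustrative names, not part of the project.

```python
from bisect import bisect_left
from typing import Sequence, Tuple


def has_match(sorted_times: Sequence[float], t: float, window: float) -> bool:
    """True if any value in sorted_times lies within [t - window, t + window]."""
    i = bisect_left(sorted_times, t - window)  # first value >= t - window
    return i < len(sorted_times) and sorted_times[i] <= t + window


def count_matches(gt_times: Sequence[float], pred_times: Sequence[float],
                  window: float) -> Tuple[int, int, int]:
    """Return (TP, FN, FP) for one electrode using a symmetric window."""
    gt_sorted, pred_sorted = sorted(gt_times), sorted(pred_times)
    tp = sum(has_match(pred_sorted, g, window) for g in gt_sorted)
    fn = len(gt_sorted) - tp
    fp = sum(not has_match(gt_sorted, p, window) for p in pred_sorted)
    return tp, fn, fp


print(count_matches([100.0, 200.0], [105, 260, 300], 20))
```

For one electrode `e`, `count_matches(gt_cropped[e], predicted_spikes.get(e, []), CAUSALITY_WINDOW)` should give the same TP, FN and FP contributions as the loop, apart from possible floating-point differences exactly at the window edges.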

In [36]:
# Print the Confusion Matrix
print("True Positive: ", true_positive)
print("False Positive: ", false_positive)
print("True Negative: ", true_negative)
print("False Negative: ", false_negative)

# Print the Total of predictions
total_predictions = true_positive + false_positive + true_negative + false_negative
print("Total Predictions: ", total_predictions)

True Positive:  749
False Positive:  104
True Negative:  59143
False Negative:  4
Total Predictions:  60000


### Calculate Prediction Metrics

In [37]:
# Calculate relevant metrics
if true_positive + false_positive == 0:
    # Raise an error instead of calling exit(), which would end the whole kernel rather than just stop this cell
    raise ValueError("No positive predictions were made. Cannot calculate metrics.")

accuracy = (true_positive + true_negative) / total_predictions * 100    # Proportion of all predictions that are correct
precision = true_positive / (true_positive + false_positive) * 100      # Proportion of predicted bursts that are correct
recall = true_positive / (true_positive + false_negative) * 100         # Proportion of ground-truth bursts that were detected
f1_score = (2 * precision * recall) / (precision + recall)              # Harmonic mean of Precision and Recall
specificity = true_negative / (true_negative + false_positive) * 100    # Proportion of negatives correctly identified as negatives

print(f"Accuracy: {accuracy}%")
print(f"Precision: {precision}%")
print(f"Recall: {recall}%")
print(f"F1 Score: {f1_score}")
print(f"Specificity: {specificity}%")

Accuracy: 99.82%
Precision: 87.80773739742087%
Recall: 99.46879150066401%
F1 Score: 93.27521793275218
Specificity: 99.82446368592502%
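As a cross-check, the metrics can be recomputed directly from the confusion-matrix counts. The hypothetical `detection_metrics` helper below uses the same formulas (percentages, with F1 on the same 0 to 100 scale) but returns 0.0 for an undefined ratio instead of stopping the notebook.

```python
def detection_metrics(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Hypothetical helper: accuracy, precision, recall, F1 and specificity."""
    total = tp + fp + tn + fn
    precision = 100 * tp / (tp + fp) if tp + fp else 0.0
    recall = 100 * tp / (tp + fn) if tp + fn else 0.0
    return {
        "accuracy": 100 * (tp + tn) / total if total else 0.0,
        "precision": precision,
        "recall": recall,
        "f1_score": (2 * precision * recall / (precision + recall)
                     if precision + recall else 0.0),
        "specificity": 100 * tn / (tn + fp) if tn + fp else 0.0,
    }


m = detection_metrics(tp=749, fp=104, tn=59143, fn=4)
print({name: round(value, 2) for name, value in m.items()})
```

Plugging in the counts printed above (TP = 749, FP = 104, TN = 59143, FN = 4) reproduces the reported values. F1 can equivalently be written as 2·TP / (2·TP + FP + FN) = 1498 / 1606, which does not depend on precision and recall being computed first.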


## Export the detection results to a JSON file
Export the results of the classification to a JSON file. This file includes:
- the `Causality Window` used
- the classification metrics (`True Positives`, `False Positives`, `True Negatives`, `False Negatives`, `Total Predictions`, `Accuracy`, `Precision`, `Recall`, `F1 Score`, `Specificity`)

In [38]:
import json

# Export the results to a JSON file

# Create a dictionary with the results
json_results = {
    "causality_window": CAUSALITY_WINDOW,
    "metrics": {
        "true_positive": true_positive,
        "false_positive": false_positive,
        "true_negative": true_negative,
        "false_negative": false_negative,
        "total_predictions": total_predictions,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity
    }
}

EXPORT_JSON_FILE = False
if EXPORT_JSON_FILE:
    json_file_name = f"{OUT_FOLDER}/{OUT_FILENAME}_metrics.json"
    with open(json_file_name, 'w') as f:
        json.dump(json_results, f)

## Stop the Runtime

In [39]:
selected_lif.stop()