# SNN that detects Network bursts
This notebook is a simple example of how to use a Spiking Neural Network (SNN) to detect network bursts in a network of 10 neurons (channels).

## Definition of a network burst
A network burst is a sequence of spikes that occur in a short time window in a neural population. The definition of a network burst is not unique and depends on the context. 

In this notebook, we **consider a channel burst to be neuronal activity in which 4 spikes occur within 20 ms on the same channel**.

In this notebook, we **consider a network burst to be neuronal activity in which 3 different channels burst within a 20 ms time frame**.
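Both definitions are sliding-window searches. The sketch below uses two hypothetical helpers, `channel_bursts` and `network_bursts`. They are not the code in `ground_truth.find_channel_bursts` / `find_net_bursts`, and several details are assumptions: the window is closed (`[t0, t0 + window]`), a burst is timestamped when its n-th spike (or n-th distinct channel) arrives, and a minimum inter-burst interval is measured from the previous burst's time. The defaults follow the values used later in the notebook (`C_IBI = 20`, `N_IBI = 40`).

```python
from bisect import bisect_right


def channel_bursts(spike_times, n_spikes=4, window_ms=20.0, min_ibi_ms=20.0):
    """Hypothetical helper: burst times of a single channel.

    A burst is reported when `n_spikes` spikes fall in [t0, t0 + window_ms];
    its time is the arrival of the n-th spike.
    """
    times = sorted(spike_times)
    bursts = []
    for i, t0 in enumerate(times):
        # Skip window starts too close to the previous burst (assumed rule)
        if bursts and t0 - bursts[-1] < min_ibi_ms:
            continue
        j = bisect_right(times, t0 + window_ms)  # index of the first spike after the window
        if j - i >= n_spikes:
            bursts.append(times[i + n_spikes - 1])
    return bursts


def network_bursts(events, n_channels=3, window_ms=20.0, min_ibi_ms=40.0):
    """Hypothetical helper: network bursts from (time, channel) channel-burst events.

    A network burst is reported when `n_channels` distinct channels appear in
    [t0, t0 + window_ms]; its time is when the last of those channels appears.
    """
    events = sorted(events)              # sort (time, channel) pairs by time
    times = [t for t, _ in events]
    bursts = []                          # list of (burst_time, channels)
    for i, (t0, _) in enumerate(events):
        if bursts and t0 - bursts[-1][0] < min_ibi_ms:
            continue
        j = bisect_right(times, t0 + window_ms)
        channels = []                    # distinct channels, in order of first appearance
        for t, ch in events[i:j]:
            if ch not in channels:
                channels.append(ch)
                if len(channels) == n_channels:
                    bursts.append((t, channels))
                    break
    return bursts


# Usage: 4 spikes inside 20 ms, then two isolated spikes
print(channel_bursts([0, 5, 10, 15, 100, 130]))   # [15]
```

On the first 22 rows of `ch_burst_events` printed further down, `network_bursts` returns `(2915.2, [54, 45, 52])` and `(5926.3, [49, 45, 54])`. These are the same as the first two ground-truth entries, but matching on a short excerpt does not show that the two implementations agree in general.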

### Check WD (change if necessary) and file loading

In [33]:
# Show current directory
import os
curr_dir = os.getcwd()
print(curr_dir)

# Check if the current WD is the file location
if "/src/network_bursts" not in os.getcwd():
    # Set working directory to this file location
    file_location = f"{os.getcwd()}/thesis-lava/src/network_bursts"
    print("File Location: ", file_location)

    # Change the current working Directory
    os.chdir(file_location)

    # New Working Directory
    print("New Working Directory: ", os.getcwd())

/home/monkin/Desktop/feup/thesis/thesis-lava/src/network_bursts


In [34]:
from lava.proc.lif.process import LIF, LIFRefractory
from lava.proc.dense.process import Dense
import numpy as np

LIF?

Init signature: LIF(*args, **kwargs)
Docstring:
Leaky-Integrate-and-Fire (LIF) neural Process.

LIF dynamics abstracts to:
u[t] = u[t-1] * (1-du) + a_in         # neuron current
v[t] = v[t-1] * (1-dv) + u[t] + bias  # neuron voltage
s_out = v[t] > vth                    # spike if threshold is exceeded
v[t] = 0                              # reset at spike

Parameters
----------
shape : tuple(int)
    Number and topology of LIF neurons.
u : float, list, numpy.ndarray, optional
    Initial value of the neurons' current.
v : float, list, numpy.ndarray, optional
    Initial value of the neurons' voltage (membrane potential).
du : float, optional
    Inverse of decay time-constant for current decay. Currently, only a
    single decay can be set for the entire population of neurons.
dv : float, optional
    Inverse of decay time-constant for voltage decay. Curr

## Define the Architecture of the Network

In [35]:
# Define the number of neurons in each LIF Layer
n1 = 10
n2 = 1

## Choose the LIF Models to use

In [36]:
import numpy as np

# -- LIF Parameters --
# Fixed
v_th = 1
v_init = 0
C_IBI = 20    # Channel Inter-Burst Interval (ms)
N_IBI = 40    # Network Inter-Burst Interval (ms)
# Tunable
du1 = 0.45
dv1 = 0.05
du2 = 0.95
dv2 = 0.05
use_refractory = True
refrac_period1 = C_IBI  # Should it be the same for both layers?
refrac_period2 = N_IBI

# Scale the weights
weights_scale_input = 0.2
weights_scale_middle = 0.5

# Simulation Parameters
init_offset = 0  # 1000                  
virtual_time_step_interval = 1  
num_steps = 10000    # 1000      # 3000  # 26500     # TODO: Check the number of steps to run the simulation for

# Input File
file_path =  "../lab_data/lab_data_all_channels.csv"   # "./data/custom_4spikes_20ms_ch1.csv"       # "../lab_data/lab_data_1-8channels.csv"
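The update rule in the `LIF` docstring above can be reproduced in plain NumPy. This gives a quick feel for how `du1`, `dv1` and `weights_scale_input` interact. The sketch assumes `bias = 0`, no refractory period, and that each input spike adds its weight to `a_in` at that step. It is not Lava's ProcessModel, so details such as spike timing may differ.

```python
import numpy as np


def simulate_lif(a_in, du=0.45, dv=0.05, vth=1.0, bias=0.0):
    """Discrete LIF update from the docstring, for a population of neurons.

    a_in: array (num_steps, n_neurons) with the synaptic input of every step.
    Returns u, v (after reset) and the boolean spike matrix, all (num_steps, n_neurons).
    No refractory period is modelled.
    """
    a_in = np.asarray(a_in, dtype=float)
    u = np.zeros(a_in.shape[1])
    v = np.zeros(a_in.shape[1])
    u_rec, v_rec, s_rec = [], [], []
    for a in a_in:
        u = u * (1 - du) + a            # current decays, then integrates the input
        v = v * (1 - dv) + u + bias     # voltage decays, then integrates the current
        s = v > vth                     # spike if threshold is exceeded
        v = np.where(s, 0.0, v)         # reset spiking neurons
        u_rec.append(u)
        v_rec.append(v)
        s_rec.append(s)
    return np.array(u_rec), np.array(v_rec), np.array(s_rec)


# Usage with this notebook's LIF1 values: each input spike adds weights_scale_input = 0.2
w = 0.2
one_spike = np.zeros((30, 1))
one_spike[0] = w                        # a single spike at step 0
_, v1, s1 = simulate_lif(one_spike, du=0.45, dv=0.05)

burst = np.zeros((30, 1))
burst[0:4] = w                          # four spikes on consecutive steps
_, v4, s4 = simulate_lif(burst, du=0.45, dv=0.05)

print("single spike: max v =", v1.max(), "spiked:", s1.any())
print("4-spike burst: spike steps =", np.flatnonzero(s4[:, 0]))
```

You can feed other spike patterns and values of `du`, `dv` or `w` through this sketch to see which patterns cross `vth`, before running the full Lava simulation.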

### Create the LIF Processes

In [37]:
# Create Processes
lif1 = LIF(
    shape=(n1,),  # One neuron per input channel (n1 = 10)
    vth=v_th,  # TODO: Verify these initial values
    v=v_init,
    dv=dv1,    # Inverse of decay time-constant for voltage decay
    du=du1,  # Inverse of decay time-constant for current decay
    bias_mant=0,
    bias_exp=0,
    name="lif1"
)

lif2 = LIF(
    shape=(n2,),  # There is 1 neuron
    vth=v_th,  # TODO: Verify these initial values
    v=v_init,
    dv=dv2,    # Inverse of decay time-constant for voltage decay
    du=du2,  # Inverse of decay time-constant for current decay
    bias_mant=0,
    bias_exp=0,
    name="lif2"
)

In [38]:
refrac_lif1 = LIFRefractory(
    shape=(n1,),  # One neuron per input channel (n1 = 10)
    vth=v_th,  # TODO: Verify these initial values
    v=v_init,
    dv=dv1,    # Inverse of decay time-constant for voltage decay
    du=du1,  # Inverse of decay time-constant for current decay
    bias_mant=0,
    bias_exp=0,
    refractory_period=refrac_period1,
    name="lif1"
)

refrac_lif2 = LIFRefractory(
    shape=(n2,),  # A single output neuron (n2 = 1)
    vth=v_th,  # TODO: Verify these initial values
    v=v_init,
    dv=dv2,    # Inverse of decay time-constant for voltage decay
    du=du2,  # Inverse of decay time-constant for current decay
    bias_mant=0,
    bias_exp=0,
    refractory_period=refrac_period2,
    name="lif2"
)

### Choose the Selected LIF

In [39]:
selected_lif1 = refrac_lif1 if use_refractory else lif1
selected_lif2 = refrac_lif2 if use_refractory else lif2

## Create the Custom Input Layer

In [40]:
from utils.input import read_spike_events

# Call the function to read the spike events
spike_events = read_spike_events(file_path)
print("Spike events: ", spike_events.shape, spike_events[:10])

Spike events:  (40020, 2) [[ 99.5 229. ]
 [303.6   7. ]
 [502.5 229. ]
 ...
 [631.6 100. ]
 [633.8 100. ]
 [758.3 229. ]]


In [41]:
# Find the max channel idx
max_channel_idx = int(np.max(spike_events[:, 1]))
print("Max Channel Idx: ", max_channel_idx)

Max Channel Idx:  251


### Map the input channels to the corresponding indexes in the input layer
The channel numbers in the input file are arbitrary (the largest one here is 251), but the input layer expects indices `0` to `n1 - 1`. The `channel_map` dictionary maps each relevant channel to its index in the input layer.
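To illustrate what this mapping is for, the sketch below turns `(time_ms, channel)` rows into a binary raster with one row per simulation step and one column per input index. `events_to_raster` is a hypothetical helper, not the internals of `SpikeEventGen`. The binning rule (step `t` covers `[init_offset + t*dt, init_offset + (t+1)*dt)`) is an assumption.

```python
import numpy as np


def events_to_raster(spike_events, channel_map, num_steps, init_offset=0.0, dt=1.0):
    """Hypothetical helper: bin (time_ms, channel) rows into a 0/1 raster.

    Row t covers [init_offset + t*dt, init_offset + (t+1)*dt); column channel_map[channel].
    Channels missing from channel_map and events outside the window are ignored.
    """
    raster = np.zeros((num_steps, len(channel_map)), dtype=np.int8)
    for time_ms, channel in spike_events:
        idx = channel_map.get(int(channel))
        if idx is None:
            continue                     # channel not selected
        step = int((time_ms - init_offset) // dt)
        if 0 <= step < num_steps:
            raster[step, idx] = 1        # several spikes in one bin collapse to one
    return raster


channel_map = {ch: idx for idx, ch in enumerate(range(45, 55))}   # channels 45..54 -> 0..9
events = np.array([[3.2, 45], [3.7, 45], [5.0, 54], [7.1, 100]])
raster = events_to_raster(events, channel_map, num_steps=10)
print(np.argwhere(raster))   # (step, input index) of every spike
```

Note the design trade-off: with 1 ms steps, two spikes in the same millisecond on one channel become a single input spike.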

In [42]:
# Map the channels of the input file to the respective index in the output list of SpikeEventGen
# channel_map = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7}
channel_map = {1: 0, 4: 1, 6: 2, 7: 3, 8: 4}
channel_map = {ch:idx for idx, ch in enumerate(range(45, 55, 1))}

relevant_channels = list(channel_map.keys())

# Create a channel map based on the max channel idx
# channel_map = {channel: idx for idx, channel in enumerate(range(1, max_channel_idx+1))}
print("Channel Map: ", channel_map)
print("Number of channels: ", len(channel_map))
print("Relevant Channels: ", relevant_channels)

Channel Map:  {45: 0, 46: 1, 47: 2, 48: 3, 49: 4, 50: 5, 51: 6, 52: 7, 53: 8, 54: 9}
Number of channels:  10
Relevant Channels:  [45, 46, 47, 48, 49, 50, 51, 52, 53, 54]


## Generate the Ground Truth

### Filter the `spike_events` belonging to a relevant channel

In [43]:
from utils.io import preview_np_array

# Keep only the spike events recorded on one of the relevant channels
rel_spike_events = spike_events[np.isin(spike_events[:, 1], relevant_channels)]

preview_np_array(rel_spike_events, "Relevant Spike Events")

Relevant Spike Events Shape: (4989, 2).
Preview: [[5.280000e+02 5.400000e+01]
 [1.251200e+03 5.400000e+01]
 [1.452800e+03 5.400000e+01]
 [1.460500e+03 5.200000e+01]
 [1.471600e+03 5.200000e+01]
 ...
 [2.972576e+05 5.400000e+01]
 [2.987729e+05 5.400000e+01]
 [2.999202e+05 5.400000e+01]
 [2.999223e+05 4.500000e+01]
 [2.999228e+05 5.400000e+01]]


### 1) Find the Channel Bursts first

In [44]:
from ground_truth import find_channel_bursts

# Generate the Ground Truth -> Find the channel bursts according to the conditions specified
CH_CAUSALITY_WINDOW = C_IBI

# Parameters to Calculate Channel Bursts
num_spikes_to_burst = 4
max_burst_duration = CH_CAUSALITY_WINDOW
min_inter_burst_interval = CH_CAUSALITY_WINDOW
# Calculate the Channel Bursts
ch_bursts, ch_bursts_detailed = find_channel_bursts(rel_spike_events, num_spikes_to_burst, max_burst_duration, min_inter_burst_interval, verbose=False)

print(f"CBs detected {sum(len(bursts) for bursts in ch_bursts.values())} bursts")
print("Channel Bursts: ", ch_bursts_detailed)

CBs detected 490 bursts
Channel Bursts:  {54.0: [2907.0, 4999.7, 5034.2, 5896.1, 5926.3, 5936.1, 5967.400000000001, 5976.5, 6008.400000000001, 6021.400000000001, 15752.9, 20753.6, 22067.4, 22103.5, 22172.1, 22204.8, 22220.4, 23479.0, 26724.2, 29609.5, 30804.0, 30836.6, 33409.8, 33444.6, 33453.9, 33486.3, 33500.1, 33535.7, 33558.2, 39237.9, 43124.1, 44357.3, 45170.4, 45203.7, 47426.0, 47456.7, 47467.4, 47503.3, 47514.8, 47545.1, 47567.7, 49918.4, 50250.1, 53529.4, 53916.6, 55145.9, 56965.3, 57199.9, 58727.5, 58760.2, 58770.8, 58804.2, 60260.8, 61611.1, 63046.1, 65677.1, 65687.7, 65721.3, 67461.9, 68145.6, 69293.8, 69418.9, 69464.0, 69496.1, 69505.8, 72200.40000000001, 72382.5, 76180.6, 77351.90000000001, 78228.0, 79559.6, 79591.3, 79609.5, 81341.0, 81373.5, 81384.0, 83448.8, 86603.0, 87504.40000000001, 87535.90000000001, 87545.7, 87577.2, 87586.5, 87621.0, 87631.6, 87658.0, 99452.8, 101396.7, 102432.4, 104102.7, 105359.6, 105390.7, 105414.5, 105426.3, 113730.4, 116531.5, 116755.7, 11676

### 2) Find the Network Bursts using the Channel Bursts as the input spikes

In [78]:
# Convert the Channel Bursts to the same format as the spike_events
ch_burst_events = []

for ch, burst_times in ch_bursts.items():
    for time in burst_times:
        ch_burst_events.append([time, ch])

ch_burst_events = np.array(ch_burst_events)
# Sort the channel burst events by time
ch_burst_events = ch_burst_events[ch_burst_events[:, 0].argsort()]

preview_np_array(ch_burst_events, "ch_burst_events", edge_items=40)

ch_burst_events Shape: (490, 2).
Preview: [[2.907000e+03 5.400000e+01]
 [2.907700e+03 4.500000e+01]
 [2.915200e+03 5.200000e+01]
 [4.999700e+03 5.400000e+01]
 [5.002400e+03 4.500000e+01]
 [5.033500e+03 4.500000e+01]
 [5.034200e+03 5.400000e+01]
 [5.894600e+03 4.500000e+01]
 [5.896100e+03 5.400000e+01]
 [5.916300e+03 4.900000e+01]
 [5.924400e+03 4.500000e+01]
 [5.926300e+03 5.400000e+01]
 [5.936100e+03 5.400000e+01]
 [5.946700e+03 4.900000e+01]
 [5.948900e+03 4.500000e+01]
 [5.964800e+03 4.500000e+01]
 [5.967400e+03 5.400000e+01]
 [5.976500e+03 5.400000e+01]
 [5.990900e+03 4.500000e+01]
 [6.008400e+03 5.400000e+01]
 [6.009400e+03 4.500000e+01]
 [6.021400e+03 5.400000e+01]
 [1.575290e+04 5.400000e+01]
 [2.075360e+04 5.400000e+01]
 [2.075760e+04 4.500000e+01]
 [2.206740e+04 5.400000e+01]
 [2.207440e+04 4.500000e+01]
 [2.208030e+04 4.900000e+01]
 [2.210270e+04 4.500000e+01]
 [2.210350e+04 5.400000e+01]
 [2.217210e+04 5.400000e+01]
 [2.217550e+04 4.500000e+01]
 [2.218840e+04 4.900000e+01]
 

In [46]:
from ground_truth import find_net_bursts

# Generate the Ground Truth -> Find the network bursts according to the conditions specified
NET_CAUSALITY_WINDOW = C_IBI

# Parameters to Calculate Network Bursts
num_spikes_to_burst = 3
max_burst_duration = NET_CAUSALITY_WINDOW
min_inter_burst_interval = N_IBI
# Calculate the Network Bursts
gt, gt_detailed = find_net_bursts(ch_burst_events, num_spikes_to_burst, max_burst_duration, min_inter_burst_interval, verbose=False)

print(f"NBs detected {len(gt)} bursts")
print("Network Bursts: ", gt)
print("Network Bursts Detailed: ", gt_detailed)

NBs detected 23 bursts
Network Bursts:  [2915.2000000000003, 5926.3, 22080.3, 22188.4, 29617.4, 33418.3, 47459.1, 55161.0, 58739.3, 69475.7, 87537.8, 105363.2, 139034.1, 141385.5, 157759.7, 167635.2, 172743.1, 196678.0, 197577.9, 239255.2, 253974.6, 265771.4, 290489.0]
Network Bursts Detailed:  [(2915.2000000000003, [54.0, 45.0, 52.0]), (5926.3, [49.0, 45.0, 54.0]), (22080.3, [54.0, 45.0, 49.0]), (22188.4, [54.0, 45.0, 49.0]), (29617.4, [54.0, 45.0, 49.0]), (33418.3, [54.0, 45.0, 49.0]), (47459.1, [45.0, 54.0, 49.0]), (55161.0, [54.0, 45.0, 49.0]), (58739.3, [54.0, 45.0, 49.0]), (69475.7, [54.0, 45.0, 49.0]), (87537.8, [49.0, 54.0, 45.0]), (105363.2, [54.0, 45.0, 49.0]), (139034.1, [54.0, 45.0, 49.0]), (141385.5, [54.0, 45.0, 49.0]), (157759.7, [45.0, 54.0, 49.0]), (167635.2, [54.0, 45.0, 52.0]), (172743.1, [45.0, 54.0, 49.0]), (196678.0, [54.0, 45.0, 49.0]), (197577.9, [54.0, 45.0, 52.0]), (239255.2, [54.0, 45.0, 49.0]), (253974.6, [54.0, 45.0, 49.0]), (265771.4, [54.0, 45.0, 49.0]), 

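Crop the ground truth to the simulated time window `[init_offset, init_offset + num_steps * virtual_time_step_interval)`. No channel filtering is needed here, because the ground truth was computed from the relevant channels only.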

In [47]:
gt_detailed_cropped = gt_detailed.copy()
gt_detailed_cropped = [burst_info for burst_info in gt_detailed_cropped if (
    burst_info[0] >= init_offset and
    burst_info[0] < init_offset + num_steps*virtual_time_step_interval
    )
]

print(f"Ground Truth Detailed Cropped has {len(gt_detailed_cropped)} bursts")
print("Ground Truth Detailed Cropped: ", gt_detailed_cropped)

Ground Truth Detailed Cropped has 2 bursts
Ground Truth Detailed Cropped:  [(2915.2000000000003, [54.0, 45.0, 52.0]), (5926.3, [49.0, 45.0, 54.0])]


In [48]:
# Crop the ground truth according to the simulation time
gt_cropped = gt.copy()
gt_cropped = [spike_time for spike_time in gt_cropped if (
    spike_time >= init_offset and
    spike_time < init_offset + num_steps*virtual_time_step_interval
    )
]

print(f"Ground Truth Cropped has {len(gt_cropped)} bursts")
print("Ground Truth Cropped: ", gt_cropped)

Ground Truth Cropped has 2 bursts
Ground Truth Cropped:  [2915.2000000000003, 5926.3]


### Create the `SpikeEventGenerator` object 

In [49]:
from utils.spike_event_gen import SpikeEventGen

# Create the Input Process
spike_event_gen = SpikeEventGen(shape=(n1,), spike_events=rel_spike_events, name="CustomInput", channel_map=channel_map,
                                virtual_time_step_interval=virtual_time_step_interval, init_offset=init_offset)

# TODO: Check the channels being used and alter the GT to only view those channels

### Create the Dense Layers

In [50]:
# Create Dense Process to connect the input layer and LIF1
# create weights of the dense layer
dense_weights_input = np.eye(N=n1, M=n1)
# multiply the weights of the Dense layer by a constant
dense_weights_input *= weights_scale_input
dense_input = Dense(
    weights=np.array(dense_weights_input), 
    name="DenseInput"
)


# Create Dense Process to connect LIF1 and LIF2
# create weights of the dense layer connecting LIF1 and LIF2
dense_weights_middle = np.ones(shape=(n2, n1))

# multiply the weights of the Dense layer by a constant
dense_weights_middle *= weights_scale_middle

# Create Dense Process to connect the two LIF layers
# Dense weights have shape (num_outputs, num_inputs) = (n2, n1), so no separate shape argument is needed
dense_middle = Dense(
    weights=np.array(dense_weights_middle),
    name="Dense_LIF1-2"
)

## Connect the Layers

In [51]:
# Connect the SpikeEventGen to the Dense Layer
spike_event_gen.s_out.connect(dense_input.s_in)

# Connect the Dense_Input to the LIF1 Layer
dense_input.a_out.connect(selected_lif1.a_in)

# Connect the LIF1 Layer to the Dense Layer
selected_lif1.s_out.connect(dense_middle.s_in)   # Connect the output of the first LIF layer to the Dense Layer
# Connect the Dense Layer to the LIF2 Layer
dense_middle.a_out.connect(selected_lif2.a_in)   # Connect the output of the Dense Layer to the second LIF Layer

### Inspect the ports of every process

In [52]:
for proc in [spike_event_gen, dense_input, selected_lif1, dense_middle, selected_lif2]:
    for port in proc.in_ports:
        print(f"Proc: {proc.name:<5} Port Name: {port.name:<5} Size: {port.size}")
    for port in proc.out_ports:
        print(f"Proc: {proc.name:<5} Port Name: {port.name:<5} Size: {port.size}")

Proc: CustomInput Port Name: s_out Size: 10
Proc: DenseInput Port Name: s_in  Size: 10
Proc: DenseInput Port Name: a_out Size: 10
Proc: lif1  Port Name: a_in  Size: 10
Proc: lif1  Port Name: s_out Size: 10
Proc: Dense_LIF1-2 Port Name: s_in  Size: 10
Proc: Dense_LIF1-2 Port Name: a_out Size: 1
Proc: lif2  Port Name: a_in  Size: 1
Proc: lif2  Port Name: s_out Size: 1


### Look at the weights of the Dense Layers

In [53]:
# Weights of the Input Dense Layer
dense_input.weights.get()

array([[0.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2]])

In [54]:
# Weights of the Dense Layer between LIF1 and LIF2
dense_middle.weights.get()

array([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]])
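Functionally, a `Dense` process turns incoming spikes into synaptic input through a matrix-vector product with its weights. The weights are stored as (number of outputs, number of inputs), which is why the middle matrix above has shape (1, 10). The NumPy sketch below shows one time step with the weights used here. It ignores Lava's scheduling and any fixed-point details, and the spiking pattern is hypothetical.

```python
import numpy as np

n1, n2 = 10, 1
w_input = np.eye(n1) * 0.2            # DenseInput: one-to-one, weights_scale_input = 0.2
w_middle = np.ones((n2, n1)) * 0.5    # Dense_LIF1-2: all-to-one, weights_scale_middle = 0.5

# Hypothetical step in which input neurons 0, 4 and 9 spike
s = np.zeros(n1)
s[[0, 4, 9]] = 1

a_lif1 = w_input @ s    # each LIF1 neuron only sees its own channel: 0.2 or 0
a_lif2 = w_middle @ s   # if the same three LIF1 neurons fire, LIF2 receives 3 * 0.5
print(a_lif1)
print(a_lif2)
```

Because the middle weights are uniform, LIF2 only sees how many LIF1 neurons fire in a step, not which ones.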

### Record Internal Vars over time
To record the evolution of the internal variables over time, we need `Monitor` processes. Here each `Monitor` records a single `Var`. We want both the voltage (`v`) and the current (`u`) of the two LIF layers, so we create four `Monitor`s.

We define the `Var` that a `Monitor` should record, as well as the recording duration in time steps, using its `probe` method.

In [55]:
from lava.proc.monitor.process import Monitor

monitor_lif1_v = Monitor()
monitor_lif1_u = Monitor()
monitor_lif2_v = Monitor()
monitor_lif2_u = Monitor()

# Connect the monitors to the variables we want to monitor
monitor_lif1_v.probe(selected_lif1.v, num_steps)
monitor_lif1_u.probe(selected_lif1.u, num_steps)
monitor_lif2_v.probe(selected_lif2.v, num_steps)
monitor_lif2_u.probe(selected_lif2.u, num_steps)

## Execution
Now that we have defined the network, we can execute it. We call `run` on one of its processes (here `selected_lif1`), and Lava then compiles and executes every process connected to it.

### Run Configuration and Conditions

In [56]:
from lava.magma.core.run_conditions import RunContinuous, RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg

# run_condition = RunContinuous()   # TODO: Change to this one
run_condition = RunSteps(num_steps=num_steps)
run_cfg = Loihi1SimCfg(select_tag="floating_pt")   # "floating_pt" selects the floating-point ProcessModels, which suit the non-integer parameters used above

### Execute

In [57]:
selected_lif1.run(condition=run_condition, run_cfg=run_cfg)



### Retrieve recorded data

In [58]:
data_lif1_v = monitor_lif1_v.get_data()
data_lif1_u = monitor_lif1_u.get_data()
data_lif2_v = monitor_lif2_v.get_data()
data_lif2_u = monitor_lif2_u.get_data()

data_lif1 = data_lif1_v.copy()
data_lif1["lif1"]["u"] = data_lif1_u["lif1"]["u"]   # Merge the dictionaries to contain both voltage and current

data_lif2 = data_lif2_v.copy()
data_lif2["lif2"]["u"] = data_lif2_u["lif2"]["u"]   # Merge the dictionaries to contain both voltage and current


In [59]:
# print("data_lif1:", data_lif1)

In [60]:
# print("data_lif2:", data_lif2)

In [61]:
# Check the shape to verify if it is printing the voltage for every step
preview_np_array(data_lif1['lif1']['v'], "LIF1 V:", edge_items=3)    # One row per simulation step (num_steps = 10000) and one column per LIF1 neuron

preview_np_array(data_lif2['lif2']['v'], "LIF2 V:", edge_items=3)   

LIF1 V: Shape: (10000, 10).
Preview: [[0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 ...
 [2.01785355e-53 1.91696087e-53 0.00000000e+00 ... 2.31546964e-89
  0.00000000e+00 1.99879937e-09]
 [1.91696087e-53 1.82111283e-53 0.00000000e+00 ... 2.19969616e-89
  0.00000000e+00 1.89885940e-09]
 [1.82111283e-53 1.73005719e-53 0.00000000e+00 ... 2.08971135e-89
  0.00000000e+00 1.80391643e-09]]
LIF2 V: Shape: (10000, 1).
Preview: [[0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 ...
 [3.22053515e-89]
 [3.05950840e-89]
 [2.90653298e-89]]


## Find the timesteps where the network bursts occur

In [62]:
from utils.data_analysis import find_spike_times

# Call the find_spike_times util function that detects the spikes in a voltage array
spike_times_lif1 = find_spike_times(data_lif1['lif1']['v'], data_lif1['lif1']['u'])
spike_times_lif2 = find_spike_times(data_lif2['lif2']['v'], data_lif2['lif2']['u'])


print("Spike times LIF1: ", spike_times_lif1)
print("Spike times LIF2: ", spike_times_lif2)

Spike times LIF1:  [[2907    9]
 [2908    0]
 [2916    7]
 ...
 [6004    9]
 [6011    0]
 [6036    9]]
Spike times LIF2:  [[2917    0]
 [5919    0]
 [5978    0]]
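The monitors record `v` after the reset, so a spike shows up as `v` dropping to 0. One way to recover spike steps from the recorded `v` and `u` is to recompute the pre-reset voltage `v[t-1] * (1 - dv) + u[t] + bias` from the docstring dynamics and flag the steps where it exceeds `vth`. The implementation of `find_spike_times` is not shown here, so `spikes_from_traces` is a hypothetical sketch of that idea. It assumes no refractory period. During a refractory period `v` is not integrated, so the reconstruction could flag spurious spikes there.

```python
import numpy as np


def spikes_from_traces(v, u, dv, vth=1.0, bias=0.0, v_init=0.0):
    """Hypothetical helper: recover (step, neuron) spikes from recorded v and u.

    v, u: arrays (num_steps, n_neurons), with v recorded after the reset.
    Recomputes the pre-reset voltage v[t-1] * (1 - dv) + u[t] + bias and
    flags the steps where it exceeds vth. Assumes no refractory period.
    """
    v = np.asarray(v, dtype=float)
    u = np.asarray(u, dtype=float)
    v_prev = np.vstack([np.full((1, v.shape[1]), v_init), v[:-1]])   # v[t-1]
    v_pre_reset = v_prev * (1 - dv) + u + bias
    return np.argwhere(v_pre_reset > vth)          # rows of [step, neuron]


# Usage: generate traces with the docstring dynamics (du=0.45, dv=0.05),
# input 0.2 on steps 0-3, and check that the spike is recovered
du, dv = 0.45, 0.05
u_t, v_t = np.zeros(1), np.zeros(1)
u_rec, v_rec = [], []
for step in range(20):
    a = 0.2 if step < 4 else 0.0
    u_t = u_t * (1 - du) + a
    v_t = v_t * (1 - dv) + u_t
    v_t = np.where(v_t > 1.0, 0.0, v_t)            # spike and reset
    u_rec.append(u_t)
    v_rec.append(v_t)

spikes = spikes_from_traces(np.array(v_rec), np.array(u_rec), dv=dv)
print(spikes)    # one spike: neuron 0 at step 3
```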


## View the Voltage and Current dynamics with an interactive plot

Grab the data from the recorded variables

In [63]:
# LIF1 variables
lif1_voltage_vals = data_lif1['lif1']['v']
lif1_current_vals = data_lif1['lif1']['u']

print("lif1 voltage shape:", len(lif1_voltage_vals))
# print("voltage head: ", lif1_voltage_vals[:10])


# LIF2 variables
lif2_voltage_vals = data_lif2['lif2']['v']
lif2_current_vals = data_lif2['lif2']['u']

print("lif2 voltage shape:", len(lif2_voltage_vals))

lif1 voltage shape: 10000
lif2 voltage shape: 10000


## Assemble the values to be plotted

In [64]:
from utils.line_plot import create_fig  # Import the function to create the figure
from bokeh.models import Range1d

# Define the x and y values
x = [val + init_offset for val in range(num_steps)]

# Create the plot: one (y values, legend label) tuple per LIF1 neuron
voltage_lif1_y_arrays = [
    (lif1_voltage_vals[:, i].tolist(), f"Ch. {i + 1}") for i in range(n1)
]
# Define the box annotation parameters
box_annotation_voltage = {
    "bottom": 0,
    "top": v_th,
    "left": 0,
    "right": num_steps,
    "fill_alpha": 0.03,
    "fill_color": "green"
}

# Create the LIF1 Voltage
voltage_lif1_plot = create_fig(
    title="LIF1 Voltage dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Voltage (V)',
    x=x, 
    y_arrays=voltage_lif1_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    box_annotation_params=box_annotation_voltage,
    y_range=Range1d(-0.05, 1.05)
)


# Create the LIF1 Current
# One (y values, legend label) tuple per LIF1 neuron
current_lif1_y_arrays = [
    (lif1_current_vals[:, i].tolist(), f"Ch. {i + 1}") for i in range(n1)
]

current_lif1_plot = create_fig(
    title="LIF1 Current dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Current (U)',
    x=x, 
    y_arrays=current_lif1_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    x_range=voltage_lif1_plot.x_range,    # Link the x-axis range to the voltage plot
)

# bplt.show(voltage_lif1_plot)

In [65]:
v_y1 = [val[0] for val in lif2_voltage_vals]

# Create the plot
voltage_lif2_y_arrays = [(v_y1, "Output")]    # List of tuples containing the y values and the legend label
# Define the box annotation parameters
box_annotation_voltage = {
    "bottom": 0,
    "top": v_th,
    "left": 0,
    "right": num_steps,
    "fill_alpha": 0.03,
    "fill_color": "green"
}

# Create the LIF2 Voltage
voltage_lif2_plot = create_fig(
    title="LIF2 Voltage dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Voltage (V)',
    x=x, 
    y_arrays=voltage_lif2_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    box_annotation_params=box_annotation_voltage,
    y_range=Range1d(-0.05, 1.05),
    x_range=voltage_lif1_plot.x_range,    # Link the x-axis range to the voltage plot
)


# Create the LIF2 Current
u_y1 = [val[0] for val in lif2_current_vals]
current_lif2_y_arrays = [(u_y1, "Output")]    # List of tuples containing the y values and the legend label
current_lif2_plot = create_fig(
    title="LIF2 Current dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Current (U)',
    x=x, 
    y_arrays=current_lif2_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    x_range=voltage_lif1_plot.x_range,    # Link the x-axis range to the voltage plot
)

# bplt.show(voltage_lif1_plot)

## Show the Plots assembled in a grid

In [66]:
import bokeh.plotting as bplt
from bokeh.layouts import gridplot

showPlot = True
if showPlot:
    # Create array of plots to be shown
    plots = [voltage_lif1_plot, current_lif1_plot, voltage_lif2_plot, current_lif2_plot]

    if len(plots) == 1:
        grid = plots[0]
    else:   # Create a grid layout
        grid = gridplot(plots, ncols=2, sizing_mode="stretch_both")

    # Show the plot
    bplt.show(grid)

## Export the plot to a file

In [67]:
export = False

OUT_FOLDER = "./results/net_burst"
OUT_FILENAME = f"lab_ch1-{n1}_3spikes5ch_20ms"

if export:
    html_path = f"{OUT_FOLDER}/{OUT_FILENAME}.html"   # separate name, so the input `file_path` is not overwritten

    # Customize the output file settings
    bplt.output_file(filename=html_path, title="Network Burst detection - Voltage and Current dynamics")

    # Save the plot
    bplt.save(grid)

## Calculate Detection Metrics

### Convert the spike times to the same format as the Ground Truth

In [68]:
# Invert the mapping of the electrodes to the neuron indices
channel_map_inv = {v: k for k, v in channel_map.items()}
print("Channel Map Inverted: ", channel_map_inv)

Channel Map Inverted:  {0: 45, 1: 46, 2: 47, 3: 48, 4: 49, 5: 50, 6: 51, 7: 52, 8: 53, 9: 54}


In [69]:
# Convert the LIF2 spike steps into absolute times (ms), the same time base as the ground truth
predicted_spikes = init_offset + np.asarray(spike_times_lif2).reshape(-1, 2)[:, 0] * virtual_time_step_interval

print(f"Predicted {len(predicted_spikes)} spikes")
print("Predicted Spikes: ", predicted_spikes)

Predicted 3 spikes
Predicted Spikes:  [2917 5919 5978]


### Calculate the Confusion Matrix

In [70]:
# Initialize the variables that store the values of the Confusion Matrix
true_positive = 0
false_positive = 0
true_negative = 0
false_negative = 0

In [71]:
# Go through the ground truth and check if the predicted spikes are correct
# TODO: For now, a prediction matches a ground-truth burst if it falls within the symmetric causality window [gt_burst_time - CAUSALITY_WINDOW, gt_burst_time + CAUSALITY_WINDOW].

# Get the True Positive and False Negative values by comparing with the ground truth
for gt_net_burst_time in gt_cropped:
    # TP if at least one prediction lies within the causality window of this ground-truth burst
    if any(abs(pred_burst_time - gt_net_burst_time) <= NET_CAUSALITY_WINDOW for pred_burst_time in predicted_spikes):
        true_positive += 1
    else:
        false_negative += 1

# Get the False Positive values: predictions outside the causality window of every ground-truth burst
for pred_net_burst_time in predicted_spikes:
    if all(abs(pred_net_burst_time - gt_net_burst_time) > NET_CAUSALITY_WINDOW for gt_net_burst_time in gt_cropped):
        false_positive += 1

# Calculate the True Negative value
# TN = num_steps - (TP + FP + FN): every simulation step not counted as TP, FP or FN is treated as a true negative
true_negative = num_steps - true_positive - false_positive - false_negative
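The two loops above can also be written as a single NumPy broadcast over a ground-truth × prediction distance matrix. `tolerance_confusion` is a hypothetical helper that applies the same symmetric tolerance rule. On this notebook's data it gives the counts printed below (TP 2, FN 0, FP 1). TN is left out because it depends on the per-step convention used above.

```python
import numpy as np


def tolerance_confusion(gt_times, pred_times, window):
    """Vectorised version of the loops above (same symmetric tolerance rule).

    Returns (tp, fn, fp): a ground-truth burst is a TP if any prediction lies within
    +/- window of it; a prediction is an FP if no ground-truth burst lies within +/- window.
    """
    gt = np.asarray(gt_times, dtype=float).reshape(-1, 1)       # column vector
    pred = np.asarray(pred_times, dtype=float).reshape(1, -1)   # row vector
    close = np.abs(gt - pred) <= window                         # (n_gt, n_pred) match matrix
    tp = int(close.any(axis=1).sum())
    fn = gt.shape[0] - tp
    fp = int((~close.any(axis=0)).sum())
    return tp, fn, fp


print(tolerance_confusion([2915.2, 5926.3], [2917, 5919, 5978], window=20))   # (2, 0, 1)
```

As with the loops, one prediction can validate several ground-truth bursts that lie within the window of each other. A one-to-one matching would be stricter.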

In [72]:
# Print the Confusion Matrix
print("True Positive: ", true_positive)
print("False Positive: ", false_positive)
print("True Negative: ", true_negative)
print("False Negative: ", false_negative)

# Print the Total of predictions
total_predictions = true_positive + false_positive + true_negative + false_negative
print("Total Predictions: ", total_predictions)

True Positive:  2
False Positive:  1
True Negative:  9997
False Negative:  0
Total Predictions:  10000


In [73]:
# Calculate relevant metrics (in %; the F1 score is on the same 0-100 scale)
accuracy = precision = recall = f1_score = specificity = 0
if true_positive + false_positive == 0:
    print("No positive predictions were made. Cannot calculate metrics.")
else:
    accuracy = (true_positive + true_negative) / total_predictions * 100    # Proportion of all predictions that are correct
    precision = true_positive / (true_positive + false_positive) * 100      # Proportion of predicted bursts that are real bursts
    if true_positive + false_negative > 0:
        recall = true_positive / (true_positive + false_negative) * 100     # Proportion of real bursts that were detected
    if precision + recall > 0:
        f1_score = (2 * precision * recall) / (precision + recall)          # Harmonic mean of Precision and Recall
    specificity = true_negative / (true_negative + false_positive) * 100    # Proportion of negatives correctly left undetected

print(f"Accuracy: {accuracy}%")
print(f"Precision: {precision}%")
print(f"Recall: {recall}%")
print(f"F1 Score: {f1_score}")
print(f"Specificity: {specificity}%")

Accuracy: 99.99%
Precision: 66.66666666666666%
Recall: 100.0%
F1 Score: 80.0
Specificity: 99.98999799959992%
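Plugging the confusion-matrix counts above into the definitions used in the code reproduces the printed values. Because TN counts simulation steps rather than burst events, accuracy and specificity mostly reflect `num_steps` rather than detection quality.

$$
\text{Precision} = \frac{TP}{TP + FP} = \frac{2}{3} \approx 66.7\%, \qquad
\text{Recall} = \frac{TP}{TP + FN} = \frac{2}{2} = 100\%
$$

$$
F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2 \cdot 66.\overline{6} \cdot 100}{166.\overline{6}} = 80
$$

$$
\text{Accuracy} = \frac{TP + TN}{TP + FP + TN + FN} = \frac{2 + 9997}{10000} = 99.99\%, \qquad
\text{Specificity} = \frac{TN}{TN + FP} = \frac{9997}{9998} \approx 99.99\%
$$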


## Export the results of the detection to a JSON file
Export the results of the classification to a JSON file. This file will include:
- `Causality Window` used
- Classification Metrics (`True Positives`, `False Positives`, `True Negatives`, `False Negatives`, `Accuracy`, `Precision`, `Recall`, `F1 Score`, `Specificity`)

In [74]:
import json

# Export the results to a JSON file

# Create a dictionary with the results
json_results = {
    "causality_window": NET_CAUSALITY_WINDOW,
    "metrics": {
        "true_positive": true_positive,
        "false_positive": false_positive,
        "true_negative": true_negative,
        "false_negative": false_negative,
        "total_predictions": total_predictions,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity
    }
}

EXPORT_JSON_FILE = False
if EXPORT_JSON_FILE:
    json_file_name = f"{OUT_FOLDER}/{OUT_FILENAME}_metrics.json"
    with open(json_file_name, 'w') as f:
        json.dump(json_results, f)

## Stop the Runtime

In [75]:
selected_lif1.stop()