# SNN that detects High Frequency Oscillations (HFOs) with constant parameters
This notebook is a simple example of how to use a Spiking Neural Network (SNN) to detect HFOs.

### What is an HFO?
High Frequency Oscillations (HFOs) are a type of brain activity that occurs in the range of 80-500 Hz. They are believed to be related to the generation of seizures in patients with epilepsy. The detection of HFOs is an important task in the diagnosis and treatment of epilepsy. 

In terms of electrophysiology, HFOs are characterized by their high frequency and short duration, often lasting only a few milliseconds. The wave of a typical HFO consists of at least 4 UP and DOWN waves.

In [33]:
from lava.proc.lif.process import LIF
from lava.proc.dense.process import Dense
import numpy as np

LIF?

Init signature: LIF(*args, **kwargs)
Docstring:
Leaky-Integrate-and-Fire (LIF) neural Process.

LIF dynamics abstracts to:
u[t] = u[t-1] * (1-du) + a_in         # neuron current
v[t] = v[t-1] * (1-dv) + u[t] + bias  # neuron voltage
s_out = v[t] > vth                    # spike if threshold is exceeded
v[t] = 0                              # reset at spike

Parameters
----------
shape : tuple(int)
    Number and topology of LIF neurons.
u : float, list, numpy.ndarray, optional
    Initial value of the neurons' current.
v : float, list, numpy.ndarray, optional
    Initial value of the neurons' voltage (membrane potential).
du : float, optional
    Inverse of decay time-constant for current decay. Currently, only a
    single decay can be set for the entire population of neurons.
dv : float, optional
    Inverse of decay time-constant for voltage decay. Curr
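
The update rule listed in the docstring can be reproduced in a few lines of NumPy. The following is a minimal sketch of those four equations for a single neuron, meant only as an illustration and not as Lava's implementation; the helper name `simulate_lif` and the parameter values are assumptions.

```python
import numpy as np

def simulate_lif(a_in, du=0.5, dv=0.1, vth=1.0, bias=0.0):
    """Simulate one LIF neuron with the docstring equations (hypothetical helper)."""
    u, v = 0.0, 0.0
    spikes = np.zeros(len(a_in), dtype=bool)
    voltages = np.zeros(len(a_in))
    for t, a in enumerate(a_in):
        u = u * (1 - du) + a          # neuron current
        v = v * (1 - dv) + u + bias   # neuron voltage
        if v > vth:                   # spike if the threshold is exceeded
            spikes[t] = True
            v = 0.0                   # reset at spike
        voltages[t] = v
    return spikes, voltages

# Two input pulses followed by silence (illustrative values)
spikes, voltages = simulate_lif(np.array([0.6, 0.6, 0.0, 0.0]))
print(spikes)
print(voltages)
```

A larger `du` makes the current forget past inputs faster, and a larger `dv` does the same for the voltage.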

### Check WD (change if necessary) and file loading

In [34]:
# Show current directory
import os
curr_dir = os.getcwd()
print(curr_dir)

# Check if the current WD is the file location
if "/src/hfo/snn" not in os.getcwd():
    # Set working directory to this file location
    file_location = f"{os.getcwd()}/thesis-lava/src/hfo/snn"
    print("File Location: ", file_location)

    # Change the current working Directory
    os.chdir(file_location)

    # New Working Directory
    print("New Working Directory: ", os.getcwd())

/home/monkin/Desktop/feup/thesis/thesis-lava/src/hfo/snn


## Create the Custom Input Layer

### Define function to read the input data from the csv file and generate the corresponding spike events

In [35]:
import pandas as pd

def read_spike_events(file_path: str):
    """Reads the spike events from the input file and returns them as a numpy array

    Args:
        file_path (str): name of the file containing the spike events
    """
    spike_events = []

    try:
        # Read the spike events from the file
        df = pd.read_csv(file_path, header=None)

        # Detect errors
        if df.empty:
            raise Exception("The input file is empty")

        # Convert the scientific notation values to integers if any exist
        df = df.applymap(lambda x: int(float(x)) if (isinstance(x, str) and 'e' in x) else x)

        # Convert the dataframe to a numpy array
        spike_events = df.to_numpy()
        return spike_events[0]
    except Exception as e:
        print("Unable to read the input file: ", file_path, " error:", e)

    return spike_events

## Configurable Parameters

In [36]:
from utils.input import BaselineAlgorithm, MarkerType, ModelDistStrategy, band_to_confidence_window, X_STEP
from math import floor

# Declare if using ripples, fast ripples, or both
chosen_band = MarkerType.RIPPLE     # RIPPLE, FAST_RIPPLE, or BOTH

# Specify the chosen Baseline Algorithm
chosen_baseline_alg_suffix = BaselineAlgorithm.Q3

# Select the Distribution Model Strategy
selected_strategy = ModelDistStrategy.LOG_NORMAL

# Define the Weight Scale of the Dense Layer #TODO: Test changing this
weights_scale_input = 0.2   # 0.5
weights_scale_std = 0.1

# Define the IQR ranges for the specific data being fed (synaptic time constants)
ripple_IQR = [0.48, 6.5]
fr_IQR = [0.49, 2.93] # Inter-Quartile Range for the time constants of the Fast-Ripple neurons

# Define the mean and std. deviation of the synaptic time constants
synaptic_tc_mean = 3.8735
scale_std_dev = 1
synaptic_tc_std_dev = 7.4958 * scale_std_dev

# Define the mean and std. deviation of the Membrane Potential Time Constants
mean_dv = 20    # Mean voltage time constant (ms); a value of 15 ms would follow Indiveri's paper
std_dev_dv = 10  # Standard deviation of the voltage time constant. If 0, then the time constant is fixed 

# Range of random values to subtract from the excitatory time constants to get the inhibitory time constants
inh_subtract_range = [0.1, 1.0]  

# Constants for the Refractory LIF Process
confidence_window = band_to_confidence_window(chosen_band)
# We know that 2 relevant events do not occur within the confidence window of a ripple event, so we set the refractory period accordingly
refrac_period = floor(confidence_window / X_STEP)   # Number of time-steps for the refractory period
print("Refractory Period: ", refrac_period)

Refractory Period:  245


### Load the UP and DOWN spikes from the CSV Files

In [37]:
# Define the name of the dataset version being used
DATASET_FILENAME = "seeg_filtered_subset_90-119_segment500_200"

#### Load the Baseline Thresholds from the output file of the baseline process

In [38]:
# Load the Baseline Thresholds
BASELINE_FILE = f"../signal_to_spike/baseline_results/{DATASET_FILENAME}_thresholds_{chosen_baseline_alg_suffix}.npy"
baseline_thresholds = np.load(BASELINE_FILE)

# preview_np_array(baseline_thresholds, "baseline_thresholds", edge_items=3)

baseline_ripple_thresh = round(baseline_thresholds[0], 4)
baseline_fr_thresh = round(baseline_thresholds[1], 4)

# For now, the UP and DN thresholds are the same (symmetric)
ripple_thresh_up = baseline_ripple_thresh
ripple_thresh_down = -baseline_ripple_thresh
fr_thresh_up = baseline_fr_thresh
fr_thresh_down = -baseline_fr_thresh

print("Ripple Thresholds: ", ripple_thresh_up, ripple_thresh_down)
print("FR Thresholds: ", fr_thresh_up, fr_thresh_down)

Ripple Thresholds:  3.7058 -3.7058
FR Thresholds:  1.4961 -1.4961


In [39]:
def band_to_thresholds(band: MarkerType):
    """Returns the UP and DOWN thresholds for the given band

    Args:
        band (MarkerType): the band for which the thresholds are required
    """
    if band == MarkerType.RIPPLE:
        return ripple_thresh_up, ripple_thresh_down
    elif band == MarkerType.FAST_RIPPLE:
        return fr_thresh_up, fr_thresh_down
    else:
        raise Exception("Invalid band type")

In [40]:
# Note: this import replaces the `read_spike_events` function defined earlier in the notebook
from utils.input import read_spike_events, band_to_file_name
from utils.io import preview_np_array

# Define the path of the files containing the spike events
INPUT_PATH = f"../signal_to_spike/results/{DATASET_FILENAME}"
band_file_name = band_to_file_name(chosen_band)

# Define the chosen thresholds
thresh_up, thresh_down = band_to_thresholds(chosen_band)
print("Thresholds: ", thresh_up, thresh_down)

# Call the function to read the spike events
up_spikes_file_path = f"{INPUT_PATH}/{band_file_name}_up_spike_train_{thresh_up}.csv"
up_spike_train = read_spike_events(up_spikes_file_path)

down_spikes_file_path = f"{INPUT_PATH}/{band_file_name}_down_spike_train_{thresh_down}.csv"
down_spike_train = read_spike_events(down_spikes_file_path)

preview_np_array(up_spike_train, "up_spike_train")
preview_np_array(down_spike_train, "down_spike_train")

Thresholds:  3.7058 -3.7058
up_spike_train Shape: (3724, 2).
Preview: [[ 9.43847656e+02 -1.00000000e+00]
 [ 9.51171875e+02 -1.00000000e+00]
 [ 9.52148438e+02 -1.00000000e+00]
 [ 9.63378906e+02 -1.00000000e+00]
 [ 9.73632812e+02 -1.00000000e+00]
 ...
 [ 1.19084473e+05 -1.00000000e+00]
 [ 1.19092285e+05 -1.00000000e+00]
 [ 1.19093262e+05 -1.00000000e+00]
 [ 1.19094727e+05 -1.00000000e+00]
 [ 1.19105957e+05 -1.00000000e+00]]
down_spike_train Shape: (3711, 2).
Preview: [[ 9.47265625e+02 -1.00000000e+00]
 [ 9.48242188e+02 -1.00000000e+00]
 [ 9.56054688e+02 -1.00000000e+00]
 [ 9.66796875e+02 -1.00000000e+00]
 [ 9.77539062e+02 -1.00000000e+00]
 ...
 [ 1.19087402e+05 -1.00000000e+00]
 [ 1.19088379e+05 -1.00000000e+00]
 [ 1.19089844e+05 -1.00000000e+00]
 [ 1.19098633e+05 -1.00000000e+00]
 [ 1.19100098e+05 -1.00000000e+00]]


### Define the SpikeEvent Generator Interface

In [41]:
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import OutPort

class SpikeEventGen(AbstractProcess):
    """Input Process that generates spike events based on the input file

    Args:
        @out_shape (tuple): Shape of the output port
        @exc_spike_events (np.ndarray): Excitatory spike events
        @inh_spike_event (np.ndarray): Inhibitory spike events
        @name (str): Name of the process
    """
    def __init__(self, out_shape: tuple, exc_spike_events: np.ndarray, inh_spike_event: np.ndarray, name: str) -> None:
        super().__init__(name=name)
        self.s_out = OutPort(shape=out_shape)
        self.exc_spike_events = Var(shape=exc_spike_events.shape, init=exc_spike_events)
        self.inh_spike_events = Var(shape=inh_spike_event.shape, init=inh_spike_event)


## Define the Architecture of the Network

In [42]:
# Define the number of neurons in the Input Spike Event Generator
n_spike_gen = 2  # 2 neurons in the input spike event generator

# Define the number of neurons in each LIF Layer
# TODO: Test with 512 neurons?
n_lif1 = 256   # 256 neurons in the first LIF layer 
# n2 = 1  # 1 neuron in the second layer

### Choose the LIF Models to use

In [43]:
use_refractory = True

### Define the LIF parameters

In [44]:
# Constants for the LIF Process
v_th = 1
v_init = 0

### Create a `ConfigTimeConstantsLIF` object

#### Define the time constants for the `ConfigTimeConstantsLIF` neurons
The synaptic time constants, corresponding to `du_exc` and `du_inh`, are different for each neuron in the LIF layer. Within each neuron, the excitatory and inhibitory time constants also differ from each other, so that the population can capture the dynamics of the HFOs.

In [45]:
from utils.neuron_dynamics import time_constant_to_fraction
# Create the np arrays for the time constants of each neuron

# Select the chosen IQR based on the chosen band
chosen_IQR = ripple_IQR
if chosen_band == MarkerType.FAST_RIPPLE:
    chosen_IQR = fr_IQR
elif chosen_band != MarkerType.RIPPLE:
    raise Exception("What to do in this case?")


if selected_strategy == ModelDistStrategy.IQR:
    chosen_mu = np.mean(chosen_IQR)  # Midpoint of the IQR
    # Calculate the standard deviation of the normal distribution
    # For a normal distribution, the first quartile is ~0.675 standard deviations below the mean
    chosen_std_dev = (chosen_IQR[1] - chosen_IQR[0]) / (2 * 0.675)  # standard deviation is the IQR divided by 2*0.675
elif selected_strategy == ModelDistStrategy.MEAN_AND_STD:
    # Use the actual values from the STS_ANALYSIS
    chosen_mu = synaptic_tc_mean
    chosen_std_dev = synaptic_tc_std_dev
elif selected_strategy == ModelDistStrategy.LOG_NORMAL:
    # Convert the mean and std. deviation to the log-normal distribution
    chosen_mu = np.log( synaptic_tc_mean**2 / np.sqrt(synaptic_tc_std_dev**2 + synaptic_tc_mean**2) )
    chosen_std_dev = np.sqrt( np.log( 1 + (synaptic_tc_std_dev**2 / synaptic_tc_mean**2) ) )

print("Chosen mu: ", chosen_mu)
print("Chosen std_dev: ", chosen_std_dev)
# print("Chosen IQR: ", chosen_IQR)

Chosen mu:  0.5756336395181222
Chosen std_dev:  1.2478179767672029
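
The `LOG_NORMAL` branch converts the desired mean and standard deviation of the time constants into the `mu` and `sigma` of the underlying normal distribution. One way to sanity-check that conversion, sketched here with hypothetical function names, is to map the parameters back through the closed-form moments of the log-normal distribution and confirm that the original mean and standard deviation are recovered.

```python
import numpy as np

def lognormal_params(mean, std):
    """Convert the desired mean/std of the samples into mu/sigma of the underlying normal."""
    mu = np.log(mean**2 / np.sqrt(std**2 + mean**2))
    sigma = np.sqrt(np.log(1 + std**2 / mean**2))
    return mu, sigma

def lognormal_moments(mu, sigma):
    """Closed-form mean/std of a log-normal distribution with parameters mu/sigma."""
    mean = np.exp(mu + sigma**2 / 2)
    std = np.sqrt((np.exp(sigma**2) - 1) * np.exp(2 * mu + sigma**2))
    return mean, std

# Values taken from the configuration cell of this notebook
mu, sigma = lognormal_params(3.8735, 7.4958)
mean_back, std_back = lognormal_moments(mu, sigma)
print(mu, sigma)
print(mean_back, std_back)
```

These are the moments of the distribution itself; a finite sample of 256 heavy-tailed draws can deviate noticeably from them.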


### Generate a random distribution of the Synaptic Excitatory time constants 

In [46]:
# Generate the excitatory synaptic time constants for the neurons of the chosen band
exc_syn_time_constants = np.random.normal(chosen_mu, chosen_std_dev, n_lif1)
# Generate Log Normal Distribution if selected
if selected_strategy == ModelDistStrategy.LOG_NORMAL:
    exc_syn_time_constants = np.random.lognormal(chosen_mu, chosen_std_dev, n_lif1)

# preview_np_array(exc_syn_time_constants, "exc_syn_time_constants", edge_items=10)
print("Min and max time constants before:", np.min(exc_syn_time_constants), np.max(exc_syn_time_constants))

# Cannot have negative time constants. Make them 0 or positive?
# exc_syn_time_constants = np.clip(exc_syn_time_constants, a_min=0, a_max=None)
exc_syn_time_constants = np.abs(exc_syn_time_constants)
# TODO: Could add the mean to the negative values? 
print("[Excitatory] Min and max time constants after abs:", np.min(exc_syn_time_constants), np.max(exc_syn_time_constants))
print(f"[Excitatory] Mean time constants after: {np.mean(exc_syn_time_constants)} ± {np.std(exc_syn_time_constants)}")

Min and max time constants before: 0.02826254263840086 38.901523357611666
[Excitatory] Min and max time constants after abs: 0.02826254263840086 38.901523357611666
[Excitatory] Mean time constants after: 3.3510693414200414 ± 5.084034031184725


### Derive the inhibitory time constants by subtracting a random offset from the excitatory time constants

In [47]:
# Generate the inhibitory time constants by subtracting a random value in a range from the excitatory time constants

# Generate the random values to subtract from the excitatory time constants
inh_syn_offset = np.random.uniform(inh_subtract_range[0], inh_subtract_range[1], n_lif1)
preview_np_array(inh_syn_offset, "inh_syn_offset", edge_items=10)

# Subtract the random values from the excitatory time constants to get the inhibitory time constants
inh_syn_time_constants = exc_syn_time_constants - inh_syn_offset

# Clip the inhibitory time constants so that none falls below the minimum excitatory time constant

inh_syn_time_constants = np.clip(inh_syn_time_constants, a_min=np.min(exc_syn_time_constants), a_max=None)
print("[Inhibitory] Min and max time constants after:", np.min(inh_syn_time_constants), np.max(inh_syn_time_constants))
print(f"[Inhibitory] Mean time constants after: {np.mean(inh_syn_time_constants)} ± {np.std(inh_syn_time_constants)}")

inh_syn_offset Shape: (256,).
Preview: [0.6245685  0.11134428 0.62556626 0.34031323 0.80856861 0.51897728
 0.909393   0.77162257 0.88902599 0.69356288 ... 0.4595099  0.1553502
 0.57208619 0.18969203 0.13813101 0.54810279 0.48532392 0.55301466
 0.94418241 0.22362338]
[Inhibitory] Min and max time constants after: 0.02826254263840086 38.43508810555922
[Inhibitory] Mean time constants after: 2.8484560268004726 ± 5.05035579643124


In [48]:
# Convert the excitatory time constants to fractions (du_exc values) that are used in the LAVA Processes dynamics
exc_syn_time_constants_frac = time_constant_to_fraction(exc_syn_time_constants)
preview_np_array(exc_syn_time_constants_frac, "exc_syn_time_constants_frac", edge_items=5)

print(f"Min. Exc. τ: {np.min(exc_syn_time_constants_frac)}. Max. Exc. τ: {np.max(exc_syn_time_constants_frac)}")
print(f"Mean Exc. τ: {np.mean(exc_syn_time_constants_frac)} ± {np.std(exc_syn_time_constants_frac)}")

# TODO: Maybe there is a better distribution than the normal one that avoids fractions of ~1.0 (the current then decays completely within a single step)

exc_syn_time_constants_frac Shape: (256,).
Preview: [0.56139554 0.74546144 0.0804599  0.91453483 0.31593486 ... 0.32193654
 0.41808139 0.99991433 0.86724273 0.78020289]
Min. Exc. τ: 0.025378349639114672. Max. Exc. τ: 0.9999999999999996
Mean Exc. τ: 0.490179927059338 ± 0.28605179867452224
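
`time_constant_to_fraction` is defined in `utils.neuron_dynamics` and is not shown in this notebook. The printed extremes are consistent with the usual exponential-decay mapping `1 - exp(-Δt / τ)` with `Δt` equal to one time step, so the sketch below assumes that form. The name `tau_to_decay_fraction` is hypothetical and the actual helper may differ.

```python
import numpy as np

def tau_to_decay_fraction(tau, dt=1.0):
    """Assumed mapping from a time constant to a per-step decay fraction.

    Hypothetical stand-in for `time_constant_to_fraction`: 1 - exp(-dt / tau).
    """
    tau = np.asarray(tau, dtype=float)
    return 1.0 - np.exp(-dt / tau)

# Smallest and largest excitatory time constants printed earlier in the notebook
fractions = tau_to_decay_fraction([0.02826254263840086, 38.901523357611666])
print(fractions)
```

Under this assumption, very small time constants map to fractions close to 1, which means the current is almost fully replaced by the new input at every step.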


In [49]:
# Convert the inhibitory time constants to fractions (du_inh values) that are used in the LAVA Processes dynamics
inh_syn_time_constants_frac = time_constant_to_fraction(inh_syn_time_constants)
preview_np_array(inh_syn_time_constants_frac, "inh_syn_time_constants_frac", edge_items=5)

print(f"Min. Inh. τ: {np.min(inh_syn_time_constants_frac)}. Max. Inh. τ: {np.max(inh_syn_time_constants_frac)}")
print(f"Mean Inh. τ: {np.mean(inh_syn_time_constants_frac)} ± {np.std(inh_syn_time_constants_frac)}")

inh_syn_time_constants_frac Shape: (256,).
Preview: [0.81702253 0.80095735 0.08472155 0.99999972 0.42185048 ... 0.38959406
 0.52020694 1.         1.         0.89887189]
Min. Inh. τ: 0.02568234376521983. Max. Inh. τ: 0.9999999999999996
Mean Inh. τ: 0.5987675552850732 ± 0.3407541841269754


#### Define the Membrane Potential Time Constants for the LIF Neurons
To add more variability to the network, the membrane potential time constants will be randomly generated from a normal distribution around a mean value.

In [50]:
dv_time_constants = np.random.normal(mean_dv, std_dev_dv, n_lif1)

preview_np_array(dv_time_constants, "dv_time_constants", edge_items=5)

# Guarantee that the time constants are positive
dv_time_constants = np.abs(dv_time_constants)
print("[ms] Min and max Voltage time constants:", np.min(dv_time_constants), np.max(dv_time_constants))

# Transform the time constants to fractions
dv_time_constants_frac = time_constant_to_fraction(dv_time_constants)
print("[Frac] Min and max Voltage time constants:", np.min(dv_time_constants_frac), np.max(dv_time_constants_frac))

dv_time_constants Shape: (256,).
Preview: [32.66680091 16.81401807  8.73880505 14.4618338  13.62065445 ...
 28.7821803  14.36544626 19.14522485 19.74947647 31.48735612]
[ms] Min and max Voltage time constants: 0.05679024673581523 53.02091484937837
[Frac] Min and max Voltage time constants: 0.01868373583735028 0.9999999774753918


In [51]:
from lava.proc.lif.process import ConfigTimeConstantsLIF, ConfigTimeConstantsRefractoryLIF
""" 
configLIF = ConfigTimeConstantsLIF(shape=(n_lif1,),  # There are 256 neurons
            vth=v_th,  # TODO: Verify these initial values
            v=v_init,
            dv=dv_time_constants_frac,    # Inverse of decay time-constant for voltage decay
            du_exc=exc_syn_time_constants_frac,  # Inverse of decay time-constant for excitatory current decay
            du_inh=inh_syn_time_constants_frac,  # Inverse of decay time-constant for inhibitory current decay
            bias_mant=0,
            bias_exp=0,
            name="lif1")

configLIF.du_exc """


In [52]:
# Config Time Constants Refractory LIF
configRefracLIF = ConfigTimeConstantsRefractoryLIF(shape=(n_lif1,),  # There are 256 neurons
            vth=v_th,  # TODO: Verify these initial values
            v=v_init,
            dv=dv_time_constants_frac,    # Inverse of decay time-constant for voltage decay
            du_exc=exc_syn_time_constants_frac,  # Inverse of decay time-constant for excitatory current decay
            du_inh=inh_syn_time_constants_frac,  # Inverse of decay time-constant for inhibitory current decay
            bias_mant=0,
            bias_exp=0,
            refractory_period=refrac_period,
            name="lif1")

In [53]:
selectedLIF = configRefracLIF # configRefracLIF

### Create the Dense Layers

In [54]:
# Create Dense Process to connect the input layer and LIF1
# create weights of the dense layer
# dense_weights_input = np.eye(N=n1, M=n1)
# Fully Connected Layer from n_spike_gen neurons to n_lif1 neurons
dense_weights_input = np.ones(shape=(n_lif1, n_spike_gen))

# Make the weights (synapses) coming from the odd-indexed neurons of the input layer (the DOWN inputs) negative (inhibitory)
dense_weights_input[:, 1::2] *= -1

# Create a Distribution of the weights_scale to multiply the weights of the Dense Layer
weights_scale_dist = np.random.normal(weights_scale_input, weights_scale_std, n_lif1)

# Multiply the weights of the Dense layer by the Distribution
dense_weights_input[:, 0] *= weights_scale_dist
dense_weights_input[:, 1] *= weights_scale_dist

# multiply the weights of the Dense layer by a constant
# dense_weights_input *= weights_scale_input

dense_input = Dense(weights=np.array(dense_weights_input), name="DenseInput")

#### Look at the weights of the Dense Layers

In [55]:
# Weights of the Input Dense Layer
dense_input.weights.get()

array([[ 0.35609889, -0.35609889],
       [ 0.00895139, -0.00895139],
       [ 0.00163203, -0.00163203],
       [ 0.21217467, -0.21217467],
       [ 0.24436039, -0.24436039],
       ...,
       [ 0.17653917, -0.17653917],
       [ 0.38541508, -0.38541508],
       [ 0.11842201, -0.11842201],
       [ 0.23033663, -0.23033663],
       [ 0.01153452, -0.01153452]])
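
The weight matrix above has one row per LIF neuron, with the same magnitude in both columns and opposite signs. A compact way to express that construction is NumPy broadcasting, sketched below. The function name `build_input_weights` is hypothetical, and the sketch uses `np.random.default_rng` rather than the legacy `np.random.normal` call of the notebook, so the drawn values will differ.

```python
import numpy as np

def build_input_weights(n_neurons, mean_scale, std_scale, rng):
    """Build an (n_neurons, 2) matrix: column 0 excitatory (UP), column 1 inhibitory (DOWN)."""
    signs = np.array([1.0, -1.0])
    scales = rng.normal(mean_scale, std_scale, n_neurons)   # one scale per neuron
    return scales[:, None] * signs[None, :]                 # broadcast to (n_neurons, 2)

rng = np.random.default_rng(0)   # seed chosen only for reproducibility of the sketch
weights = build_input_weights(256, 0.2, 0.1, rng)
print(weights.shape)
```

Because the scale is drawn from a normal distribution, a small fraction of neurons can receive a negative scale, which swaps the excitatory and inhibitory roles of the two inputs for those neurons.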

### Map the input channels to the corresponding indexes in the input layer
The channel identifiers in the input file can be arbitrary numbers, so we need to **map the input channels to the corresponding indexes in the input layer**. This is done by the `channel_map` dictionaries.

The network expects an UP and a DOWN spike train for each channel. Thus, we define 2 dictionaries, one for the UP spikes and one for the DOWN spikes, so that the UP and DOWN spike trains of each channel occupy consecutive indexes in the input layer.

In [56]:
# Map the channels of the input file to the respective index in the output list of SpikeEventGen

# Define the mapping of the channels of the UP spike train to the respective index in the output list of SpikeEventGen
up_channel_map = {-1: 0}
# Define the mapping of the channels of the DOWN spike train to the respective index in the output list of SpikeEventGen
down_channel_map = {-1: 1}
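
This notebook uses a single channel with identifier `-1`. For a recording with several channels, the same interleaved layout could be generated programmatically. The helper `build_channel_maps` and the channel identifiers `3, 7, 12` below are hypothetical and only illustrate the idea.

```python
def build_channel_maps(channels):
    """Map each channel to consecutive output indexes: UP at 2*i, DOWN at 2*i + 1."""
    up_map = {ch: 2 * i for i, ch in enumerate(channels)}
    down_map = {ch: 2 * i + 1 for i, ch in enumerate(channels)}
    return up_map, down_map

# Single-channel case used in this notebook
up_map, down_map = build_channel_maps([-1])

# Hypothetical multi-channel case
multi_up, multi_down = build_channel_maps([3, 7, 12])
print(multi_up, multi_down)
```

With this layout the input layer would need `2 * len(channels)` outputs, so `n_spike_gen` and the shape of the Dense weights would have to grow accordingly.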

### Define the constants related to the simulation time

In [57]:
init_offset = 0 # 900 # 33400      #   
virtual_time_step_interval = 1  # TODO: Check if this should be the time-step value. it is not aligned with the sampling rate of the input data

num_steps = 120000    # 200 # Number of steps to run the simulation

# OPTIONAL: Scale down the simulation time
time_downscale = 4  # Factor by which the number of simulation steps is divided

num_steps = num_steps // time_downscale

### Update the UP and DOWN spike trains to include only the spikes that occur within the simulation time

In [58]:
# Iterate the UP spike train to find the spikes that occur within the time interval
up_train_start = -1
up_train_end = up_spike_train.shape[0]
for i, (spike_time, _) in enumerate(up_spike_train):
    if up_train_start == -1 and spike_time >= init_offset:
        up_train_start = i
    
    if spike_time > init_offset + num_steps:
        up_train_end = i
        break

# Slice the spike train to the time interval
up_spike_train_interval = []
if up_train_start != -1:
    # If there are spikes in the interval
    up_spike_train_interval = up_spike_train[up_train_start:up_train_end]
    
preview_np_array(up_spike_train_interval, "Spike Events")

Spike Events Shape: (947, 2).
Preview: [[ 9.43847656e+02 -1.00000000e+00]
 [ 9.51171875e+02 -1.00000000e+00]
 [ 9.52148438e+02 -1.00000000e+00]
 [ 9.63378906e+02 -1.00000000e+00]
 [ 9.73632812e+02 -1.00000000e+00]
 ...
 [ 2.95771484e+04 -1.00000000e+00]
 [ 2.95869141e+04 -1.00000000e+00]
 [ 2.95888672e+04 -1.00000000e+00]
 [ 2.98808594e+04 -1.00000000e+00]
 [ 2.98901367e+04 -1.00000000e+00]]


In [59]:
# Iterate the DOWN spike train to find the spikes that occur within the time interval
down_train_start = -1
down_train_end = down_spike_train.shape[0]
for i, (spike_time, _) in enumerate(down_spike_train):
    if down_train_start == -1 and spike_time >= init_offset:
        down_train_start = i
    
    if spike_time > init_offset + num_steps:
        down_train_end = i
        break

# Slice the spike train to the time interval
down_spike_train_interval = []
if down_train_start != -1:
    # If there are spikes in the interval
    down_spike_train_interval = down_spike_train[down_train_start:down_train_end]
    
preview_np_array(down_spike_train_interval, "Spike Events")

Spike Events Shape: (951, 2).
Preview: [[ 9.47265625e+02 -1.00000000e+00]
 [ 9.48242188e+02 -1.00000000e+00]
 [ 9.56054688e+02 -1.00000000e+00]
 [ 9.66796875e+02 -1.00000000e+00]
 [ 9.77539062e+02 -1.00000000e+00]
 ...
 [ 2.95825195e+04 -1.00000000e+00]
 [ 2.95839844e+04 -1.00000000e+00]
 [ 2.95947266e+04 -1.00000000e+00]
 [ 2.98833008e+04 -1.00000000e+00]
 [ 2.98935547e+04 -1.00000000e+00]]
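
The two loops above select the events whose time lies in `[init_offset, init_offset + num_steps]`. Assuming the spike times in the first column are sorted in ascending order (which the previews suggest, but which is not verified here), the same window can be extracted with `np.searchsorted`. The helper name `slice_spike_train` and the demo data are hypothetical.

```python
import numpy as np

def slice_spike_train(spike_train, start_time, duration):
    """Return the rows whose time (column 0) lies in [start_time, start_time + duration]."""
    times = spike_train[:, 0]
    start = np.searchsorted(times, start_time, side="left")              # first time >= start_time
    end = np.searchsorted(times, start_time + duration, side="right")    # first time > end of window
    return spike_train[start:end]

# Small made-up spike train: (time, channel) pairs
demo_train = np.array([[5.0, -1], [10.0, -1], [20.0, -1], [30.0, -1], [45.0, -1]])
window = slice_spike_train(demo_train, start_time=10, duration=20)
print(window[:, 0])
```

Unlike the loop version, this always returns a 2-D array, which is simply empty when no event falls inside the window.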


## Implement the `SpikeEventGenerator` Model

In [60]:
from lava.magma.core.model.py.model import PyLoihiProcessModel  # Processes running on CPU inherit from this class
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyOutPort

@implements(proc=SpikeEventGen, protocol=LoihiProtocol)
@requires(CPU)
class PySpikeEventGenModel(PyLoihiProcessModel):
    """Spike Event Generator Process implementation running on CPU (Python)."""
    s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)   # IT IS POSSIBLE TO SEND FLOATS AFTER ALL
    exc_spike_events: np.ndarray = LavaPyType(np.ndarray, np.ndarray)
    inh_spike_events: np.ndarray = LavaPyType(np.ndarray, np.ndarray)

    def __init__(self, proc_params) -> None:
        super().__init__(proc_params=proc_params)
        # print("spike events", self.spike_events.__str__())    # TODO: Check why during initialization the variable prints the class, while during run it prints the value
        
        self.curr_exc_idx = 0     # Index of the next excitatory spiking event to send
        self.curr_inh_idx = 0     # Index of the next inhibitory spiking event to send
        self.virtual_time_step_interval = virtual_time_step_interval  # 1000    # Arbitrary time between time steps (in microseconds). This is not a real time interval (1000ms = 1s)
        self.init_offset = init_offset        # 698995               # Arbitrary offset to start the simulation (in microseconds)
        
        # Try to increment the curr_exc_idx and curr_inh_idx to the first spike event that is greater than the init_offset here?

    def run_spk(self) -> None:
        spike_data = np.zeros(self.s_out.shape) # Initialize the spike data to 0
        
        #print("time step:", self.time_step)

        # If the current simulation time is greater than a spike event, send a spike in the corresponding channel
        currTime = self.init_offset + self.time_step*self.virtual_time_step_interval

        spiking_channels = set()   # Set of channels that will spike in the current time step

        # Add the excitatory spike events to the spike_data
        while (self.curr_exc_idx < len(self.exc_spike_events)) and currTime >= self.exc_spike_events[self.curr_exc_idx][0]:
            # Get the channel of the current spike event
            curr_channel = self.exc_spike_events[self.curr_exc_idx][1]

            # Check if the channel is valid (belongs to a channel in the up_channel_map therefore it has an output index)
            if curr_channel not in up_channel_map:
                self.curr_exc_idx += 1
                continue    # Skip the current spike event

            # Check if the next spike belongs to a channel that will already spike in this time step
            # If so, we don't add the event and stop looking for more events
            if curr_channel in spiking_channels:
                break

            # Add the channel to the list of spiking channels
            spiking_channels.add(curr_channel)

            # Get the output index of the current channel according to the up_channel_map
            out_idx = up_channel_map[curr_channel]
            if out_idx < self.s_out.shape[0]:   # Check if the channel is valid
                # Update the spike_data with the excitatory spike event (value = 1.0)
                spike_data[out_idx] = 1.0   # Send spike (value corresponds to the punctual current of the spike event)

            # Move to the next spike event
            self.curr_exc_idx += 1

        # Add the inhibitory spike events to the spike_data
        while (self.curr_inh_idx < len(self.inh_spike_events)) and currTime >= self.inh_spike_events[self.curr_inh_idx][0]:
            # Get the channel of the current spike event
            curr_channel = self.inh_spike_events[self.curr_inh_idx][1]

            # Check if the channel is valid (belongs to a channel in the down_channel_map therefore it has an output index)
            if curr_channel not in down_channel_map:
                self.curr_inh_idx += 1
                continue    # Skip the current spike event

            # Check if the next spike belongs to a channel that will already spike in this time step
            # If so, we don't add the event and stop looking for more events
            if curr_channel in spiking_channels:
                break

            # Add the channel to the list of spiking channels
            spiking_channels.add(curr_channel)

            # Get the output index of the current channel according to the down_channel_map
            out_idx = down_channel_map[curr_channel]
            if out_idx < self.s_out.shape[0]:   # Check if the channel is valid
                # The spike itself is sent as a positive value; the negative weight of the Dense synapse applies the inhibition
                spike_data[out_idx] = 1.0   # Send spike (value corresponds to the punctual current of the spike event)

            # Move to the next spike event
            self.curr_inh_idx += 1


        if len(spiking_channels) > 0:   # Print the spike event if there are any spikes
            VERBOSE = False
            if VERBOSE:
                print(f"""Sending spike event at time: {currTime}({self.time_step}). Last (E/I) spike idx: {self.curr_exc_idx-1}/{self.curr_inh_idx-1}
                        Spike times: {self.exc_spike_events[self.curr_exc_idx-1][0] if self.curr_exc_idx > 0 else "?"}/\
                        {self.inh_spike_events[self.curr_inh_idx-1][0] if self.curr_inh_idx > 0 else "?"}
                        Spike_data: {spike_data}\n"""
                )
            #else:
            #     print(f"Sending spike event at time: {currTime}({self.time_step}).")

        # Send the spikes
        # print("sending spike_data: ", spike_data, " at step: ", self.time_step)
        self.s_out.send(spike_data)

        # Stop the Process if there are no more spike events to send. (It will stop all the connected processes)
        # TODO: Should it be another process that stops the simulation? Such as the last LIF process
        # if self.curr_spike_idx >= 5: # len(self.spike_events):
        #    self.pause()

        # Print a progress message every 1000 time steps
        if self.time_step % 1000 == 0:
            print(f"Time step: {self.time_step}")
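
The event-consumption logic of `run_spk` can be tried out without Lava. The sketch below is a simplified, hypothetical re-implementation of one of its `while` loops (`emit_due_events` is not part of the notebook), applied to a made-up list of events. It shows how two events on the same channel that fall due in the same step are spread over consecutive steps.

```python
import numpy as np

def emit_due_events(events, channel_map, next_idx, curr_time, n_out, busy_channels):
    """Lava-free sketch of one `while` loop of `run_spk` (hypothetical helper)."""
    spike_data = np.zeros(n_out)
    while next_idx < len(events) and curr_time >= events[next_idx][0]:
        channel = events[next_idx][1]
        if channel not in channel_map:
            next_idx += 1          # unmapped channel: skip the event
            continue
        if channel in busy_channels:
            break                  # channel already spiked in this step: defer the event
        busy_channels.add(channel)
        out_idx = channel_map[channel]
        if out_idx < n_out:
            spike_data[out_idx] = 1.0
        next_idx += 1
    return spike_data, next_idx

# Made-up (time, channel) events; the first two fall due in the same step
events = np.array([[1.2, -1.0], [1.8, -1.0], [4.0, -1.0]])
idx = 0
raster = []
for step in range(1, 6):
    data, idx = emit_due_events(events, {-1: 0}, idx, step, 2, set())
    raster.append(int(data[0]))
print(raster)
```

In the notebook, the UP and DOWN loops share the same `spiking_channels` set and both use channel `-1`, so, as written, a DOWN event that falls due in a step where an UP event was sent is deferred to the next step.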

## Connect the Layers
To define the connectivity between the `SpikeEventGen` process and the `LIF` population, we use the `Dense` layer created above.

In [61]:
# Create the Input Process
spike_event_gen = SpikeEventGen(out_shape=(n_spike_gen,),
                                exc_spike_events=up_spike_train_interval,
                                inh_spike_event=down_spike_train_interval,
                                name="SpikeEventsGenerator")

# If I connect the SpikeEventGen to the Dense Layer, the a_out value of the custom input will be rounded to 0 or 1 in the Dense Layer (it will not be a float) 
# However, setting the Dense weights to a float works instead
# Connect the SpikeEventGen to the Dense Layer
spike_event_gen.s_out.connect(dense_input.s_in)

# Connect the Dense_Input to the LIF1 Layer
dense_input.a_out.connect(selectedLIF.a_in)

### Take a look at the connections in the Input Layer

In [62]:
for proc in [spike_event_gen, dense_input, selectedLIF]:
    for port in proc.in_ports:
        print(f"Proc: {proc.name:<5} Port Name: {port.name:<5} Size: {port.size}")
    for port in proc.out_ports:
        print(f"Proc: {proc.name:<5} Port Name: {port.name:<5} Size: {port.size}")

Proc: SpikeEventsGenerator Port Name: s_out Size: 2
Proc: DenseInput Port Name: s_in  Size: 2
Proc: DenseInput Port Name: a_out Size: 256
Proc: lif1  Port Name: a_in  Size: 256
Proc: lif1  Port Name: s_out Size: 256
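
The port sizes above (2 in, 256 out) follow from the shape of the `Dense` weight matrix, which is (outputs, inputs). As a minimal NumPy sketch of that projection, assuming hypothetical weight values (the notebook's real weights are defined elsewhere), with column 0 as the excitatory channel and column 1 as the inhibitory one:

```python
import numpy as np

n_inputs, n_neurons = 2, 256   # sizes taken from the port listing above

# Hypothetical weights, for illustration only:
# column 0 carries the excitatory (UP) channel, column 1 the inhibitory (DOWN) channel.
rng = np.random.default_rng(0)
weights = np.empty((n_neurons, n_inputs))
weights[:, 0] = rng.uniform(0.1, 0.5, n_neurons)    # excitatory, positive floats
weights[:, 1] = -rng.uniform(0.1, 0.5, n_neurons)   # inhibitory, negative floats

s_in = np.array([1.0, 0.0])    # an UP spike, like "Spike_data: [1. 0.]"
a_out = weights @ s_in         # shape (256,): one activation per LIF neuron
print(a_out.shape)
```

Because the spike vector only holds 0s and 1s, every graded quantity in `a_out` comes from the float weights.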


### Record Internal Vars over time
To record the evolution of the internal variables over time, we need a `Monitor`. For this example, we want to record both the membrane potential (`v`) and the current (`u`) of the `LIF` layer. Each `Monitor` below probes a single `Var`, so we create two `Monitor`s.

We define the `Var` that a `Monitor` should record, as well as the recording duration, using the `probe` function.

In [81]:
from lava.proc.monitor.process import Monitor

monitor_lif1_v = Monitor()
monitor_lif1_u = Monitor()

# Connect the monitors to the variables we want to monitor
monitor_lif1_v.probe(selectedLIF.v, num_steps)
monitor_lif1_u.probe(selectedLIF.u, num_steps)  # Monitoring the net_current (u_exc + u_inh) of the LIF1 Process

## Execution
Now that we have defined the network, we can execute it. We will use the `run` function to execute the network.

### Run Configuration and Conditions

In [82]:
from lava.magma.core.run_conditions import RunContinuous, RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg

# run_condition = RunContinuous()   # TODO: Change to this one
run_condition = RunSteps(num_steps=num_steps)
run_cfg = Loihi1SimCfg(select_tag="floating_pt")   # TODO: Check why we need this select_tag="floating_pt"

### Execute

In [83]:
selectedLIF.run(condition=run_condition, run_cfg=run_cfg)

Sending spike event at time: 944(944). Last (E/I) spike idx: 0/-1
                        Spike times: 943.84765625/                        ?
                        Spike_data: [1. 0.]

Resetting voltage of neurons that spiked at time step: 947.
Resetting voltage of neurons that spiked at time step: 948.
Sending spike event at time: 948(948). Last (E/I) spike idx: 0/0
                        Spike times: 943.84765625/                        947.265625
                        Spike_data: [0. 1.]


Sending spike event at time: 949(949). Last (E/I) spike idx: 0/1
                        Spike times: 943.84765625/                        948.2421875
                        Spike_data: [0. 1.]

Sending spike event at time: 952(952). Last (E/I) spike idx: 1/1
                        Spike times: 951.171875/                        948.2421875
                        Spike_data: [1. 0.]

Sending spike event at time: 953(953). Last (E/I) spike idx: 2/1
                        Spike times: 952.14

### Retrieve recorded data

In [84]:
data_lif1_v = monitor_lif1_v.get_data()
data_lif1_u = monitor_lif1_u.get_data()

print("Copying...")
data_lif1 = data_lif1_v.copy()
data_lif1["lif1"]["u"] = data_lif1_u["lif1"]["u"]   # Merge the dictionaries to contain both voltage and current


Copying...
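
Note that `dict.copy()` is a shallow copy: the inner `"lif1"` dictionary is shared, so adding `"u"` to the merged result also adds it to the original. The sketch below uses small stand-in dictionaries with the same nesting as the monitor data (`{process name: {variable name: array}}`) to show the difference from `copy.deepcopy`. The array contents are placeholders.

```python
import copy
import numpy as np

# Stand-ins for the two Monitor results: {process name: {variable name: array}}
data_v = {"lif1": {"v": np.zeros((3, 2))}}
data_u = {"lif1": {"u": np.ones((3, 2))}}

shallow = data_v.copy()                  # copies only the outer dict
shallow["lif1"]["u"] = data_u["lif1"]["u"]
leaked = "u" in data_v["lif1"]           # the shared inner dict was modified too

data_v2 = {"lif1": {"v": np.zeros((3, 2))}}
merged = copy.deepcopy(data_v2)          # inner dicts (and arrays) are independent
merged["lif1"]["u"] = data_u["lif1"]["u"]
isolated = "u" not in data_v2["lif1"]    # the original stays untouched

print(leaked, isolated)
```

For this notebook the shared inner dictionary is harmless, since only the merged result is used afterwards.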


In [85]:
selectedLIF

<lava.proc.lif.process.ConfigTimeConstantsRefractoryLIF at 0x7f41d6bf5ab0>

In [86]:
# Check the length to verify that the voltage was recorded for every step
print(len(data_lif1['lif1']['v']))     # Should equal num_steps, the number of steps the simulation ran for

30000


### Plot the recorded data

In [87]:
import matplotlib
%matplotlib inline
from matplotlib import pyplot as plt

# Boolean defining if we should use the monitor plot
MONITOR_PLOT = False

if MONITOR_PLOT:
    # Create a subplot for each monitored variable
    fig = plt.figure(figsize=(16, 10))
    ax0 = fig.add_subplot(221)
    ax0.set_title('Voltage (V) / time step')
    ax1 = fig.add_subplot(222)
    ax1.set_title('Current (U) / time step')


    # Plot the data (use the same Process whose Vars were probed above)
    monitor_lif1_v.plot(ax0, selectedLIF.v)
    monitor_lif1_u.plot(ax1, selectedLIF.u)

## Find the timesteps where the network spiked

In [88]:
from utils.data_analysis import find_spike_times

lif1_voltage_vals = np.array(data_lif1['lif1']['v'])
lif1_current_vals = np.array(data_lif1['lif1']['u'])
# preview_np_array(voltage_arr_1, "Voltage Array")

# Call the find_spike_times util function that detects the spikes in a voltage array
# TODO: Improve the find_spike_times method to check the current of the previous timestep to make sure it is a spike, instead of an inhibition
spike_times_lif1 = find_spike_times(lif1_voltage_vals, lif1_current_vals)

print(f"Found {len(spike_times_lif1)} spikes in the LIF1 Process")
for (spike_time, neuron_idx) in spike_times_lif1:
    print(f"Spike time: {init_offset + spike_time * virtual_time_step_interval} (iter. {spike_time}) at neuron: {neuron_idx}")



Found 5668 spikes in the LIF1 Process
Spike time: 946 (iter. 946) at neuron: 76
Spike time: 947 (iter. 947) at neuron: 14
Spike time: 947 (iter. 947) at neuron: 33
Spike time: 947 (iter. 947) at neuron: 48
Spike time: 947 (iter. 947) at neuron: 96
Spike time: 947 (iter. 947) at neuron: 105
Spike time: 947 (iter. 947) at neuron: 112
Spike time: 947 (iter. 947) at neuron: 152
Spike time: 947 (iter. 947) at neuron: 163
Spike time: 947 (iter. 947) at neuron: 231
Spike time: 956 (iter. 956) at neuron: 246
Spike time: 965 (iter. 965) at neuron: 222
Spike time: 966 (iter. 966) at neuron: 16
Spike time: 966 (iter. 966) at neuron: 98
Spike time: 966 (iter. 966) at neuron: 123
Spike time: 966 (iter. 966) at neuron: 161
Spike time: 966 (iter. 966) at neuron: 201
Spike time: 966 (iter. 966) at neuron: 249
Spike time: 977 (iter. 977) at neuron: 79
Spike time: 985 (iter. 985) at neuron: 176
Spike time: 985 (iter. 985) at neuron: 200
Spike time: 986 (iter. 986) at neuron: 157
Spike time: 986 (iter. 9
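
The implementation of `find_spike_times` is not shown in this notebook. As a rough sketch of one way to tell a spike reset apart from an inhibition, the hypothetical helper below (`find_reset_spikes`, not part of `utils.data_analysis`) recomputes the pre-reset voltage from the standard LIF update `v[t] = v[t-1] * (1 - dv) + u[t]` and flags a spike only when that value crosses the threshold and the recorded voltage is reset to 0. It assumes zero bias and a single `dv` for all neurons, and it ignores any refractory behaviour of the custom LIF used here.

```python
import numpy as np

def find_reset_spikes(v, u, dv, v_th):
    """Return (time step, neuron index) pairs where a threshold crossing was followed by a reset.

    v, u: arrays of shape (time steps, neurons) with the recorded voltage and current.
    """
    v = np.asarray(v, dtype=float)
    u = np.asarray(u, dtype=float)
    # Voltage each neuron would have reached at step t before any reset
    pre_reset = v[:-1] * (1.0 - dv) + u[1:]
    # A spike needs both a threshold crossing and a recorded reset to zero
    spiked = (pre_reset > v_th) & (v[1:] == 0.0)
    steps, neurons = np.nonzero(spiked)
    return [(int(t) + 1, int(n)) for t, n in zip(steps, neurons)]

# Tiny hand-made example: neuron 0 crosses v_th at step 2 and is reset to 0
v = np.array([[0.0, 0.0], [0.5, 0.2], [0.0, 0.5]])
u = np.array([[0.0, 0.0], [0.5, 0.2], [0.8, 0.32]])
spikes = find_reset_spikes(v, u, dv=0.1, v_th=1.0)
print(spikes)
```

A drop in voltage caused by a negative current never satisfies the threshold condition, so it is not counted as a spike.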

## View the Voltage and Current dynamics with an interactive plot

Grab the data from the recorded variables

In [89]:
preview_np_array(lif1_voltage_vals, "Voltage Values", edge_items=3)

Voltage Values Shape: (30000, 256).
Preview: [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 ...
 [-2.01272642e-28 -4.74768567e-05 -1.04971936e-08 ... -3.38417463e-04
  -1.33054534e-07  4.57217616e-04]
 [-1.10400906e-28 -4.48079081e-05 -9.04947140e-09 ... -3.21447068e-04
  -1.17192599e-07  4.42756661e-04]
 [-6.05564665e-29 -4.22889966e-05 -7.80141204e-09 ... -3.05327676e-04
  -1.03221626e-07  4.28753080e-04]]


## Assemble the values to be plotted

In [90]:
from utils.line_plot import create_fig  # Import the function to create the figure
from bokeh.models import Range1d

# Define the x and y values
x = [val + init_offset for val in range(num_steps)]

v_y1 = [val[215] for val in lif1_voltage_vals]
v_y2 = [val[124] for val in lif1_voltage_vals]
v_y3 = [val[129] for val in lif1_voltage_vals]
v_y4 = [val[138] for val in lif1_voltage_vals]
v_y5 = [val[164] for val in lif1_voltage_vals]
v_y6 = [val[5] for val in lif1_voltage_vals]
v_y7 = [val[6] for val in lif1_voltage_vals]
v_y8 = [val[7] for val in lif1_voltage_vals]
v_y9 = [val[8] for val in lif1_voltage_vals]
v_y10 = [val[9] for val in lif1_voltage_vals]

# Create the plot
# Label each trace with the index of the neuron it actually shows
voltage_lif1_y_arrays = [
    (v_y1, "Neuron 215"), (v_y2, "Neuron 124"), (v_y3, "Neuron 129"),
    (v_y4, "Neuron 138"), (v_y5, "Neuron 164"), # (v_y6, "Neuron 5"),
    # (v_y7, "Neuron 6"), (v_y8, "Neuron 7"), (v_y9, "Neuron 8"),
    # (v_y10, "Neuron 9")
]    # List of tuples containing the y values and the legend label
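# Define the box annotation parameters (sub-threshold band between 0 and v_th).
# The x values start at init_offset, so the box spans the same range.
box_annotation_voltage = {
    "bottom": 0,
    "top": v_th,
    "left": init_offset,
    "right": init_offset + num_steps,
    "fill_alpha": 0.03,
    "fill_color": "green"
}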

# Create the LIF1 Voltage
voltage_lif1_plot = create_fig(
    title="LIF1 Voltage dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Voltage (V)',
    x=x, 
    y_arrays=voltage_lif1_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    box_annotation_params=box_annotation_voltage,
    y_range=Range1d(-1.05, 1.05)
)


# Create the LIF1 Current
u_y1 = [val[215] for val in lif1_current_vals]
u_y2 = [val[124] for val in lif1_current_vals]
u_y3 = [val[129] for val in lif1_current_vals]
u_y4 = [val[138] for val in lif1_current_vals]
u_y5 = [val[164] for val in lif1_current_vals]
# List of tuples containing the y values and the legend label (the real neuron index)
current_lif1_y_arrays = [(u_y1, "Neuron 215"), (u_y2, "Neuron 124"), (u_y3, "Neuron 129"),
                          (u_y4, "Neuron 138"), (u_y5, "Neuron 164")]
current_lif1_plot = create_fig(
    title="LIF1 Current dynamics", 
    x_axis_label='time (ms)', 
    y_axis_label='Current (U)',
    x=x, 
    y_arrays=current_lif1_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    x_range=voltage_lif1_plot.x_range,    # Link the x-axis range to the voltage plot
)

# bplt.show(voltage_lif1_plot)
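
The per-neuron traces above are built with one list comprehension per neuron. Since the recorded arrays have shape (time steps, neurons), the same traces can be taken as column slices. The sketch below uses a small stand-in array and hypothetical neuron indices, and it assumes `create_fig` accepts plain lists, as in the cell above.

```python
import numpy as np

# Stand-in for lif1_voltage_vals, shape (time steps, neurons)
voltage_vals = np.arange(12, dtype=float).reshape(4, 3)
neuron_ids = [2, 0]   # hypothetical neurons of interest

# One (trace, label) tuple per neuron, labelled with the real neuron index
y_arrays = [(voltage_vals[:, idx].tolist(), f"Neuron {idx}") for idx in neuron_ids]
print(y_arrays[0])
```

Keeping the indices in one list means the voltage and current plots can reuse it, so the traces and their legend labels cannot drift apart.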

## Show the Plots assembled in a grid

In [91]:
import bokeh.plotting as bplt
from bokeh.layouts import gridplot

showPlot = True
if showPlot:
    # Create array of plots to be shown
    plots = [voltage_lif1_plot, current_lif1_plot]

    if len(plots) == 1:
        grid = plots[0]
    else:   # Create a grid layout
        grid = gridplot(plots, ncols=2, sizing_mode="stretch_both")

    # Show the plot
    bplt.show(grid)

## Export the plot to a file

In [92]:
export = False
OUTPUT_FOLDER = f"./results/{DATASET_FILENAME}"
TIME_SUFFIX = f"time{init_offset}-{num_steps}-{virtual_time_step_interval}"
THRESH_SUFFIX = f"thresh{thresh_up}-{thresh_down}"
STRAT_SUFFIX = f"strat{selected_strategy}"
WEIGHT_SUFFIX = f"w{weights_scale_input}"
DV_SUFFIX = f"dv{mean_dv}"
INH_RANGE = f"inh{inh_subtract_range[0]}-{inh_subtract_range[1]}"

if export:
    # Create the output folder if it does not exist
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    file_path = f"{OUTPUT_FOLDER}/{band_file_name}_output_{DV_SUFFIX}_{WEIGHT_SUFFIX}_{THRESH_SUFFIX}_{INH_RANGE}_{STRAT_SUFFIX}_{TIME_SUFFIX}.html"

    # Customize the output file settings
    bplt.output_file(filename=file_path, title="HFO Detection - Voltage and Current dynamics")

    # Save the plot
    bplt.save(grid)

## Export the Voltage and Current dynamics to a `.npy` file
To classify the feature neurons (Noisy, Silent, Ripple, or Fast Ripple Detector), we export the voltage and current dynamics to `.npy` files so that a classification algorithm can analyze them.

In [93]:
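EXPORT_DYNAMICS = True
if EXPORT_DYNAMICS:
    # Make sure the output folder exists (the plot export above may have been skipped)
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    # Define the file paths to save the Voltage and Current dynamics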
    v_dynamics_file_path = f"{OUTPUT_FOLDER}/{band_file_name}_v_dynamics_{DV_SUFFIX}_{WEIGHT_SUFFIX}_{THRESH_SUFFIX}_{INH_RANGE}_{STRAT_SUFFIX}_{TIME_SUFFIX}.npy"
    u_dynamic_file_path = f"{OUTPUT_FOLDER}/{band_file_name}_u_dynamics_{DV_SUFFIX}_{WEIGHT_SUFFIX}_{THRESH_SUFFIX}_{INH_RANGE}_{STRAT_SUFFIX}_{TIME_SUFFIX}.npy"
    
    # Export the Voltage dynamics to a numpy file
    np.save(v_dynamics_file_path, lif1_voltage_vals)
    
    # Export the Current dynamics to a numpy file
    np.save(u_dynamic_file_path, lif1_current_vals)
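
Before handing the files to the classification step, it can be worth checking that they load back unchanged. A minimal round-trip sketch, using a temporary folder, a hypothetical file name, and random stand-in data instead of the real dynamics:

```python
import os
import tempfile
import numpy as np

# Stand-in for lif1_voltage_vals (time steps x neurons)
voltage_vals = np.random.default_rng(0).normal(size=(5, 4))

with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "v_dynamics.npy")   # hypothetical file name
    np.save(path, voltage_vals)
    restored = np.load(path)

# The .npy format stores dtype and shape, so the array comes back identical
round_trip_ok = restored.shape == voltage_vals.shape and np.array_equal(restored, voltage_vals)
print(round_trip_ok)
```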

## Export the Ground Truth data to a `.npy` file along with necessary Simulation Parameters

### Load the Ground Truth data from the `.npy` file

In [94]:
# Load the ground_truth data
ground_truth_file_name = f"{INPUT_PATH}/{band_file_name}_ground_truth.npy"

ground_truth = np.load(ground_truth_file_name)

preview_np_array(ground_truth, "ground_truth", edge_items=3)
print(f"Number of relevant events: {np.count_nonzero(ground_truth)}")

ground_truth Shape: (222,).
Preview: [('Fast-Ripple',   1000.  , 0.)
 ('Spike+Ripple+Fast-Ripple',   3206.54, 0.)
 ('Spike+Ripple',   3521.  , 0.) ... ('Spike+Ripple', 116216.  , 0.)
 ('Ripple+Fast-Ripple', 116769.  , 0.) ('Ripple', 119000.  , 0.)]
Number of relevant events: 222


In [95]:
from utils.snn import SNNSimConfig

EXPORT_CONFIG = True
if EXPORT_CONFIG:
    # Define the simulation configuration
    snn_config = SNNSimConfig(ground_truth, init_offset, virtual_time_step_interval, num_steps)

    snn_config_file_name = f"{OUTPUT_FOLDER}/{band_file_name}_snn_config_{TIME_SUFFIX}.npy"
    # Save the SNN Config object to a npy file.
    # np.save pickles Python objects, so load it back with np.load(..., allow_pickle=True)
    np.save(snn_config_file_name, snn_config)
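
Because `np.save` stores a non-array Python object as a pickled 0-d object array, reading the file back needs `allow_pickle=True` followed by `.item()`. The sketch below shows this with a plain dictionary standing in for `SNNSimConfig` (whose definition is not shown here). The keys and values are illustrative only.

```python
import os
import tempfile
import numpy as np

# Plain dict standing in for the SNNSimConfig object (illustrative values)
config = {"init_offset": 0, "virtual_time_step_interval": 1, "num_steps": 30000}

with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "snn_config.npy")   # hypothetical file name
    np.save(path, config)                            # stored as a pickled 0-d object array

    try:
        np.load(path)                                # allow_pickle defaults to False
        default_load_failed = False
    except ValueError:
        default_load_failed = True

    restored = np.load(path, allow_pickle=True).item()

print(default_load_failed, restored == config)
```

When the saved object is an instance of a custom class, that class must also be importable in the environment that loads the file.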

## Stop the Runtime

In [96]:
# Stop the runtime through the same Process that was used to start it
selectedLIF.stop()