# Pre-Processing of the SEEG Signal
This notebook presents the pre-processing stages the SEEG signal goes through before being fed to the SNN. The pre-processing stages are as follows:
1. **Filtering**: The SEEG signal is bandpass filtered to isolate the bands where high-frequency oscillations (HFOs) occur and to suppress noise and artifacts. The filter is a Butterworth design and, since we are working with intracranial EEG (*iEEG*), the signal is filtered in the ripple and fast ripple (FR) bands. The co-occurrence of HFOs in both bands delineates the "HFO area", an estimate of the epileptogenic zone (EZ), which is a good predictor of post-surgical seizure freedom.
2. **Signal-to-Spike Conversion**: To interface and communicate with the silicon neurons in the SNN, the SEEG signal must be converted to spikes.

## Filtering
Depending on the EEG modality, the signal is filtered in different frequency bands. Since we are handling *iEEG*/*sEEG* data, the signal is filtered in both the ripple band (80-250 Hz) and the fast ripple (FR) band (250-500 Hz). HFOs that co-occur in these two bands define the "HFO area" (an estimate of the EZ), which predicts post-surgical seizure freedom.

The filter is implemented in different ways depending on the setup it will run on.
1. **Neuromorphic Hardware**: The filter is implemented with analog filter circuits.
2. **Software Simulation**: *Butterworth filters* are utilized since they are a good approximation of the tuned *Tow-Thomas* architectures implemented in hardware.

The frequency response of the *Butterworth filter* is maximally flat in the passband and rolls off towards zero in the stopband.
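To check the "maximally flat" property numerically, the sketch below designs the ripple-band filter as second-order sections and evaluates its gain with `scipy.signal.sosfreqz`. It assumes the 2048 Hz sampling rate and the order 9 used later in this notebook. For a Butterworth bandpass, the gain should be close to 1 inside the passband, equal to 1/√2 (-3 dB) at both cutoff frequencies, and close to 0 well inside the stopband.

```python
import numpy as np
from scipy.signal import butter, sosfreqz

fs = 2048.0                      # sampling rate assumed from this notebook (Hz)
order = 9                        # prototype order; the bandpass has order 2 * order
lowcut, highcut = 80.0, 250.0    # ripple band edges (Hz)

# Design the bandpass as second-order sections (numerically safer than b, a)
sos = butter(order, [lowcut, highcut], btype='band', fs=fs, output='sos')

# Evaluate the complex frequency response at a few frequencies of interest (Hz)
probe_freqs = np.array([lowcut, 150.0, highcut, 400.0])
_, h = sosfreqz(sos, worN=probe_freqs, fs=fs)
gain = np.abs(h)

for f, g in zip(probe_freqs, gain):
    print(f"{f:6.1f} Hz -> gain {g:.4f} ({20 * np.log10(g):7.2f} dB)")
```

The gains at 80 Hz and 250 Hz should both be about 0.707, the gain at 150 Hz essentially 1, and the gain at 400 Hz (fast ripple band) far below 1.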

### Check WD (change if necessary) and file loading

In [13]:
# Show current directory
import os
curr_dir = os.getcwd()
print(curr_dir)

# Check if the current WD is the file location
if "/src/hfo" not in os.getcwd():
    # Set working directory to this file location
    file_location = f"{os.getcwd()}/thesis-lava/src/hfo/"
    print("File Location: ", file_location)

    # Change the current working Directory
    os.chdir(file_location)

    # New Working Directory
    print("New Working Directory: ", os.getcwd())

PATH_TO_FILE = '' # 'src/hfo/'  # This is needed if the WD is not the same as the file location

/home/monkin/Desktop/feup/thesis/thesis-lava/src/hfo


In [14]:
import numpy as np
import math

seeg_file_name = "seeg_csl.npy"    # seeg_synthetic_humans.npy
recorded_data = np.load(f"{PATH_TO_FILE}data/{seeg_file_name}")

print("Data shape: ", recorded_data.shape)
print("First time steps: ", recorded_data[:10])

Data shape:  (129239, 86)
First time steps:  [[   1.0633698    34.293705     -5.3168535    31.369436    -37.483818
    -1.8608985     9.0386505  -102.88112       9.304497      6.911907
    38.813034     68.32156     -38.547188    -75.76517       8.241123
    57.422016   -140.36493      63.270557     15.684719     -5.582697
   -10.633707    -21.799099   -145.94763    -128.40201     -79.7528
   -44.927414     12.228767    106.33707     -26.052582   -141.16248
   -55.56112     -59.814606    -10.102022    -18.077301     12.228763
    -1.3292141   -70.979996     48.649216      7.7094383    61.941345
   -62.47303     -15.684718     24.72337     -56.358646    -51.573483
   -75.76517      11.431244     53.70022     103.14696     -36.154602
   -31.369438    -14.887188    -46.788315     60.877975     14.88719
   -32.16696      74.7018     -114.04651     -18.874832     29.774384
    40.408085    -14.621349     21.799103      3.4559555    67.52403
   -32.166965     71.77753     -41.205612    -28.4

### Add the parent directory to the path to detect the utils module

In [15]:
import os
import sys

# Add the parent directory to the path so it detects the utils module
module_path = os.path.abspath(os.path.join(os.getcwd(), '..'))    # Parent of the WD (src/, since the WD was set to src/hfo above)
if module_path not in sys.path:
    sys.path.append(module_path)

## Define the Filter

In [16]:
from scipy.signal import butter, lfilter

# ================================================================ #
# ============ Butterworth Filter Coefficients =================== #
# ================================================================ #
def butter_bandpass(lowcut, highcut, sampling_freq, order=5):
    """
    This function computes the coefficients (b, a) of a Butterworth bandpass filter.
    @lowcut, highcut (int): cutoff frequencies for the bandpass filter (Hz)
    @sampling_freq (float): sampling frequency of the wideband signal (Hz)
    @order (int): order of the lowpass prototype; the resulting bandpass filter has order 2 * order

    - return b, a (np.ndarray): filter coefficients (numerator, denominator) to be applied to the wideband signal
    """
    nyq = 0.5 * sampling_freq   # Nyquist frequency
    low = lowcut / nyq          # Normalizing the cutoff frequencies
    high = highcut / nyq        # Normalizing the cutoff frequencies

    return butter(order, [low, high], btype='band')    

# ================================================================ #
# ====================== Butterworth Filters ===================== #
# ================================================================ #
def butter_bandpass_filter(data, lowcut, highcut, sampling_freq, order=5):
    """
    This function applies the filtering coefficients calculated above to the wideband signal (original signal).
    @data (array): Array with the amplitude values of the wideband signal.
    @lowcut, highcut (int): cutoff frequencies for the bandpass filter.
    @sampling_freq (float): sampling frequency of the original signal.
    @order (int): filter order.

    - return (array): Array with the amplitude values of the filtered signal.
    """
    coef_b, coef_a = butter_bandpass(lowcut, highcut, sampling_freq, order)

    return lfilter(coef_b, coef_a, data)
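The functions above return `(b, a)` transfer-function coefficients and apply them with `lfilter`. SciPy's `butter` documentation warns that the `(b, a)` form can suffer from numerical problems even at moderate orders and recommends second-order sections; the order-9 bandpass used below is an 18th-order filter in total. `lfilter` is also causal, so each filtered band is delayed relative to the raw signal. A possible alternative for the offline software simulation, sketched here and not what this notebook currently uses, is `butter(..., output='sos')` combined with `sosfiltfilt` for zero-phase filtering. Forward-backward filtering is non-causal, so it does not fit a real-time or hardware pipeline. The helper name `butter_bandpass_filter_sos` is hypothetical.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def butter_bandpass_filter_sos(data, lowcut, highcut, sampling_freq, order=5):
    """Hypothetical zero-phase variant of butter_bandpass_filter using second-order sections."""
    sos = butter(order, [lowcut, highcut], btype='band', fs=sampling_freq, output='sos')
    # Forward-backward filtering: zero phase shift, squared magnitude response
    return sosfiltfilt(sos, data)

# Synthetic check: a 150 Hz (ripple) tone plus a 400 Hz (fast ripple) tone
fs = 2048
t = np.arange(4 * fs) / fs
raw = np.sin(2 * np.pi * 150 * t) + np.sin(2 * np.pi * 400 * t)

ripple = butter_bandpass_filter_sos(raw, 80, 250, fs, order=9)
fast_ripple = butter_bandpass_filter_sos(raw, 250, 500, fs, order=9)

# Compare each band with its pure tone, away from the signal edges (edge transients)
core = slice(fs, 3 * fs)
ripple_err = np.max(np.abs(ripple[core] - np.sin(2 * np.pi * 150 * t[core])))
fr_err = np.max(np.abs(fast_ripple[core] - np.sin(2 * np.pi * 400 * t[core])))
print(f"max error ripple band: {ripple_err:.4f}, fast ripple band: {fr_err:.4f}")
```

Because the phase shift is zero, each band output should overlap its pure tone sample by sample, which makes filtered HFOs line up with the raw recording and the annotated marker positions.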
    

### Extract a single SEEG channel from the SEEG data

In [17]:
selected_ch_idx = 30 # 1 32 63 94
seeg_ch_np = recorded_data[:, selected_ch_idx]  # All time steps of the selected channel

print(f"seeg_ch_np shape: {seeg_ch_np.shape}. \nPreview: {seeg_ch_np}")

seeg_ch_np shape: (129239,). 
Preview: [-55.56112  -56.890335 -58.219543 ... -48.915054 -52.63685  -48.117523]


## Define Global Parameters of the Experiment

In [18]:
sampling_rate = 2048    # 2048 Hz
num_samples = recorded_data.shape[0]    # 129239 samples in seeg_csl.npy
num_channels = recorded_data.shape[1]   # 86 channels in seeg_csl.npy
input_duration = num_samples / sampling_rate * (10**3)    # Recording duration in ms (about 63106 ms for seeg_csl.npy)

x_step = 1 / sampling_rate * (10**3)  # 0.48828125 ms

## Apply the Butterworth filter to the channel

In [19]:
filter_order = 9    # Prototype order; each bandpass filter has order 2 * filter_order
# Filter the raw signal in the Ripple band (80-250 Hz)
ripple_band_seeg = butter_bandpass_filter(seeg_ch_np, 80, 250, sampling_rate, filter_order)    # TODO: Check if the order is correct

# Filter the raw signal in the Fast Ripple band (250-500 Hz)
fr_band_seeg = butter_bandpass_filter(seeg_ch_np, 250, 500, sampling_rate, filter_order)

## Import the Markers (Annotated Events) 
The markers are stored in a NumPy object array of shape (num_channels,):
- Each entry holds the annotated events of one channel
- Each event is composed of the following 3 fields: `label`, `position` and `duration` (times in ms)
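Each channel's events can be handled as a NumPy structured array with the dtype shown in the output below (`label`, `position`, `duration`). As a self-contained illustration, the sketch builds a toy array with that dtype (the three events are invented, not taken from the recording) and selects the events whose label mentions a ripple. Treating "the label contains `Ripple`" as "the event is an HFO" is an assumption made here; the notebook itself delegates that decision to `utils.input.label_has_hfo_event`.

```python
import numpy as np

# Same dtype as the marker arrays loaded from seeg_csl_markers.npy
marker_dtype = [('label', '<U64'), ('position', '<f4'), ('duration', '<f4')]

# Toy events (invented for illustration); position and duration in ms
toy_markers = np.array(
    [('Spike', 500.0, 0.0), ('Spike+Ripple', 4139.16, 0.0), ('Ripple+Fast-Ripple', 9000.0, 0.0)],
    dtype=marker_dtype,
)

# Field access returns whole columns
print(toy_markers['label'])
print(toy_markers['position'])

# Vectorised substring test on the label column (assumption: "Ripple" marks an HFO;
# this also matches "Fast-Ripple")
is_hfo = np.char.find(toy_markers['label'], 'Ripple') >= 0
hfo_positions = toy_markers['position'][is_hfo]
print(hfo_positions)
```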

In [20]:
markers_seeg_file_name = "seeg_csl_markers.npy"
markers = np.load(f"{PATH_TO_FILE}data/{markers_seeg_file_name}", allow_pickle=True)

print("Markers shape: ", markers.shape)
print("First time steps: ", markers[:10])

Markers shape:  (86,)
First time steps:  [list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])
 list([array([],
       dtype=[('label', '<U64'), ('position', '<f4'), ('duration', '<f4')])])]


## Visualize the filtered signals

In [21]:
# Interactive Plot for the HFO detection
# bokeh docs: https://docs.bokeh.org/en/2.4.1/docs/first_steps/first_steps_1.html

from utils.line_plot import create_fig  # Import the function to create the figure
from bokeh.models import Range1d

# Define the x and y values
# Should the first input start at 0 or x_step?
# Time axis in ms, one timestamp per sample. Building it from integer indices (instead of a
# float-step range) avoids rounding issues and guarantees len(x) == num_samples.
x = np.arange(1, num_samples + 1) * x_step

# Create the y arrays for the voltage plot representing the voltage of each electrode
v_yarrays = [ripple_band_seeg, fr_band_seeg]

## Create the Plot

In [22]:
# Create the plot
# List of tuples containing the y values and the legend label
hfo_y_arrays = [(ripple_band_seeg, "Ripple Band"), (fr_band_seeg, "Fast Ripple Band")]

# Create the SEEG Voltage plot
hfo_plot = create_fig(
    title="SEEG Voltage dynamics of Filtered Ripple and Fast Ripple Bands", 
    x_axis_label='time (ms)', 
    y_axis_label='Voltage (μV)',
    x=x, 
    y_arrays=hfo_y_arrays, 
    sizing_mode="stretch_both", 
    tools="pan, box_zoom, wheel_zoom, hover, undo, redo, zoom_in, zoom_out, reset, save",
    tooltips="Data point @x: @y",
    legend_location="top_right",
    legend_bg_fill_color="navy",
    legend_bg_fill_alpha=0.1,
    # y_range=Range1d(-0.05, 1.05)
)

# If there are more than 30 channels, hide the legend
if len(hfo_y_arrays) > 30:
    # Hide the legend
    hfo_plot.legend.visible = False



## Add Box Annotations to the plot to identify the marked HFOs (ground truth)

In [23]:
from bokeh.models import BoxAnnotation
# from utils.line_plot import color_map

show_markers = False    # Boolean to show the markers

color_map = {                   # Map the label to a color
    'Fast Ripple': 'brown',
    'Ripple': 'yellow',  
}

confidence_range = 100          # TODO: Check this value. When the duration is missing (0), we consider the 200ms window around the marked position 
visited_markers = {}    # Avoid inserting multiple boxes for the same marker (only one of each label)
use_visited = False     # Boolean controlling if we remove duplicate markers
plot_instant = True     # Boolean to plot the markers as instant events or as boxes
instant_width = 100 # 20       # Width of the instant event for visualization purposes

channels_used = {selected_ch_idx}   # Set of channels to be used
if show_markers:
    for ch_idx in channels_used:
        channel_markers = markers[ch_idx]
        # print("channel_markers", channel_markers)
        for idx2, marker in enumerate(channel_markers):
            # print("marker:", marker)
            if len(marker) == 0:
                continue    # Skip empty markers
            
            if use_visited:
                # Check if the marker has already been visited and skip it if it has
                if marker['position'] in visited_markers:
                    visited_labels = visited_markers[marker['position']]    # Get the labels that already have an annotation for this position
                    if marker['label'] in visited_labels:
                        # print("Skipping marker", marker['position'], marker['label'])
                        continue    # Skip this marker
                    else:
                        visited_labels.append(marker['label'])  # Add the label to the visited labels
                else:
                    visited_markers[marker['position']] = [marker['label']] # Add the marker to the visited markers

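            # Add a box annotation for each marker
            has_duration = marker['duration'] > 0
            
            confidence_constant = 0 if plot_instant or has_duration else confidence_range
            # Use the annotated duration when available (and not plotting instants)
            box_width = marker['duration'] if has_duration and not plot_instant else instant_width

            left = marker['position'] - confidence_constant
            right = marker['position'] + confidence_constant + box_width

            # Choose a color according to the label. Labels are combinations such as
            # 'Spike+Ripple' or 'Ripple+Fast-Ripple', so match on substrings.
            label = str(marker['label'])
            if 'Fast-Ripple' in label:
                box_color = color_map['Fast Ripple']
            elif 'Ripple' in label:
                box_color = color_map['Ripple']
            else:
                continue    # Not an HFO marker (e.g. 'Spike')
            
            # if left < min_t or right > max_t:
            #     continue    # Skip this marker

            box = BoxAnnotation(left=left, right=right, fill_color=box_color, fill_alpha=0.1)
            # print("Added marker for channel: ", ch_idx, " at position: ", left)
            hfo_plot.add_layout(box)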

## Show the Plot

In [24]:
import bokeh.plotting as bplt

showPlot = True
if showPlot:
    bplt.show(hfo_plot)

## Export the plot to a file

In [53]:
export = False

if export:
    file_path = f"{PATH_TO_FILE}results/clinical/filtered_seeg_csl_ch30.html"

    # Customize the output file settings
    bplt.output_file(filename=file_path, title="SEEG Data - Filtered Voltage dynamics across time")

    # Save the plot
    bplt.save(hfo_plot)

## Checkpoint 1

Right now, we have the filtered SEEG signal in both the ripple and FR bands. The next step is to convert the signal to spikes.

## Signal-to-Spike Conversion
The signal can be converted to spikes in different ways. First, we will try a method where **two spike trains are generated from the filtered signal**:
- **UP Spike Train**: A spike is generated when the signal's amplitude has increased by more than `threshold_up` since the last spike (UP or DOWN).
- **DOWN Spike Train**: A spike is generated when the change in the signal's amplitude since the last spike (UP or DOWN) falls below `threshold_down` (a negative value).

The spike trains are generated by comparing the amount of change in the signal since the last time a spike was generated (UP or DOWN). If the positive/negative amplitude change is greater than the defined threshold, the algorithm stores the current timestep in the respective spike train and takes the new amplitude as the reference for the next comparison.

Another important aspect of this algorithm is to model the time that silicon neurons need before they can generate another spike. Both in hardware and software, we call this time `refractory_period`.
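The conversion described above behaves like a delta-modulation encoder with a refractory period. The sketch below is a minimal, hypothetical re-implementation for illustration only; it is not the `signal_to_spike` function from `hfo.signal_to_spike` used later. It assumes that `threshold_down` is negative, that `times` and `refractory_period` share the same unit, that at most one spike is emitted per sample, and that samples falling inside the refractory period are ignored without moving the reference.

```python
import numpy as np

def delta_encode(signal, times, threshold_up, threshold_down, refractory_period=0.0):
    """Hypothetical UP/DOWN delta encoder; returns (up_times, down_times)."""
    up_times, down_times = [], []
    reference = signal[0]              # amplitude at the last spike (or the first sample)
    last_spike_time = -np.inf          # no spike emitted yet
    for value, t in zip(signal[1:], times[1:]):
        if t - last_spike_time < refractory_period:
            continue                   # still refractory: ignore this sample
        delta = value - reference
        if delta > threshold_up:
            up_times.append(t)
        elif delta < threshold_down:
            down_times.append(t)
        else:
            continue                   # change too small: no spike, keep the reference
        reference = value              # the current amplitude becomes the new reference
        last_spike_time = t
    return np.array(up_times), np.array(down_times)

# Example: a ramp up followed by a ramp down, sampled every 1 ms
times = np.arange(10.0)
signal = np.array([0, 2, 4, 6, 8, 10, 8, 6, 4, 2], dtype=float)
up, down = delta_encode(signal, times, threshold_up=3, threshold_down=-3)
print(up, down)   # [2. 4.] [8.]
```

With `refractory_period=3`, the second UP spike moves from t = 4 to t = 5, because the samples at t = 3 and t = 4 fall inside the refractory window and the change keeps accumulating against the old reference.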

### Configurable Parameters for the Signal-to-Spike Conversion
- `threshold_up`: The threshold for the UP spike train.
- `threshold_down`: The threshold for the DOWN spike train.

The accuracy of this algorithm is heavily dependent on the choice of these parameters. To find the optimal values, we can perform a ***baseline detection*** to determine the optimal spike generation threshold automatically for the signal conversion.
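One possible way to automate this choice, assumed here purely for illustration and not necessarily the baseline-detection method planned for this work, is to scale the thresholds to the typical sample-to-sample change of the filtered signal during a quiet baseline segment. The helper name `baseline_thresholds`, the use of the median absolute deviation (MAD) as a robust spread estimate, and the multiplier `k` are all hypothetical choices.

```python
import numpy as np

def baseline_thresholds(signal, baseline_samples, k=3.0):
    """Hypothetical heuristic: thresholds = +/- k * robust spread of the baseline differences."""
    diffs = np.diff(np.asarray(signal[:baseline_samples], dtype=float))
    # Median absolute deviation, scaled to match the standard deviation of Gaussian noise
    mad = np.median(np.abs(diffs - np.median(diffs)))
    spread = 1.4826 * mad
    return k * spread, -k * spread

rng = np.random.default_rng(0)
noise = rng.normal(0.0, 1.0, 20_000)     # stand-in for a quiet, band-limited segment (μV)
up_thr, down_thr = baseline_thresholds(noise, baseline_samples=10_000, k=3.0)
print(f"threshold_up={up_thr:.2f} μV, threshold_down={down_thr:.2f} μV")
```

For unit-variance white noise the differences have a standard deviation of √2, so the thresholds land near ±4.2 μV. The MAD is used instead of the standard deviation so that occasional HFOs inside the baseline segment inflate the thresholds less.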

**As a first solution, we set these values manually to have a working prototype. Later, we will find the optimal values and compare the results of both methods.**

In [16]:
# Define variables of the Signal to Spike Conversion Manually
ripple_threshold_up = 5   # Threshold for the UP spike detection (in μV)
ripple_threshold_down = -5 # Threshold for the DOWN spike detection (in μV)

fr_threshold_up = 3   # Threshold for the UP spike detection (in μV)
fr_threshold_down = -3 # Threshold for the DOWN spike detection (in μV)

In [17]:
from hfo.signal_to_spike.signal_to_spike import signal_to_spike, SignalToSpikeParameters

# Convert the filtered ripple signal to spikes
ripple_spike_trains = signal_to_spike(
    SignalToSpikeParameters(
        signal=ripple_band_seeg, times=np.array(x),
        threshold_up=ripple_threshold_up, threshold_down=ripple_threshold_down,
        # refractory_period=0.002, interpolation_factor=1
        )
)

np.set_printoptions(edgeitems=5)
print("Ripple UP Spike Train shape: ", ripple_spike_trains.up.shape, "Preview: ", ripple_spike_trains.up)
print("Ripple DOWN Spike Train shape: ", ripple_spike_trains.down.shape, "Preview: ", ripple_spike_trains.down)

Ripple UP Spike Train shape:  (307,) Preview:  [  4153.3203125    4155.76171875   4165.52734375   7249.0234375
   7272.4609375    7273.92578125   7285.15625      7286.62109375
  13386.23046875  13397.4609375   13398.92578125  16553.7109375
  16554.6875      16555.6640625   16556.640625    16564.94140625
  16565.4296875   16565.91796875  16566.40625     16566.89453125
  16567.3828125   16567.87109375  16568.359375    16568.84765625
  16577.1484375   16577.63671875  16578.125       16578.61328125
  16579.1015625   16579.58984375  16580.078125    16581.0546875
  16589.35546875  16590.33203125  16591.30859375  16606.4453125
  17410.64453125  17992.1875      18169.43359375  18180.17578125
  18375.9765625   18928.22265625  19090.8203125   19542.48046875
  19783.69140625  20202.1484375   21666.50390625  22318.359375
  22545.41015625  22551.7578125   22558.59375     22559.5703125
  22565.4296875   22566.89453125  25450.1953125   25461.9140625
  25476.07421875  25558.59375     25683.59375     2

In [18]:
# Convert the filtered FR signal to spikes
fr_spike_trains = signal_to_spike(
    SignalToSpikeParameters(
        signal=fr_band_seeg, times=np.array(x),
        threshold_up=fr_threshold_up, threshold_down=fr_threshold_down,
        # refractory_period=0.002, interpolation_factor=1
        )
)

np.set_printoptions(edgeitems=5)
print("Fast Ripple UP Spike Train shape: ", fr_spike_trains.up.shape, "Preview: ", fr_spike_trains.up)
print("Fast Ripple DOWN Spike Train shape: ", fr_spike_trains.down.shape, "Preview: ", fr_spike_trains.down)

Fast Ripple UP Spike Train shape:  (198,) Preview:  [  1005.37109375   1008.7890625    1012.20703125   1012.6953125
   1016.11328125   1020.5078125   13357.421875    13359.375
  13359.86328125  13362.3046875   13362.79296875  13365.72265625
  13370.1171875   13372.55859375  19777.34375     19780.2734375
  19783.69140625  27168.9453125   27170.8984375   27171.38671875
  27173.828125    27174.31640625  27177.24609375  27184.5703125
  27187.98828125  29808.59375     29811.03515625  29813.4765625
  29829.1015625   29833.0078125   29836.42578125  29839.35546875
  29843.75        29848.14453125  32293.45703125  32295.8984375
  32298.33984375  32298.828125    32301.7578125   32305.6640625
  32309.08203125  32312.5         33736.328125    34437.01171875
  34633.30078125  34635.7421875   34638.18359375  34640.625
  34643.5546875   34646.484375    34649.90234375  34653.80859375
  34657.2265625   34661.1328125   36347.65625     37800.29296875
  37804.6875      37807.6171875   37810.05859375  3781

## Visualize the Spike Trains
Let's plot the generated spike trains via a raster plot.

In [19]:
from utils.raster_plot import create_raster_fig

# ------------------------------------------------------------------------------- #
# ------- Create the raster plot for the Ripple UP and DOWN spike trains -------- #
# ------------------------------------------------------------------------------- #

# Create a list containing the x values for the raster plot.
ripple_raster_x = np.concatenate((ripple_spike_trains.up, ripple_spike_trains.down), axis=0)

# Create a list containing the y values for the raster plot.
# The UP spike train will be represented by 1s and the DOWN spike train by 0s
ripple_raster_y = [1 for _ in range(len(ripple_spike_trains.up))] + [0 for _ in range(len(ripple_spike_trains.down))]

ripple_train_raster = create_raster_fig("Ripple UP and DOWN spike events", "Time (ms)", "Channel", ripple_raster_x, ripple_raster_y)

In [20]:
showRasterPlot = True

# Plot the raster plot for the Ripple spike trains
if showRasterPlot:
    bplt.show(ripple_train_raster)

In [21]:
# ------------------------------------------------------------------------------- #
# ------- Create the raster plot for the Fast Ripple UP and DOWN spike trains -------- #
# ------------------------------------------------------------------------------- #

# Create a list containing the x values for the raster plot.
fr_raster_x = np.concatenate((fr_spike_trains.up, fr_spike_trains.down), axis=0)

# Create a list containing the y values for the raster plot.
# The UP spike train will be represented by 1s and the DOWN spike train by 0s
fr_raster_y = [1 for _ in range(len(fr_spike_trains.up))] + [0 for _ in range(len(fr_spike_trains.down))]

fr_train_raster = create_raster_fig("Fast Ripple UP and DOWN spike events", "Time (ms)", "Channel", fr_raster_x, fr_raster_y)

In [22]:
# Plot the raster plot for the Fast Ripple spike trains
if showRasterPlot:
    bplt.show(fr_train_raster)

## Export the Spike Trains to CSV Files for the Lava SNN

We have successfully converted the SEEG signal to spikes. The next step is to feed these spikes to the SNN for Ripple and Fast Ripple detection.

For this, we will create a file for each type of spike train (UP and DOWN). It is still an open question whether the spikes of both bands (Ripple and Fast Ripple) should be merged into a single file or kept in separate files.

In [23]:
# Create a csv file with the spike train data
import csv

def write_spike_train_to_csv(file_name, spike_train, channel_idx):
    """
    This function writes the spike train to a csv file.
    @file_name (str): Name of the file to be created.
    @spike_train (np.ndarray): Array with the spike train data.
    @channel_idx (int): Index of the channel that generated the spike train. (According to the original data)
    """
    with open(file_name, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["time", "channel_idx"])
        for spike_time in spike_train:
            writer.writerow([spike_time, channel_idx])
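To sanity-check the files produced by `write_spike_train_to_csv`, they can be read back with the standard `csv` module. The sketch repeats the writer so it runs on its own, writes a toy spike train (values copied from the Ripple UP preview above) to a temporary file, and loads it back into NumPy arrays. `read_spike_train_csv` is a hypothetical helper, not part of the Lava SNN code.

```python
import csv
import os
import tempfile

import numpy as np

def write_spike_train_to_csv(file_name, spike_train, channel_idx):
    """Same format as the notebook's writer: header 'time,channel_idx', one row per spike."""
    with open(file_name, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["time", "channel_idx"])
        for spike_time in spike_train:
            writer.writerow([spike_time, channel_idx])

def read_spike_train_csv(file_name):
    """Hypothetical reader: returns (spike_times, channel_indices) as NumPy arrays."""
    with open(file_name, newline='') as file:
        rows = list(csv.DictReader(file))
    times = np.array([float(row["time"]) for row in rows])
    channels = np.array([int(row["channel_idx"]) for row in rows])
    return times, channels

toy_train = np.array([4153.3203125, 4155.76171875, 7249.0234375])   # ms, toy values
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_up_spike_train.csv")
    write_spike_train_to_csv(path, toy_train, channel_idx=30)
    times, channels = read_spike_train_csv(path)

print(times, channels)
```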

In [24]:
WRITE_RIPPLE_CSV_FILES = False
WRITE_FR_CSV_FILES = False

In [25]:
if WRITE_RIPPLE_CSV_FILES:
    # Create the csv file for the Ripple UP spike train
    ripple_up_file_name = f"{PATH_TO_FILE}snn/data/ripple_up_spike_train_5.csv"
    write_spike_train_to_csv(ripple_up_file_name, ripple_spike_trains.up, selected_ch_idx)

    # Create the csv file for the Ripple DOWN spike train
    ripple_down_file_name = f"{PATH_TO_FILE}snn/data/ripple_down_spike_train_-5.csv"
    write_spike_train_to_csv(ripple_down_file_name, ripple_spike_trains.down, selected_ch_idx)

In [26]:
if WRITE_FR_CSV_FILES:
    # Create the csv file for the Fast Ripple UP spike train
    fr_up_file_name = f"{PATH_TO_FILE}snn/data/fr_up_spike_train_3.csv"
    write_spike_train_to_csv(fr_up_file_name, fr_spike_trains.up, selected_ch_idx)

    # Create the csv file for the Fast Ripple DOWN spike train
    fr_down_file_name = f"{PATH_TO_FILE}snn/data/fr_down_spike_train_-3.csv"
    write_spike_train_to_csv(fr_down_file_name, fr_spike_trains.down, selected_ch_idx)

## Generate the input files for the SNN

### Is the SNN going to detect HFO events in windows, or do we feed it a continuous input?
**If windowed**: The input to the network must be organized in windows of a fixed size. The window size is a hyperparameter that can be tuned to improve the performance of the network.

I think windowing makes more sense when training an ANN. Since we want real-time detection, feeding a continuous input makes more sense.

### Let's assume we do NOT need to window the input
In this case, our input is simply a continuous stream of spikes that can be fed to the SNN in real time. At each timestep, the SNN receives 2 binary inputs (UP and DOWN) indicating whether a spike occurred in the respective spike train.
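Because the spike times are expressed on the same millisecond time axis as the samples, the per-timestep binary input can also be built without an explicit loop, by locating each spike time on the time axis with `np.searchsorted`. The sketch below is an alternative assumed for illustration (`spikes_to_binary` is a hypothetical name); each spike is assigned to the first time step whose timestamp is greater than or equal to the spike time, and several spikes falling in the same step collapse into a single 1.

```python
import numpy as np

def spikes_to_binary(spike_times, time_axis):
    """Hypothetical helper: 1 at each time step that receives at least one spike, else 0."""
    # For each spike, index of the first time step t with t >= spike time
    idx = np.searchsorted(time_axis, spike_times, side='left')
    idx = idx[idx < len(time_axis)]          # drop spikes after the last time step
    binary = np.zeros(len(time_axis))
    binary[idx] = 1.0
    return binary

x_step = 1000 / 2048                                  # ms per sample at 2048 Hz
time_axis = np.arange(1, 11) * x_step                 # 10 toy time steps
up_spikes = np.array([2 * x_step, 5 * x_step])        # toy UP spike times (ms)
down_spikes = np.array([7 * x_step])                  # toy DOWN spike times (ms)

# Column 0: UP spikes, column 1: DOWN spikes (same layout as snn_input)
snn_input_toy = np.column_stack([spikes_to_binary(up_spikes, time_axis),
                                 spikes_to_binary(down_spikes, time_axis)])
print(snn_input_toy)
```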

In [27]:
# Create a numpy array that will store the input (2D)
snn_input = np.zeros((num_samples, 2))  # 2 columns: UP and DOWN spike trains

# ---------------------------------------------------------------------------------- #
# -------- Select the Spike Trains to be used as input for the SNN ----------------- #
# ---------------------------------------------------------------------------------- #
selected_up_spikes = ripple_spike_trains.up
selected_down_spikes = ripple_spike_trains.down

# Iterate the time steps of the recording and mark the spikes that fall in each timestep.
# `while` loops consume every UP and DOWN spike up to this timestep, so simultaneous spikes
# are not shifted to later timesteps.
curr_up_idx = 0
curr_down_idx = 0
for (idx, time_step) in enumerate(x):
    # Mark the UP spikes that occur at (or before) this time step
    while curr_up_idx < len(selected_up_spikes) and selected_up_spikes[curr_up_idx] <= time_step:
        snn_input[idx][0] = 1   # Mark the UP spike in the input array
        curr_up_idx += 1    # Move to the next spike in the UP spike train
    # Mark the DOWN spikes that occur at (or before) this time step
    while curr_down_idx < len(selected_down_spikes) and selected_down_spikes[curr_down_idx] <= time_step:
        snn_input[idx][1] = 1   # Mark the DOWN spike in the input array
        curr_down_idx += 1  # Move to the next spike in the DOWN spike train
    
    if curr_up_idx >= len(selected_up_spikes) and curr_down_idx >= len(selected_down_spikes):
        # All the spikes have been added to the input array
        break

np.set_printoptions(edgeitems=5)
print("snn_input: ", snn_input.shape, "Preview:", snn_input)

snn_input:  (245760, 2) Preview: [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 ...
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]


In [28]:
# Export the input array to a numpy file
EXPORT_INPUT = False
if EXPORT_INPUT:
    input_file_name = f"{PATH_TO_FILE}snn/data/ripple_train_5_-5/snn_input_ripple_5_-5.npy"
    np.save(input_file_name, snn_input)

    # Export to CSV for visualization purposes
    with open(f"{PATH_TO_FILE}snn/data/ripple_train_5_-5/snn_input_ripple_5_-5.csv", mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["time", "up_spike", "down_spike"])
        for (idx, time_step) in enumerate(x):
            writer.writerow([time_step, snn_input[idx][0], snn_input[idx][1]])

## Generate the Ground Truth File for the SNN (Target)

Extract the markers from the selected channel

In [38]:
from utils.io import preview_np_array

# The target output for the SNN must have the same length as the input.
target_np = np.zeros(num_samples)  # One value per timestep: 0 = no event, 1 = Ripple and/or Fast Ripple
# TODO: Could have more than 2 classes to differentiate the labels

# Get the markers for the selected channel
# Each marker has the following keys:   position, label, and duration
selected_ch_markers = markers[selected_ch_idx]
preview_np_array(selected_ch_markers, "Selected Channel Markers")

Selected Channel Markers Shape: (42,).
Preview: [('Spike+Fast-Ripple',   1000.  , 0.) ('Spike+Ripple',   4139.16, 0.)
 ('Spike+Ripple',   7255.86, 0.) ('Spike',  10473.6 , 0.)
 ('Spike+Ripple+Fast-Ripple',  13362.8 , 0.) ...
 ('Spike+Ripple', 108516.  , 0.) ('Ripple', 111657.  , 0.)
 ('Ripple+Fast-Ripple', 114574.  , 0.) ('Spike', 116793.  , 0.)
 ('Spike+Ripple', 119000.  , 0.)]


### Export to a NumPy File for the Lava SNN

In [None]:
# Export the target array to a numpy file
EXPORT_TARGET_1 = False

if EXPORT_TARGET_1:
    target_file_name = f"{PATH_TO_FILE}snn/ground_truth/instants_ch-{selected_ch_idx}.npy"
    np.save(target_file_name, target_np)

### Convert to the format that the learnable SNN (Slayer) can read

#### Fill the target numpy array with 1s where the HFOs are present

In [55]:
from utils.input import label_has_hfo_event

# Iterate the time steps of the recording and check if there are annotated events at each timestep
curr_markers_idx = 0
for (idx, time_step) in enumerate(x):
    # Consume every event whose position falls at or before this time step
    while curr_markers_idx < len(selected_ch_markers) and selected_ch_markers[curr_markers_idx]['position'] <= time_step:
        # If the label has an HFO event, mark it as 1 in the target array
        if label_has_hfo_event(selected_ch_markers[curr_markers_idx]['label']):
            target_np[idx] = 1   # Mark the labelled event in the target array
        
        curr_markers_idx += 1    # Move to the next annotated event
    
    if curr_markers_idx >= len(selected_ch_markers):
        # All the annotated events have been processed
        break
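The same mapping can be written without a per-timestep loop: keep only the HFO markers, then map each marker position (in ms) to a sample index with `np.searchsorted`. The sketch uses a hypothetical stand-in for `label_has_hfo_event`, assumed here to return `True` when the label mentions `Ripple` (which also covers `Fast-Ripple`); the real helper in `utils.input` may apply different rules. The toy markers and the 20-sample time axis are invented.

```python
import numpy as np

def label_has_hfo_event_stub(label):
    """Hypothetical stand-in for utils.input.label_has_hfo_event."""
    return 'Ripple' in str(label)

marker_dtype = [('label', '<U64'), ('position', '<f4'), ('duration', '<f4')]
toy_markers = np.array(
    [('Spike+Fast-Ripple', 1.0, 0.0), ('Spike', 2.5, 0.0), ('Spike+Ripple', 4.0, 0.0)],
    dtype=marker_dtype,
)

x_step = 1000 / 2048                          # ms per sample
time_axis = np.arange(1, 21) * x_step         # 20 toy samples (about 9.8 ms)

# Keep only HFO markers, then map each position to the first sample at or after it
hfo_mask = np.array([label_has_hfo_event_stub(lbl) for lbl in toy_markers['label']])
positions = toy_markers['position'][hfo_mask].astype(float)
sample_idx = np.searchsorted(time_axis, positions, side='left')
sample_idx = sample_idx[sample_idx < len(time_axis)]    # ignore markers past the recording

target_toy = np.zeros(len(time_axis))
target_toy[sample_idx] = 1.0
print(np.flatnonzero(target_toy))   # [2 8]
```

Dropping indices past the end of the time axis matters for this recording, whose annotated positions extend beyond the last sample.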

In [56]:
preview_np_array(target_np, "target_np")

target_np Shape: (245760,).
Preview: [0. 0. 0. 0. 0. ... 0. 0. 0. 0. 0.]


## Export the target file to a numpy file

In [57]:
# Export the target array to a numpy file
EXPORT_TARGET = False
if EXPORT_TARGET:
    target_file_name = f"{PATH_TO_FILE}snn/ground_truth/ch-{selected_ch_idx}.npy"
    np.save(target_file_name, target_np)