# Direct Neural Biasing - Python Demo

### STEP 1 - Import Python Modules

In [None]:
from pathlib import Path
from time import time
import re
import warnings

from scipy import signal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import mne
import direct_neural_biasing as dnb

### STEP 2 - Define Data Array

#### 2.1 - Read EDF File

In [None]:
file = './data/JS.edf'
raw = mne.io.read_raw_edf(file)
sample_freq = raw.info["sfreq"]

# raw.get_data() returns an (n_channels, n_times) array. Flattening it would
# splice every channel end-to-end into one artificial signal, so load a single
# channel explicitly. For a one-channel recording this matches `.flatten()`.
channel_index = 0
if len(raw.ch_names) > 1:
    print(f"{file} has {len(raw.ch_names)} channels; using channel "
          f"{channel_index} ({raw.ch_names[channel_index]!r}).")
data_full = raw.get_data(picks=[channel_index])[0]

#### Truncate Data Array for Debugging

In [None]:
# Keep only the first 1,000,000 samples so debug runs stay fast.
data = data_full[:1_000_000]

#### 2.2 ALT - Read Sample CSV

In [None]:
# Alternative input (temporary workaround): load `data` from signal.csv instead
# of the EDF file. Uncomment to use; it replaces the `data` array set above.
# data = pd.read_csv('signal.csv').values.flatten()

### STEP 3 - Set Up Signal Processor

#### 3.1 - Create Signal Processor

In [None]:
# The processor object; filters, detectors and the pulse trigger are attached
# to it in the following cells.
downsample_rate = 100  # 1 = full sampling rate; larger values downsample (handy for large files and demos)
logging = False  # True writes a CSV log file of the run
signal_processor = dnb.PySignalProcessor(logging, downsample_rate)

#### 3.2 - Create Filter

In [None]:
# `filter_id` avoids shadowing the builtin `id`.
filter_id = 'simple_filter'  # name the detectors use to pick this filter's output
f0 = 0.5  # bandpass filter center frequency
fs = sample_freq  # signal sample rate

# Pass f0 rather than repeating the literal, so editing f0 above takes effect.
signal_processor.add_filter(filter_id, f0, fs)

#### 3.3 - Create IED Detector

In [None]:
# `detector_id` avoids shadowing the builtin `id`.
detector_id = 'ied_detector'
filter_id = 'simple_filter'  # which filtered signal the detector reads from
z_score_threshold = 5.0  # threshold for a candidate detection event
buffer_size = 10  # length of the buffer, to increase noise resistance
sensitivity = 0.5  # between 0 and 1: fraction of buffered values over threshold required to trigger an 'IED Detected' event

signal_processor.add_threshold_detector(
    detector_id,
    filter_id,
    z_score_threshold,
    buffer_size,
    sensitivity,
)

#### 3.4 - Create Slow Wave Detector

In [None]:
# `detector_id` avoids shadowing the builtin `id`.
detector_id = 'slow_wave_detector'
filter_id = 'simple_filter'  # which filtered signal the detector reads from
sinusoid_threshold = 0.8  # between 0 and 1
# NOTE(review): presumably amplitude bounds for an accepted slow wave —
# confirm meaning and units against the direct_neural_biasing docs.
absolute_min_threshold = 0.0
absolute_max_threshold = 100.0

signal_processor.add_slow_wave_detector(
    detector_id,
    filter_id,
    sinusoid_threshold,
    absolute_min_threshold,
    absolute_max_threshold,
)

#### 3.5 - Create Pulse Trigger

In [None]:
# `trigger_id` avoids shadowing the builtin `id`.
trigger_id = 'pulse_trigger'
activation_detector_id = 'slow_wave_detector'  # detector whose events trigger a pulse
inhibition_detector_id = 'ied_detector'  # detector whose events start an inhibition cooldown (an IED here)
activation_cooldown = 2  # cooldown in seconds after a pulse event
inhibition_cooldown = 2  # cooldown in seconds after an IED detection

signal_processor.add_pulse_trigger(
    trigger_id,
    activation_detector_id,
    inhibition_detector_id,
    activation_cooldown,
    inhibition_cooldown,
)

### STEP 4 - Analyse Signal

#### 4.1 - Run

In [None]:
# Run the configured filter, detectors and trigger over `data`. The result is
# indexed (out[0]) and iterated as a sequence of output records below.
out = signal_processor.run(data)

#### 4.2 - Example Sample Output

In [None]:
# Display the first output record to inspect the available keys.
out[0]

### STEP 5 - Find Detected Events

In [None]:
# Keep only the output records in which the pulse trigger fired.
pulse_flag_key = 'triggers:pulse_trigger:triggered'
events = [record for record in out if record[pulse_flag_key] == 1.0]

In [None]:
# Uncomment the next line to display every detected event.
# events

In [None]:
# Number of output records in which the pulse trigger fired.
len(events)

### STEP 6 - Visualise Detected Events

In [None]:
# Context shown on each side of an event, as a multiple of the event's index count.
extra_signal_factor = 10

# Matches the sample-index list embedded in a slow-wave output key, i.e. the
# "<i, j, ...>" part of "...slow_wave_idx:<i, j, ...>:next_maxima".
_SLOW_WAVE_IDX_PATTERN = re.compile(r'slow_wave_idx:([\d, ]+):next_maxima')


def _parse_slow_wave_event(event):
    """Return ``(key, indices)`` for the slow-wave entry of *event*, or None.

    The first key containing 'slow_wave_idx' is used. None is returned when
    there is no such key or its embedded index list cannot be parsed.
    """
    slow_wave_key = next((key for key in event if 'slow_wave_idx' in key), None)
    if slow_wave_key is None:
        return None
    match = _SLOW_WAVE_IDX_PATTERN.search(slow_wave_key)
    if match is None:
        return None
    # The character class also admits stray commas/spaces; skip empty tokens
    # instead of letting int('') raise.
    indices = [int(token) for token in match.group(1).split(',') if token.strip()]
    if not indices:
        return None
    return slow_wave_key, indices


def plot_events(data, events, extra_signal_factor, sample_freq, downsample_rate):
    """Plot each detected slow-wave event in context, one figure per event.

    For every event, the signal around the event's indices is drawn, the
    event span itself is overlaid in red, and the predicted next maximum is
    marked with a dashed green vertical line.

    Args:
        data: Signal array that the event indices index into.
        events: Output records from the signal processor. Each is expected to
            have a key embedding 'slow_wave_idx:<i, j, ...>:next_maxima' whose
            value is the predicted next-maximum index.
        extra_signal_factor: Context shown on each side of the event, as a
            multiple of the number of event indices.
        sample_freq: Signal sampling frequency.
        downsample_rate: Downsample rate the processor was run with.

    Events whose slow-wave key is missing or unparsable are skipped with a
    warning instead of aborting the remaining plots.
    """
    # NOTE(review): indices are converted to seconds by dividing by
    # sample_freq * downsample_rate (kept from the original). If the indices
    # are raw-sample positions (they index `data` directly), seconds would be
    # idx / sample_freq; if they are downsampled positions, it would be
    # idx * downsample_rate / sample_freq. Confirm before trusting the x-axis.
    time_divisor = sample_freq * downsample_rate
    n_samples = len(data)

    for event_number, event in enumerate(events):
        parsed = _parse_slow_wave_event(event)
        if parsed is None:
            warnings.warn(
                f"Event {event_number} has no parsable 'slow_wave_idx' key; skipping it."
            )
            continue
        slow_wave_key, idx_list = parsed

        # Context window: pad the event on each side by extra_signal_factor
        # times the number of event indices, clamped to the data.
        extra_length = len(idx_list) * extra_signal_factor
        start_idx = max(0, idx_list[0] - extra_length)
        end_idx = min(n_samples, idx_list[-1] + extra_length)
        segment = data[start_idx:end_idx]
        # Size the x-axis from the slice actually taken so x and y always match.
        # (Named `times` to avoid shadowing the imported `time`.)
        times = np.arange(start_idx, start_idx + len(segment)) / time_divisor

        plt.figure(figsize=(10, 4))
        plt.plot(times, segment, label='Signal')

        # Event span, clamped to the data for the same reason; otherwise an
        # event past the end of `data` makes plt.plot fail on mismatched lengths.
        event_start_idx = min(idx_list[0], n_samples)
        event_end_idx = min(idx_list[-1] + 1, n_samples)
        event_segment = data[event_start_idx:event_end_idx]
        event_times = (
            np.arange(event_start_idx, event_start_idx + len(event_segment)) / time_divisor
        )
        plt.plot(event_times, event_segment, color='red', label='Event')

        # The slow-wave key's value is the predicted next-maximum index.
        next_maxima_time = int(event[slow_wave_key]) / time_divisor
        plt.axvline(x=next_maxima_time, color='green', linestyle='--', label='Next Maxima')

        plt.xlabel('Time (seconds)')
        plt.ylabel('Amplitude')
        plt.title('Event Signal with Next Predicted Maxima')
        plt.legend()
        plt.show()

# Plot every detected event over the analysed signal `data`.
plot_events(
    data,
    events,
    extra_signal_factor=extra_signal_factor,
    sample_freq=sample_freq,
    downsample_rate=downsample_rate,
)