# Direct Neural Biasing - Python Demo

## STEP 1 - Python Imports

In [None]:
from pathlib import Path
from time import time
import re

from scipy import signal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import mne
import direct_neural_biasing as dnb

## STEP 2 - Define Data Array

#### 2.1 - Read EDF File

In [None]:
file = './data/JS.edf'  # place an EDF file in the ./data dir
channel_idx = 0  # index of the EDF channel to analyse

raw = mne.io.read_raw_edf(file)
sample_freq = raw.info["sfreq"]
# get_data() returns an (n_channels, n_times) array, so the previous .flatten()
# concatenated all channels end to end into one bogus signal for multi-channel
# recordings. Select a single channel instead (identical result for
# single-channel files).
data_full = raw.get_data(picks=[channel_idx])[0]

In [None]:
# Alternative input (temporary hack): load `data` from a CSV file instead of
# the EDF recording. Uncomment the line below to use it.
# NOTE(review): the truncation cell in 2.2 reassigns `data` from `data_full`,
# so skip that cell when using this CSV input.
# data = pd.read_csv('signal.csv').values.flatten()

#### 2.2 - Truncate Data Array for Debugging

In [None]:
# Keep only the first samples so that debugging runs stay fast.
n_debug_samples = 100_000
data = data_full[:n_debug_samples]

#### 2.2 ALT - Read Sample CSV (use the commented-out `pd.read_csv` cell under 2.1 instead of the truncation above)

## STEP 3 - Set Up SignalProcessor

#### 3.1 - Create Signal Processor

In [None]:
# Set verbose=True to get more detailed (debugging) output in the results.
verbose = False
signal_processor = dnb.PySignalProcessor(verbose)

#### 3.2 - Create Filters

In [None]:
# Band-pass filter for the slow-wave band, applied at the recording's sample rate.
slow_wave_filter_id = 'bandpass_filter_slow_wave'
f_low, f_high = 0.5, 4.0  # lower / upper cutoff frequencies (Hz)

signal_processor.add_filter(slow_wave_filter_id, f_low, f_high, sample_freq)

In [None]:
# Band-pass filter for the IED band, applied at the recording's sample rate.
ied_filter_id = 'bandpass_filter_ied'
f_low, f_high = 80.0, 120.0  # lower / upper cutoff frequencies (Hz)

signal_processor.add_filter(ied_filter_id, f_low, f_high, sample_freq)

#### 3.3 - Create Slow Wave Detector

In [None]:
activation_detector_id = 'slow_wave_detector'

# Thresholds for accepting a candidate slow wave.
z_score_threshold = 3.0       # amplitude threshold (z-score) for a candidate wave
sinusoidness_threshold = 0.5  # threshold on correlation with a cosine wave, between 0 and 1

signal_processor.add_slow_wave_detector(
    activation_detector_id,
    slow_wave_filter_id,  # the detector reads the output of the slow-wave filter
    z_score_threshold,
    sinusoidness_threshold,
)

#### 3.4 - Create IED Detector

In [None]:
inhibition_detector_id = 'ied_detector'

# Threshold detector settings for IEDs.
z_score_threshold = 5.0  # z-score threshold for a candidate detection event
buffer_size = 10         # buffer length, used to manage noise resistance
sensitivity = 0.5        # between 0 and 1: fraction of buffered values that must be OVER
                         # the threshold to trigger an 'IED Detected' event

signal_processor.add_threshold_detector(
    inhibition_detector_id,
    ied_filter_id,  # the detector reads the output of the IED filter
    z_score_threshold,
    buffer_size,
    sensitivity,
)

#### 3.5 - Create Pulse Trigger

In [None]:
trigger_id = 'pulse_trigger'
# Both cooldowns are in milliseconds (the old comment wrongly said "seconds").
inhibition_cooldown_ms = 2000  # cooldown (ms) after an IED detection
pulse_cooldown_ms = 2000       # cooldown (ms) after a pulse event

signal_processor.add_pulse_trigger(
    trigger_id,
    activation_detector_id,  # detector that triggers a pulse - SlowWave in this case
    inhibition_detector_id,  # detector that triggers an inhibition cooldown - IED in this case
    inhibition_cooldown_ms,
    pulse_cooldown_ms,
)

## STEP 4 - Analyse Signal

#### 4.1 - Run

In [None]:
import os

# Ask the Rust extension to print a backtrace if it panics.
os.environ['RUST_BACKTRACE'] = '1'

# `signal_processor` is the PySignalProcessor configured in step 3 and `data`
# is the input signal from step 2.
try:
    out = signal_processor.run(data)
except Exception as e:
    # Report the error, then re-raise it. Swallowing it would let later cells
    # run against a missing (NameError) or stale `out`.
    print(f"An error occurred: {e}")
    raise

In [None]:
# Runs the processor a second time on the same data. If the cell above only
# printed an error, this call re-raises it with a full traceback.
# NOTE(review): this is redundant when the cell above succeeded. If the processor
# keeps internal state between run() calls, this second result may also differ
# from the first; consider deleting this cell.
out = signal_processor.run(data)

#### 4.2 - Example Sample Output

In [None]:
# Number of result entries returned by signal_processor.run(data).
len(out)

In [None]:
# First result entry. Inspect it to see the available keys
# (e.g. 'global:raw_sample', 'filters:<id>:filtered_sample', 'triggers:<id>:triggered').
out[0]

## STEP 5 - Check Filtered Signal Against Scipy

#### 5.1 - Prepare Signal Arrays

In [None]:
# `out` is a list of result dicts, one per processor output sample.
dnb_raw_signal = [sample['global:raw_sample'] for sample in out]
# Compare against the slow-wave filter (the undefined `filter_id` used before
# raised NameError). Its 0.5-4 Hz band is what step 5.2 reproduces with SciPy.
dnb_filtered_signal = [sample[f'filters:{slow_wave_filter_id}:filtered_sample'] for sample in out]

# `downsample_rate` was never defined in this notebook. Infer it from the number
# of input samples per output entry (1 when the processor emits one entry per
# input sample).
# NOTE(review): replace this with the processor's configured rate if it exposes one.
downsample_rate = max(1, round(len(data) / len(out))) if out else 1

# Bring the raw input to the processor's output rate so it lines up with
# dnb_raw_signal. decimate() always applies an anti-aliasing low-pass filter,
# so skip it when no down-sampling is needed.
if downsample_rate > 1:
    raw_signal_downsampled = signal.decimate(data, downsample_rate)
else:
    raw_signal_downsampled = np.asarray(data)

#### 5.2 - Create Scipy Filtered Signal

In [None]:
# SciPy reference filter. Its band must match the DNB slow-wave filter set up
# in step 3.2 (f_low=0.5, f_high=4.0). The previous 0.25 Hz low cut made the two
# outputs differ by design, not just by implementation.
# NOTE(review): the DNB filter's order is not visible here; confirm it is 2nd order.
order = 2
lowcut = 0.5   # Low cut frequency in Hz (matches the slow-wave filter)
highcut = 4.0  # High cut frequency in Hz (matches the slow-wave filter)

# Design a Butterworth band-pass at the rate of the signal it is applied to.
b, a = signal.butter(order, [lowcut, highcut], 'bp', fs=sample_freq / downsample_rate)

# Single-pass, causal filtering (lfilter rather than the zero-phase filtfilt).
scipy_filtered_signal = signal.lfilter(b, a, raw_signal_downsampled)

#### 5.3 - Truncate Signals for Easier Debug

In [None]:
# Keep only a short window of every signal so the plots stay readable.
start_idx, end_idx = 0, 1000
debug_window = slice(start_idx, end_idx)

raw_signal_truncated = raw_signal_downsampled[debug_window]
dnb_raw_signal_truncated = dnb_raw_signal[debug_window]
dnb_filtered_signal_truncated = dnb_filtered_signal[debug_window]
scipy_filtered_signal_truncated = scipy_filtered_signal[debug_window]

#### 5.4 - Plot Graphs

In [None]:
# Convert sample indices to time in seconds. The truncated signals are at the
# down-sampled rate, sample_freq / downsample_rate samples per second (the old
# divisor, sample_freq * downsample_rate, was wrong). Offset by start_idx so the
# axis reflects where the truncation window begins. Named time_s so it does not
# shadow the imported `time`.
effective_rate = sample_freq / downsample_rate
time_s = np.arange(start_idx, start_idx + len(raw_signal_truncated)) / effective_rate

# One panel per signal: (values, title/legend label, line colour or None for the default).
panels = [
    (raw_signal_truncated, 'Raw Signal', None),
    (dnb_raw_signal_truncated, 'DNB - Raw Signal', None),
    (dnb_filtered_signal_truncated, 'DNB - Filtered Signal', 'orange'),
    (scipy_filtered_signal_truncated, 'SciPy - Filtered Signal', 'green'),
]

fig, axs = plt.subplots(len(panels), 1, figsize=(12, 12), sharex=True)
for ax, (values, title, color) in zip(axs, panels):
    # Only pass `color` when one is requested, so other panels keep the default colour.
    line_kwargs = {'color': color} if color is not None else {}
    ax.plot(time_s, values, label=title, **line_kwargs)
    ax.set_title(title)
    ax.set_ylabel('Amplitude')
    ax.legend()
axs[-1].set_xlabel('Time (seconds)')

# Show the plots
plt.tight_layout()
plt.show()

## STEP 6 - Find Detected Events

In [None]:
# Keep only the result entries whose pulse-trigger flag equals 1.0.
triggered_key = f'triggers:{trigger_id}:triggered'
events = [entry for entry in out if entry[triggered_key] == 1.0]

In [None]:
# Number of result entries in which the pulse trigger fired.
len(events)

## STEP 7 - Visualise Detected Events

In [None]:
# Padding multiplier passed to plot_events. Each side of an event's plot window
# is extended by len(indices) * signal_padding_factor * downsample_rate samples.
signal_padding_factor = 2

# Function to extract indices and plot each event
def plot_events(data, events, signal_padding_factor, sample_freq, downsample_rate):
    for event in events:
        # Extract the indices from the event keys
        slow_wave_key = next(key for key in event.keys() if 'slow_wave_idx' in key)
        idx_str = re.search(r'slow_wave_idx:([\d, ]+):next_maxima', slow_wave_key).group(1)
        idx_list = list(map(int, idx_str.split(',')))

        # Determine the range for plotting
        event_length = len(idx_list)
        extra_length = event_length * signal_padding_factor * downsample_rate
        start_idx = max(0, idx_list[0] - extra_length)
        end_idx = min(len(data), idx_list[-1] + extra_length)

        # Extract the signal segment to plot
        segment = data[start_idx:end_idx]

        # Convert indices to time in seconds
        time = np.arange(start_idx, end_idx) # / (sample_freq * downsample_rate)

        # Plot the signal segment
        plt.figure(figsize=(10, 4))
        plt.plot(time, segment, label='Signal')

        # Highlight the event signal
        event_start_idx = idx_list[0]
        event_end_idx = idx_list[-1]
        event_segment = data[event_start_idx:event_end_idx]
        event_times = np.arange(event_start_idx, event_end_idx) # / (sample_freq * downsample_rate)
        plt.plot(event_times, event_segment, color='red', label='Event')

        # Plot the predicted next maxima as a green vertical line
        next_maxima_index = int(event[slow_wave_key])
        next_maxima_time = next_maxima_index # / (sample_freq * downsample_rate)
        plt.axvline(x=next_maxima_time, color='green', linestyle='--', label='Next Maxima')

        # Add labels and legend
        plt.xlabel('Time (seconds)')
        plt.ylabel('Amplitude')
        plt.title('Event Signal with Next Predicted Maxima')
        plt.ticklabel_format(useOffset=False)
        plt.legend()
        plt.show()

# Plot every triggered event on top of the DNB filtered signal.
plot_events(
    data=dnb_filtered_signal,
    events=events,
    signal_padding_factor=signal_padding_factor,
    sample_freq=sample_freq,
    downsample_rate=downsample_rate,
)

In [None]:
# Inspect the full result dict of the first triggered event.
events[0]