In [12]:
import json
import os
import sys

import joblib  # Used for loading sklearn models
import numpy as np
import torch
from pylsl import StreamInlet, resolve_stream

sys.path.append('./src/processing')
from preprocessing import *

sys.path.append('./models')
from eegconformer import EEGConformer

In [13]:
# Directories: trained models are loaded from the first, results are written to the second
models_dir = './models/trained/'
results_dir = './results/'

# Recording configuration
srate = 160            # sampling rate of the EEG data
epoch_length_sec = 5   # length of one epoch, in seconds
samples_needed = epoch_length_sec * srate  # samples that make up one ~5 second epoch

# Channel labels, entered by hand from the eegbci dataset
channel_names = [
    'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6',
    'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6',
    'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6',
    'Fp1', 'Fpz', 'Fp2',
    'AF7', 'AF3', 'AFz', 'AF4', 'AF8',
    'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8',
    'FT7', 'FT8',
    'T7', 'T8', 'T9', 'T10',
    'TP7', 'TP8',
    'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
    'PO7', 'PO3', 'POz', 'PO4', 'PO8',
    'O1', 'Oz', 'O2',
    'Iz',
]


In [None]:

# srate, epoch_length_sec and samples_needed are set once in the configuration
# cell near the top of the notebook. They are intentionally not repeated here so
# that the recording and decoding cells cannot end up with different values.

# Find an LSL stream whose type is 'EEG' and open an inlet on the first match.
print("Looking for an EEG stream...")
streams = resolve_stream('type', 'EEG')
inlet = StreamInlet(streams[0])

def collect_and_save_single_sample(inlet, samples_needed, out_path='sample.npy'):
    """Pull `samples_needed` samples from `inlet` and save them as one array.

    Parameters
    ----------
    inlet : object with a ``pull_sample()`` method returning ``(sample, timestamp)``
        Source of the samples (a pylsl ``StreamInlet`` in this notebook).
    samples_needed : int
        Number of samples to collect before saving. Must be positive.
    out_path : str, optional
        File the collected array is written to with ``np.save``.
        Defaults to ``'sample.npy'``, the path used before this parameter existed.

    Raises
    ------
    ValueError
        If `samples_needed` is not positive; previously the function returned
        silently without writing a file in that case.

    Notes
    -----
    The confirmation message reads the module-level ``epoch_length_sec``; it is
    used only for display.
    """
    if samples_needed <= 0:
        raise ValueError(f"samples_needed must be positive, got {samples_needed!r}")

    buffer = []  # collected samples, one entry per pulled sample
    while len(buffer) < samples_needed:
        # The timestamp is not stored: only the sample values are saved.
        sample, _timestamp = inlet.pull_sample()
        if sample:  # skip empty pulls (falsy sample)
            buffer.append(sample)

    # Save once the loop has gathered exactly the requested number of samples.
    np.save(out_path, np.array(buffer))
    print(f"Saved ~{epoch_length_sec}-second sample with {len(buffer)} samples.")

# Record one epoch of `samples_needed` samples from `inlet`; the call returns
# once the collected array has been saved to disk.
collect_and_save_single_sample(inlet, samples_needed)


In [20]:
# Reload the epoch previously saved to 'sample.npy' and display its shape
# as a quick sanity check before decoding.
sample = np.load('sample.npy')
sample.shape

(800, 64)

In [None]:
# Preprocess the saved epoch. `preprocess_single_trial` is not defined in this
# notebook; presumably it comes from the star-import of `preprocessing` -- confirm.
preprocessed_data = preprocess_single_trial(sample, srate, channel_names)

In [None]:
# csp + lda decode: load the saved 'csp_lda.pkl' model and predict on the
# preprocessed epoch.
# NOTE: joblib.load unpickles the file -- only load model files from a trusted source.
model_path = os.path.join(models_dir, 'csp_lda.pkl')
loaded_model = joblib.load(model_path)
predicted_labels = loaded_model.predict(preprocessed_data)
predicted_labels

# TODO: save result (not implemented in this cell)

In [None]:
# csp + logistic regression decode: load the saved 'csp_logistic.pkl' model and
# predict on the preprocessed epoch. Rebinds `model_path`, `loaded_model` and
# `predicted_labels` from the previous decode cell.
# NOTE: joblib.load unpickles the file -- only load model files from a trusted source.
model_path = os.path.join(models_dir, 'csp_logistic.pkl')
loaded_model = joblib.load(model_path)
predicted_labels = loaded_model.predict(preprocessed_data)
predicted_labels

In [None]:
# csp + svm decode: load the saved 'csp_svm.pkl' model and predict on the
# preprocessed epoch. Rebinds `model_path`, `loaded_model` and
# `predicted_labels` from the previous decode cell.
# NOTE: joblib.load unpickles the file -- only load model files from a trusted source.
model_path = os.path.join(models_dir, 'csp_svm.pkl')
loaded_model = joblib.load(model_path)
predicted_labels = loaded_model.predict(preprocessed_data)
predicted_labels

In [None]:
# Unpack the three dimensions of the preprocessed array; the first is discarded
# and the other two are passed to the EEGConformer constructor as n_chans / n_times.
# presumably the layout is (trials, channels, time samples) -- confirm against preprocess_single_trial
_, n_chans, n_times = preprocessed_data.shape

In [None]:
# eeg_conformer decode: instantiate the network that the saved weights are loaded into
model = EEGConformer(
    # sizes taken from the preprocessed data and the recording configuration
    n_outputs=2,
    n_chans=n_chans,
    n_times=n_times,
    sfreq=srate,
    chs_info=None,
    input_window_seconds=None,
    # filter_* / pool_* settings
    n_filters_time=40,
    filter_time_length=25,
    pool_time_length=75,
    pool_time_stride=15,
    drop_prob=0.7,
    # att_* settings
    att_depth=3,
    att_heads=10,
    att_drop_prob=0.7,
    # output settings
    final_fc_length='auto',  # either 'auto' or an int
    return_features=False,  # if True, returns the features before the last classification layer
    add_log_softmax=True,
)

In [None]:
# Load the saved conformer weights. The path gets its own name so that
# `loaded_model` (the scikit-learn model used by the other decode cells) is not
# overwritten with a string.
checkpoint_path = os.path.join(models_dir, 'cross_subject_conformer.pth')
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

In [None]:
# Copy the weights held in `checkpoint` into the EEGConformer instance `model`.
# NOTE(review): the model is built with drop_prob=0.7 / att_drop_prob=0.7;
# presumably model.eval() should be called before any inference -- confirm.
model.load_state_dict(checkpoint)

In [16]:
import numpy as np
import json
import os

def online_decode(inlet, samples_per_epoch, loaded_model, srate, channel_names, results_dir, n_epochs=3):
    """Decode consecutive epochs pulled from `inlet` and save the predictions.

    Samples are pulled one at a time; every `samples_per_epoch` samples form an
    epoch that is passed to ``decode_sample``. After `n_epochs` predictions the
    history is written as JSON to ``<results_dir>/results.txt`` (overwriting any
    existing file) and the function returns.

    Parameters
    ----------
    inlet : object with a ``pull_sample()`` method returning ``(sample, timestamp)``
    samples_per_epoch : int
        Number of samples that make up one epoch.
    loaded_model, srate, channel_names
        Forwarded unchanged to ``decode_sample``.
    results_dir : str
        Directory for ``results.txt``; it is created if it does not exist, so a
        missing directory no longer fails after all the data has been collected.
    n_epochs : int, optional
        Number of epochs to decode before saving. Defaults to 3, the value that
        used to be hard-coded.
    """
    def _to_jsonable(obj):
        # json calls this for objects it cannot serialise itself. Convert NumPy
        # containers/scalars; reject anything else explicitly instead of handing
        # the same object back to the encoder.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    buffer = []     # samples waiting to be grouped into an epoch
    pred_hist = []  # one prediction per decoded epoch

    while len(pred_hist) < n_epochs:
        sample, timestamp = inlet.pull_sample()
        if not sample:
            continue  # empty pull: nothing to add
        buffer.append(sample)
        if len(buffer) < samples_per_epoch:
            continue

        # Cut one epoch off the front of the buffer and keep the remainder.
        epoch = np.array(buffer[:samples_per_epoch])
        buffer = buffer[samples_per_epoch:]

        prediction = decode_sample(epoch, loaded_model, srate, channel_names)
        # Store plain lists so the history can be dumped as JSON.
        if isinstance(prediction, np.ndarray):
            prediction = prediction.tolist()
        pred_hist.append(prediction)

        # `timestamp` belongs to the last sample pulled into this epoch.
        print(f"Timestamp: {timestamp}, Prediction: {prediction}")

    os.makedirs(results_dir, exist_ok=True)
    data_to_save = json.dumps(pred_hist, default=_to_jsonable)
    with open(os.path.join(results_dir, 'results.txt'), 'w') as file:
        file.write(data_to_save)

    print(f"Saved {len(pred_hist)} epochs and their predictions.")

def decode_sample(epoch, loaded_model, srate, channel_names):
    """Preprocess one raw epoch and return the model's prediction for it.

    `preprocess_single_trial` is not defined in this notebook; it is assumed to
    be defined elsewhere (presumably the star-import of `preprocessing`).
    """
    return loaded_model.predict(
        preprocess_single_trial(epoch, srate, channel_names)
    )


In [17]:
# Load the CSP + logistic regression model that online_decode will use.
# NOTE: joblib.load unpickles the file -- only load model files from a trusted source.
model_path = os.path.join(models_dir, 'csp_logistic.pkl')
loaded_model = joblib.load(model_path)

# Find a stream of type 'EEG', open an inlet on the first match, and run the
# online decoder on it with the epoch size and settings from the config cell.
print("Looking for an EEG stream...")
streams = resolve_stream('type', 'EEG')
inlet = StreamInlet(streams[0])
online_decode(inlet, samples_needed, loaded_model, srate, channel_names, results_dir)

Looking for an EEG stream...
Creating RawArray with float64 data, n_channels=64, n_times=800
    Range : 0 ... 799 =      0.000 ...     4.994 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 7 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 265 samples (1.656 s)

Timestamp: 167079.8842762, Prediction: [0]


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Creating RawArray with float64 data, n_channels=64, n_times=800
    Range : 0 ... 799 =      0.000 ...     4.994 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 7 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 265 samples (1.656 s)

Timestamp: 167084.8808496, Prediction: [1]


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Creating RawArray with float64 data, n_channels=64, n_times=800
    Range : 0 ... 799 =      0.000 ...     4.994 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 7 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 265 samples (1.656 s)

Timestamp: 167089.8904504, Prediction: [1]
Saved 3 epochs and their predictions.


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


In [None]:
samples_per_epoch = srate * epoch_length_sec

# Resolve the stream.
# resolve_stream() takes either (property, value[, minimum]) or one predicate
# string. The former four-argument call ('type', 'EEG', 'name', 'BioSemi') passed
# 'name' as the `minimum` count, so both conditions are combined in one predicate.
print("Looking for an EEG stream...")
streams = resolve_stream("type='EEG' and name='BioSemi'")
#inlet = StreamInlet(streams[0])

# Initialize a buffer for accumulating samples
buffer = []

# NOTE: joblib.load unpickles the file -- only load model files from a trusted source.
model_path = os.path.join(models_dir, 'csp_logistic.pkl')
loaded_model = joblib.load(model_path)

def decode_sample(epoch, loaded_model, srate, channel_names):
    """
    Process and decode an epoch of EEG data.

    The epoch is passed through ``preprocess_single_trial`` (not defined in this
    notebook; presumably from the `preprocessing` star-import) and the result is
    given to ``loaded_model.predict``, whose output is returned.
    """
    preprocessed_epoch = preprocess_single_trial(epoch, srate, channel_names)
    # Predict on the epoch that was just preprocessed, not on the module-level
    # `preprocessed_data` left over from an earlier cell, and return the value
    # that was actually computed (the old code returned an undefined name).
    prediction = loaded_model.predict(preprocessed_epoch)
    return prediction

def online_decode(inlet):
    """
    Continuously pull samples from the LSL stream and decode them.

    Reads the module-level ``samples_per_epoch``, ``loaded_model``, ``srate``,
    ``channel_names`` and ``results_dir``. After three epochs have been decoded,
    the ``(timestamp, prediction)`` pairs are saved to
    ``<results_dir>/results.npy`` as a ``(3, 2)`` object array and the function
    returns. Because the array has object dtype, read it back with
    ``np.load(path, allow_pickle=True)``.
    """
    # The buffer is local: assigning to the module-level `buffer` from inside
    # the function made every access raise UnboundLocalError.
    epoch_buffer = []
    pred_hist = []
    while len(pred_hist) < 3:
        # Pull sample from LSL stream
        sample, timestamp = inlet.pull_sample()
        if not sample:
            continue  # empty pull: do not put None into the epoch
        epoch_buffer.append(sample)

        # Check if buffer has enough samples to form an epoch
        if len(epoch_buffer) < samples_per_epoch:
            continue
        epoch = np.array(epoch_buffer[:samples_per_epoch])
        epoch_buffer = epoch_buffer[samples_per_epoch:]

        # Decode the epoch
        prediction = decode_sample(epoch, loaded_model, srate, channel_names)
        pred_hist.append((timestamp, prediction))
        print(f"Timestamp: {timestamp}, Prediction: {prediction}")

    # Save the results. Each row pairs a float timestamp with a prediction of
    # arbitrary shape, so fill an object array explicitly instead of asking
    # np.array() to build one from a ragged list.
    results = np.empty((len(pred_hist), 2), dtype=object)
    for row, (epoch_timestamp, epoch_prediction) in enumerate(pred_hist):
        results[row, 0] = epoch_timestamp
        results[row, 1] = epoch_prediction
    np.save(os.path.join(results_dir, 'results.npy'), results)
    return

