In [111]:
from utils import load_buffers,load_task_markers
import pandas as pd
import config
import os
import numpy as np

from mne import create_info
from mne.epochs import EpochsArray
from mne.decoding import Vectorizer

from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline

In [54]:
# load the task and signal files stored in the disk
# Based on the task timestamps, subset the signals and annotate the events
# create epoch data that includes the signal and events
# use the epoch data to train a classifier
# save the classifier model as a pickle file for prediction later

# Load the .npz file for task
file_path = 'processed_data/task_data/task_markers.npz'  # Replace with your file path
df_tasks = load_task_markers(file_path)

# Load the .npz file for buffer
folder_path = 'processed_data/signal_data/'
df_buffers = load_buffers(folder_path)

print(df_buffers)
print("Reading signals and markers from signals from disk...")

df_buffers = df_buffers.sort_values('timestamps').reset_index(drop=True)
df_tasks = df_tasks.sort_values('timestamps').reset_index(drop=True)


Array names in the .npz file: ['event_ids', 'timestamps']
      buffer_col_0  buffer_col_1  buffer_col_2  buffer_col_3  buffer_col_4  \
0      710440.6875   710383.6875   711033.0000   710989.0625   710835.6875   
1      710439.1250   710382.4375   711030.9375   710987.5000   710834.2500   
2      710439.5625   710382.3750   711031.5625   710987.8125   710834.6875   
3      710440.3125   710383.3750   711032.6875   710988.6250   710835.1875   
4      710442.1250   710385.0625   711034.1250   710990.2500   710836.6250   
...            ...           ...           ...           ...           ...   
2620   710441.1875   710385.3125   711034.1250   710990.4375   710836.5625   
2621   710441.3750   710385.5000   711034.1250   710990.8125   710836.8125   
2622   710441.5000   710385.6875   711034.3125   710990.8125   710837.0000   
2623   710442.1250   710386.3750   711034.9375   710991.4375   710837.7500   
2624   710441.1250   710385.2500   711034.3125   710990.8125   710837.2500   

     

In [55]:
# Perform an as of merge to find the closest earlier and later event_id
df_buffers['prev_event_id'] = pd.merge_asof(df_buffers, df_tasks,
                                            left_on='timestamps', right_on='timestamps',
                                            direction='backward')['event_ids']

df_buffers['next_event_id'] = pd.merge_asof(df_buffers, df_tasks,
                                            left_on='timestamps', right_on='timestamps',
                                            direction='forward')['event_ids']

markers = {
    'forward_start': [1],
    'forward_end': [2],
    'reverse_start': [3],
    'reverse_end': [4],
    'rest_start': [99],
    'rest_end': [100],
    'task_start': [-1],
    'task_end': [-2]
}

def determine_phase(row):
    if row['prev_event_id'] in markers['forward_start'] and row['next_event_id'] in markers['forward_end']:
        return 'forward'
    elif row['prev_event_id'] in markers['reverse_start'] and row['next_event_id'] in markers['reverse_end']:
        return 'reverse'
    elif row['prev_event_id'] in markers['rest_start'] and row['next_event_id'] in markers['rest_end']:
        return 'rest'
    else:
        return 'unknown'

df_buffers['phase'] = df_buffers.apply(determine_phase, axis=1)

grouped = df_buffers.groupby(['epoch_number', 'phase']).size().reset_index(name='count')

sfreq = config.device_details['sfreq']
buffer_duration = config.epoch_information['duration']

valid_phases = ['forward', 'reverse', 'rest']


In [56]:
df_buffers.shape

(2625, 22)

In [57]:
df_buffers_filt = df_buffers[df_buffers['phase'].apply(lambda x: any(item in x for item in valid_phases))]

In [58]:
df_buffers_filt

Unnamed: 0,buffer_col_0,buffer_col_1,buffer_col_2,buffer_col_3,buffer_col_4,buffer_col_5,buffer_col_6,buffer_col_7,buffer_col_8,buffer_col_9,...,buffer_col_12,buffer_col_13,buffer_col_14,buffer_col_15,buffer_col_16,timestamps,epoch_number,prev_event_id,next_event_id,phase
250,710433.9375,710377.7500,711026.7500,710982.7500,710829.2500,711187.7500,711126.9375,711360.5000,0.149414,0.998535,...,0.061035,0.579834,60.0,2812.0,1.0,603453.625703,3,1.0,2.0,forward
251,710433.6250,710377.7500,711026.2500,710982.5000,710829.2500,711187.7500,711127.0625,711360.1875,0.147217,0.998291,...,0.030518,0.396728,60.0,2813.0,1.0,603453.625704,3,1.0,2.0,forward
252,710434.6875,710378.1875,711026.8125,710983.0000,710829.6875,711188.5625,711127.8750,711361.3125,0.149902,0.997314,...,0.030518,0.518799,60.0,2814.0,1.0,603453.629201,3,1.0,2.0,forward
253,710434.6875,710378.1875,711027.0000,710983.1875,710829.8750,711188.8125,711128.0000,711361.6250,0.148193,0.995117,...,0.061035,0.671387,60.0,2815.0,1.0,603453.629210,3,1.0,2.0,forward
254,710434.3750,710378.0625,711026.8125,710983.0000,710829.8750,711188.4375,711127.8750,711361.3125,0.147949,0.992920,...,0.030518,0.793457,60.0,2816.0,1.0,603453.640400,3,1.0,2.0,forward
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2516,710440.1875,710384.0625,711033.0625,710989.4375,710835.6875,711194.1875,711134.1250,711367.1875,0.148193,0.997314,...,0.061035,0.671387,60.0,7203.0,1.0,603471.171721,21,99.0,100.0,rest
2517,710439.7500,710384.1875,711032.8750,710989.3750,710835.5625,711193.9375,711133.7500,711366.9375,0.149414,0.996582,...,0.000000,0.701904,60.0,7204.0,1.0,603471.175673,21,99.0,100.0,rest
2518,710438.3125,710383.0625,711031.9375,710988.1250,710834.5625,711193.3125,711132.7500,711366.0000,0.148926,0.995850,...,-0.061035,0.762939,60.0,7205.0,1.0,603471.175689,21,99.0,100.0,rest
2519,710437.6250,710382.5625,711031.0000,710987.3750,710833.6250,711192.6875,711131.8750,711365.3750,0.149658,0.998291,...,-0.030518,0.823975,60.0,7206.0,1.0,603471.175692,21,99.0,100.0,rest


In [63]:
df_buffers_filt['phase_group'] = (df_buffers_filt['phase'] != df_buffers_filt['phase'].shift()).cumsum()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_buffers_filt['phase_group'] = (df_buffers_filt['phase'] != df_buffers_filt['phase'].shift()).cumsum()


In [64]:
df_buffers_filt

Unnamed: 0,buffer_col_0,buffer_col_1,buffer_col_2,buffer_col_3,buffer_col_4,buffer_col_5,buffer_col_6,buffer_col_7,buffer_col_8,buffer_col_9,...,buffer_col_13,buffer_col_14,buffer_col_15,buffer_col_16,timestamps,epoch_number,prev_event_id,next_event_id,phase,phase_group
250,710433.9375,710377.7500,711026.7500,710982.7500,710829.2500,711187.7500,711126.9375,711360.5000,0.149414,0.998535,...,0.579834,60.0,2812.0,1.0,603453.625703,3,1.0,2.0,forward,1
251,710433.6250,710377.7500,711026.2500,710982.5000,710829.2500,711187.7500,711127.0625,711360.1875,0.147217,0.998291,...,0.396728,60.0,2813.0,1.0,603453.625704,3,1.0,2.0,forward,1
252,710434.6875,710378.1875,711026.8125,710983.0000,710829.6875,711188.5625,711127.8750,711361.3125,0.149902,0.997314,...,0.518799,60.0,2814.0,1.0,603453.629201,3,1.0,2.0,forward,1
253,710434.6875,710378.1875,711027.0000,710983.1875,710829.8750,711188.8125,711128.0000,711361.6250,0.148193,0.995117,...,0.671387,60.0,2815.0,1.0,603453.629210,3,1.0,2.0,forward,1
254,710434.3750,710378.0625,711026.8125,710983.0000,710829.8750,711188.4375,711127.8750,711361.3125,0.147949,0.992920,...,0.793457,60.0,2816.0,1.0,603453.640400,3,1.0,2.0,forward,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2516,710440.1875,710384.0625,711033.0625,710989.4375,710835.6875,711194.1875,711134.1250,711367.1875,0.148193,0.997314,...,0.671387,60.0,7203.0,1.0,603471.171721,21,99.0,100.0,rest,10
2517,710439.7500,710384.1875,711032.8750,710989.3750,710835.5625,711193.9375,711133.7500,711366.9375,0.149414,0.996582,...,0.701904,60.0,7204.0,1.0,603471.175673,21,99.0,100.0,rest,10
2518,710438.3125,710383.0625,711031.9375,710988.1250,710834.5625,711193.3125,711132.7500,711366.0000,0.148926,0.995850,...,0.762939,60.0,7205.0,1.0,603471.175689,21,99.0,100.0,rest,10
2519,710437.6250,710382.5625,711031.0000,710987.3750,710833.6250,711192.6875,711131.8750,711365.3750,0.149658,0.998291,...,0.823975,60.0,7206.0,1.0,603471.175692,21,99.0,100.0,rest,10


In [65]:
df_phase_map = df_buffers_filt[['phase_group', 'phase']].drop_duplicates()

In [67]:
val_sel_columns = config.device_details['relevant_channels_from_device']
columns_to_select = df_buffers.columns[:val_sel_columns].tolist() + ['epoch_number', 'timestamps', 'phase', 'phase_group']
df_buffers_filt = df_buffers_filt[columns_to_select]

In [68]:
channel_names = config.device_details['channels']
rename_dict = {old_name: new_name for old_name, new_name in zip(df_buffers_filt.columns[:val_sel_columns], channel_names)}
df_buffers_renamed = df_buffers_filt.rename(columns=rename_dict)

In [76]:
# df_buffers_renamed

In [79]:
# Creating epochs data
epochs_data = []
events = []

event_id = {'forward': 1, 'reverse': 2, 'rest': 3}
phases = ['forward', 'reverse', 'rest']

In [103]:
max_shape = 0  # Initialize maximum shape

for _, df_phase_group in df_buffers_renamed.groupby('phase_group'):
    data_array = df_phase_group[channel_names].values.T
    max_shape = max(max_shape, data_array.shape[1])
    epochs_data.append(data_array)
    
    phase_group = df_phase_group['phase_group'].iloc[0]
    phase = df_phase_map[df_phase_map['phase_group'] == phase_group]['phase'].iloc[0]
    event = event_id[phase]
    
    events.append([len(epochs_data) - 1, 0, event])

In [106]:
events = np.array(events)

In [107]:
events.shape

(10, 3)

In [114]:
# Pad or truncate each epoch's data to match the maximum shape
for i, data_array in enumerate(epochs_data):
    current_shape = data_array.shape[1]
    if current_shape < max_shape:
        padding = ((0, 0), (0, max_shape - current_shape))
        epochs_data[i] = np.pad(data_array, padding, mode='constant')
    elif current_shape > max_shape:
        epochs_data[i] = data_array[:, :max_shape]

# Convert epoch_data to numpy array
epochs_data = np.array(epochs_data)

# Print the shape of epoch_data
print("Shape of epoch_data:", epochs_data.shape)

Shape of epoch_data: (10, 8, 373)


In [115]:
# Create MNE info structure
info = create_info(ch_names=channel_names, sfreq=config.device_details['sfreq'], ch_types='eeg')

# Create MNE Epochs object
epochs = EpochsArray(epochs_data, info, events, event_id=event_id, tmin=0)

X = epochs.get_data(copy=False)  # Shape: (n_epochs, n_channels, n_times)
y = epochs.events[:, 2]  # Shape: (n_epochs,)


Not setting metadata
10 matching events found
No baseline correction applied
0 projection items activated


In [121]:
y

array([1, 3, 2, 3, 1, 3, 2, 3, 1, 3])

In [117]:

# Reshape the data to (n_samples, n_features)
X = X.reshape(len(X), -1)

# Create and train a classifier
clf = make_pipeline(Vectorizer(), RandomForestClassifier(n_estimators=100))
clf.fit(X, y)

test_index = 0  # Use the first training data point for prediction
test_sample = X[test_index].reshape(1, -1)  # Reshape to (1, n_features)
predicted_label = clf.predict(test_sample)
actual_label = y[test_index]
print(f"Predicted label: {predicted_label[0]}, Actual label: {actual_label}")

Predicted label: 1, Actual label: 1
