In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import mne

from gait_modulation.file_reader import MatFileReader
from gait_modulation.data_processor import DataProcessor
from gait_modulation.viz import Visualise


# Loading the data

In [2]:
# Handle multiple patients with nested directories.
root_directory = '/Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/Chiara/organized_data'
mat_reader = MatFileReader(root_directory, max_workers=1)  #  adjust the number of workers for parallelism

# Read all data from nested folders of multiple patients and sessions
all_data = mat_reader.read_data()
n_sessions = len(all_data)

print(f"Number of sessions: {n_sessions}")

Loading data from file: /Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/Chiara/organized_data/EM_FH_HK/PW_HK59/26_10_22/walking_sync_2_short.mat
Loading data from file: /Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/Chiara/organized_data/EM_FH_HK/PW_HK59/26_10_22/walking_sync_3_short.mat
Loading data from file: /Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/Chiara/organized_data/EM_FH_HK/PW_FH57/no_date/walking_sync_4_short.mat
Loading data from file: /Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/Chiara/organized_data/EM_FH_HK/PW_FH57/no_date/walking_sync_2_short.mat
Loading data from file: /Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/Chiara/organized_data/EM_FH_HK/PW_FH57/no_date/walking_sync_3_short.mat
Loading data from file: /Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/mast

In [3]:
# Access specific sessions for a patient
session = all_data[0] # pick any session e.g. first one to load the meta data

# Extract LFP meta data for subject/session
lfp_metadata = DataProcessor.np_to_dict(session['hdr_LFP'])

# Load LFP parameters
lfp_sfreq = lfp_metadata['Fs'].item()
lfp_ch_names = DataProcessor.rename_lfp_channels(lfp_metadata['labels'])
lfp_n_channels = lfp_metadata['NumberOfChannels'].item()

# Prepare for mne data structure
info = mne.create_info(ch_names=lfp_ch_names[0:6], sfreq=lfp_sfreq, ch_types='dbs', verbose=40)

# Select one event to work with: mod_start
event_of_interest = 'mod_start'
mod_start_event_id = 1

# Define normal walking events
normal_walking_event_id = -1

# Define the event dictionary
event_dict = {
    'mod_start': mod_start_event_id,
    'normal_walking': normal_walking_event_id
}

# Define parameters
epoch_tmin = -2.0
epoch_tmax = 0.0
epoch_duration = epoch_tmax - epoch_tmin
epoch_sample_length = int(epoch_duration * lfp_sfreq)
gap_duration = 10  # At least 10 seconds away from modulation events
gap_sample_length = int(gap_duration * lfp_sfreq)

epochs_list = []
events_list = []

for s in range(n_sessions):
    print(f'Session: {s}')
    session = all_data[s] # Access specific patient/sessions

    # Extract events and lfp data of the subject/session
    lfp_data = session['data_LFP'] * 1e-6  # Convert microvolts to volts
    
    # lfp_raw = mne.io.RawArray(lfp_data, info, verbose=40)
    # lfp_raw.plot(start=0, duration=np.inf, remove_dc=False)
    # plt.show()

    # Handle events
    events_KIN = DataProcessor.np_to_dict(session['events_KIN'])
    events_before_trim, event_dict_before_trim = DataProcessor.create_events_array(events_KIN, lfp_sfreq)

    # Trim the data and adjust the event onsets accordingly
    lfp_data, events_after_trim = DataProcessor.trim_data(lfp_data, events_before_trim, lfp_sfreq)
    lfp_duration = lfp_data.shape[1] / lfp_sfreq
    n_samples = int(lfp_duration * lfp_sfreq)

    # Update raw data after trimming
    lfp_raw = mne.io.RawArray(lfp_data, info, verbose=40)

    # events_mod_start = events_before_trim[events_before_trim[:, 2] == event_dict_before_trim[event_of_interest]]
    events_mod_start = events_after_trim[events_after_trim[:, 2] == event_dict_before_trim[event_of_interest]]
    events_mod_start[:, 1] = s # mark the session nr  
    # print("--->", np.unique(events_mod_start[:, 1]), np.unique(events_mod_start[:, 2], return_counts=True))

    # Rename Gait Modulation Events
    events_mod_start[:, 2] = mod_start_event_id
        
    # Define normal walking events
    normal_walking_events = DataProcessor.define_normal_walking_events(
        normal_walking_event_id, events_mod_start,
        gap_sample_length, epoch_sample_length, n_samples
    )
    
    events_mod_start[:, 1] = s # mark the session nr
    normal_walking_events[:, 1] = s # mark the session nr

    # ## Remove artifacts from raw LFP data using ICA.
    # ica_n_components = 6 # 6 = n_channels.
    # ica = mne.preprocessing.ICA(n_components=ica_n_components, random_state=97, max_iter=800, verbose=40)
    # print(lfp_raw.ch_names)
    # ica.fit(lfp_raw)
    # raw_data_clean = ica.apply(lfp_raw, verbose=40) # Apply ICA to the raw data

    # Combine events and create epochs
    events, epochs = DataProcessor.create_epochs_with_events(
        lfp_raw,
        events_mod_start,
        normal_walking_events,
        mod_start_event_id,
        normal_walking_event_id,
        epoch_tmin,
        epoch_tmax,
        event_dict
    )
    print(f"Total epochs: {len(epochs)}")
    for cls in event_dict.keys():
        print(f"{cls}: {len(epochs[cls])} epochs", end='; ')
    
    # events[:, 1] = s # No need to mark the session nr for events again!
    epochs.events[:, 1] = s # mark the session nr
    
    my_annot = mne.Annotations(
        onset=events[:, 0]/lfp_sfreq,  # in seconds
        duration=len(events)*[epoch_duration],  # in seconds, too
        description=events[:, 2],
    )
    lfp_raw.set_annotations(my_annot)
    
    fig = lfp_raw.plot(start=0, duration=np.inf, show=False) # lfp_duration
    fig.suptitle(f'Session {s}', fontsize=16)
    plt.tight_layout()
    plt.savefig(f'plots/session{s}.png')
    plt.close(fig)
    
    epochs_list.append(epochs)
    events_list.append(events)
    
    print("\n==========================================================")


epochs = mne.concatenate_epochs(epochs_list, verbose=40)
events = np.vstack(events_list)
events = events[np.argsort(events[:, 0])]  # TODO: Sort by onset time

# Preprocessing
## Apply band-pass filtering to the raw LFP data.
l_freq = 1
h_freq = 50
epochs.filter(l_freq=l_freq, h_freq=h_freq, fir_design='firwin', verbose=40)

Session: 0
Number of samples removed: 14738
Number of seconds removed: 58.95 seconds
Total epochs: 105
mod_start: 8 epochs; normal_walking: 97 epochs; Using matplotlib as 2D backend.

Session: 1
No trimming needed as the beginning of signal is not flat.
Total epochs: 64
mod_start: 6 epochs; normal_walking: 58 epochs; 
Session: 2
Number of samples removed: 2282
Number of seconds removed: 9.13 seconds
Total epochs: 25
mod_start: 15 epochs; normal_walking: 10 epochs; 
Session: 3
Number of samples removed: 7049
Number of seconds removed: 28.20 seconds
Total epochs: 21
mod_start: 12 epochs; normal_walking: 9 epochs; 
Session: 4
Number of samples removed: 1950
Number of seconds removed: 7.80 seconds
Total epochs: 53
mod_start: 8 epochs; normal_walking: 45 epochs; 

  lfp_raw.set_annotations(my_annot)
  lfp_raw.set_annotations(my_annot)



Session: 5
No trimming needed as the beginning of signal is not flat.
Total epochs: 28
mod_start: 11 epochs; normal_walking: 17 epochs; 

  lfp_raw.set_annotations(my_annot)



Session: 6
Number of samples removed: 1819
Number of seconds removed: 7.28 seconds
Total epochs: 40
mod_start: 8 epochs; normal_walking: 32 epochs; 
Session: 7
Number of samples removed: 672
Number of seconds removed: 2.69 seconds
Total epochs: 26
mod_start: 11 epochs; normal_walking: 15 epochs; 
Session: 8
No trimming needed as the beginning of signal is not flat.
Total epochs: 40
mod_start: 13 epochs; normal_walking: 27 epochs; 
Session: 9
Number of samples removed: 731
Number of seconds removed: 2.92 seconds
Total epochs: 26
mod_start: 15 epochs; normal_walking: 11 epochs; 
Session: 10
No trimming needed as the beginning of signal is not flat.
Total epochs: 85
mod_start: 3 epochs; normal_walking: 82 epochs; 
Session: 11
Number of samples removed: 3215
Number of seconds removed: 12.86 seconds
Total epochs: 54
mod_start: 10 epochs; normal_walking: 44 epochs; 

  lfp_raw.set_annotations(my_annot)



Session: 12
Number of samples removed: 6608
Number of seconds removed: 26.43 seconds
Total epochs: 66
mod_start: 5 epochs; normal_walking: 61 epochs; 
Session: 13
Number of samples removed: 13713
Number of seconds removed: 54.85 seconds
Total epochs: 48
mod_start: 5 epochs; normal_walking: 43 epochs; 

  lfp_raw.set_annotations(my_annot)



Session: 14
Number of samples removed: 1090
Number of seconds removed: 4.36 seconds
Total epochs: 42
mod_start: 9 epochs; normal_walking: 33 epochs; 
Session: 15
Number of samples removed: 13850
Number of seconds removed: 55.40 seconds
Total epochs: 34
mod_start: 8 epochs; normal_walking: 26 epochs; 


0,1
Number of events,757
Events,mod_start: 147 normal_walking: 610
Time range,-2.000 – 0.000 s
Baseline,off


In [4]:
Visualise.plot_event_occurrence(events=events, 
                                epoch_sample_length=epoch_sample_length, 
                                lfp_sfreq=lfp_sfreq, 
                                event_dict=event_dict,
                                # gait_modulation_event_id=mod_start_event_id, 
                                # normal_walking_event_id=normal_walking_event_id, 
                                n_sessions=n_sessions,
                                show_fig=False, 
                                save_fig=True, 
                                file_name=f'plots/event_classes.png')


Visualise.plot_event_occurrence(events=epochs.events, 
                                epoch_sample_length=epoch_sample_length, 
                                lfp_sfreq=lfp_sfreq, 
                                event_dict=event_dict,
                                # gait_modulation_event_id=mod_start_event_id, 
                                # normal_walking_event_id=normal_walking_event_id, 
                                n_sessions=n_sessions,
                                show_fig=False, 
                                save_fig=True, 
                                file_name=f'plots/epochs.event_classes.png')

Plot saved as plots/event_classes.png
Plot saved as plots/epochs.event_classes.png


In [5]:
Visualise.plot_event_class_histogram(events=events,
                                    event_dict=epochs.event_id,
                                    n_sessions=n_sessions,
                                    show_fig=False, 
                                    save_fig=True,
                                    file_name=f'plots/event_class_histogram.png')

Visualise.plot_event_class_histogram(events=epochs.events,
                                    event_dict=epochs.event_id,
                                    n_sessions=n_sessions,
                                    show_fig=False, 
                                    save_fig=True,
                                    file_name=f'plots/epochs.event_class_histogram.png')

Plot saved as plots/event_class_histogram.png
Plot saved as plots/epochs.event_class_histogram.png


# Baseline Model

In [6]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import classification_report


# Step 1: Train-test split
X_train = train_all_epochs_combined.get_data(copy=True)
X_test = test_all_epochs_combined.get_data(copy=True)

y_train = train_all_epochs_combined.events[:, -1]
y_test = test_all_epochs_combined.events[:, -1]

print(f'X_train shape: {X_train.shape}')# (n_epochs, n_channels, n_samples_per_epoch)
print(f'y_train shape: {y_train.shape}')  # (n_epochs,)
print(f'--- Total epochs: {len(y_train)}, with {sum(y_train == -1)} normal walking and {sum(y_train == 1)} event-related gait modulation')

print(f'X_test shape: {X_test.shape}')# (n_epochs, n_channels, n_samples_per_epoch)
print(f'y_test shape: {y_test.shape}')  # (n_epochs,)
print(f'--- Total epochs: {len(y_train)}, with {sum(y_test == -1)} normal walking and {sum(y_test == 1)} event-related gait modulation')

# Step 2: Flatten the X array (n_epochs, n_channels, n_samples_per_epoch) -> (n_epochs, n_channels * n_samples_per_epoch)
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1] * X_train.shape[2])
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])


# Step 3: Standardize the features (mean=0, std=1)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Step 4: Train a Logistic Regression model
clf = LogisticRegression(random_state=42, max_iter=1000)
clf.fit(X_train_scaled, y_train)

# Step 5: Make predictions
y_pred = clf.predict(X_test)

# Step 6: Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)

# Output the results
print(f'Accuracy: {accuracy:.4f}')
print('Confusion Matrix:')
print(conf_matrix)

# Output the classification report
print(classification_report(y_test, y_pred))

NameError: name 'train_all_epochs_combined' is not defined

# Filtering the data

In [None]:
# raw psd
raw_spectrum = raw_data_clean.compute_psd(method='welch', fmin=1, fmax=50, n_fft=2048)

psd_arr = raw_spectrum.get_data()
psd_freqs = raw_spectrum.freqs
print(raw_data_clean.get_data().shape)
print(f"PSD data has shape: {psd_arr.shape}  # channels x frequencies")
print(f"Frequencies has shape: {psd_freqs.shape}  # frequencies")

raw_spectrum.plot()

In [None]:
raw_spectrum.get_data().shape, raw_data_clean.get_data().shape

In [None]:
# epochs psd: Train set
train_epoch_spectrum = train_all_epochs_combined.compute_psd(method='welch', fmax=50)

psd_arr = train_epoch_spectrum.get_data()
psd_freqs = train_epoch_spectrum.freqs

print(train_all_epochs_combined.get_data().shape)
print(f"PSD data has shape: {psd_arr.shape}")
print(f"Frequencies has shape: {psd_freqs.shape}")

train_epoch_spectrum.plot(average=False)

In [None]:
train_epoch_spectrum['mod_start'].plot(average=False)

In [None]:
train_epoch_spectrum['normal_walking'].plot(average=False)

In [None]:
# epochs psd: Test set
test_epoch_spectrum = test_all_epochs_combined.compute_psd(method='welch', fmax=50)

psd_arr = test_epoch_spectrum.get_data()
psd_freqs = test_epoch_spectrum.freqs

print(test_all_epochs_combined.get_data().shape)
print(f"PSD data has shape: {psd_arr.shape}")
print(f"Frequencies has shape: {psd_freqs.shape}")

test_epoch_spectrum.plot(average=False)

In [None]:
test_epoch_spectrum['mod_start'].plot(average=False)

In [None]:
test_epoch_spectrum['normal_walking'].plot(average=False)

In [None]:
train_all_epochs_combined.get_data().shape, train_epoch_spectrum.get_data().shape



# train_all_epochs_combined.events[:, -1].shape, test_all_epochs_combined.events[:, -1].shape

## Logistic Regression based on PSD

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import classification_report


# Step 1: Train-test split
X_train = train_epoch_spectrum.get_data()
X_test = test_epoch_spectrum.get_data()

y_train = train_all_epochs_combined.events[:, -1]
y_test = test_all_epochs_combined.events[:, -1]

print(f'X_train shape: {X_train.shape}')# (n_epochs, n_channels, n_samples_per_epoch)
print(f'y_train shape: {y_train.shape}')  # (n_epochs,)
print(f'--- Total epochs: {len(y_train)}, with {sum(y_train == -1)} normal walking and {sum(y_train == 1)} event-related gait modulation')

print(f'X_test shape: {X_test.shape}')# (n_epochs, n_channels, n_samples_per_epoch)
print(f'y_test shape: {y_test.shape}')  # (n_epochs,)
print(f'--- Total epochs: {len(y_train)}, with {sum(y_test == -1)} normal walking and {sum(y_test == 1)} event-related gait modulation')

# Step 2: Flatten the X array (n_epochs, n_channels, n_samples_per_epoch) -> (n_epochs, n_channels * n_samples_per_epoch)
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1] * X_train.shape[2])
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])


# Step 3: Standardize the features (mean=0, std=1)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Step 4: Train a Logistic Regression model
clf = LogisticRegression(random_state=42, max_iter=1000)
clf.fit(X_train_scaled, y_train)

# Step 5: Make predictions
y_pred = clf.predict(X_test)

# Step 6: Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)

# Output the results
print(f'Accuracy: {accuracy:.4f}')
print('Confusion Matrix:')
print(conf_matrix)

# Output the classification report
print(classification_report(y_test, y_pred))

In [None]:
# ICA raw
## CODE GOES HERE
ica_95PCA = mne.preprocessing.ICA(n_components=0.95, random_state=0)
ica_95PCA.fit(inst=lfp_raw)
# ica_95PCA.plot_sources(inst=lfp_raw, title="ICA sources (95% variance PCA components)");

In [None]:
# raw_ica_excluded = ica.apply(inst=lfp_raw.copy())
# raw_ica = ica.apply(inst=lfp_raw.copy())

# raw_ica_excluded.plot(scalings='auto', start=12, duration=4, title='clraned sensor signals (without noise)')
# raw_ica.plot(scalings='auto', start=12, duration=4, title='clraned sensor signals (without noise)')

In [None]:
# # ICA rpochs
# ## CODE GOES HERE
# ica_95PCA = mne.preprocessing.ICA(n_components=0.95, random_state=0)
# ica_95PCA.fit(inst=epochs_raw)
# ica_95PCA.plot_sources(inst=epochs_raw, title="ICA sources (95% variance PCA components)");

In [None]:
# # Compute variance explained by PCA components
# explained_variance = ica.pca_explained_variance_ / np.sum(ica.pca_explained_variance_)

# print(f"Variance explained by PCA components: {explained_variance}")
# print(f"Variance explained by first 4 PCA components: {np.sum(explained_variance[:4]) * 100:.2f}%")

In [None]:
# ica = mne.preprocessing.ICA(random_state=0)
# ica.fit(raw_sensors)

# # Remove the first ICA component (the random noise) from the data
# raw_cleaned = ica.apply(inst=raw_sensors.copy(), exclude=[1,2])
# raw_cleaned.plot(scalings='auto', title='clraned sensor signals (without noise)')

In [None]:
epochs_raw["mod_start"].plot_image(combine="mean");

In [None]:
mne.viz.plot_epochs_image(
    epochs_raw['mod_start'],
    picks=[0, 1, 2, 3, 4, 5],
    sigma=0.5,
    # combine="mean",
    # evoked=True
)

# Evoked

evoked_0 = epochs_raw['trial_start'].average()
evoked_4 = epochs_raw['mod_start'].average()

## Global Field Power (GFP)

The GFP is the population standard deviation of the signal across channels.

In [None]:
fig0 = evoked_0.plot(gfp=True);
fig1 = evoked_1.plot(gfp=True);

In [None]:
evoked_0.plot(gfp="only");
evoked_1.plot(gfp="only");

In [None]:
gfp = evoked_0.data.std(axis=0, ddof=0)

# Reproducing the MNE-Python plot style seen above
fig, ax = plt.subplots()
ax.plot(evoked_0.times, gfp * 1e6, color="lime")
ax.fill_between(evoked_0.times, gfp * 1e6, color="lime", alpha=0.2)
ax.set(xlabel="Time (s)", ylabel="GFP (µV)")

In [None]:
gfp = evoked_1.data.std(axis=0, ddof=0)

# Reproducing the MNE-Python plot style seen above
fig, ax = plt.subplots()
ax.plot(evoked_1.times, gfp * 1e6, color="lime")
ax.fill_between(evoked_1.times, gfp * 1e6, color="lime", alpha=0.2)
ax.set(xlabel="Time (s)", ylabel="GFP (µV)")

# # Time-frequency analysis

In [None]:
freqs = np.arange(2, 50, 2) # Frequencies from 2 to 50 Hz
n_cycles = freqs / 2 # Number of cycles in Morlet wavelet


In [None]:
freqs = np.arange(7, 30, 3)
power = epochs_raw['mod_start'].compute_tfr(
    "morlet", 
    n_cycles=2,
    return_itc=False, 
    freqs=freqs,
    decim=3,
    average=True
)
power.plot(title='auto')

In [None]:
freqs = np.arange(7, 30, 3)
power = epochs_raw['min_vel'].compute_tfr(
    "morlet", 
    n_cycles=2,
    return_itc=False, 
    freqs=freqs,
    decim=3,
    average=True
)
power.plot(title='auto')

# CSD

In [None]:
csd_fft = mne.time_frequency.csd_fourier(train_all_epochs_combined, fmin=1, fmax=50)
csd_mt = mne.time_frequency.csd_multitaper(train_all_epochs_combined, fmin=1, fmax=50, adaptive=True)
frequencies = np.arange(1,51, 1)
csd_wav = mne.time_frequency.csd_morlet(train_all_epochs_combined, frequencies, decim=1)

In [None]:
plot_dict = {
    "Short-time Fourier transform": csd_fft,
    "Adaptive multitapers": csd_mt,
    "Morlet wavelet transform": csd_wav,
}
for title, csd in plot_dict.items():
    (fig,) = csd.mean().plot()
    fig.suptitle(title)