In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import mne

from gait_modulation.file_reader import MatFileReader
from gait_modulation.data_processor import DataProcessor
from gait_modulation.viz import Visualise


# Loading the data

In [None]:
# Handle multiple patients with nested directories.
root_directory = '/Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/Chiara/organized_data'
mat_reader = MatFileReader(root_directory, max_workers=1)  #  adjust the number of workers for parallelism

# Read all data from nested folders of multiple patients and sessions
all_data = mat_reader.read_data()
n_sessions = len(all_data)

print(f"Number of sessions: {n_sessions}")

In [3]:
# Initialize lists to store epochs and events
train_all_epochs = []
train_all_events = []

test_all_epochs = []
test_all_events = []

test_session_idx = [2, 3, 4]

for s in range(n_sessions):
    # Access specific sessions for a patient
    session = all_data[s]

    # Extract events and lfp data of the subject/session
    lfp_data = session['data_LFP']
    lfp_data *= 1e-6  # Convert microvolts to volts
    lfp_metadata = DataProcessor.np_to_dict(session['hdr_LFP'])
    events_KIN = DataProcessor.np_to_dict(session['events_KIN'])

    # Load parameters
    lfp_sfreq = lfp_metadata['Fs'].item()
    lfp_ch_names = DataProcessor.rename_lfp_channels(lfp_metadata['labels'])
    lfp_n_channels = lfp_metadata['NumberOfChannels'].item()

    # Prepare for mne data structure
    info = mne.create_info(ch_names=lfp_ch_names[0:6], sfreq=lfp_sfreq, ch_types='dbs')
    # lfp_raw = mne.io.RawArray(lfp_data, info)

    # Handle events
    ori_events, ori_event_dict = DataProcessor.create_events_array(events_KIN, lfp_sfreq)

    # Trim the data and adjust the event onsets accordingly
    lfp_data, events = DataProcessor.trim_data(lfp_data, ori_events, lfp_sfreq)

    # Update raw data after trimming
    lfp_raw = mne.io.RawArray(lfp_data, info)

    # Select one event to work with: mod_start
    event_of_interest = 'mod_start'
    events_mod_start = ori_events[ori_events[:, 2] == ori_event_dict[event_of_interest]]

    # Rename Gait Modulation Events
    mod_start_event_id = 1
    events_mod_start[:, 2] = mod_start_event_id

    # Define parameters
    epoch_tmin = -2.0
    epoch_tmax = 0.0
    epoch_duration_length = epoch_tmax - epoch_tmin
    epoch_sample_length = int(epoch_duration_length * lfp_sfreq)
    gap_duration = 10  # At least 10 seconds away from modulation events
    gap_sample_length = int(gap_duration * lfp_sfreq)
    lfp_duration = lfp_data.shape[1] / lfp_sfreq
    n_samples = int(lfp_duration * lfp_sfreq)

    # Define normal walking events
    normal_walking_event_id = -1
    normal_walking_events = DataProcessor.define_normal_walking_events(
        normal_walking_event_id, events_mod_start,
        gap_sample_length, epoch_sample_length, n_samples
    )

    # Define the event dictionary
    event_dict = {
        'mod_start': mod_start_event_id,
        'normal_walking': normal_walking_event_id
    }
    
    # Preprocessing
    ## Apply band-pass filtering to the raw LFP data.
    l_freq = 1
    h_freq = 50
    lfp_raw.filter(l_freq=l_freq, h_freq=h_freq, fir_design='firwin')

    ## Remove artifacts from raw LFP data using ICA.
    ica_n_components = 6 # 6 = n_channels.
    ica = mne.preprocessing.ICA(n_components=ica_n_components, random_state=97, max_iter=800)
    print(lfp_raw.ch_names)
    ica.fit(lfp_raw)
    
    # Apply ICA to the raw data
    raw_data_clean = ica.apply(lfp_raw)

    # Combine events and create epochs
    events, epochs = DataProcessor.create_epochs_with_events(
        raw_data_clean,
        events_mod_start,
        normal_walking_events,
        mod_start_event_id,
        normal_walking_event_id,
        epoch_tmin,
        epoch_tmax,
        event_dict
    )

    # Store epochs and events    
    if s in test_session_idx:
        test_all_epochs.append(epochs)
        test_all_events.append(events)
    else:
        train_all_epochs.append(epochs)
        train_all_events.append(events)
    
    Visualise.plot_event_class_histogram(events=epochs.events, 
                                        event_dict=epochs.event_id, 
                                        show_fig=False, 
                                        save_fig=True,
                                        file_name=f'plots/{s}_hist.png')

    Visualise.plot_event_occurrence(events=epochs.events, 
                                    epoch_sample_length=epoch_sample_length, 
                                    lfp_sfreq=lfp_sfreq, 
                                    gait_modulation_event_id=mod_start_event_id, 
                                    normal_walking_event_id=normal_walking_event_id, 
                                    show_fig=False, 
                                    save_fig=True, 
                                    file_name=f'plots/{s}_event_classes.png')


# Combine all epochs into one object
train_all_epochs_combined = mne.concatenate_epochs(train_all_epochs)
test_all_epochs_combined = mne.concatenate_epochs(test_all_epochs)

# Create a single event array by concatenating
train_all_events_combined = np.vstack(train_all_events)
train_all_events_combined = train_all_events_combined[np.argsort(train_all_events_combined[:, 0])]  # Sort by onset time

test_all_events_combined = np.vstack(test_all_events)
test_all_events_combined = test_all_events_combined[np.argsort(test_all_events_combined[:, 0])]  # Sort by onset time

# train set
Visualise.plot_event_class_histogram(events=train_all_epochs_combined.events,
                                    event_dict=epochs.event_id, 
                                    show_fig=False, 
                                    save_fig=True,
                                    file_name=f'plots/train_all_epochs_combined_hist.png')

Visualise.plot_event_occurrence(events=train_all_epochs_combined.events,
                                epoch_sample_length=epoch_sample_length, 
                                lfp_sfreq=lfp_sfreq, 
                                gait_modulation_event_id=mod_start_event_id, 
                                normal_walking_event_id=normal_walking_event_id, 
                                show_fig=False, 
                                save_fig=True, 
                                file_name=f'plots/train_all_epochs_combined_event_classes.png')

# test set
Visualise.plot_event_class_histogram(events=test_all_epochs_combined.events,
                                    event_dict=epochs.event_id, 
                                    show_fig=False, 
                                    save_fig=True,
                                    file_name=f'plots/test_all_epochs_combined_hist.png')

Visualise.plot_event_occurrence(events=test_all_epochs_combined.events,
                                epoch_sample_length=epoch_sample_length, 
                                lfp_sfreq=lfp_sfreq, 
                                gait_modulation_event_id=mod_start_event_id, 
                                normal_walking_event_id=normal_walking_event_id, 
                                show_fig=False, 
                                save_fig=True, 
                                file_name=f'plots/test_all_epochs_combined_event_classes.png')


    Zeroing out 0 ICA components
    Projecting back using 6 PCA components
Not setting metadata
66 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 66 events and 501 original time points ...
1 bad epochs dropped
Epochs info: <Epochs |  65 events (all good), -2 – 0 s, baseline off, ~1.5 MB, data loaded,
 'mod_start': 6
 'normal_walking': 59>
Number of epochs: 65


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.0s finished


Plot saved as plots/1_hist.png
Plot saved as plots/1_event_classes.png
Number of samples removed: 2282
Number of seconds removed: 9.13 seconds
Creating RawArray with float64 data, n_channels=6, n_times=44688
    Range : 0 ... 44687 =      0.000 ...   178.748 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (3.300 s)

['LFP_L03', 'LFP_L13', 'LFP_L02', 'LFP_R03', 'LFP_R13', 'LFP_R02']
Fitting ICA to data using 6 channels (please be patient, this may take a while)
Selecting by number: 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.0s finished


Plot saved as plots/2_event_classes.png
Number of samples removed: 7049
Number of seconds removed: 28.20 seconds
Creating RawArray with float64 data, n_channels=6, n_times=38563
    Range : 0 ... 38562 =      0.000 ...   154.248 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (3.300 s)

['LFP_L03', 'LFP_L13', 'LFP_L02', 'LFP_R03', 'LFP_R13', 'LFP_R02']
Fitting ICA to data using 6 channels (please be patient, this may take a while)
Selecting by number: 6 components
Fitting ICA took 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.0s finished


Plot saved as plots/3_event_classes.png
Number of samples removed: 1950
Number of seconds removed: 7.80 seconds
Creating RawArray with float64 data, n_channels=6, n_times=43437
    Range : 0 ... 43436 =      0.000 ...   173.744 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (3.300 s)

['LFP_L03', 'LFP_L13', 'LFP_L02', 'LFP_R03', 'LFP_R13', 'LFP_R02']
Fitting ICA to data using 6 channels (please be patient, this may take a while)
Selecting by number: 6 components
Fitting ICA took 0

In [None]:
print(train_all_epochs_combined.events.shape[0], train_all_events_combined.shape[0])
print(test_all_epochs_combined.events.shape[0], test_all_events_combined.shape[0])

nlabels = train_all_epochs_combined.events.shape[0] + test_all_epochs_combined.events.shape[0]

print(train_all_epochs_combined.events.shape[0] / nlabels)
print(test_all_epochs_combined.events.shape[0] / nlabels)

In [None]:
n_mod_labels = np.sum(train_all_epochs_combined.events[:, 2] == -1) + np.sum(test_all_epochs_combined.events[:, 2] == -1)

print(n_mod_labels)

print(np.sum(train_all_epochs_combined.events[:, 2] == -1))
print(np.sum(test_all_epochs_combined.events[:, 2] == -1))

print(np.sum(train_all_epochs_combined.events[:, 2] == -1) / n_mod_labels)
print(np.sum(test_all_epochs_combined.events[:, 2] == -1) / n_mod_labels)

# Pre-Processing

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import classification_report


# Step 1: Train-test split
X_train = train_all_epochs_combined.get_data(copy=True)
X_test = test_all_epochs_combined.get_data(copy=True)

y_train = train_all_epochs_combined.events[:, -1]
y_test = test_all_epochs_combined.events[:, -1]

print(f'X_train shape: {X_train.shape}')# (n_epochs, n_channels, n_samples_per_epoch)
print(f'y_train shape: {y_train.shape}')  # (n_epochs,)
print(f'--- Total epochs: {len(y_train)}, with {sum(y_train == -1)} normal walking and {sum(y_train == 1)} event-related gait modulation')

print(f'X_test shape: {X_test.shape}')# (n_epochs, n_channels, n_samples_per_epoch)
print(f'y_test shape: {y_test.shape}')  # (n_epochs,)
print(f'--- Total epochs: {len(y_train)}, with {sum(y_test == -1)} normal walking and {sum(y_test == 1)} event-related gait modulation')

# Step 2: Flatten the X array (n_epochs, n_channels, n_samples_per_epoch) -> (n_epochs, n_channels * n_samples_per_epoch)
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1] * X_train.shape[2])
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])


# Step 3: Standardize the features (mean=0, std=1)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Step 4: Train a Logistic Regression model
clf = LogisticRegression(random_state=42, max_iter=1000)
clf.fit(X_train_scaled, y_train)

# Step 5: Make predictions
y_pred = clf.predict(X_test)

# Step 6: Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)

# Output the results
print(f'Accuracy: {accuracy:.4f}')
print('Confusion Matrix:')
print(conf_matrix)

# Output the classification report
print(classification_report(y_test, y_pred))

# Filtering the data

In [None]:
# raw psd
spectrum = lfp_raw.compute_psd(method='welch', fmax=50, n_fft=2048)

psd_arr = spectrum.get_data()
psd_freqs = spectrum.freqs

print(f"PSD data has shape: {psd_arr.shape}  # channels x frequencies")
print(f"Frequencies has shape: {psd_freqs.shape}  # frequencies")

spectrum.plot()

In [None]:
# epochs psd
spectrum = epochs.compute_psd(method='welch', fmax=50)

psd_arr = spectrum.get_data()
psd_freqs = spectrum.freqs

print(f"PSD data has shape: {psd_arr.shape}  # channels x frequencies")
print(f"Frequencies has shape: {psd_freqs.shape}  # frequencies")

spectrum.plot()

In [None]:
lfp_raw.filter(l_freq = 1, h_freq=None)
# raw_copy.notch_filter(freqs=[5, 10, 20])

In [None]:
spectrum = lfp_raw.compute_psd(method='welch', fmax=50, n_fft=2048)

psd_arr = spectrum.get_data()
psd_freqs = spectrum.freqs

print(f"PSD data has shape: {psd_arr.shape}  # channels x frequencies")
print(f"Frequencies has shape: {psd_freqs.shape}  # frequencies")

spectrum.plot()

In [None]:
lfp_raw.compute_psd(method='welch', fmax=50, n_fft=2048).plot()


In [None]:
# ICA raw
## CODE GOES HERE
ica_95PCA = mne.preprocessing.ICA(n_components=0.95, random_state=0)
ica_95PCA.fit(inst=lfp_raw)
# ica_95PCA.plot_sources(inst=lfp_raw, title="ICA sources (95% variance PCA components)");

In [None]:
# raw_ica_excluded = ica.apply(inst=lfp_raw.copy())
# raw_ica = ica.apply(inst=lfp_raw.copy())

# raw_ica_excluded.plot(scalings='auto', start=12, duration=4, title='clraned sensor signals (without noise)')
# raw_ica.plot(scalings='auto', start=12, duration=4, title='clraned sensor signals (without noise)')

In [None]:
# # ICA rpochs
# ## CODE GOES HERE
# ica_95PCA = mne.preprocessing.ICA(n_components=0.95, random_state=0)
# ica_95PCA.fit(inst=epochs_raw)
# ica_95PCA.plot_sources(inst=epochs_raw, title="ICA sources (95% variance PCA components)");

In [None]:
# # Compute variance explained by PCA components
# explained_variance = ica.pca_explained_variance_ / np.sum(ica.pca_explained_variance_)

# print(f"Variance explained by PCA components: {explained_variance}")
# print(f"Variance explained by first 4 PCA components: {np.sum(explained_variance[:4]) * 100:.2f}%")

In [None]:
# ica = mne.preprocessing.ICA(random_state=0)
# ica.fit(raw_sensors)

# # Remove the first ICA component (the random noise) from the data
# raw_cleaned = ica.apply(inst=raw_sensors.copy(), exclude=[1,2])
# raw_cleaned.plot(scalings='auto', title='clraned sensor signals (without noise)')

In [None]:
epochs_raw["mod_start"].plot_image(combine="mean");

In [None]:
mne.viz.plot_epochs_image(
    epochs_raw['mod_start'],
    picks=[0, 1, 2, 3, 4, 5],
    sigma=0.5,
    # combine="mean",
    # evoked=True
)

# Evoked

evoked_0 = epochs_raw['trial_start'].average()
evoked_4 = epochs_raw['mod_start'].average()

## Global Field Power (GFP)

The GFP is the population standard deviation of the signal across channels.

In [None]:
fig0 = evoked_0.plot(gfp=True);
fig1 = evoked_1.plot(gfp=True);

In [None]:
evoked_0.plot(gfp="only");
evoked_1.plot(gfp="only");

In [None]:
gfp = evoked_0.data.std(axis=0, ddof=0)

# Reproducing the MNE-Python plot style seen above
fig, ax = plt.subplots()
ax.plot(evoked_0.times, gfp * 1e6, color="lime")
ax.fill_between(evoked_0.times, gfp * 1e6, color="lime", alpha=0.2)
ax.set(xlabel="Time (s)", ylabel="GFP (µV)")

In [None]:
gfp = evoked_1.data.std(axis=0, ddof=0)

# Reproducing the MNE-Python plot style seen above
fig, ax = plt.subplots()
ax.plot(evoked_1.times, gfp * 1e6, color="lime")
ax.fill_between(evoked_1.times, gfp * 1e6, color="lime", alpha=0.2)
ax.set(xlabel="Time (s)", ylabel="GFP (µV)")

# # Time-frequency analysis

In [None]:
freqs = np.arange(2, 50, 2) # Frequencies from 2 to 50 Hz
n_cycles = freqs / 2 # Number of cycles in Morlet wavelet


In [None]:
freqs = np.arange(7, 30, 3)
power = epochs_raw['mod_start'].compute_tfr(
    "morlet", 
    n_cycles=2,
    return_itc=False, 
    freqs=freqs,
    decim=3,
    average=True
)
power.plot(title='auto')

In [None]:
freqs = np.arange(7, 30, 3)
power = epochs_raw['min_vel'].compute_tfr(
    "morlet", 
    n_cycles=2,
    return_itc=False, 
    freqs=freqs,
    decim=3,
    average=True
)
power.plot(title='auto')

# CSD

In [None]:
csd_fft = mne.time_frequency.csd_fourier(epochs_raw, fmin=1, fmax=70)
csd_mt = mne.time_frequency.csd_multitaper(epochs_raw, fmin=1, fmax=70, adaptive=True)
frequencies = np.arange(1,71, 1)
csd_wav = mne.time_frequency.csd_morlet(epochs_raw, frequencies, decim=1)

In [None]:
plot_dict = {
    "Short-time Fourier transform": csd_fft,
    "Adaptive multitapers": csd_mt,
    "Morlet wavelet transform": csd_wav,
}
for title, csd in plot_dict.items():
    (fig,) = csd.mean().plot()
    fig.suptitle(title)