In [1]:
import os
import pickle

import mne
import numpy as np
import yaml

from gait_modulation import FeatureExtractor, DataProcessor
from gait_modulation.utils.utils import split_data, load_config

In [2]:
# Load the pre-processed LFP recordings.
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only open pickles that were produced by this project's own pipeline.
with open('processed/all_lfp_data.pkl', 'rb') as lfp_file:
    all_lfp_data = pickle.load(lfp_file)

# Load the preprocessing settings; yaml.safe_load only builds plain Python objects.
with open('gait_modulation/configs/data_preprocessing.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# 1. Time domain

## Feature 1: Time domain representation of continuous data

In [3]:
# Bring every recording to a common length (pad or truncate) so that all
# trials share one shape; the displayed shape has three axes.
# Expected input: a list of arrays, each of shape (n_channels, time).
# NOTE(review): the target length is presumably read from `config` inside
# DataProcessor.pad_or_truncate — confirm against its implementation.
all_lfp_uniform_size = DataProcessor.pad_or_truncate(
    all_lfp_data,
    config,
)
all_lfp_uniform_size.shape

(16, 6, 38213)

In [4]:
# Persist the uniform-length time-domain array under the key 'times_uniform'.
# np.savez does not create missing folders, so make sure the output directory
# exists first; otherwise a fresh checkout fails with FileNotFoundError.
os.makedirs('processed/features', exist_ok=True)
np.savez('processed/features/time_continuous_uniform-feat.npz',
         times_uniform=all_lfp_uniform_size)

In [5]:
# import matplotlib.pyplot as plt
# # plt.plot(processed_trials[0][0])
# # plt.plot(processed_trials[0][1])
# # plt.plot(processed_trials[0][2])
# # plt.plot(processed_trials[0][3])
# plt.plot(processed_trials[2][4])
# # plt.plot(processed_trials[0][5])

## Feature 2: Time domain representation of continuous data – channels and times flattened into one axis (channels × times)

In [6]:
# Flatten every trial into a single feature vector:
# (n_trials, n_channels, n_samples) -> (n_trials, n_channels * n_samples).
n_trials, n_channels, n_samples = all_lfp_uniform_size.shape

lfp_data_combined_ch_time = all_lfp_uniform_size.reshape(
    n_trials, n_channels * n_samples)
lfp_data_combined_ch_time.shape

(16, 229278)

In [7]:
np.savez('processed/features/time_continuous_uniform_combined_ch_time-feat.npz',
         lfp_data_combined_ch_time=lfp_data_combined_ch_time)

## Feature 3: Summary statistics on continuous data in the time domain

In [8]:
# Summary statistics of the continuous signal, reducing one axis at a time.
# Outer key -> which axis is collapsed (0: trials, 1: channels, 2: times);
# inner key -> the statistic computed along that axis.
_reduced_axis = {'trials_stat': 0, 'channels_stat': 1, 'times_stat': 2}
_reducers = {'mean': np.mean, 'std': np.std, 'median': np.median}

time_domain_stats_feat = {
    group: {
        stat: reducer(all_lfp_uniform_size, axis=axis)
        for stat, reducer in _reducers.items()
    }
    for group, axis in _reduced_axis.items()
}

In [9]:
# Store the nested statistics dict under a single key. Because the value is a
# dict rather than an ndarray, NumPy pickles it as a 0-d object array; read it
# back with np.load(path, allow_pickle=True)['time_domain_stats_feat'].item().
np.savez(
    'processed/features/time_continuous_stats-feat.npz',
    time_domain_stats_feat=time_domain_stats_feat,
)

## Feature 4: Summary statistics on fixed windows in the time domain

In [10]:
# Sliding-window summary statistics over the continuous signal.
# NOTE(review): window_size / step_size are presumably expressed in samples —
# confirm against FeatureExtractor.extract_windowed_stat_features.
time_windowed_stat_feat = FeatureExtractor.extract_windowed_stat_features(
    all_lfp_uniform_size,
    methods=['mean', 'std', 'median'],
    window_size=100,
    step_size=50,
    verbose=False,
)

# Report the array shape produced for each statistic.
for stat_name, stat_values in time_windowed_stat_feat.items():
    print(f"{stat_name} features shape:", stat_values.shape)

# The dict is pickled into the archive as a 0-d object array, so loading it
# requires np.load(..., allow_pickle=True) followed by .item().
np.savez(
    'processed/features/time_windowed_stat-feat.npz',
    time_windowed_stat_feat=time_windowed_stat_feat,
)
all_lfp_uniform_size.shape

mean features shape: (16, 6, 763)
std features shape: (16, 6, 763)
median features shape: (16, 6, 763)


(16, 6, 38213)

## Feature 5: Time domain representation of epochs

In [11]:
# Load the pre-computed epochs from disk.
epochs = mne.read_epochs('processed/lfp_-3.0tmin_5gap-epo.fif')

# Keep the epoch *data array* itself. Previously `.shape` was assigned here,
# so the feature file written from `time_epoched` contained only the shape
# tuple instead of the time-domain samples.
# copy=False avoids duplicating the data where MNE can hand back a view.
time_epoched = epochs.get_data(copy=False)

# Display the array shape as the cell output.
time_epoched.shape

Reading /Users/orabe/Library/Mobile Documents/com~apple~CloudDocs/0_TU/Master/master_thesis/gait_modulation/processed/lfp_-3.0tmin_5gap-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =   -3000.00 ...       0.00 ms
        0 CTF compensation matrices available
Not setting metadata
893 matching events found
No baseline correction applied
0 projection items activated


In [12]:
# Write `time_epoched` to disk under the key 'time_epoched'.
np.savez(
    'processed/features/time_epoched-feat.npz',
    time_epoched=time_epoched,
)

## Feature 6: Summary statistics on epochs in the time domain

In [13]:
# Summary statistics of the epoched data. The result is a nested dict:
# aggregation axis -> statistic name -> array (see the printed shapes).
time_epoched_stat_feat = FeatureExtractor.extract_epoched_stat_features(
    epochs=epochs,
    methods=['mean', 'std', 'median'],
)

# The nested dict is pickled into the archive as a 0-d object array, so loading
# it requires np.load(..., allow_pickle=True) followed by .item().
np.savez(
    'processed/features/time_epoched_stat-feat.npz',
    time_epoched_stat_feat=time_epoched_stat_feat,
)

# One line per (aggregation axis, statistic) pair with the resulting shape.
# The message text is kept verbatim so the printed output does not change.
for axis_name, stats_by_name in time_epoched_stat_feat.items():
    for stat_name, stat_values in stats_by_name.items():
        print(f"Avergave by: {axis_name}, feature: {stat_name}, shape: {stat_values.shape}")

Avergave by: epochs, feature: mean, shape: (6, 751)
Avergave by: epochs, feature: std, shape: (6, 751)
Avergave by: epochs, feature: median, shape: (6, 751)
Avergave by: channels, feature: mean, shape: (893, 751)
Avergave by: channels, feature: std, shape: (893, 751)
Avergave by: channels, feature: median, shape: (893, 751)
Avergave by: times, feature: mean, shape: (893, 6)
Avergave by: times, feature: std, shape: (893, 6)
Avergave by: times, feature: median, shape: (893, 6)


# 2. Frequency domain

## Feature 1: Frequency domain representation of epoched data

In [14]:
# Band name -> (lower edge, upper edge); the smallest lower edge and the largest
# upper edge are later used as fmin / fmax of the PSD computation.
# NOTE(review): edges are presumably in Hz. The integer edges leave 3–4 and 7–8
# unassigned, while adjacent beta/gamma bands share an edge; how edge bins are
# treated depends on FeatureExtractor.extract_psd_and_band_power — confirm.
freq_bands = dict(
    delta=(0.1, 3),
    theta=(4, 7),
    alpha=(8, 12),
    low_beta=(12, 16),
    middle_beta=(16, 20),
    high_beta=(20, 30),
    gamma=(30, 100),
    high_gamma=(100, 125),
)

In [15]:
# Extract spectral features for both classes at once.
# The analysed frequency range runs from the lowest lower edge to the highest
# upper edge found in `freq_bands`.
psds, freqs, band_power = FeatureExtractor.extract_psd_and_band_power(
    epochs,
    freq_bands,
    fmin=min(low for low, high in freq_bands.values()),
    fmax=max(high for low, high in freq_bands.values()),
)

# Store `freqs` next to `psds`: without the frequency axis the last dimension
# of the saved PSD array cannot be mapped back to frequencies. Existing readers
# that access 'psds' / 'band_power' by key are unaffected by the extra entry.
np.savez_compressed('processed/features/psd_bandPower-feat.npz',
                    psds=psds,
                    freqs=freqs,
                    band_power=band_power)

print(freqs.shape)
print(psds.shape)
print(band_power.shape)

Effective window size : 3.004 (s)
(375,)
(893, 6, 375)
(893, 6, 8)


In [16]:
# # mod_start features
# spectral_feat_mod_start = FeatureExtractor.extract_psd_and_band_power(
#     epochs['mod_start'],
#     freq_bands,
#     fmin=min([f[0] for f in freq_bands.values()]),
#     fmax=max([f[1] for f in freq_bands.values()])
# )
# psds_mod_start, freqs_mod_start, band_power_mod_start = spectral_feat_mod_start

# np.savez_compressed('processed/features/spectral_feat_mod_start.npz', 
#                     psds_mod_start=psds_mod_start, 
#                     band_power_mod_start=band_power_mod_start)

# print(freqs_mod_start.shape)
# print(psds_mod_start.shape)  
# print(band_power_mod_start.shape)

In [17]:
# # normal_walking features
# spectral_feat_normal_walking = FeatureExtractor.extract_psd_and_band_power(
#     epochs['normal_walking'],
#     freq_bands,
#     fmin=min([f[0] for f in freq_bands.values()]),
#     fmax=max([f[1] for f in freq_bands.values()])
# )

# psds_normal_walking, freqs_normal_walking, band_power_normal_walking = spectral_feat_normal_walking

# np.savez_compressed('processed/features/spectral_feat_normal_walking.npz', 
#                     psds_normal_walking=psds_normal_walking, 
#                     band_power_normal_walking=band_power_normal_walking)

# print(freqs_normal_walking.shape)
# print(psds_normal_walking.shape)  
# print(band_power_normal_walking.shape)

In [18]:
# psds_bandPower_mod_start = np.concatenate((psds_mod_start,
#                                            band_power_mod_start), axis=2)

# psds_bandPower_normal_walking = np.concatenate((psds_normal_walking, 
#                                                 band_power_normal_walking), axis=2)

# psds_bandPower_both_classes = np.concatenate((psds_bandPower_mod_start,
#                                               psds_bandPower_normal_walking), axis=0)
# # Generate labels
# labels = np.concatenate((np.ones(psds_bandPower_mod_start.shape[0]),
#                          -np.ones(psds_bandPower_normal_walking.shape[0])), axis=0)

# print(psds_mod_start.shape, band_power_mod_start.shape, psds_bandPower_mod_start.shape)
# print(psds_normal_walking.shape, band_power_normal_walking.shape, psds_bandPower_normal_walking.shape)

# print(psds_bandPower_both_classes.shape, labels.shape)
# psds_bandPower_both_classes.reshape(psds_bandPower_both_classes.shape[0], -1).shape

In [19]:
# # Combine modulation start features
# psds_bandPower_mod_start2 = np.concatenate((
#     psds_mod_start.reshape(psds_mod_start.shape[0], -1),
#     band_power_mod_start.reshape(band_power_mod_start.shape[0], -1)), axis=1)

# # Combine normal walking features
# psds_bandPower_normal_walking2 = np.concatenate((
#     psds_normal_walking.reshape(psds_normal_walking.shape[0], -1),
#     band_power_normal_walking.reshape(band_power_normal_walking.shape[0], -1)), axis=1)

# combined_psds_bandPower2 = np.concatenate((psds_bandPower_mod_start2,
#                                           psds_bandPower_normal_walking2), axis=0)
# # Generate labels
# labels2 = np.concatenate((np.ones(psds_bandPower_mod_start2.shape[0]),
#                          np.zeros(psds_bandPower_normal_walking2.shape[0])), axis=0)


# print(psds_mod_start.shape, band_power_mod_start.shape, psds_bandPower_mod_start2.shape)
# print(psds_normal_walking.shape, band_power_normal_walking.shape, psds_bandPower_normal_walking2.shape)

# print(combined_psds_bandPower2.shape, labels2.shape)
# combined_psds_bandPower2.reshape(combined_psds_bandPower2.shape[0], -1).shape

In [20]:
# X = psds_bandPower_both_classes.reshape(psds_bandPower_both_classes.shape[0], -1)
# y = labels
# splits = split_data(X, y, n_splits=5)

# print(X.shape, y.shape, '\n')
# for i, (X_train, X_test, y_train, y_test) in enumerate(splits):
#     print(f"Fold {i+1}:")
#     print("X_train shape:", X_train.shape)
#     print("X_test shape:", X_test.shape)
#     print("y_train shape:", y_train.shape)
#     print("y_test shape:", y_test.shape)
#     print("---------------------------")
