In [2]:
import mne
import numpy as np
import os
import matplotlib.pyplot as plt

# Define the path to the dataset and the output directory
gdf_path = '/home/jovyan/BCICIV_2a_gdf/A02T.gdf'
output_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2'

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Step 1: Load the BCI-IV 2a dataset
raw = mne.io.read_raw_gdf(gdf_path, preload=True)

# Remove the last three channels (EOG-left, EOG-central, and EOG-right)
raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])

# Step 2: Extract epochs for each class using provided event IDs
event_id = {'769': 7, '770': 8, '771': 9, '772': 10}
events, _ = mne.events_from_annotations(raw, event_id=event_id)
tmin, tmax = 1.5, 6  # 4.5 seconds epochs starting at 1.5s

# Creating epochs for each class
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=None, preload=True)
data = epochs.get_data()  # shape is (n_epochs, n_channels, n_times)

# Standardize each channel
data = (data - data.mean(axis=2, keepdims=True)) / data.std(axis=2, keepdims=True)

# Initial Temporal Feature Extraction Module (iTFE Module)
def extract_temporal_features(epoch_data):
    # Extract temporal features: mean and std for each channel
    mean_features = epoch_data.mean(axis=1)
    std_features = epoch_data.std(axis=1)
    # Concatenate features to form a feature vector
    features = np.concatenate((mean_features, std_features), axis=0)
    return features

# Step 3: Save each epoch and its features
class_counts = {7: 0, 8: 0, 9: 0, 10: 0}

for i, epoch_data in enumerate(data):
    # Adjust labels to start from 0 by subtracting 7 from each label
    label = epochs.events[i, -1]
    file_path = os.path.join(output_dir, f'epoch_{i+1}_class_{label - 7}.npy')
    np.save(file_path, epoch_data)
    
    # Extract and save features
    features = extract_temporal_features(epoch_data)
    features_path = os.path.join(output_dir, f'features_{i+1}_class_{label - 7}.npy')
    np.save(features_path, features)
    
    class_counts[label] += 1

print(f'Saved {len(data)} epochs and their features to {output_dir}')

# Step 4: Plot and save some epochs without axes and color bar
def plot_and_save_epochs(output_dir, save_dir, num_files_to_plot=5):
    os.makedirs(save_dir, exist_ok=True)
    saved_files = [f for f in os.listdir(output_dir) if f.endswith('.npy') and 'epoch' in f]
    saved_file_paths = []
    for i in range(min(num_files_to_plot, len(saved_files))):
        file_path = os.path.join(output_dir, saved_files[i])
        epoch_data = np.load(file_path)
        
        # Plot without axes and color bar
        plt.figure(figsize=(12, 6))
        plt.imshow(epoch_data, aspect='auto', cmap='jet')
        plt.axis('off')
        save_path = os.path.join(save_dir, f'{saved_files[i]}_without_axes.png')
        plt.savefig(save_path)
        plt.close()
        saved_file_paths.append(save_path)
    
    return saved_file_paths

# Step 5: Plot the distribution of files per class and save image
def plot_and_save_class_distribution(class_counts, save_path):
    labels = list(class_counts.keys())
    counts = list(class_counts.values())
    
    plt.figure(figsize=(10, 5))
    plt.bar(labels, counts, color='blue')
    plt.xlabel('Class Labels')
    plt.ylabel('Number of Files')
    plt.title('Distribution of Files per Class')
    plt.savefig(save_path)
    plt.close()
    return save_path

# Directory to save plots
plot_save_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2'
os.makedirs(plot_save_dir, exist_ok=True)

# Plot and save example epochs
saved_epoch_files = plot_and_save_epochs(output_dir, plot_save_dir)

# Save the distribution of files per class
class_distribution_path = plot_and_save_class_distribution(class_counts, os.path.join(plot_save_dir, 'class_distribution.png'))

# Print the paths to the saved files
print("Saved epoch plots:")
for file_path in saved_epoch_files:
    print(file_path)

print(f"Class distribution plot saved at: {class_distribution_path}")


Extracting EDF parameters from /home/jovyan/BCICIV_2a_gdf/A02T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  next(self.gen)


Used Annotations descriptions: ['769', '770', '771', '772']
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 288 events and 1126 original time points ...
1 bad epochs dropped


  data = epochs.get_data()  # shape is (n_epochs, n_channels, n_times)


Saved 287 epochs and their features to /home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2
Saved epoch plots:
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2/epoch_1_class_0.npy_without_axes.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2/epoch_62_class_3.npy_without_axes.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2/epoch_72_class_2.npy_without_axes.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2/epoch_196_class_2.npy_without_axes.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2/epoch_145_class_1.npy_without_axes.png
Class distribution plot saved at: /home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2_p2/class_distribution.png


In [7]:
import mne
import numpy as np
import os
import matplotlib.pyplot as plt
import pywt
from scipy.ndimage import convolve1d

# Define the path to the dataset and the output directory
gdf_path = '/home/jovyan/BCICIV_2a_gdf/A02T.gdf'
output_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2'

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Step 1: Load the BCI-IV 2a dataset
raw = mne.io.read_raw_gdf(gdf_path, preload=True)

# Remove the last three channels (EOG-left, EOG-central, and EOG-right)
raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])

# Step 2: Extract epochs for each class using provided event IDs
event_id = {'769': 7, '770': 8, '771': 9, '772': 10}
events, _ = mne.events_from_annotations(raw, event_id=event_id)
tmin, tmax = 1.5, 6  # 4.5 seconds epochs starting at 1.5s

# Creating epochs for each class
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=None, preload=True)
data = epochs.get_data(copy=True)  # shape is (n_epochs, n_channels, n_times)

# Standardize each channel
data = (data - data.mean(axis=2, keepdims=True)) / data.std(axis=2, keepdims=True)

# Initial Temporal Feature Extraction (iTFE) Module with Convolution
def apply_convolution(epoch_data, kernel):
    return convolve1d(epoch_data, kernel, axis=-1, mode='constant')

def extract_temporal_features_with_convolution(epoch_data):
    kernels = [np.ones((3,))/3, np.ones((5,))/5, np.ones((11,))/11]
    mean_features = []
    std_features = []
    for kernel in kernels:
        convolved_data = apply_convolution(epoch_data, kernel)
        mean_features.append(convolved_data.mean(axis=1))
        std_features.append(convolved_data.std(axis=1))
    mean_features = np.mean(mean_features, axis=0)
    std_features = np.mean(std_features, axis=0)
    return np.concatenate((mean_features, std_features), axis=0)

# Deep EEG-Channel-attention (DEC) Module
def eeg_channel_attention(features, channel_weights):
    # Assuming channel_weights is the same length as number of channels
    n_channels = len(channel_weights)
    mean_features = features[:n_channels]
    std_features = features[n_channels:]
    
    # Apply channel weights
    weighted_mean_features = mean_features * channel_weights
    weighted_std_features = std_features * channel_weights
    
    # Combine weighted features
    weighted_features = np.concatenate((weighted_mean_features, weighted_std_features), axis=0)
    return weighted_features

# Wavelet-based Temporal-Spectral-attention (WTS) Module
def wavelet_temporal_spectral_attention(epoch_data, wavelet='db4', level=4):
    coeffs = pywt.wavedec(epoch_data, wavelet, level=level)
    return np.concatenate(coeffs, axis=1)

# Simple Discrimination Module
def simple_discrimination(feature_matrix):
    # Placeholder for a simple classification model
    # Assuming a binary classification for simplicity
    return np.random.randint(0, 2, size=feature_matrix.shape[0])

# Initialize channel weights (assuming equal importance initially)
channel_weights = np.ones(data.shape[1])

# Step 3: Process each epoch through the modules
class_counts = {7: 0, 8: 0, 9: 0, 10: 0}
features_list = []

for i, epoch_data in enumerate(data):
    # Extract temporal features using the iTFE module
    temporal_features = extract_temporal_features_with_convolution(epoch_data)
    
    # Deep EEG-Channel-attention
    attended_features = eeg_channel_attention(temporal_features, channel_weights)
    
    # Wavelet-based Temporal-Spectral-attention
    wts_features = wavelet_temporal_spectral_attention(epoch_data)
    
    # Combine features
    combined_features = np.concatenate((attended_features, wts_features.flatten()), axis=0)
    features_list.append(combined_features)
    
    # Adjust labels to start from 0 by subtracting 7 from each label
    label = epochs.events[i, -1]
    class_counts[label] += 1

# Convert features list to a matrix
feature_matrix = np.array(features_list)

# Simple Discrimination
predictions = simple_discrimination(feature_matrix)

# Print the results
print(f'Feature matrix shape: {feature_matrix.shape}')
print(f'Predictions: {predictions}')

# Save feature matrix and predictions
np.save(os.path.join(output_dir, 'feature_matrix.npy'), feature_matrix)
np.save(os.path.join(output_dir, 'predictions.npy'), predictions)

# Step 4: Plot and save some epochs without axes and color bar
def plot_and_save_epochs(output_dir, save_dir, num_files_to_plot=5):
    os.makedirs(save_dir, exist_ok=True)
    saved_files = [f for f in os.listdir(output_dir) if f.endswith('.npy') and 'epoch' in f]
    saved_file_paths = []
    for i in range(min(num_files_to_plot, len(saved_files))):
        file_path = os.path.join(output_dir, saved_files[i])
        epoch_data = np.load(file_path)
        
        # Plot without axes and color bar
        plt.figure(figsize=(12, 6))
        plt.imshow(epoch_data, aspect='auto', cmap='jet')
        plt.axis('off')
        save_path = os.path.join(save_dir, f'{saved_files[i]}_without_axes.png')
        plt.savefig(save_path)
        plt.close()
        saved_file_paths.append(save_path)
    
    return saved_file_paths

# Step 5: Plot the distribution of files per class and save image
def plot_and_save_class_distribution(class_counts, save_path):
    labels = list(class_counts.keys())
    counts = list(class_counts.values())
    
    plt.figure(figsize=(10, 5))
    plt.bar(labels, counts, color='blue')
    plt.xlabel('Class Labels')
    plt.ylabel('Number of Files')
    plt.title('Distribution of Files per Class')
    plt.savefig(save_path)
    plt.close()
    return save_path

# Directory to save plots
plot_save_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5'
os.makedirs(plot_save_dir, exist_ok=True)

# Plot and save example epochs
saved_epoch_files = plot_and_save_epochs(output_dir, plot_save_dir)

# Save the distribution of files per class
class_distribution_path = plot_and_save_class_distribution(class_counts, os.path.join(plot_save_dir, 'class_distribution.png'))

# Print the paths to the saved files
print("Saved epoch plots:")
for file_path in saved_epoch_files:
    print(file_path)

print(f"Class distribution plot saved at: {class_distribution_path}")


Extracting EDF parameters from /home/jovyan/BCICIV_2a_gdf/A02T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  next(self.gen)


Used Annotations descriptions: ['769', '770', '771', '772']
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 288 events and 1126 original time points ...
1 bad epochs dropped
Feature matrix shape: (287, 25344)
Predictions: [0 1 1 0 1 1 1 0 1 0 1 0 1 0 1 0 1 0 0 0 0 0 0 1 0 0 0 1 0 1 0 1 1 1 0 1 1
 0 1 0 1 1 0 1 0 0 1 0 1 1 1 0 0 1 1 0 0 1 0 0 0 1 1 1 0 0 1 1 1 0 0 0 0 1
 1 1 0 1 0 1 1 1 1 1 1 0 0 0 1 1 1 0 1 1 1 1 1 1 0 0 1 0 0 0 0 0 1 1 1 1 1
 0 0 0 1 1 0 1 1 0 0 0 1 1 1 0 1 0 0 1 0 1 1 1 1 1 0 1 0 1 0 1 1 1 0 1 0 0
 1 0 1 1 0 0 0 1 1 0 1 0 1 1 1 1 1 0 1 0 1 1 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0
 1 1 1 0 1 0 0 0 0 0 1 1 0 0 1 0 1 0 0 1 0 1 0 0 1 0 1 0 0 1 1 1 0 1 1 0 1
 1 0 1 0 1 1 0 1 0 0 1 0 1 1 1 0 0 1 1 1 0 0 0 0 0 1 0 0 1 0 1 0 1 0 0 1 0
 1 0 1 0 1 0 1 1 1 0 1 1 1 1 0 1 1 0 1 0 1 1 0 1 1 1 0 0]
Saved epoch plots:
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_1_class_0.npy_without_axes.pn

In [8]:
import mne
import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.ndimage import convolve1d

# Define the path to the dataset and the output directory
gdf_path = '/home/jovyan/BCICIV_2a_gdf/A02T.gdf'
output_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2'

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Step 1: Load the BCI-IV 2a dataset
raw = mne.io.read_raw_gdf(gdf_path, preload=True)

# Remove the last three channels (EOG-left, EOG-central, and EOG-right)
raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])

# Step 2: Extract epochs for each class using provided event IDs
event_id = {'769': 7, '770': 8, '771': 9, '772': 10}
events, _ = mne.events_from_annotations(raw, event_id=event_id)
tmin, tmax = 1.5, 6  # 4.5 seconds epochs starting at 1.5s

# Creating epochs for each class
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=None, preload=True)
data = epochs.get_data(copy=True)  # shape is (n_epochs, n_channels, n_times)

# Standardize each channel
data = (data - data.mean(axis=2, keepdims=True)) / data.std(axis=2, keepdims=True)

# Step 3: Apply the iTFE module
def apply_convolution(epoch_data, kernel):
    return convolve1d(epoch_data, kernel, axis=-1, mode='constant')

def extract_temporal_features_with_convolution(epoch_data):
    kernels = [np.ones((3,))/3, np.ones((5,))/5, np.ones((11,))/11]
    features = []
    for kernel in kernels:
        convolved_data = apply_convolution(epoch_data, kernel)
        mean_features = convolved_data.mean(axis=1)
        std_features = convolved_data.std(axis=1)
        features.append(mean_features)
        features.append(std_features)
    # Concatenate the features from different kernels
    return np.concatenate(features, axis=0)

# Initialize an empty list to store features
features_list = []

# Process each epoch through the iTFE module
for i, epoch_data in enumerate(data):
    temporal_features = extract_temporal_features_with_convolution(epoch_data)
    features_list.append(temporal_features)

# Convert features list to a matrix
feature_matrix = np.array(features_list)

# Print the shape of the feature matrix
print(f'Feature matrix shape: {feature_matrix.shape}')

# Save the feature matrix
np.save(os.path.join(output_dir, 'feature_matrix_iTFE.npy'), feature_matrix)

# Plot and save some example epochs after iTFE processing
def plot_and_save_example_epochs(feature_matrix, save_dir, num_files_to_plot=5):
    os.makedirs(save_dir, exist_ok=True)
    saved_file_paths = []
    for i in range(min(num_files_to_plot, len(feature_matrix))):
        epoch_features = feature_matrix[i]
        
        # Plot the features
        plt.figure(figsize=(12, 6))
        plt.plot(epoch_features)
        plt.title(f'Epoch {i+1} Features after iTFE')
        save_path = os.path.join(save_dir, f'epoch_{i+1}_iTFE_features.png')
        plt.savefig(save_path)
        plt.close()
        saved_file_paths.append(save_path)
    
    return saved_file_paths

# Directory to save plots
plot_save_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5'
os.makedirs(plot_save_dir, exist_ok=True)

# Plot and save example epochs
saved_epoch_files = plot_and_save_example_epochs(feature_matrix, plot_save_dir)

# Print the paths to the saved files
print("Saved epoch feature plots after iTFE:")
for file_path in saved_epoch_files:
    print(file_path)


Extracting EDF parameters from /home/jovyan/BCICIV_2a_gdf/A02T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  next(self.gen)


Used Annotations descriptions: ['769', '770', '771', '772']
Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 288 events and 1126 original time points ...
1 bad epochs dropped
Feature matrix shape: (287, 132)
Saved epoch feature plots after iTFE:
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_1_iTFE_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_2_iTFE_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_3_iTFE_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_4_iTFE_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_5_iTFE_features.png


In [10]:
import numpy as np
import os
import matplotlib.pyplot as plt

# Define paths
feature_matrix_path = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2/feature_matrix_iTFE.npy'
output_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2'
plot_save_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5'

# Load the feature matrix from the iTFE module
feature_matrix = np.load(feature_matrix_path)

# Print the shape of the loaded feature matrix
print(f'Loaded feature matrix shape: {feature_matrix.shape}')

# Helper functions for ELU and Softmax
def elu(x, alpha=1.0):
    return np.where(x >= 0, x, alpha * (np.exp(x) - 1))

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

# Custom Dense layer implementation
def dense_layer(input_data, weights, bias):
    return np.dot(input_data, weights) + bias

# Step 4: Apply the DEC module
def dec_module(input_features):
    # Squeeze operation
    global_avg_pool = np.mean(input_features, axis=-1)
    
    # Define the weights and biases for the dense layers
    input_dim = global_avg_pool.shape[-1]
    hidden_units = input_dim // 2
    
    weights_1 = np.random.randn(input_dim, hidden_units) * 0.01
    bias_1 = np.zeros(hidden_units)
    weights_2 = np.random.randn(hidden_units, input_dim) * 0.01
    bias_2 = np.zeros(input_dim)
    
    # Excitation operation
    fc1_output = elu(dense_layer(global_avg_pool, weights_1, bias_1))
    fc2_output = softmax(dense_layer(fc1_output, weights_2, bias_2))
    
    # Rescale features
    rescaled_features = input_features * fc2_output[:, np.newaxis]
    return rescaled_features

# Apply DEC module to the feature matrix
dec_feature_matrix = np.array([dec_module(features.reshape(22, -1)) for features in feature_matrix])

# Save the DEC feature matrix
dec_feature_matrix_path = os.path.join(output_dir, 'dec_feature_matrix.npy')
np.save(dec_feature_matrix_path, dec_feature_matrix)

# Plot and save some example epochs after DEC processing
def plot_and_save_example_epochs(feature_matrix, save_dir, num_files_to_plot=5):
    os.makedirs(save_dir, exist_ok=True)
    saved_file_paths = []
    for i in range(min(num_files_to_plot, len(feature_matrix))):
        epoch_features = feature_matrix[i]
        
        # Plot the features
        plt.figure(figsize=(12, 6))
        plt.plot(epoch_features.flatten())
        plt.title(f'Epoch {i+1} Features after DEC')
        save_path = os.path.join(save_dir, f'epoch_{i+1}_dec_features.png')
        plt.savefig(save_path)
        plt.close()
        saved_file_paths.append(save_path)
    
    return saved_file_paths

# Directory to save plots
os.makedirs(plot_save_dir, exist_ok=True)

# Plot and save example epochs
saved_epoch_files = plot_and_save_example_epochs(dec_feature_matrix, plot_save_dir)

# Print the paths to the saved files
print("Saved epoch feature plots after DEC:")
for file_path in saved_epoch_files:
    print(file_path)

# Print path to the saved DEC feature matrix
print(f"DEC feature matrix saved at: {dec_feature_matrix_path}")


Loaded feature matrix shape: (287, 132)
Saved epoch feature plots after DEC:
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_1_dec_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_2_dec_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_3_dec_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_4_dec_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_5_dec_features.png
DEC feature matrix saved at: /home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2/dec_feature_matrix.npy


In [12]:
import numpy as np
import os
import matplotlib.pyplot as plt
import pywt

# Define paths
dec_feature_matrix_path = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2/dec_feature_matrix.npy'
output_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2'
plot_save_dir = '/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5'

# Load the DEC feature matrix
dec_feature_matrix = np.load(dec_feature_matrix_path)

# Print the shape of the loaded feature matrix
print(f'Loaded DEC feature matrix shape: {dec_feature_matrix.shape}')

# Step 1: Apply Continuous Wavelet Transform (CWT) using Morlet wavelet
def apply_cwt(epoch_data, scales=np.arange(1, 31), wavelet='cmor'):
    coefficients, frequencies = pywt.cwt(epoch_data, scales, wavelet, axis=-1)
    return coefficients

# Step 2: Transform DEC features using CWT
cwt_feature_matrix = np.array([apply_cwt(epoch) for epoch in dec_feature_matrix])

# Step 3: Perform Independent Sample T-statistics
def calculate_t_statistics(group1, group2):
    mean1 = np.mean(group1, axis=0)
    mean2 = np.mean(group2, axis=0)
    std1 = np.std(group1, axis=0, ddof=1)
    std2 = np.std(group2, axis=0, ddof=1)
    n1 = group1.shape[0]
    n2 = group2.shape[0]

    t_statistics = (mean1 - mean2) / np.sqrt((std1**2 / n1) + (std2**2 / n2))
    return t_statistics

# Group labels based on MI tasks (assuming the classes are labeled from 0 to 3)
# Adjust the indices according to your specific data
group_A_indices = np.where((epochs.events[:, -1] == 0) | (epochs.events[:, -1] == 1))[0]
group_B_indices = np.where((epochs.events[:, -1] == 2) | (epochs.events[:, -1] == 3))[0]

group_A = cwt_feature_matrix[group_A_indices]
group_B = cwt_feature_matrix[group_B_indices]

# Calculate T-statistics between groups A and B
t_statistics = calculate_t_statistics(group_A, group_B)

# Step 4: Weight features using T-statistics
weighted_cwt_feature_matrix = cwt_feature_matrix * t_statistics[np.newaxis, :, :, :]

# Save the weighted CWT feature matrix
weighted_cwt_feature_matrix_path = os.path.join(output_dir, 'weighted_cwt_feature_matrix.npy')
np.save(weighted_cwt_feature_matrix_path, weighted_cwt_feature_matrix)

# Step 5: Plot and save some example weighted CWT feature maps
def plot_and_save_example_weighted_cwt_features(feature_matrix, save_dir, num_files_to_plot=5):
    os.makedirs(save_dir, exist_ok=True)
    saved_file_paths = []
    
    for i in range(min(num_files_to_plot, len(feature_matrix))):
        epoch_features = feature_matrix[i]

        # Calculate the magnitude of CWT coefficients
        magnitude_features = np.abs(epoch_features)

        # Normalize magnitude to [0, 1] for visualization
        normalized_features = (magnitude_features - np.min(magnitude_features)) / (np.max(magnitude_features) - np.min(magnitude_features))

        # Plot the features
        plt.figure(figsize=(12, 6))
        plt.imshow(normalized_features.mean(axis=0), aspect='auto', cmap='jet', extent=[0, normalized_features.shape[-1], normalized_features.shape[1], 1])
        plt.title(f'Epoch {i+1} Weighted CWT Features')
        plt.colorbar(label='Magnitude')
        save_path = os.path.join(save_dir, f'epoch_{i+1}_weighted_cwt_features.png')
        plt.savefig(save_path)
        plt.close()
        saved_file_paths.append(save_path)

    return saved_file_paths


# Plot and save example weighted CWT feature maps
saved_epoch_files = plot_and_save_example_weighted_cwt_features(weighted_cwt_feature_matrix, plot_save_dir)

# Print the paths to the saved files
print("Saved weighted CWT feature plots:")
for file_path in saved_epoch_files:
    print(file_path)

# Print path to the saved weighted CWT feature matrix
print(f"Weighted CWT feature matrix saved at: {weighted_cwt_feature_matrix_path}")


Loaded DEC feature matrix shape: (287, 22, 6)


  t_statistics = (mean1 - mean2) / np.sqrt((std1**2 / n1) + (std2**2 / n2))


Saved weighted CWT feature plots:
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_1_weighted_cwt_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_2_weighted_cwt_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_3_weighted_cwt_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_4_weighted_cwt_features.png
/home/jovyan/BCICIV_2a_gdf/processed_epoch_sub5/epoch_5_weighted_cwt_features.png
Weighted CWT feature matrix saved at: /home/jovyan/BCICIV_2a_gdf/processed_epoch_sub2/weighted_cwt_feature_matrix.npy


In [8]:
import os
import numpy as np
import pywt
import mne
from sklearn.preprocessing import StandardScaler
from scipy.stats import kurtosis, skew
from spectrum import arburg
from mne.decoding import CSP

class FeatureExtractor:
    def __init__(self, dataset_path='data/BCICIV_2a_gdf', save_path='output/path', n_sub=9, sub_list=None,
                 sample_freq=250, data_aug=False, n_hop=0.1, window_sz=2, low_freq=0.5, high_freq=35,
                 wavelet=True, f_bank=False, wpd_noc=False, n_bands=8, low_frequencies=None,
                 high_frequencies=None, feature_list=None):
        self.dataset_path = dataset_path
        self.save_path = save_path
        self.n_sub = n_sub
        self.subjects = sub_list if sub_list is not None else list(range(1, n_sub + 1))
        self.sample_freq = sample_freq
        self.data_aug = data_aug
        self.n_hop = n_hop
        self.window_sz = window_sz
        self.low_freq = low_freq
        self.high_freq = high_freq
        self.wavelet = wavelet
        self.f_bank = f_bank
        self.wpd_noc = wpd_noc
        self.n_bands = n_bands
        self.low_frequencies = low_frequencies if low_frequencies is not None else np.arange(4, 37, 1)
        self.high_frequencies = high_frequencies if high_frequencies is not None else np.arange(8, 41, 1)
        self.feature_list = feature_list if feature_list is not None else [0, 1, 2, 6, 8, 12, 13, 19, 21]

        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def load_gdf_data(self, file_path):
        raw = mne.io.read_raw_gdf(file_path, preload=True)
        raw.filter(self.low_freq, self.high_freq, fir_design='firwin')
        return raw

    def extract_features(self, raw):
        csp = CSP(n_components=2, reg=None, log=True, norm_trace=False)
        epochs = mne.make_fixed_length_epochs(raw, duration=self.window_sz, overlap=self.n_hop)
        X = epochs.get_data()
        y = epochs.events[:, -1]

        X = X.reshape(X.shape[0], -1)  # Flatten the data for CSP
        X = csp.fit_transform(X, y)

        features = []
        for x in X:
            feature_vector = []
            if 0 in self.feature_list:  # Mean Absolute Value
                feature_vector.append(np.mean(np.abs(x)))
            if 1 in self.feature_list:  # Root Mean Square
                feature_vector.append(np.sqrt(np.mean(x**2)))
            if 2 in self.feature_list:  # Average amplitude change
                feature_vector.append(np.mean(np.diff(x)))
            if 6 in self.feature_list:  # Mean Energy
                feature_vector.append(np.mean(x**2))
            if 8 in self.feature_list:  # Standard deviation
                feature_vector.append(np.std(x))
            if 12 in self.feature_list:  # Energy_ratio
                feature_vector.append(np.sum(x**2) / np.sum(x))
            if 13 in self.feature_list:  # Hjorth Activity and Complexity
                feature_vector.extend(self.hjorth(x))
            if 19 in self.feature_list:  # FFT features
                feature_vector.extend(self.fft_features(x))
            if 21 in self.feature_list:  # Autoregression model- Burg Algorithm
                feature_vector.extend(self.burg_features(x))

            features.append(feature_vector)
        return np.array(features)

    def hjorth(self, x):
        hjorth_activity = np.var(x)
        hjorth_mobility = np.sqrt(np.var(np.diff(x)) / hjorth_activity)
        hjorth_complexity = np.sqrt(np.var(np.diff(np.diff(x))) / np.var(np.diff(x))) / hjorth_mobility
        return [hjorth_activity, hjorth_mobility, hjorth_complexity]

    def fft_features(self, x):
        fft_vals = np.fft.fft(x)
        return [np.max(np.abs(fft_vals))]

    def burg_features(self, x):
        ar_model = arburg(x, 4)
        return ar_model

    def process_data(self):
        for sub in self.subjects:
            file_name = f'A{sub:02d}E.gdf'
            file_path = os.path.join(self.dataset_path, file_name)
            if os.path.exists(file_path):
                raw = self.load_gdf_data(file_path)
                features = self.extract_features(raw)
                np.save(os.path.join(self.save_path, f'{file_name}_features.npy'), features)
                print(f'Features extracted and saved for {file_name}')
            else:
                print(f'File {file_name} not found in {self.dataset_path}')

if __name__ == '__main__':
    # Example configuration
    extractor = FeatureExtractor(
        dataset_path='data/BCICIV_2a_gdf',
        save_path='output/path',
        n_sub=9,
        sub_list=list(range(1, 10)),
        sample_freq=250,
        data_aug=False,
        n_hop=0.1,
        window_sz=2,
        low_freq=0.5,
        high_freq=35,
        wavelet=True,
        f_bank=False,
        wpd_noc=False,
        n_bands=8,
        low_frequencies=np.arange(4, 37, 1),
        high_frequencies=np.arange(8, 41, 1),
        feature_list=[0, 1, 2, 6, 8, 12, 13, 19, 21]
    )
    extractor.process_data()


File A01E.gdf not found in data/BCICIV_2a_gdf
File A02E.gdf not found in data/BCICIV_2a_gdf
File A03E.gdf not found in data/BCICIV_2a_gdf
File A04E.gdf not found in data/BCICIV_2a_gdf
File A05E.gdf not found in data/BCICIV_2a_gdf
File A06E.gdf not found in data/BCICIV_2a_gdf
File A07E.gdf not found in data/BCICIV_2a_gdf
File A08E.gdf not found in data/BCICIV_2a_gdf
File A09E.gdf not found in data/BCICIV_2a_gdf


In [9]:
import os

dataset_path = 'data/BCICIV_2a_gdf'
files = os.listdir(dataset_path)
print("Files in directory:", files)


Files in directory: ['processed_epoch_sub1_QC']


In [12]:
import os
import numpy as np
import mne
import pywt
from scipy import signal
from spectrum import arburg
import time

class FeatureExtractor:
    def __init__(self, **kwargs):
        # Default dataset path and configuration
        self.dataset_path = kwargs.pop('dataset_path', 'data/BCICIV_2a_gdf')
        self.save_path = kwargs.pop('save_path', 'output/path')
        self.n_sub = kwargs.pop('n_sub', 9)
        self.subjects = kwargs.pop('sub_list', list(range(1, self.n_sub + 1)))  # assuming subjects are 1-based
        self.sample_freq = kwargs.pop('sample_freq', 250)
        self.data_aug = kwargs.pop('data_aug', True)
        self.n_hop = kwargs.pop('n_hop', 0.1)
        self.window_sz = kwargs.pop('window_sz', 2)
        self.low_freq = kwargs.pop('low_freq', 0.5)
        self.high_freq = kwargs.pop('high_freq', 35)
        self.wavelet = kwargs.pop('wavelet', True)
        self.f_bank = kwargs.pop('f_bank', False)
        self.wpd_noc = kwargs.pop('wpd_noc', False)
        self.n_bands = kwargs.pop('n_bands', 8)
        self.low_frequencies = kwargs.pop('low_frequencies', np.arange(4, 37, 1))
        self.high_frequencies = kwargs.pop('high_frequencies', np.arange(8, 41, 1))
        self.feature_list = kwargs.pop('feature_list', [0, 1, 2, 6, 8, 12, 13, 19, 21])

        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def downsample(self, x_data, sample_freq=128):
        q = self.sample_freq / sample_freq
        return mne.filter.resample(x_data, down=q, npad='auto')

    def sliding_window(self, x_data, y_data, fs):
        duration = x_data.shape[2] / fs
        windows = np.arange(0, duration - self.window_sz + self.n_hop, self.n_hop)
        n_windows = len(windows)
        X_aug = np.zeros((x_data.shape[0] * n_windows, x_data.shape[1], int(self.window_sz * fs)))
        y_aug = np.zeros((y_data.shape[0] * n_windows, 1))
        for i in range(x_data.shape[0]):
            for j in range(x_data.shape[1]):
                for idx, w in enumerate(windows):
                    X_aug[(i * n_windows) + idx, j, :] = x_data[i, j, int(w * fs):int((w + self.window_sz) * fs)]
                    y_aug[(i * n_windows) + idx] = y_data[i]
        return X_aug, y_aug

    def filter_data(self, x_data, fs, low, high):
        iir_params = dict(order=6, ftype='butter')
        filt = mne.filter.create_filter(x_data, fs, l_freq=low, h_freq=high, method='iir', iir_params=iir_params, verbose=False)
        return signal.sosfiltfilt(filt['sos'], x_data)

    def wpd(self, x_data):
        coeffs = pywt.WaveletPacket(x_data, 'db4', mode='symmetric', maxlevel=6)
        return coeffs

    def feature_bands(self, x_data, level=6, start=1, n_bands=8):
        all_bands = []
        for i in range(x_data.shape[0]):
            bands = []
            for j in range(x_data.shape[1]):
                subbands = []
                C = self.wpd(x_data[i, j, :])
                wpd_bands = C.get_level(level, 'natural')
                for b in range(start, start + n_bands):
                    subbands.append(wpd_bands[b].data)
                bands.append(subbands)
            all_bands.append(bands)
        return np.array(all_bands)

    def filter_bank(self, x_data, fs):
        filtered_X = np.zeros((x_data.shape[0], x_data.shape[1], len(self.low_frequencies), x_data.shape[2]))
        for i in range(len(self.low_frequencies)):
            filtered_X[:, :, i] = self.filter_data(x_data, fs, self.low_frequencies[i], self.high_frequencies[i])
        return filtered_X

    def hjorth(self, xV):
        hjorth_activity = np.var(xV, axis=1)
        hjorth_mobility = np.sqrt(np.var(np.diff(xV, axis=1), axis=1) / hjorth_activity)
        hjorth_diffmobility = np.sqrt(np.var(np.diff(np.diff(xV, axis=1), axis=1), axis=1) / np.var(np.diff(xV, axis=1), axis=1))
        hjorth_complexity = hjorth_diffmobility / hjorth_mobility
        return hjorth_activity, hjorth_mobility, hjorth_complexity

    def extract_features(self):
        start_time = time.time()
        for sub_id in self.subjects:
            sub_id_str = f'A{sub_id:02d}E'
            path = os.path.join(self.dataset_path, f'{sub_id_str}.gdf')
            save_path = os.path.join(self.save_path, f'sub_{sub_id_str}')
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            # Check if file exists
            if not os.path.isfile(path):
                print(f"File {path} not found.")
                continue

            # Load GDF file
            raw = mne.io.read_raw_gdf(path, preload=True)
            x_data = raw.get_data()
            y_data = raw.annotations.description  # Assuming annotations are used as labels

            if self.sample_freq:
                x_data = self.downsample(x_data, sample_freq=self.sample_freq)

            if self.data_aug:
                x_data, y_data = self.sliding_window(x_data, y_data, fs=self.sample_freq)

            x_data = self.filter_data(x_data, fs=self.sample_freq, low=self.low_freq, high=self.high_freq)

            if self.wavelet:
                x_data = self.feature_bands(x_data, level=6, start=1, n_bands=self.n_bands)
                x_data = np.transpose(x_data, (0, 1, 2, 3))
            else:
                if self.f_bank:
                    x_data = self.filter_bank(x_data, fs=self.sample_freq)
                x_data = np.transpose(x_data, (0, 1, 2, 3))

            features = np.zeros((x_data.shape[0], x_data.shape[1], 22))
            for i in range(x_data.shape[0]):
                for j in range(x_data.shape[1]):
                    x = x_data[i, j, :]
                    if 0 in self.feature_list:
                        features[i, j, 0] = np.mean(np.abs(x))
                    if 1 in self.feature_list:
                        features[i, j, 1] = np.sqrt(np.mean(x ** 2))
                    if 2 in self.feature_list:
                        features[i, j, 2] = np.mean(np.diff(x))
                    if 6 in self.feature_list:
                        features[i, j, 6] = np.mean(np.abs(x) ** 2)
                    if 8 in self.feature_list:
                        features[i, j, 8] = np.std(x)
                    if 12 in self.feature_list:
                        features[i, j, 12] = np.sum(x ** 2)
                    if 13 in self.feature_list:
                        hjorth_activity, hjorth_mobility, hjorth_complexity = self.hjorth(x)
                        features[i, j, 13] = np.mean(hjorth_activity)
                        features[i, j, 14] = np.mean(hjorth_mobility)
                        features[i, j, 15] = np.mean(hjorth_complexity)
                    if 19 in self.feature_list:
                        fft_vals = np.abs(np.fft.fft(x))
                        features[i, j, 19] = np.max(fft_vals)
                    if 21 in self.feature_list:
                        ar_coeffs = arburg(x, 4)
                        features[i, j, 21] = np.mean(np.abs(ar_coeffs))

            features = np.mean(features, axis=1)
            features_save_path = os.path.join(save_path, f'features_sub_{sub_id_str}.npy')
            np.save(features_save_path, features)
            print(f"Extracted features for subject {sub_id_str} in {time.time() - start_time:.2f} seconds.")

if __name__ == "__main__":
    # Example usage
    extractor = FeatureExtractor(
        dataset_path='data/BCICIV_2a_gdf',
        save_path='data/BCICIV_2a_gdf/processed',
        n_sub=9
    )
    extractor.extract_features()


File data/BCICIV_2a_gdf/A01E.gdf not found.
File data/BCICIV_2a_gdf/A02E.gdf not found.
File data/BCICIV_2a_gdf/A03E.gdf not found.
File data/BCICIV_2a_gdf/A04E.gdf not found.
File data/BCICIV_2a_gdf/A05E.gdf not found.
File data/BCICIV_2a_gdf/A06E.gdf not found.
File data/BCICIV_2a_gdf/A07E.gdf not found.
File data/BCICIV_2a_gdf/A08E.gdf not found.
File data/BCICIV_2a_gdf/A09E.gdf not found.


In [15]:
import os

dataset_path = 'data/BCICIV_2a_gdf'

# List all files in the directory
try:
    files = os.listdir(dataset_path)
    if not files:
        print(f"No files found in directory: {dataset_path}")
    else:
        print("Files found in directory:")
        for file in files:
            print(file)
except FileNotFoundError:
    print(f"Directory not found: {dataset_path}")


Files found in directory:
processed_epoch_sub1_QC
processed


In [3]:
import os
import numpy as np
import pywt
from scipy import signal
from sklearn import preprocessing
from spectrum import arburg

class FeatureExtractor:
    def __init__(self, **kwargs):
        # Default dataset path and configuration
        self.dataset_path = kwargs.pop('dataset_path', 'data/BCICIV_2a_gdf/sub1_segment_QC')
        self.save_path = kwargs.pop('save_path', 'output/path')
        self.sample_freq = kwargs.pop('sample_freq', 250)
        self.data_aug = kwargs.pop('data_aug', True)
        self.n_hop = kwargs.pop('n_hop', 0.1)
        self.window_sz = kwargs.pop('window_sz', 2)
        self.low_freq = kwargs.pop('low_freq', 0.5)
        self.high_freq = kwargs.pop('high_freq', 35)
        self.wavelet = kwargs.pop('wavelet', True)
        self.f_bank = kwargs.pop('f_bank', False)
        self.n_bands = kwargs.pop('n_bands', 8)
        self.low_frequencies = kwargs.pop('low_frequencies', np.arange(4, 37, 1))
        self.high_frequencies = kwargs.pop('high_frequencies', np.arange(8, 41, 1))
        self.feature_list = kwargs.pop('feature_list', [0, 1, 2, 6, 8, 12, 13, 19, 21])

        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def load_data(self):
        # Initialize lists to store data and labels
        data_list = []
        labels = []

        # Loop through class folders
        for class_folder in range(1, 5):  # assuming class folders are named class_1 to class_4
            class_path = os.path.join(self.dataset_path, f'class_{class_folder}')
            print(f"Checking class folder: {class_path}")
            
            if not os.path.isdir(class_path):
                print(f"Class folder not found: {class_path}")
                continue

            # Loop through .npy files in each class folder
            for file_name in os.listdir(class_path):
                if file_name.endswith('.npy'):
                    file_path = os.path.join(class_path, file_name)
                    print(f"Loading file: {file_path}")
                    if not os.path.isfile(file_path):
                        print(f"File not found: {file_path}")
                        continue
                    data = np.load(file_path)
                    print(f"Loaded data shape: {data.shape}")
                    data_list.append(data)
                    labels.extend([class_folder] * len(data))

        # Convert to numpy arrays
        data_array = np.concatenate(data_list, axis=0) if data_list else np.array([])
        labels_array = np.array(labels) if labels else np.array([])

        print(f"Data shape: {data_array.shape}")
        print(f"Labels shape: {labels_array.shape}")

        return data_array, labels_array

    def downsample(self, x_data, sample_freq=128):
        q = self.sample_freq / sample_freq
        return mne.filter.resample(x_data, down=q, npad='auto')

    def sliding_window(self, x_data, y_data, fs):
        duration = x_data.shape[2] / fs
        windows = np.arange(0, duration - self.window_sz + self.n_hop, self.n_hop)
        n_windows = len(windows)
        X_aug = np.zeros((x_data.shape[0] * n_windows, x_data.shape[1], int(self.window_sz * fs)))
        y_aug = np.zeros((y_data.shape[0] * n_windows, 1))
        for i in range(x_data.shape[0]):
            for j in range(x_data.shape[1]):
                for idx, w in enumerate(windows):
                    X_aug[(i * n_windows) + idx, j, :] = x_data[i, j, int(w * fs):int((w + self.window_sz) * fs)]
                    y_aug[(i * n_windows) + idx] = y_data[i]
        return X_aug, y_aug

    def filter_data(self, x_data, fs, low, high):
        iir_params = dict(order=6, ftype='butter')
        filt = mne.filter.create_filter(x_data, fs, l_freq=low, h_freq=high, method='iir', iir_params=iir_params, verbose=False)
        return signal.sosfiltfilt(filt['sos'], x_data)

    def wpd(self, x_data):
        coeffs = pywt.WaveletPacket(x_data, 'db4', mode='symmetric', maxlevel=6)
        return coeffs

    def feature_bands(self, x_data, level=6, start=1, n_bands=8):
        all_bands = []
        for i in range(x_data.shape[0]):
            bands = []
            for j in range(x_data.shape[1]):
                subbands = []
                C = self.wpd(x_data[i, j, :])
                wpd_bands = C.get_level(level, 'natural')
                for b in range(start, start + n_bands):
                    subbands.append(wpd_bands[b].data)
                bands.append(subbands)
            all_bands.append(bands)
        return np.array(all_bands)

    def filter_bank(self, x_data, fs):
        filtered_X = np.zeros((x_data.shape[0], x_data.shape[1], len(self.low_frequencies), x_data.shape[2]))
        for i in range(len(self.low_frequencies)):
            filtered_X[:, :, i] = self.filter_data(x_data, fs, self.low_frequencies[i], self.high_frequencies[i])
        return filtered_X

    def hjorth(self, xV):
        hjorth_activity = np.var(xV, axis=1)
        hjorth_mobility = np.sqrt(np.var(np.diff(xV, axis=1), axis=1) / hjorth_activity)
        hjorth_diffmobility = np.sqrt(np.var(np.diff(np.diff(xV, axis=1), axis=1), axis=1) / np.var(np.diff(xV, axis=1), axis=1))
        hjorth_complexity = hjorth_diffmobility / hjorth_mobility
        return hjorth_activity, hjorth_mobility, hjorth_complexity

    def extract_features(self):
        x_data, y_data = self.load_data()
        
        if x_data.size == 0:
            print("No data loaded. Check the file paths and try again.")
            return

        if self.sample_freq:
            x_data = self.downsample(x_data, sample_freq=self.sample_freq)

        if self.data_aug:
            x_data, y_data = self.sliding_window(x_data, y_data, fs=self.sample_freq)

        x_data = self.filter_data(x_data, fs=self.sample_freq, low=self.low_freq, high=self.high_freq)

        if self.wavelet:
            x_data = self.feature_bands(x_data, level=6, start=1, n_bands=self.n_bands)
            x_data = np.transpose(x_data, (0, 1, 2, 3))
        else:
            if self.f_bank:
                x_data = self.filter_bank(x_data, fs=self.sample_freq)
            x_data = np.transpose(x_data, (0, 1, 2, 3))

        features = np.zeros((x_data.shape[0], x_data.shape[1], 22))
        for i in range(x_data.shape[0]):
            for j in range(x_data.shape[1]):
                x = x_data[i, j, :]
                if 0 in self.feature_list:
                    features[i, j, 0] = np.mean(np.abs(x))
                if 1 in self.feature_list:
                    features[i, j, 1] = np.sqrt(np.mean(x ** 2))
                if 2 in self.feature_list:
                    features[i, j, 2] = np.mean(np.diff(x))
                if 6 in self.feature_list:
                    features[i, j, 6] = np.mean(np.abs(x) ** 2)
                if 8 in self.feature_list:
                    features[i, j, 8] = np.std(x)
                if 12 in self.feature_list:
                    features[i, j, 12] = np.sum(x ** 2)
                if 13 in self.feature_list:
                    hjorth_activity, hjorth_mobility, hjorth_complexity = self.hjorth(x)
                    features[i, j, 13] = np.mean(hjorth_activity)
                    features[i, j, 14] = np.mean(hjorth_mobility)
                    features[i, j, 15] = np.mean(hjorth_complexity)
                if 19 in self.feature_list:
                    fft_vals = np.abs(np.fft.fft(x))
                    features[i, j, 19] = np.max(fft_vals)
                if 21 in self.feature_list:
                    ar_coeffs = arburg(x, 4)
                    features[i, j, 21] = np.mean(np.abs(ar_coeffs))

        features = np.mean(features, axis=1)
        features_save_path = os.path.join(self.save_path, f'features_subject_1.npy')
        np.save(features_save_path, features)
        print(f"Extracted features for subject 1 saved to {features_save_path}")

# Example usage
feature_extractor = FeatureExtractor()
feature_extractor.extract_features()


Checking class folder: data/BCICIV_2a_gdf/sub1_segment_QC/class_1
Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/class_1
Checking class folder: data/BCICIV_2a_gdf/sub1_segment_QC/class_2
Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/class_2
Checking class folder: data/BCICIV_2a_gdf/sub1_segment_QC/class_3
Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/class_3
Checking class folder: data/BCICIV_2a_gdf/sub1_segment_QC/class_4
Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/class_4
Data shape: (0,)
Labels shape: (0,)
No data loaded. Check the file paths and try again.


In [10]:
import os
import numpy as np
import mne
import pywt
from scipy import signal
from scipy.io import loadmat
from sklearn import preprocessing
from spectrum import arburg

class FeatureExtractor:
    def __init__(self, **kwargs):
        # Default configuration
        self.dataset_path = kwargs.pop('dataset_path', 'data/BCICIV_2a_gdf/sub1_segment_QC')
        self.save_path = kwargs.pop('save_path', 'output/path')
        self.n_sub = kwargs.pop('n_sub', 1)  # Assuming just one subject here
        self.subjects = kwargs.pop('sub_list', [1])
        self.sample_freq = kwargs.pop('sample_freq', 250)
        self.data_aug = kwargs.pop('data_aug', True)
        self.n_hop = kwargs.pop('n_hop', 0.1)
        self.window_sz = kwargs.pop('window_sz', 2)
        self.low_freq = kwargs.pop('low_freq', 0.5)
        self.high_freq = kwargs.pop('high_freq', 35)
        self.wavelet = kwargs.pop('wavelet', True)
        self.f_bank = kwargs.pop('f_bank', False)
        self.wpd_noc = kwargs.pop('wpd_noc', False)
        self.n_bands = kwargs.pop('n_bands', 8)
        self.low_frequencies = kwargs.pop('low_frequencies', np.arange(4, 37, 1))
        self.high_frequencies = kwargs.pop('high_frequencies', np.arange(8, 41, 1))
        self.feature_list = kwargs.pop('feature_list', [0, 1, 2, 6, 8, 12, 13, 19, 21])

        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def filter_data(self, x_data, fs, low, high):
        iir_params = dict(order=6, ftype='butter')
        filt = mne.filter.create_filter(x_data, fs, l_freq=low, h_freq=high, method='iir', iir_params=iir_params, verbose=False)
        return signal.sosfiltfilt(filt['sos'], x_data)

    def feature_bands(self, x_data, level=6, start=1, n_bands=8):
        all_bands = []
        for i in range(x_data.shape[0]):
            bands = []
            for j in range(x_data.shape[1]):
                subbands = []
                C = self.wpd(x_data[i, j, :])
                wpd_bands = C.get_level(level, 'natural')
                for b in range(start, start + n_bands):
                    subbands.append(wpd_bands[b].data)
                bands.append(subbands)
            all_bands.append(bands)
        return np.array(all_bands)

    def wpd(self, x_data):
        coeffs = pywt.WaveletPacket(x_data, 'db4', mode='symmetric', maxlevel=6)
        return coeffs

    def hjorth(self, xV):
        hjorth_activity = np.var(xV, axis=1)
        hjorth_mobility = np.sqrt(np.var(np.diff(xV, axis=1), axis=1) / hjorth_activity)
        hjorth_diffmobility = np.sqrt(np.var(np.diff(np.diff(xV, axis=1), axis=1), axis=1) / np.var(np.diff(xV, axis=1), axis=1))
        hjorth_complexity = hjorth_diffmobility / hjorth_mobility
        return hjorth_activity, hjorth_mobility, hjorth_complexity

    def extract_features(self):
        for sub_id in self.subjects:
            sub_id_str = f'sub{sub_id}'
            path = os.path.join(self.dataset_path, sub_id_str)
            save_path = os.path.join(self.save_path, sub_id_str)
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            # Check for class folders
            for class_label in range(1, 5):
                class_path = os.path.join(path, f'class_{class_label}')
                if not os.path.exists(class_path):
                    print(f"Class folder not found: {class_path}")
                    continue
                
                # Process each file in the class folder
                for file_name in os.listdir(class_path):
                    if file_name.endswith('.npy'):
                        file_path = os.path.join(class_path, file_name)
                        print(f"Processing file: {file_path}")
                        data = np.load(file_path)
                        print(f"Data shape: {data.shape}")

                        # Process data (assuming 22 channels, 1125 samples)
                        x_data = data
                        fs = self.sample_freq

                        # Filter data if needed
                        x_data = self.filter_data(x_data, fs, low=self.low_freq, high=self.high_freq)

                        # Extract features
                        if self.wavelet:
                            x_data = self.feature_bands(x_data, level=6, start=1, n_bands=self.n_bands)
                            x_data = np.transpose(x_data, (0, 1, 2, 3))
                        else:
                            if self.f_bank:
                                x_data = self.filter_bank(x_data, fs=self.sample_freq)
                            x_data = np.transpose(x_data, (0, 1, 2, 3))

                        features = np.zeros((x_data.shape[0], x_data.shape[1], 22))
                        for i in range(x_data.shape[0]):
                            for j in range(x_data.shape[1]):
                                x = x_data[i, j, :]
                                if 0 in self.feature_list:
                                    features[i, j, 0] = np.mean(np.abs(x))
                                if 1 in self.feature_list:
                                    features[i, j, 1] = np.sqrt(np.mean(x ** 2))
                                if 2 in self.feature_list:
                                    features[i, j, 2] = np.mean(np.diff(x))
                                if 6 in self.feature_list:
                                    features[i, j, 6] = np.mean(np.abs(x) ** 2)
                                if 8 in self.feature_list:
                                    features[i, j, 8] = np.std(x)
                                if 12 in self.feature_list:
                                    features[i, j, 12] = np.sum(x ** 2)
                                if 13 in self.feature_list:
                                    hjorth_activity, hjorth_mobility, hjorth_complexity = self.hjorth(x)
                                    features[i, j, 13] = np.mean(hjorth_activity)
                                    features[i, j, 14] = np.mean(hjorth_mobility)
                                    features[i, j, 15] = np.mean(hjorth_complexity)
                                if 19 in self.feature_list:
                                    fft_vals = np.abs(np.fft.fft(x))
                                    features[i, j, 19] = np.max(fft_vals)
                                if 21 in self.feature_list:
                                    ar_coeffs = arburg(x, 4)
                                    features[i, j, 21] = np.mean(np.abs(ar_coeffs))

                        features = np.mean(features, axis=1)
                        features_save_path = os.path.join(save_path, f'features_{file_name}')
                        np.save(features_save_path, features)
                        print(f"Extracted features for file {file_name}. Saved to {features_save_path}")

# Example usage
extractor = FeatureExtractor(
    dataset_path='data/BCICIV_2a_gdf/sub1_segment_QC',
    save_path='output/path'
)
extractor.extract_features()


Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/sub1/class_1
Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/sub1/class_2
Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/sub1/class_3
Class folder not found: data/BCICIV_2a_gdf/sub1_segment_QC/sub1/class_4


In [11]:
import os

file_path = 'data/BCICIV_2a_gdf/sub1_segment_QC/class_1/sub_epoch_103.npy'

# Check if the file exists
if os.path.isfile(file_path):
    print(f"File exists: {file_path}")
else:
    print(f"File does not exist: {file_path}")


File does not exist: data/BCICIV_2a_gdf/sub1_segment_QC/class_1/sub_epoch_103.npy


In [12]:
import os

base_path = 'data/BCICIV_2a_gdf/sub1_segment_QC'
class_folder = 'class_1'

# Full path to the expected directory
full_path = os.path.join(base_path, class_folder)

print(f"Checking folder: {full_path}")

if os.path.exists(full_path):
    print(f"Directory exists: {full_path}")
    files = os.listdir(full_path)
    print(f"Files in {full_path}: {files}")
else:
    print(f"Directory not found: {full_path}")


Checking folder: data/BCICIV_2a_gdf/sub1_segment_QC/class_1
Directory not found: data/BCICIV_2a_gdf/sub1_segment_QC/class_1
