# Import Necessary Libraries

In [1]:
import numpy as np
import os
import mne
import pywt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow.keras import layers, models
from scipy.interpolate import interp1d
from scipy.stats import skew
from scipy.signal import welch
from scipy.stats import entropy
import numpy as np
import pywt
from scipy.stats import skew, kurtosis
from scipy.signal import find_peaks, welch
import warnings
# Silence every warning for the rest of the session. NOTE(review): this also
# hides NumPy/SciPy RuntimeWarnings (divide-by-zero, invalid value), so
# non-finite values produced by later numeric code will not be announced.
warnings.filterwarnings('ignore')

# Data and Labels

In [2]:
# Load the array stored under key 'X' in combined_data.npz.
# The context manager closes the archive even if the key lookup raises,
# which the previous manual data.close() did not guarantee.
with np.load('combined_data.npz') as data:
    X = data['X']

# Display the loaded array
X

array([[[-2.90148955e-06,  1.12909709e-06,  2.63543120e-06, ...,
         -2.89727495e-06, -6.80166352e-07,  1.37569339e-06],
        [-1.20626610e-05, -1.10767687e-05, -1.06193111e-05, ...,
         -5.28302598e-06, -3.08162134e-06, -1.08686640e-06],
        [-3.79762923e-07, -2.07405219e-06, -4.63428751e-06, ...,
         -1.40524844e-06, -2.77773978e-07,  5.66622077e-07],
        ...,
        [ 1.29488217e-06,  1.93576102e-06,  1.93742471e-06, ...,
         -5.94415587e-06, -6.17930937e-06, -4.59392322e-06],
        [ 6.58896488e-06,  7.07184518e-06,  7.21316883e-06, ...,
          1.58482089e-06,  1.10701546e-06,  1.01823366e-06],
        [ 2.77446316e-06, -3.67728944e-06, -5.35505842e-06, ...,
          1.31410061e-05,  9.10587005e-06,  1.86895170e-06]],

       [[ 1.65448073e-05,  1.31738316e-05,  9.04923854e-06, ...,
         -4.42389546e-06, -3.17349259e-06, -1.59048273e-06],
        [ 4.49196432e-06,  3.74644878e-06,  1.34007439e-06, ...,
         -9.20414914e-06, -8.74245961e

In [3]:
# Load the array stored under key 'Y' in combined_labels.npz; the archive is
# closed on leaving the `with` block, including when the lookup fails.
with np.load('combined_labels.npz') as data:
    Y = data['Y']

Y

array([0, 0, 0, ..., 1, 1, 1])

In [4]:
# Load the array stored under key 'group' in combined_groups.npz; the archive
# is closed on leaving the `with` block, including when the lookup fails.
with np.load('combined_groups.npz') as data:
    group = data['group']

group

array([  0,   0,   0, ..., 179, 179, 179])

In [5]:
import numpy as np
from scipy.stats import skew, kurtosis
from scipy.signal import find_peaks, welch
import pywt

# Indices used to pick channels out of each epoch (`epoch[channel_idx]`) by all
# three feature extractors below. The index -> electrode-name mapping in the
# trailing comment cannot be verified from this notebook -- TODO confirm it
# against the channel order of the recordings.
channel_indices = [2, 17, 8, 15]  # C3, Cz, Fz, T4

# Utility functions
def zero_crossing_rate(signal):
    """Return the number of sign changes divided by the number of samples.

    A sign change is any position where ``np.sign`` differs between two
    consecutive samples. A sample that is exactly zero has sign 0, so a
    path such as ``+ -> 0 -> -`` contributes two changes.

    Parameters
    ----------
    signal : sequence or ndarray
        Samples to analyse.

    Returns
    -------
    float
        ``sign_changes / len(signal)``; ``0.0`` for an empty signal (the
        unguarded division would raise ``ZeroDivisionError``).
    """
    n_samples = len(signal)
    if n_samples == 0:
        return 0.0

    sign_changes = np.count_nonzero(np.diff(np.sign(signal)))
    return sign_changes / n_samples

def hjorth_parameters(signal):
    """Compute the Hjorth activity, mobility and complexity of a signal.

    With ``dx`` and ``ddx`` the first and second discrete differences::

        activity   = var(x)
        mobility   = sqrt(var(dx) / var(x))
        complexity = sqrt(var(ddx) / var(dx)) / mobility

    Parameters
    ----------
    signal : sequence or ndarray
        Samples to analyse.

    Returns
    -------
    tuple
        ``(activity, mobility, complexity)``. When the signal variance or the
        variance of its first difference is not strictly positive (constant
        signal, constant slope, too few samples) the ratios are undefined;
        mobility and complexity are then returned as ``0.0`` instead of the
        nan/inf an unguarded division would produce.
    """
    first_diff = np.diff(signal)
    second_diff = np.diff(first_diff)

    activity = np.var(signal)
    first_diff_var = np.var(first_diff)

    # Both variances act as denominators below; bail out before dividing.
    if not (activity > 0 and first_diff_var > 0):
        return activity, 0.0, 0.0

    mobility = np.sqrt(first_diff_var / activity)
    complexity = np.sqrt(np.var(second_diff) / first_diff_var) / mobility

    return activity, mobility, complexity

# Time-Domain Feature Extraction
def extract_time_domain_features(epochs):
    """Build 13 time-domain descriptors per selected channel for each epoch.

    For every epoch, each channel listed in the module-level
    ``channel_indices`` is flattened and summarised as, in this order:
    mean, median, variance, standard deviation, skewness, kurtosis,
    zero-crossing rate, number of peaks, samples per peak (0 when there are
    no peaks), peak-to-peak amplitude, Hjorth activity, Hjorth mobility and
    Hjorth complexity.

    Parameters
    ----------
    epochs : iterable
        Each item is indexed as ``epoch[channel_idx]`` and the result must
        provide ``.flatten()``.

    Returns
    -------
    ndarray
        ``np.array`` of the nested per-epoch / per-channel feature lists.
    """
    def _describe(samples):
        # Hjorth activity is the same quantity as the variance entry; both are
        # kept so the feature layout stays unchanged.
        activity, mobility, complexity = hjorth_parameters(samples)
        peak_count = len(find_peaks(samples)[0])
        samples_per_peak = len(samples) / peak_count if peak_count > 0 else 0

        return [
            np.mean(samples),
            np.median(samples),
            np.var(samples),
            np.std(samples),
            skew(samples),
            kurtosis(samples),
            zero_crossing_rate(samples),
            peak_count,
            samples_per_peak,
            np.ptp(samples),
            activity,
            mobility,
            complexity,
        ]

    return np.array([
        [_describe(epoch[idx].flatten()) for idx in channel_indices]
        for epoch in epochs
    ])

# Wavelet decomposition helper
def get_wavelet_coeffs(channel_data, wavelet='db4', level=3):
    """Run a multilevel discrete wavelet decomposition on one channel.

    Thin wrapper around ``pywt.wavedec``; the result of that call is returned
    unchanged.

    Parameters
    ----------
    channel_data : array-like
        Samples of a single channel.
    wavelet : str
        Wavelet name passed to ``pywt.wavedec``.
    level : int
        Decomposition level passed to ``pywt.wavedec``.
    """
    return pywt.wavedec(channel_data, wavelet, level=level)

# Frequency-Domain Feature Extraction
def extract_frequency_domain_features(epochs, sfreq, wavelet='db4', bands=None, nperseg=180, level=3):
    """Build 20 spectral descriptors per selected channel for each epoch.

    For every epoch, each channel listed in the module-level
    ``channel_indices`` gets a Welch PSD estimate, from which the following
    are produced in this order:

    1-6.   PSD mean, median, variance, standard deviation, skewness, kurtosis
    7-12.  band power (sum of PSD bins) for delta, theta, alpha, beta, gamma,
           sigma
    13-19. theta/alpha, beta/alpha, (theta+alpha)/beta, theta/beta,
           (theta+alpha)/(alpha+beta), gamma/delta,
           (gamma+beta)/(delta+alpha)
    20.    mean of the first array returned by ``get_wavelet_coeffs``

    Parameters
    ----------
    epochs : iterable
        Each item is indexed as ``epoch[channel_idx]`` to obtain one channel.
    sfreq : float
        Sampling frequency handed to ``scipy.signal.welch``.
    wavelet : str
        Wavelet name forwarded to ``get_wavelet_coeffs``.
    bands : dict or None
        Maps band name to an inclusive ``(low, high)`` frequency range. Must
        contain the keys delta, theta, alpha, beta, gamma and sigma. ``None``
        selects the built-in defaults; the dict is created per call rather
        than shared as a mutable default argument.
    nperseg : int
        Segment length forwarded to ``scipy.signal.welch``.
    level : int
        Decomposition level forwarded to ``get_wavelet_coeffs``.

    Returns
    -------
    ndarray
        ``np.array`` of the nested per-epoch / per-channel feature lists.

    Notes
    -----
    * Band edges are inclusive on both sides. With the default integer edges,
      PSD bins lying strictly between two bands (e.g. between 7 and 8 Hz) are
      counted in no band, and sigma overlaps alpha and beta.
    * NOTE(review): ratios are plain divisions; a zero denominator is not
      guarded here, so non-finite values can reach the output -- confirm
      whether downstream code filters them.
    """
    if bands is None:
        bands = {'delta': (1, 3), 'theta': (4, 7), 'alpha': (8, 12), 'beta': (13, 30), 'gamma': (31, 60), 'sigma': (11, 16)}

    features = []

    for epoch in epochs:
        epoch_features = []
        for channel_idx in channel_indices:
            channel_data = epoch[channel_idx]
            freqs, psd = welch(channel_data, sfreq, nperseg=nperseg)

            # Summary statistics of the PSD itself
            mean_val = np.mean(psd)
            median_val = np.median(psd)
            var_val = np.var(psd)
            std_dev = np.std(psd)
            skewness = skew(psd)
            kurt = kurtosis(psd)

            # Mean of the first coefficient array of the wavelet decomposition
            wave_coeffs = get_wavelet_coeffs(channel_data, wavelet, level=level)
            wave_coeffs_mean = np.mean(wave_coeffs[0])

            # Band power: sum of the PSD bins inside each inclusive range
            band_powers = {
                band: np.sum(psd[(freqs >= low) & (freqs <= high)])
                for band, (low, high) in bands.items()
            }
            delta = band_powers['delta']
            theta = band_powers['theta']
            alpha = band_powers['alpha']
            beta = band_powers['beta']
            gamma = band_powers['gamma']
            sigma = band_powers['sigma']

            # Band power ratios
            theta_alpha_ratio = theta / alpha
            beta_alpha_ratio = beta / alpha
            theta_alpha_beta_ratio = (theta + alpha) / beta
            theta_beta_ratio = theta / beta
            theta_alpha_beta_alpha_ratio = (theta + alpha) / (alpha + beta)
            gamma_delta_ratio = gamma / delta
            gamma_beta_delta_alpha_ratio = (gamma + beta) / (delta + alpha)

            epoch_features.append([
                mean_val, median_val, var_val, std_dev, skewness, kurt,
                delta, theta, alpha, beta, gamma, sigma,
                theta_alpha_ratio, beta_alpha_ratio, theta_alpha_beta_ratio, theta_beta_ratio,
                theta_alpha_beta_alpha_ratio, gamma_delta_ratio, gamma_beta_delta_alpha_ratio,
                wave_coeffs_mean
            ])

        features.append(epoch_features)

    return np.array(features)

# DWT Feature Extraction
def extract_dwt_features(epochs, wavelet='db4', level=4):
    """Return the concatenated DWT coefficients of each selected channel.

    For every epoch, each channel listed in the module-level
    ``channel_indices`` is decomposed with ``pywt.wavedec`` and the resulting
    coefficient arrays are joined end to end into one vector.

    Parameters
    ----------
    epochs : iterable
        Each item is indexed as ``epoch[channel_idx]`` to obtain one channel.
    wavelet : str
        Wavelet name passed to ``pywt.wavedec``.
    level : int
        Decomposition level passed to ``pywt.wavedec``.

    Returns
    -------
    ndarray
        ``np.array`` of the nested per-epoch / per-channel coefficient vectors.
    """
    def _flat_dwt(samples):
        # Join all coefficient arrays of one decomposition into a single vector.
        return np.concatenate(pywt.wavedec(samples, wavelet, level=level), axis=0)

    return np.array([
        [_flat_dwt(epoch[idx]) for idx in channel_indices]
        for epoch in epochs
    ])

# Helper Function to Combine All Features
def extract_combined_features(epochs, sfreq):
    """Run the three feature extractors and also return their concatenation.

    Parameters
    ----------
    epochs : iterable
        Passed unchanged to each extractor.
    sfreq : float
        Sampling frequency, used only by the frequency-domain extractor.

    Returns
    -------
    tuple
        ``(time_features, freq_features, dwt_features, combined_features)``
        where ``combined_features`` joins the first three along axis 2.
    """
    time_block = extract_time_domain_features(epochs)
    freq_block = extract_frequency_domain_features(epochs, sfreq)
    dwt_block = extract_dwt_features(epochs)

    # Axis 2 is the per-channel feature axis of each extractor's output.
    per_family = (time_block, freq_block, dwt_block)
    stacked = np.concatenate(per_family, axis=2)

    return (*per_family, stacked)

# Saving Features in NPZ Format
def save_features(time_features, freq_features, dwt_features, combined_features, filename='features.npz'):
    """Write the four feature arrays into one ``.npz`` archive and report it.

    The arrays are stored under the keys ``time_features``, ``freq_features``,
    ``dwt_features`` and ``combined_features``. A confirmation line containing
    ``filename`` as given is printed afterwards.
    """
    arrays = {
        'time_features': time_features,
        'freq_features': freq_features,
        'dwt_features': dwt_features,
        'combined_features': combined_features,
    }
    np.savez(filename, **arrays)
    print(f"Features saved to {filename}")

# Example Usage
# Value passed to the Welch PSD estimate as the sampling frequency.
# NOTE(review): assumed to match the recordings' actual rate -- TODO confirm.
sfreq = 256  # Sampling frequency

# `X` (loaded from combined_data.npz above) is used directly as the epoch array.
# The extractors index it as epochs[i][channel_idx], so it is presumably laid
# out as (epochs, channels, samples) -- verify against how the file was built.
epochs = X

# Run the time-domain, frequency-domain and DWT extractors and get their
# axis-2 concatenation as the fourth result
time_features, freq_features, dwt_features, combined_features = extract_combined_features(epochs, sfreq)

# Store all four arrays in a single .npz archive
save_features(time_features, freq_features, dwt_features, combined_features, filename='eeg_features.npz')

Features saved to eeg_features.npz


In [6]:
# Report the shape of every feature array, one line per array
print(
    *(
        f"{label} {array.shape}"
        for label, array in (
            ('Time Features Shape: ', time_features),
            ('Frequency Features Shape: ', freq_features),
            ('DWT Features Shape: ', dwt_features),
            ('Combined Features Shape: ', combined_features),
        )
    ),
    sep="\n",
)

Time Features Shape:  (73519, 4, 13)
Frequency Features Shape:  (73519, 4, 20)
DWT Features Shape:  (73519, 4, 205)
Combined Features Shape:  (73519, 4, 238)


# Complete