# Feature Extraction

In [87]:
import os
import mne 
import pywt
import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import scipy.stats as stats
from scipy.signal import welch
from scipy.integrate import simpson
from scipy.stats import skew

In [None]:
# Directory that is scanned for previously saved epoch files.
load_dir = r"C:\Users\ADMIN\Documents\Summerof2024\Epilepsy_Detection\Epochs"
loaded_epoch_list = []

# Read every "*_epo.fif" file found in load_dir, in os.listdir() order,
# and keep the resulting epochs objects in loaded_epoch_list.
for filename in os.listdir(load_dir):
    if not filename.endswith("_epo.fif"):
        continue  # skip anything that is not an epochs file
    load_path = os.path.join(load_dir, filename)
    epochs = mne.read_epochs(load_path, preload=True)
    loaded_epoch_list.append(epochs)

print("Epoch list loaded")

In [89]:
def time_domain_features(epochs):
    """Compute per-channel summary statistics for the first epoch.

    Only the epoch at index 0 of ``epochs.get_data()`` is used.

    Parameters
    ----------
    epochs : object exposing ``get_data()``
        ``get_data()`` must return a 3-D array indexed as
        (epoch, channel, sample).

    Returns
    -------
    list of dict
        Four dictionaries, each mapping ``'ch<k>'`` to one statistic of
        channel ``k``:

        * position 0 -- mean (``np.mean``)
        * position 1 -- standard deviation (``np.std``)
        * position 2 -- skewness (``scipy.stats.skew``)
        * position 3 -- kurtosis (``scipy.stats.kurtosis``, default settings)
    """
    samples = epochs.get_data()
    _, channel_count, _ = samples.shape  # also enforces a 3-D array

    # One estimator per output position; the order defines the list layout.
    estimators = (np.mean, np.std, stats.skew, stats.kurtosis)

    per_statistic = []
    for estimate in estimators:
        per_channel = {
            f'ch{idx}': estimate(samples[0, idx, :])
            for idx in range(channel_count)
        }
        per_statistic.append(per_channel)
    return per_statistic

In [90]:

def BandPower(data, sfreq, band, window_sec = None, relative = False):
    """Estimate the power of a 1-D signal inside a named frequency band.

    The power spectral density is estimated with ``scipy.signal.welch`` and
    integrated over the band with Simpson's rule.

    Parameters
    ----------
    data : array-like
        Signal samples passed straight to ``welch``.
    sfreq : float
        Sampling frequency of ``data`` (same unit as the band limits below).
    band : str
        One of ``'delta'`` (0.5-4), ``'theta'`` (4-8), ``'alpha'`` (8-12),
        ``'beta'`` (12-30) or ``'gamma'`` (30-100).
    window_sec : float, optional
        Length of each Welch segment in seconds. When ``None`` (default) the
        segment spans two cycles of the band's lower limit, i.e.
        ``(2 / low) * sfreq`` samples.
    relative : bool, optional
        When true, return the band power divided by the integral of the
        whole spectrum instead of the absolute band power.

    Returns
    -------
    float
        Absolute or relative band power.

    Raises
    ------
    ValueError
        If ``band`` is not one of the names listed above.
    """
    band_limits = {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 30),
        'gamma': (30, 100),
    }
    try:
        low, high = band_limits[band]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are just as invalid.
        raise ValueError('Invalid band') from None

    # Segment length in samples: explicit window if requested, otherwise two
    # cycles of the lowest frequency of interest.
    if window_sec is not None:
        nperseg = window_sec * sfreq
    else:
        nperseg = (2 / low) * sfreq
    nperseg = int(nperseg)  # welch needs a whole number of samples

    freq, psd = welch(data, sfreq, nperseg = nperseg)

    freq_res = freq[1] - freq[0]

    # Frequency bins that fall inside the band (both limits inclusive).
    idx_band = np.logical_and(freq >= low, freq <= high)

    bp = simpson(psd[idx_band], dx = freq_res)

    if relative:
        bp = bp / simpson(psd, dx = freq_res)
    return bp


In [None]:
# change the frequency based separation: 

def time_features_aggregation(epoch):
    """Arrange the time-domain statistics of one epochs object as a table.

    Parameters
    ----------
    epoch : object accepted by ``time_domain_features``

    Returns
    -------
    pandas.DataFrame
        One row per channel key (``'ch0'``, ``'ch1'``, ...) and one column
        per statistic: ``'Mean'``, ``'Std'``, ``'Skew'``, ``'Kurtosis'``.
    """
    time_features = time_domain_features(epoch)

    # Column labels, in the same order as the list returned by
    # time_domain_features. Named stat_names so the scipy.stats alias
    # `stats` is not shadowed.
    stat_names = ['Mean', 'Std', 'Skew', 'Kurtosis']

    # Each statistic has its own {channel: value} dict; turn every dict into
    # a single-column frame (passing the whole list to from_dict fails).
    stat_columns = [
        pd.DataFrame.from_dict(per_channel, orient = 'index', columns = [name])
        for name, per_channel in zip(stat_names, time_features)
    ]

    result_df = pd.concat(stat_columns, axis = 1)
    return result_df

def band_feature_aggregation(epoch):
    """Tabulate the band power of every channel in the first epoch.

    Only the epoch at index 0 of ``epoch.get_data()`` is used. The channel
    count is taken from the data itself and the sampling frequency from
    ``epoch.info['sfreq']`` rather than from hard-coded values.

    Parameters
    ----------
    epoch : object exposing ``get_data()`` and ``info``
        ``get_data()`` must return a 3-D array indexed as
        (epoch, channel, sample); ``info['sfreq']`` must hold the sampling
        frequency of that data.

    Returns
    -------
    pandas.DataFrame
        One row per channel key (``'ch0'``, ``'ch1'``, ...) and one column
        per band, in the order alpha, beta, gamma, theta, delta.
    """
    band_list = ['alpha', 'beta', 'gamma', 'theta', 'delta']
    epoch_data = epoch.get_data()
    n_channels = epoch_data.shape[1]
    sfreq = epoch.info['sfreq']

    band_columns = []
    for band in band_list:
        # {channel key: power of that channel in the current band}
        band_power = {
            f'ch{ch}': BandPower(data = epoch_data[0, ch, :], sfreq = sfreq,
                                 band = band)
            for ch in range(n_channels)
        }
        band_columns.append(
            pd.DataFrame.from_dict(band_power, orient = 'index',
                                   columns = [band])
        )

    Band_feature_df = pd.concat(band_columns, axis = 1)
    return Band_feature_df

def Merge_csv(time_feature_df, Band_feature_df, index ):
    """Join the two feature tables column-wise, save them as CSV, return them.

    Parameters
    ----------
    time_feature_df, Band_feature_df : pandas.DataFrame
        Tables to place side by side (concatenated along the column axis).
    index : object
        Value interpolated into the output file name (``Ch<index>_...``).

    Returns
    -------
    pandas.DataFrame
        The concatenated table that was written to disk.
    """
    combined = pd.concat([time_feature_df, Band_feature_df], axis = 1)
    # Output location is fixed; only the file name varies with `index`.
    csv_path = rf"C:\Users\ADMIN\Documents\Summerof2024\Epilepsy_Detection\\Feature Vector\Ch{index}_feature_df.csv"
    combined.to_csv(csv_path)
    return combined

# Build both feature tables for every loaded epochs object and write one CSV
# per object, numbered by its position in loaded_epoch_list.
for i, item in enumerate(loaded_epoch_list):
    time_df, band_df = (
        time_features_aggregation(item),
        band_feature_aggregation(item),
    )
    Merge_csv(time_df, band_df, i)
   

In [None]:
# Print how many channel names each loaded epochs object declares.
for item in loaded_epoch_list:
    channel_names = item.info['ch_names']
    print(len(channel_names))