In [1]:
from glob import glob
import os
import mne
import numpy as np
import pandas
import matplotlib.pyplot as plt
from joblib import dump

In [None]:
# List the EDF recordings; sorted because glob returns them in an arbitrary,
# filesystem-dependent order.
sorted(glob('dataverse_files/*.edf'))

In [3]:
# Sort the paths: glob's order is filesystem-dependent, and this order later
# determines each recording's group index (used by GroupKFold), so sorting
# keeps the cross-validation split reproducible across machines.
all_file_path = sorted(glob('dataverse_files/*.edf'))
print(len(all_file_path))

28


In [4]:
# Classify recordings by file name: an 'h' marks a healthy control, an 's' a
# schizophrenia patient. os.path.basename is used instead of split('/') so the
# check also works with the '\\' separator glob produces on Windows.
healthy_file_path = [p for p in all_file_path if 'h' in os.path.basename(p)]
patient_file_path = [p for p in all_file_path if 's' in os.path.basename(p)]
# Every file must land in exactly one class (no overlap, nothing dropped).
assert len(healthy_file_path) + len(patient_file_path) == len(all_file_path), \
    "file-name based class split is ambiguous or incomplete"
print(len(healthy_file_path), len(patient_file_path))

14 14


In [5]:
def read_data(file_path, l_freq=0.5, h_freq=45, epoch_duration=5, epoch_overlap=1, verbose=None):
    """Load one EDF recording and return it as an array of fixed-length epochs.

    Pipeline: load the whole file into memory, re-reference every channel to
    the average of all EEG channels, band-pass filter, then cut the continuous
    signal into overlapping fixed-length epochs.

    Parameters
    ----------
    file_path : str
        Path to an ``.edf`` recording.
    l_freq, h_freq : float, optional
        Lower / upper pass-band edges in Hz (defaults keep 0.5-45 Hz).
    epoch_duration : float, optional
        Length of each epoch in seconds (default 5).
    epoch_overlap : float, optional
        Overlap between consecutive epochs in seconds (default 1); must be
        smaller than ``epoch_duration``.
    verbose : bool | str | int | None, optional
        Passed to every MNE call to control its log output; ``None`` keeps
        MNE's global logging level (the previous behaviour).

    Returns
    -------
    numpy.ndarray
        Shape (n_epochs, n_channels, n_times), e.g. (216, 19, 1250) for the
        sample recording inspected in this notebook.
    """
    # preload=True: re-referencing and filtering operate on in-memory data.
    raw = mne.io.read_raw_edf(file_path, preload=True, verbose=verbose)
    # Average reference: subtract the mean of all EEG channels at each time point.
    raw.set_eeg_reference(ref_channels='average', verbose=verbose)
    # Band-pass filter in place, keeping l_freq < f < h_freq.
    raw.filter(l_freq=l_freq, h_freq=h_freq, verbose=verbose)
    # Split the continuous signal into overlapping fixed-length segments.
    epochs = mne.make_fixed_length_epochs(
        raw, duration=epoch_duration, overlap=epoch_overlap, verbose=verbose
    )
    return epochs.get_data()


In [None]:
# Load a single healthy recording to inspect the epoch array layout.
sample_data = read_data(file_path=healthy_file_path[1])

In [7]:
np.shape(sample_data)  # (n_epochs, n_channels, n_samples per epoch)

(216, 19, 1250)

In [None]:
# Epoch every recording: one (n_epochs, n_channels, n_times) array per file.
control_epochs_array = list(map(read_data, healthy_file_path))  # healthy controls
patient_epochs_array = list(map(read_data, patient_file_path))  # schizophrenia patients

In [9]:
# control_epochs_array[0].shape, control_epochs_array[1].shape  # (inspect the format of the data)

In [10]:
# Label every epoch with its class: 0 = healthy control, 1 = schizophrenia.
# One label list per recording, as long as that recording's epoch array.
control_epochs_labels = [[0] * len(epochs) for epochs in control_epochs_array]
patient_epochs_labels = [[1] * len(epochs) for epochs in patient_epochs_array]

In [11]:
# Controls first, then patients; the same order in both lists keeps each
# recording's epochs aligned with its labels.
data_list = [*control_epochs_array, *patient_epochs_array]
label_list = [*control_epochs_labels, *patient_epochs_labels]

In [12]:
# Tag every epoch with the index of the recording it came from, so that
# cross-validation can keep all epochs of one recording in the same fold.
group_list = [[recording_idx] * len(epochs) for recording_idx, epochs in enumerate(data_list)]

In [13]:
#group_list[1]

In [14]:
# Flatten the per-recording lists into arrays indexed by epoch:
#   data_array  -> (n_epochs_total, n_channels, n_times)
#   label_array -> (n_epochs_total,)  class of each epoch
#   group_array -> (n_epochs_total,)  recording index of each epoch
data_array = np.vstack(data_list)
label_array = np.hstack(label_list)
group_array = np.hstack(group_list)

# Every epoch must have exactly one label and one group.
assert data_array.shape[0] == label_array.shape[0] == group_array.shape[0], \
    "data, labels and groups are misaligned"
data_array.shape, label_array.shape, group_array.shape

In [15]:
# feature extraction
# np.mean(data_array, axis=-1).shape
# reduces the last axis (the samples of each channel), leaving one value per epoch and channel: (7201, 19)

In [16]:
from scipy import stats
# axis=-1 refers to the last axis in the array

def mean(data):
    """Return the arithmetic mean of ``data`` along its last axis (the samples)."""
    reduce_axis = -1
    return np.mean(data, axis=reduce_axis)

def std(data):
    """Return the population standard deviation (ddof=0) along the last axis."""
    return np.std(data, axis=-1, ddof=0)

def ptp(data):
    """Return the peak-to-peak range (maximum minus minimum) along the last axis."""
    return np.ptp(a=data, axis=-1)

def var(data):
    """Return the population variance (ddof=0) along the last axis."""
    return np.var(data, axis=-1, ddof=0)

def minim(data):
    """Return the smallest value along the last axis."""
    return np.min(a=data, axis=-1)

def maxim(data):
    """Return the largest value along the last axis."""
    return np.max(a=data, axis=-1)

def argminim(data):
    """Return the position of the minimum along the last axis (first one on ties)."""
    return np.argmin(a=data, axis=-1)

def argmaxim(data):
    """Return the position of the maximum along the last axis (first one on ties)."""
    return np.argmax(a=data, axis=-1)

def mean_square(data):
    """Return the mean of the squared samples along the last axis."""
    squared = data ** 2
    return np.mean(squared, axis=-1)

def rms(data):
    """Return the root mean square along the last axis: sqrt(mean(data ** 2))."""
    power = np.mean(data ** 2, axis=-1)
    return np.sqrt(power)

def abs_diffs_signal(data):
    """Return sum(|x[t+1] - x[t]|) along the last axis.

    This is the total absolute change between consecutive samples (a "line
    length" measure).
    """
    steps = np.diff(data, axis=-1)
    return np.abs(steps).sum(axis=-1)

def skewness(data):
    """Return the (biased) sample skewness along the last axis.

    Skewness measures the asymmetry of the value distribution.
    """
    return stats.skew(data, axis=-1, bias=True)

def kurtosis(data):
    """Return the (biased) Fisher/excess kurtosis along the last axis.

    Kurtosis measures how heavy the distribution's tails are; the Fisher
    definition gives 0 for a normal distribution.
    """
    return stats.kurtosis(data, axis=-1, fisher=True, bias=True)

def concatenate_features(data):
    """Compute every feature defined above and join them into one array.

    Each feature function reduces the last axis. For one epoch of shape
    (n_channels, n_times), the result is therefore a vector of
    13 * n_channels values, grouped feature by feature in the order of
    ``feature_funcs`` (247 values for 19 channels).
    """
    feature_funcs = (mean, std, ptp, var, minim, maxim, argminim, argmaxim,
                     mean_square, rms, abs_diffs_signal, skewness, kurtosis)
    return np.concatenate([func(data) for func in feature_funcs], axis=-1)


In [17]:
# One feature vector per epoch. Epochs are processed one at a time, so the
# temporaries created by the feature functions stay epoch-sized.
features = [concatenate_features(epoch) for epoch in data_array]

In [18]:
features_array = np.asarray(features)  # (n_epochs, n_features)
np.shape(features_array)

(7201, 247)

In [19]:
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold, GridSearchCV

In [20]:
# L2-regularised logistic regression. The default lbfgs solver hit its
# iteration limit with the default max_iter=100 (ConvergenceWarning in the
# previous run), so the limit is raised.
clf = LogisticRegression(max_iter=1000)
# GroupKFold keeps all epochs of one recording (same value in group_array) in
# the same fold, so each validation fold contains only recordings that were
# not seen during training.
gkf = GroupKFold(n_splits=5)
# The scaler sits inside the pipeline, so it is fit on each training split only
# and no statistics leak from the validation fold.
pipe = Pipeline([('scaler', StandardScaler()), ('clf', clf)])
# C is the inverse regularisation strength: smaller C = stronger penalty.
param_grid = {'clf__C': [0.01, 0.05, 0.1, 0.5, 1, 2, 3, 4, 5, 8, 10, 12, 15]}
gscv = GridSearchCV(pipe, param_grid, cv=gkf, n_jobs=-1)  # -1: use all available cores
gscv.fit(features_array, label_array, groups=group_array)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

In [1]:
best_cv_accuracy = gscv.best_score_  # mean validation score of the best C (default scoring)
best_cv_accuracy

NameError: name 'gscv' is not defined

In [22]:
# Persist the whole fitted search object (refit best pipeline plus CV results).
dump(value=gscv, filename='model.joblib')

['model.joblib']