### Imports

In [1]:
# !pip install mne
# !pip install mne-connectivity

In [30]:
import os

import numpy as np

import mne
from mne.time_frequency import psd_welch

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.model_selection import cross_val_score

from matplotlib import pyplot as plt

import gc

import time

from joblib import Parallel, delayed
import matlab.engine

from sklearn.model_selection import train_test_split
import shutil

## Preprocessing

### Loading edf

In [3]:
file = "..\dataverse_files\h01.edf"
edfs_path = "..\dataverse_files"
manifest_path = "..\dataverse_files\MANIFEST.txt"

In [4]:
def load_patients_data(edfs_path):
    raw_patients_data = []
    
    edfs_file_names = [f for f in os.listdir(edfs_path) if f.endswith('.edf')]
    
    for file_name in edfs_file_names:
        path = edfs_path + '\\' + file_name 
        raw_data = mne.io.read_raw_edf(path, preload=True, verbose=False)
        raw_patients_data.append(raw_data)

    return raw_patients_data

In [5]:
raw_patients_data = load_patients_data(edfs_path)

### Filtered EEG signals segmentation

In [6]:
def get_label(edf):
    patient_edf_file_name = edf.filenames[0].split('\\')[-1]
    isSick = patient_edf_file_name.lower().startswith('s')
    return int(isSick == True) # 1 - is sick, 0 is healthy

In [7]:
def get_min_max_duration_for_classes(print_durations=False):
    min_SZ_negative_duration = float("inf") # healthy
    min_SZ_positive_duration = float("inf") # sick

    max_SZ_negative_duration = 0 # healthy
    max_SZ_positive_duration = 0 # sick

    for edf in raw_patients_data:
        duration = edf.times[-1]

        if(get_label(edf) == 0):
            min_SZ_negative_duration = duration if duration < min_SZ_negative_duration else min_SZ_negative_duration
            max_SZ_negative_duration = duration if duration > max_SZ_negative_duration else max_SZ_negative_duration
        else:
            min_SZ_positive_duration = duration if duration < min_SZ_positive_duration else min_SZ_positive_duration
            max_SZ_positive_duration = duration if duration > max_SZ_positive_duration else max_SZ_positive_duration


    print('SZ_negative: min =', min_SZ_negative_duration, ', max =', max_SZ_negative_duration)
    print('SZ_positive: min =', min_SZ_positive_duration, ', max =', max_SZ_positive_duration)
    
    return min_SZ_negative_duration, min_SZ_positive_duration, max_SZ_negative_duration, max_SZ_positive_duration

In [8]:
def crop_raw_data_to_equalize_duration_per_class():
    print("Duration per class before cropping: ")
    min_dur_neg, min_dur_pos, *_ = get_min_max_duration_for_classes(True)
    
    for edf in raw_patients_data:
        duration = edf.times[-1]

        if(get_label(edf) == 0):
            if(duration > min_dur_neg):
                edf.crop(tmin=0, tmax=min_dur_neg, include_tmax=True)
        else:
            if(duration > min_dur_pos):
                edf.crop(tmin=0, tmax=min_dur_pos, include_tmax=True)
                
    print("\nDuration per class after cropping: ")

    get_min_max_duration_for_classes(True)


In [9]:
def print_info(epochs_num_per_patient, labels):
    print('\nEpochs number per patient: ', epochs_num_per_patient)
    
    class_SZ_positive = sum(labels) 
    class_SZ_negative= len(labels)-sum(labels)

    print('\nnegative: ', class_0_num)
    print('positive: ', class_1_num)

In [10]:
def transform_patients_data_into_X_y_sets(patients_data, segment_duration=1.0, info=True):
    epochs_per_patient = []
    labels = []
    
    epochs_num_per_patient = []
    for edf in raw_patients_data:
        epochs = mne.make_fixed_length_epochs(edf, duration=segment_duration, preload=True, verbose=False)
        epochs_per_patient.append(epochs)
        epochs_num_per_patient.append(len(epochs))
        
        label = get_label(edf)
        labels.extend([label for epoch in epochs])
    
    epochs = mne.concatenate_epochs(epochs_per_patient)

    if info:
        print_info(epochs_num_per_patient, labels)
        
    del epochs_num_per_patient
    gc.collect()
    
    return (epochs, np.array(labels)) # (X, y)

In [11]:
crop_raw_data_to_equalize_duration_per_class()

Duration per class before cropping: 
SZ_negative: min = 864.996 , max = 1114.996
SZ_positive: min = 739.996 , max = 2169.996

Duration per class after cropping: 
SZ_negative: min = 864.996 , max = 864.996
SZ_positive: min = 739.996 , max = 739.996


In [12]:
X, y = transform_patients_data_into_X_y_sets(patients_data=raw_patients_data, segment_duration=5.0)

Not setting metadata
4494 matching events found
No baseline correction applied
0 bad epochs dropped

Epochs number per patient:  [173, 173, 173, 173, 173, 173, 173, 173, 173, 173, 173, 173, 173, 173, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148]

negative:  2072
positive:  2422


In [13]:
del raw_patients_data
gc.collect()

0

In [14]:
print(len(y))
print(len(X))
print(X[0].get_data().shape)

4494
4494
(1, 19, 1250)


In [15]:
X.to_data_frame().head()

Unnamed: 0,time,condition,epoch,Fp2,F8,T4,T6,O2,Fp1,F7,...,O1,F4,C4,P4,F3,C3,P3,Fz,Cz,Pz
0,0,1,0,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025
1,4,1,0,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025
2,8,1,0,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025
3,12,1,0,0.461215,0.461215,0.30831,0.30831,0.155405,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,-0.150405,-0.30331,0.0025,0.0025,-0.30331
4,16,1,0,0.461215,0.461215,0.461215,0.30831,0.155405,0.0025,0.0025,...,-0.30331,0.0025,0.155405,0.0025,0.0025,-0.150405,-0.30331,0.0025,0.0025,-0.150405


In [16]:
X[0].to_data_frame().head()

Unnamed: 0,time,condition,epoch,Fp2,F8,T4,T6,O2,Fp1,F7,...,O1,F4,C4,P4,F3,C3,P3,Fz,Cz,Pz
0,0,1,0,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025
1,4,1,0,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025
2,8,1,0,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025,0.0025
3,12,1,0,0.461215,0.461215,0.30831,0.30831,0.155405,0.0025,0.0025,...,0.0025,0.0025,0.0025,0.0025,0.0025,-0.150405,-0.30331,0.0025,0.0025,-0.30331
4,16,1,0,0.461215,0.461215,0.461215,0.30831,0.155405,0.0025,0.0025,...,-0.30331,0.0025,0.155405,0.0025,0.0025,-0.150405,-0.30331,0.0025,0.0025,-0.150405


In [17]:
### Data preparation

np.set_printoptions(precision=50)

x_data = X.get_data()
print('x_data shape:', x_data.shape)

column_names = X[0].to_data_frame().columns
column_names = column_names[-19:]
print('column_names:', column_names)

epoch_num, channel_num, epoch_len = x_data.shape

x_data shape: (4494, 19, 1250)
column_names: Index(['Fp2', 'F8', 'T4', 'T6', 'O2', 'Fp1', 'F7', 'T3', 'T5', 'O1', 'F4',
       'C4', 'P4', 'F3', 'C3', 'P3', 'Fz', 'Cz', 'Pz'],
      dtype='object')


### Generating spectrograms with matlab engine

In [18]:
# create spec for all segments v3 flatten segments before creating spectrogram
import matplotlib.pyplot as plt

from math import floor

from scipy.io import savemat

In [19]:
def create_spectrogram(x, mat_path, spec_path, eng):

    combined_channels_data = np.array([])
    
    for i in range(len(x)):
        combined_channels_data = np.concatenate((combined_channels_data, x[i]))
        
#     print(combined_channels_data.shape)
        
    mdic = {"data": combined_channels_data, "filename": spec_path}

    savemat(mat_path, mdic)

                    
#     eng = matlab.engine.start_matlab()

#     channels_data_mat = eng.cell2mat(channels_data.tolist())
    
#     print(mat_path) 
    eng.workspace['mat_path'] = mat_path
    
    eng.evalc("M = load(mat_path);")
    
#     eng.evalc("axes('Units','Normalize','Position',[0 0 1 1])")
    
    eng.evalc("f = figure(Position=[0 0 224 224])")
    eng.evalc("ax = axes('Units','Normalize','Position',[0 0 1 1])")
    

    eng.evalc("spectrogram(M.data, 'yaxis')") 
#     matlab.double(s, is_complex=True)

    
        
#     eng.spectrogram(eng.cell2mat(s),'yaxis');
    eng.evalc("colormap turbo")
    eng.evalc("colorbar off")
    eng.evalc("axis off")
    
#     eng.evalc("saveas(gcf, M.filename)")
    eng.evalc("cdata = print(gcf, '-RGBImage', '-r96')")
    eng.evalc("imwrite(cdata, turbo, M.filename)")
    
#     eng.quit()

In [20]:
def get_spectrogram(epoch_index, label, epoch_data, eng):
#     print('get_spectrogram')
    filename = file_name.format(
        epoch_index,
        label)

    spec_path = specs_dir_path.format(filename)
    mat_dir_path_path = mat_dir_path.format(filename)
#     print(filename)
#     print(epoch_data.shape)
    if(not os.path.isfile(spec_path)):
        create_spectrogram(epoch_data, mat_dir_path_path, spec_path, eng)

In [21]:
def run_partial_loop(from_index, step):
    to_index = from_index+step-1

    if to_index > epoch_num:
        to_index = epoch_num
        
    eng = matlab.engine.start_matlab()

    for epoch_index in range(from_index, to_index, 1):
        get_spectrogram(epoch_index=epoch_index,
                        label = y[epoch_index],
                        epoch_data=x_data[epoch_index],
                        eng=eng
                       )
    eng.quit()

In [22]:
# path = '../spectrograms/{0}_{1}'
file_name = '{0}_{1}'
specs_dir_path = '../spectrograms/{0}.png'
mat_dir_path = '../mats/{0}.mat'

In [26]:
# generating spectrograms parallel
n_jobs = 12
step = floor(epoch_num/n_jobs)+1 # number of epochs/segments per thread/job

Parallel(n_jobs=n_jobs)(delayed(run_partial_loop)(epoch_index, step) for epoch_index in range(0, epoch_num, step))


[None, None, None, None, None, None, None, None, None, None, None, None]

### Splitting spectrograms images to train and test datasets

In [32]:
# defining paths

generated_spectrograms_path = '../spectrograms/'

dataset_path = '../spectrograms_dataset/'
train_datase_path = dataset_path+'train/'
train_SZ_negative_class_dir = train_datase_path+'SZ_negative/'
train_SZ_positive_class_dir = train_datase_path+'SZ_positive/'

test_datase_path = dataset_path+'test/'
test_SZ_negative_class_dir = test_datase_path+'SZ_negative/'
test_SZ_positive_class_dir = test_datase_path+'SZ_positive/'

In [33]:
#creating dataset directory structure
if os.path.exists(train_datase_path):
    os.removedirs(train_datase_path)
    
os.makedirs(train_SZ_negative_class_dir)
os.makedirs(train_SZ_positive_class_dir)  
    
if os.path.exists(test_datase_path):
    os.removedirs(test_datase_path)

os.makedirs(test_SZ_negative_class_dir)
os.makedirs(test_SZ_positive_class_dir)

In [39]:
# moving images to train dataset directory and splitting according to class belonging
files = os.listdir(generated_spectrograms_path)

for file_name in files:
    src_img_path = generated_spectrograms_path + file_name

    if file_name.split('.')[0][-1] == '0':
        shutil.copy(src_img_path, train_SZ_negative_class_dir)
    else:
        shutil.copy(src_img_path, train_SZ_positive_class_dir)

In [41]:
class_dirs = os.listdir(train_datase_path)

src_class_dir_path = '../spectrograms_dataset/train/{0}'

src_path = '../spectrograms_dataset/train/{0}/{1}'
dst_path = '../spectrograms_dataset/test/{0}/{1}'

test_size_ration = 0.2

for class_dir in class_dirs:
    class_dir_path = src_class_dir_path.format(class_dir)
    img_files = os.listdir(class_dir_path)
    _, test_img_files = train_test_split(img_files, test_size=test_size_ration, random_state=1, shuffle=True)
    
    
    for img_file in test_img_files:
        src_img_path = src_path.format(class_dir, img_file)
        dst_img_path = dst_path.format(class_dir, img_file)
        
        shutil.move(src_img_path, dst_img_path)
           