In [None]:
import csv
import glob
import os.path
import re
import time

import mne
import nltk, string
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.signal import resample, hann
from sklearn import feature_extraction
from sklearn import metrics
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, recall_score
from sklearn.model_selection import cross_val_score, KFold, cross_validate
from sklearn.svm import SVC

%matplotlib inline

In [None]:
def parse_document(document):
    """Clean a text document and split it into stripped sentences.

    Every run of the punctuation/newline characters listed in the pattern
    below is replaced by a single space, the result is stripped, and the
    cleaned text is passed to ``nltk.sent_tokenize``. Note that '.', '?' and
    '!' are among the replaced characters, so they are gone before the
    sentence tokenizer sees the text.

    Parameters
    ----------
    document : str
        Raw text to clean.

    Returns
    -------
    list of str
        Sentences returned by ``nltk.sent_tokenize``, each stripped.

    Raises
    ------
    ValueError
        If ``document`` is not a string.
    """
    # Validate first: re.sub() on a non-string raises TypeError, which used to
    # pre-empt this check and made the ValueError unreachable.
    if not isinstance(document, str):
        raise ValueError('Document is not string!')

    punctuation = "[!@#$%^&*()_+{}:\"<>?,./;“”‘’\n]+"
    cleaned = re.sub(punctuation, ' ', document).strip()
    return [sentence.strip() for sentence in nltk.sent_tokenize(cleaned)]

In [None]:
def get_info_wave(info_file):
    """Return the brain-wave band names mentioned in a text file.

    The file is read, cleaned and sentence-split with ``parse_document``,
    word-tokenized with ``nltk.word_tokenize`` and lower-cased. A band name is
    reported when it occurs as a substring of any token.

    Parameters
    ----------
    info_file : str
        Path of the text file to scan.

    Returns
    -------
    list of str
        Subset of ``['gamma', 'beta', 'alpha', 'theta', 'delta']`` in that
        order; empty when the file contains no tokens.
    """
    # Context manager so the handle is closed even if reading fails.
    with open(info_file, 'r') as handle:
        info_content = handle.read()

    sentences = parse_document(info_content)
    # Flatten the tokens of every sentence. Indexing only the first sentence
    # raised IndexError for an empty file and ignored any further sentences.
    tokens = [
        token.lower()
        for sentence in sentences
        for token in nltk.word_tokenize(sentence)
    ]

    waves_list = ['gamma', 'beta', 'alpha', 'theta', 'delta']
    # Substring match, so e.g. a token such as "alpha-like" still counts.
    return [wave for wave in waves_list if any(wave in token for token in tokens)]

In [None]:
def extract_sampling_frequency(edf_filename):
    """Open an EDF file with MNE and return its ``info['sfreq']`` as an int.

    The file is opened with ``stim_channel=None``; the value is truncated by
    ``int()``.
    """
    raw = mne.io.read_raw_edf(edf_filename, stim_channel=None)
    return int(raw.info['sfreq'])

In [None]:
def _seizure_intervals(file_text):
    """Parse (start_sec, stop_sec) integer pairs from one file's summary entry."""
    start_sec = re.findall(r"Seizure Start Time: ([0-9]*) seconds", file_text)
    end_sec = re.findall(r"Seizure End Time: ([0-9]*) seconds", file_text)
    # Check the pairing before any element is used.
    assert len(start_sec) == len(end_sec), 'Unpaired seizure start/end times'
    return [(int(start), int(stop)) for start, stop in zip(start_sec, end_sec)]


def extract_data_and_labels(edf_filename, summary_text, margin_sec=60):
    """Load one EDF recording and cut it into labelled one-second windows.

    Exactly one channel is selected per wanted electrode (the channel whose
    name contains ``' <electrode>-'``). The data are scaled by 1e6 and split
    into windows of ``sampling_frequency`` samples. Only the seconds from
    ``margin_sec`` before the earliest seizure start to ``margin_sec`` after
    the latest seizure end are returned, clipped to the recording.

    Parameters
    ----------
    edf_filename : str
        Path to the EDF file. Its basename must occur in ``summary_text``.
    summary_text : str
        Summary text containing "Seizure Start Time: N seconds" and
        "Seizure End Time: N seconds" lines. Only the part from the file's
        basename up to the next "File Name" (or the end) is used.
    margin_sec : int, optional
        Seconds of context kept on each side of the seizures (default 60).

    Returns
    -------
    X_final : list of numpy.ndarray
        One float32 array of shape (21, sampling_frequency) per kept second.
    y_final : list
        One label per kept second: 1 inside a seizure interval, else 0.

    Raises
    ------
    ValueError
        If the summary entry for this file lists no seizure, or if the
        basename is missing from ``summary_text`` (raised by ``str.index``).
    AssertionError
        If an electrode matches zero or several channels, or if start and end
        times are unpaired.
    """
    _, basename = os.path.split(edf_filename)

    edf = mne.io.read_raw_edf(edf_filename, stim_channel=None)
    sampling_frequency = int(edf.info['sfreq'])
    edf.load_data()

    wanted_elecs = ['C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1', 'FP2', 'FZ', 'O1', 'O2',
                    'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6', 'A1', 'A2']

    selected_ch_names = []
    for wanted_part in wanted_elecs:
        matches = [ch_name for ch_name in edf.ch_names
                   if ' ' + wanted_part + '-' in ch_name]
        assert len(matches) == 1, \
            'Expected exactly one channel for {}, found {}'.format(wanted_part, matches)
        selected_ch_names.append(matches[0])

    edf = edf.pick_channels(selected_ch_names)
    assert np.array_equal(sorted(edf.ch_names), sorted(selected_ch_names))
    assert len(edf.ch_names) == len(wanted_elecs)

    # MNE returns data in volts, so the 1e6 factor gives microvolts
    # (the earlier comment said mV).
    X = edf.get_data().astype(np.float32) * 1e6

    n_seconds = X.shape[1] // sampling_frequency
    y = np.zeros(n_seconds, dtype=np.int64)

    # Isolate this file's entry inside the summary text.
    i_text_start = summary_text.index(basename)
    if 'File Name' in summary_text[i_text_start:]:
        i_text_stop = summary_text.index('File Name', i_text_start)
    else:
        i_text_stop = len(summary_text)
    assert i_text_stop > i_text_start
    file_text = summary_text[i_text_start:i_text_stop]

    # Parse this file's entry only. Searching the whole summary mixed in the
    # seizure times of every other file listed there.
    intervals = _seizure_intervals(file_text) if 'Seizure Start' in file_text else []
    if not intervals:
        # Previously this path failed later with a NameError.
        raise ValueError('No seizure listed for {} in the summary text'.format(basename))

    for i_seizure_start, i_seizure_stop in intervals:
        y[i_seizure_start:i_seizure_stop] = 1

    first_seizure = min(start for start, _ in intervals)
    last_seizure = max(stop for _, stop in intervals)
    start_cut = max(0, first_seizure - margin_sec)
    last_cut = min(last_seizure + margin_sec, n_seconds)

    kept_seconds = range(start_cut, last_cut)
    X_final = [X[:, sampling_frequency * i:sampling_frequency * (i + 1)]
               for i in kept_seconds]
    y_final = [y[i] for i in kept_seconds]

    assert len(X_final) == len(y_final)
    return X_final, y_final

In [None]:
def fft(data):
    """Return the FFT magnitudes of ``data`` as one flat 1-D array.

    The transform is taken along the last axis of the array; the absolute
    values are then flattened in row-major order.
    """
    last_axis = data.ndim - 1
    magnitudes = np.absolute(np.fft.fft(data, axis=last_axis))
    return magnitudes.ravel()

def get_file(edf_file_names, summary_file):
    """Build flattened FFT-magnitude features for every window of one recording.

    Parameters
    ----------
    edf_file_names : str
        Path of the EDF file, forwarded to ``extract_data_and_labels``.
    summary_file : str
        Path of the summary text file describing the recording's seizures.

    Returns
    -------
    output : list
        ``fft(window)`` for each window returned by
        ``extract_data_and_labels``.
    y : list
        The labels returned by ``extract_data_and_labels``.
    """
    # Read the summary from the path we were given. The function used to rely
    # on a notebook-level `summary_content` left behind by another cell.
    with open(summary_file, 'r') as handle:
        summary_content = handle.read()

    X_final, y = extract_data_and_labels(edf_file_names, summary_content)
    output = [fft(window) for window in X_final]
    return output, y

In [None]:
def get_fft(snippet, Fs):
    """Return the one-sided FFT magnitude spectrum of a signal.

    The transform is taken along the LAST axis, so ``snippet`` may be a 1-D
    signal or a ``(channels, samples)`` window. Previously the sample count
    was taken from ``len(snippet)``; for a 2-D window that is the channel
    count, so both the frequency axis and the kept "half" were wrong.

    Parameters
    ----------
    snippet : array_like
        Signal with ``n`` samples on its last axis.
    Fs : int or float
        Sampling frequency used to scale the frequency axis.

    Returns
    -------
    frq : numpy.ndarray
        ``n // 2`` frequencies, ``k * Fs / n`` for ``k = 0 .. n//2 - 1``.
    magnitude : numpy.ndarray
        ``abs(FFT)`` restricted to those bins; same leading shape as
        ``snippet`` with a last axis of length ``n // 2``.
    """
    y = np.asarray(snippet)
    n = y.shape[-1]
    half = n // 2
    T = n / Fs                      # duration of the snippet
    frq = np.arange(half) / T       # one-sided frequency axis
    Y = np.fft.fft(y, axis=-1)[..., :half]
    return frq, abs(Y)


def selected_band(f, Y):
    """Split a magnitude spectrum into the five frequency bands.

    Each band is the half-open interval ``(low, high]`` on the frequency axis
    ``f``; the bins are taken along the last axis of ``Y``, so 1-D spectra and
    ``(channels, bins)`` spectra are both supported. The DC bin (f == 0)
    belongs to no band.

    Parameters
    ----------
    f : array_like
        Frequency of each bin; length must equal ``Y.shape[-1]``.
    Y : array_like
        Spectrum values with bins on the last axis.

    Returns
    -------
    tuple of numpy.ndarray
        ``(gamma, alpha, theta, delta, beta)`` in that order.
    """
    f = np.asarray(f)
    Y = np.asarray(Y)

    delta_range = (0, 4)
    theta_range = (4, 8)
    alpha_range = (8, 14)
    # Upper edge is 30 so beta meets gamma without overlap; it used to be 31,
    # which placed the (30, 31] bins in both bands.
    beta_range = (14, 30)
    gamma_range = (30, 50)

    def _band(limits):
        low, high = limits
        return Y[..., (f > low) & (f <= high)]

    gamma = _band(gamma_range)
    alpha = _band(alpha_range)
    theta = _band(theta_range)
    delta = _band(delta_range)
    beta = _band(beta_range)

    return gamma, alpha, theta, delta, beta

In [None]:
def final_selected_band(selected_bands, info_wave):
    """Concatenate the values of the requested bands into one flat list.

    Bands are visited in the iteration order of ``selected_bands``; a band is
    kept when its name is in ``info_wave``. Each kept array is flattened with
    ``ravel()`` and its elements are appended one by one.
    """
    return [
        value
        for band_name, band_values in selected_bands.items()
        if band_name in info_wave
        for value in band_values.ravel()
    ]

In [None]:
def get_one_file(edf_file_names, summary_file, info_file):
    """Build band features for one recording, using the bands named in its info file.

    For every window returned by ``extract_data_and_labels`` the spectrum is
    computed with ``get_fft``, split with ``selected_band``, and the bands
    listed by ``get_info_wave(info_file)`` are flattened with
    ``final_selected_band``.

    Parameters
    ----------
    edf_file_names : str
        Path of the EDF file.
    summary_file : str
        Path of the summary text file for the recording.
    info_file : str
        Path of the text file scanned for band names.

    Returns
    -------
    output : list of list
        One flat feature list per window (empty lists when no band is named).
    y : list
        The labels returned by ``extract_data_and_labels``.
    """
    # Context manager so the summary file handle is always closed.
    with open(summary_file, 'r') as handle:
        summary_content = handle.read()

    X_final, y = extract_data_and_labels(edf_file_names, summary_content)
    sampling_frequency = extract_sampling_frequency(edf_file_names)
    info_wave = get_info_wave(info_file)

    output = []
    for window in X_final:
        f, Y = get_fft(window, sampling_frequency)
        gamma, alpha, theta, delta, beta = selected_band(f, Y)
        selected_bands = {'gamma': gamma, 'alpha': alpha, 'theta': theta,
                          'delta': delta, 'beta': beta}
        output.append(final_selected_band(selected_bands, info_wave))

    return output, y

In [None]:
def get_one_whole_file(edf_file_names, summary_file, info_file):
    """Build band features for one recording using all five bands.

    Same processing as the info-driven variant, except that the band list is
    fixed to gamma, alpha, theta, delta and beta.

    Parameters
    ----------
    edf_file_names : str
        Path of the EDF file.
    summary_file : str
        Path of the summary text file for the recording.
    info_file : str
        Not read by this function; kept so the signature stays unchanged.

    Returns
    -------
    output : list of list
        One flat feature list per window.
    y : list
        The labels returned by ``extract_data_and_labels``.
    """
    # Context manager so the summary file handle is always closed.
    with open(summary_file, 'r') as handle:
        summary_content = handle.read()

    X_final, y = extract_data_and_labels(edf_file_names, summary_content)
    sampling_frequency = extract_sampling_frequency(edf_file_names)

    info_wave = ['gamma', 'alpha', 'theta', 'delta', 'beta']

    output = []
    for window in X_final:
        f, Y = get_fft(window, sampling_frequency)
        gamma, alpha, theta, delta, beta = selected_band(f, Y)
        selected_bands = {'gamma': gamma, 'alpha': alpha, 'theta': theta,
                          'delta': delta, 'beta': beta}
        output.append(final_selected_band(selected_bands, info_wave))

    return output, y

In [None]:
def Resample(data, sampling_freq=250):
    """Resample every row of a 2-D array to a fixed number of points.

    Uses ``scipy.signal.resample`` along axis 1. Rows shorter than
    ``sampling_freq`` are resampled as well, so every call returns the same
    row length; previously that case left ``out`` unassigned and raised
    ``UnboundLocalError``.

    Parameters
    ----------
    data : array_like
        Two-dimensional input, one series per row.
    sampling_freq : int, optional
        Number of points per row in the result (default 250).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_rows, sampling_freq)``.

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional or its rows are empty.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError('Resample expects 2-D data, got {} dimension(s)'.format(data.ndim))
    if data.shape[1] == 0:
        raise ValueError('Resample cannot resample empty rows')

    return resample(data, sampling_freq, axis=1)

In [None]:
# Recording IDs, one list per seizure-type code (the code is the middle part of
# each list name; presumably TUH seizure-corpus labels — TODO confirm).
# Each ID names both a folder and the files inside it: the dataset-building cell
# formats "../clips/{id}/{id}.edf", "../clips/{id}/{id}_summary.txt" and
# "../clips/{id}/{id}.txt" from it.
all_TNSZ_file = ['00008889_s003_t010','00009044_s001_t000','00008444_s003_t001','00008444_s003_t002','00008444_s003_t003','00008444_s003_t004','00008444_s003_t005','00008444_s003_t010','00008444_s003_t012','00008444_s004_t004','00008889_s004_t000','00008889_s004_t001','00008889_s004_t002','00008889_s004_t003','00008889_s004_t004','00008889_s004_t005','00008889_s004_t006','00008889_s004_t007','00008889_s004_t008','00008889_s004_t009','00008889_s004_t010','00008889_s004_t011','00008889_s004_t012','00008889_s004_t013','00008889_s004_t014','00008889_s004_t015','00008889_s005_t004','00008889_s005_t005']
all_TCSZ_file = ['00010158_s001_t001','00010088_s010_t001','00000906_s005_t000','00000258_s003_t002','00000258_s003_t003','00000258_s003_t004','00008889_s002_t002','00008889_s002_t003','00008889_s002_t005','00008889_s002_t006','00008889_s002_t008','00008889_s002_t010','00008889_s003_t001','00008889_s003_t006','00009578_s003_t003','00009578_s004_t001','00009578_s004_t002','00009578_s004_t004']
all_MYSZ_file = ['00008606_s001_t000']
all_SPSZ_file = ['00006546_s021_t002','00006546_s024_t000','00006546_s024_t001','00008527_s003_t003','00008527_s004_t000','00008527_s004_t002','00008527_s004_t003','00008616_s001_t000']
all_ABSZ_file = ['00008608_s001_t000']
# The longest list; the literal continues over the next two lines.
all_CPSZ_file = ['00000883_s002_t000','00006904_s004_t002','00008544_s002_t004','00006904_s004_t003','00008544_s004_t006','00006904_s004_t004','00008544_s004_t007','00006904_s005_t000','00008544_s004_t008','00001981_s009_t001','00006904_s005_t001','00008544_s004_t010','00002806_s001_t001','00006904_s005_t002','00008544_s004_t011','00002806_s001_t002','00006904_s007_t000','00008544_s004_t012','00002806_s001_t003','00006904_s007_t001','00008544_s005_t000','00002806_s001_t004','00006904_s007_t002','00008544_s005_t001','00002806_s001_t006','00006904_s007_t003','00008544_s005_t002','00002806_s001_t007','00006904_s007_t004','00008544_s005_t003','00004456_s012_t002','00006904_s007_t005','00008544_s005_t005',
'00004456_s012_t003','00006904_s007_t006','00008544_s005_t006','00005479_s003_t000','00006904_s008_t001','00008544_s005_t007','00005479_s004_t000','00006904_s008_t002','00006535_s005_t007','00006986_s001_t001','00006535_s006_t006','00008295_s001_t000','00006546_s025_t003','00008345_s001_t000','00008615_s001_t000','00006811_s001_t000','00008453_s005_t000','00013407_s001_t000','00008460_s001_t000','00013407_s001_t004',
'00006904_s004_t000','00008544_s001_t000','00013407_s001_t013','00006904_s004_t001','00008544_s002_t001']

In [None]:
# --- Configuration -----------------------------------------------------------
# The loop used to iterate over `all_XXXX_file`, a template placeholder that is
# never defined, so the cell failed with NameError on a fresh kernel. Pick the
# seizure type explicitly instead.
FILES_BY_SEIZURE_TYPE = {
    'TNSZ': all_TNSZ_file,
    'TCSZ': all_TCSZ_file,
    'MYSZ': all_MYSZ_file,
    'SPSZ': all_SPSZ_file,
    'ABSZ': all_ABSZ_file,
    'CPSZ': all_CPSZ_file,
}
SEIZURE_TYPE = 'TNSZ'  # NOTE(review): default chosen arbitrarily; set to the type under study
selected_files = FILES_BY_SEIZURE_TYPE[SEIZURE_TYPE]

# --- Feature extraction ------------------------------------------------------
start = time.time()

output_all = []
y_all = []

for file_id in selected_files:
    edf_file_names = "../clips/{}/{}.edf".format(file_id, file_id)
    summary_file = "../clips/{}/{}_summary.txt".format(file_id, file_id)
    info_file = "../clips/{}/{}.txt".format(file_id, file_id)

    # Only the info-driven band features ("input 3") were ever kept: the cell
    # also ran a bare extraction plus "input 1" (get_file) and "input 2"
    # (get_one_whole_file) and then overwrote their results. Those redundant
    # passes are dropped; swap the call below to try another feature variant.
    output, y = get_one_file(edf_file_names, summary_file, info_file)

    # Skip recordings with no windows or with empty feature vectors
    # (no band named in the info file).
    if len(output) == 0 or len(output[0]) == 0:
        continue

    output_all.append(Resample(output).tolist())
    y_all.append(y)

# --- Flatten per-file results into one sample list ---------------------------
output_final = [features for per_file in output_all for features in per_file]
y_final = np.array([label for per_file in y_all for label in per_file])
assert len(output_final) == len(y_final), 'features and labels are misaligned'

end = time.time()
elapsed_sec = end - start

In [None]:
# Metrics passed to cross_validate: results are reported under 'test_<name>'.
# Sensitivity is recall with the default positive label; specificity is recall
# computed with label 0 as the positive class.
scoring = dict(
    ACC='accuracy',
    AUC='roc_auc',
    sensitivity=make_scorer(recall_score),
    specificity=make_scorer(recall_score, pos_label=0),
)

In [None]:
# RBF-kernel SVM evaluated with shuffled 5-fold cross-validation.
clf_SVC = SVC(C=1000, kernel='rbf', gamma=1e-09, random_state=0)

score_SVC = cross_validate(
    clf_SVC,
    output_final,
    y_final,
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    scoring=scoring,
)

# Fold means, in the order ACC, AUC, sensitivity, specificity.
scores_SVC = [
    score_SVC[key].mean()
    for key in ('test_ACC', 'test_AUC', 'test_sensitivity', 'test_specificity')
]


In [None]:
scores_SVC

In [None]:
# Random forest evaluated with the same shuffled 5-fold cross-validation.
clf_RF = RandomForestClassifier(
    n_estimators=100,
    min_samples_split=5,
    bootstrap=False,
    random_state=0,
)

score_RF = cross_validate(
    clf_RF,
    output_final,
    y_final,
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    scoring=scoring,
)

# Fold means, in the order ACC, AUC, sensitivity, specificity.
scores_RF = [
    score_RF[key].mean()
    for key in ('test_ACC', 'test_AUC', 'test_sensitivity', 'test_specificity')
]

In [None]:
scores_RF