In [3]:
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, recall_score, precision_score
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.naive_bayes import GaussianNB
import xgboost as xgb
from imblearn.over_sampling import SMOTE 

import dataloader
import bandpower
import preprocessing
from plot_confusion_matrix import plot_confusion_matrix


# Leave-one-subject out cross validation

In [4]:
show_acc = False
average_regions = True
output_cm = True
folder_name = 'BP_classification_raw'
label_formats = list(range(2,3))

for label_format in label_formats:

    X, Y, df_all = dataloader.read_data(label_format=label_format, data_folder='rawdata')
    
    # Remove bad trials
    X, Y, df_all = preprocessing.remove_trials(X, Y, df_all)

    if average_regions:
        X = preprocessing.avg_channels_into_regions(X, df_all)
    
    # Transform into power
    low, high = list(range(1,50)), list(range(2,51))
    powers, psds, freqs = bandpower.get_bandpower(X, 1000, low=low, high=high, dB_scale=True)
    
    '''
    # Use psd as feature
    psds = np.array(psds)
    psds = psds[..., freqs<50]
    freqs = freqs[freqs<50]
    powers = psds
    '''

    # Select models
    model_names = ['naive', 'svm', 'xgboost']
    sampling_names = ['original', 'SMOTE']
    #cv_modes = ['KFold', 'LOSO']
    preprocess_methods = ['None', 'PCAVar', 'PCAL1select', 'L1select']
    #cost_sensitive = [True, False]
    
    cv_modes = ['KFold']
    cost_sensitive = [False]
    
    for model_name, sampling_name, cv_mode, preprocess_name, cs in itertools.product(model_names, sampling_names, cv_modes, preprocess_methods, cost_sensitive):

        file_name = '%s/label%d_%s_%s_cs%s_%s_%s.csv'%(folder_name, label_format, model_name, sampling_name, str(cs), cv_mode, preprocess_name)
        description = 'Label: %d, Model: %s, Sampling: %s, Cost-sensitive: %s, CV mode: %s'%(label_format, model_name, sampling_name, str(cs), cv_mode)

        num_sample = len(X)
        if average_regions:
            X_power = np.array(powers)
            X_power = X_power.reshape((len(X_power),-1))
        else:
            # Average over channels
            X_power = np.zeros((num_sample, powers[1].shape[1]))
            for i in range(len(X)):
                X_power[i,:] = np.mean(powers[i],0)

        kf = KFold(n_splits=10, shuffle=True, random_state=23)
        S = df_all['subject'].values
        folds = kf.split(X_power) if cv_mode=='KFold' else np.unique(S)

        train_acc_list = []
        val_acc_list = []
        train_recall_list = []
        val_recall_list = []
        train_prec_list = []
        val_prec_list = []
        num_val_list = []
        chance_list = []
        fold_name_list = []

        if show_acc:
            print('Sub  \t Chance\t | Train | Val')
        for i_fold, fold in enumerate(folds):
            if cv_mode == 'KFold':
                train_indices = fold[0]
                val_indices = fold[1]
            else:
                subID = fold
                train_indices = np.where(S!=subID)[0]
                val_indices = np.where(S==subID)[0]

            X_train, Y_train = X_power[train_indices,...], Y[train_indices]
            X_val, Y_val = X_power[val_indices,...], Y[val_indices]

            # Preprocessing features
            if preprocess_name == 'L1select':
                scaler = StandardScaler()
                X_train = scaler.fit_transform(X_train)
                X_val = scaler.transform(X_val)
                X_train, X_val = preprocessing.select_features(X_train, X_val, Y_train)
            elif preprocess_name in ['PCAVar', 'PCAL1select']:
                n_components = 0.9 if preprocess_name=='PCAVar' else np.min(X_train.shape)
                scaler, pca = StandardScaler(), PCA(n_components=n_components)
                X_train = scaler.fit_transform(X_train)
                X_val = scaler.transform(X_val)
                
                X_train = pca.fit_transform(X_train)
                X_val = pca.transform(X_val)
                
                if preprocess_name == 'PCAL1select':
                    X_train, X_val = preprocessing.select_features(X_train, X_val, Y_train)
            
            # Resample training data
            if sampling_name == 'SMOTE':
                sm = SMOTE(random_state=23)
                X_train, Y_train = sm.fit_resample(X_train, Y_train)

            # Train classifier
            num_neg, num_pos = tuple(np.bincount(Y_train))
            if model_name in ['svm','xgboost','naive']:
                if model_name == 'svm':
                    if cs:
                        clf = svm.SVC(class_weight='balanced')
                    else:
                        clf = svm.SVC()
                elif model_name == 'xgboost':
                    if cs:
                        clf = xgb.XGBClassifier(scale_pos_weight=num_neg/num_pos)
                    else:
                        clf = xgb.XGBClassifier()
                elif model_name == 'naive':
                    if cs:
                        clf = GaussianNB(priors = [num_pos/len(Y_train), num_neg/len(Y_train)])
                    else:
                        clf = GaussianNB()
                        
                clf.fit(X_train, Y_train)

                # Test classifier
                pred_train = clf.predict(X_train)
                pred_val = clf.predict(X_val)
                
                train_acc = accuracy_score(Y_train, pred_train)
                val_acc = accuracy_score(Y_val, pred_val)
                
                train_recall = recall_score(Y_train, pred_train)
                val_recall = recall_score(Y_val, pred_val)
                
                train_prec = precision_score(Y_train, pred_train)
                val_prec = precision_score(Y_val, pred_val)
                
            fold_name = 'Fold %d'%(i_fold) if cv_mode == 'KFold' else 'Sub %d'%(subID)
            chance = np.sum(Y_train)/len(Y_train)
            chance = (1-chance) if chance<0.5 else chance

            train_acc_list.append(train_acc)
            val_acc_list.append(val_acc)
            train_recall_list.append(train_recall)
            val_recall_list.append(val_recall)
            train_prec_list.append(train_prec)
            val_prec_list.append(val_prec)
            num_val_list.append(len(Y_val))
            fold_name_list.append(fold_name)
            chance_list.append(chance)

            if show_acc:
                print('%s\t %.1f%%\t | %.1f%% | %.1f%%'%(fold_name, chance*100, train_acc*100, val_acc*100))
            
            if output_cm:
                fileName_cm = './results/' + file_name[:file_name.find('.')] + '_' + fold_name
                plot_confusion_matrix(Y_train, pred_train, np.array(['Normal','Increase']), 
                                     fileName='%s_train.png'%(fileName_cm), title=fold_name)
                plot_confusion_matrix(Y_val, pred_val, np.array(['Normal','Increase']), 
                                     fileName='%s_val.png'%(fileName_cm), title=fold_name)

        avg_acc = sum([val_acc_list[i]*num_val_list[i] for i in range(len(num_val_list))])/num_sample
        avg_recall = sum([val_recall_list[i]*num_val_list[i] for i in range(len(num_val_list))])/num_sample
        avg_prec = sum([val_prec_list[i]*num_val_list[i] for i in range(len(num_val_list))])/num_sample
        if show_acc:
            print('Average val acc: %.1f%%'%(avg_acc*100))

        # Save result as csv file
        df_result = pd.DataFrame({description: fold_name_list, 'Chance_train_acc': chance_list, 
                                  'Train_acc': train_acc_list, 'Val_acc': val_acc_list,
                                  'Train_recall' : train_recall_list, 'Val_recall': val_recall_list,
                                  'Train_precision' : train_prec_list, 'Val_precision': val_prec_list})
        df_result.loc[len(df_result),description] = 'Average'
        df_result.loc[len(df_result)-1, 'Val_acc'] = avg_acc
        df_result.loc[len(df_result)-1, 'Val_recall'] = avg_recall
        df_result.loc[len(df_result)-1, 'Val_precision'] = avg_prec
        df_result.to_csv('./results/%s'%(file_name))

Load data from .mat files...


  val = np.array(val, copy=False)
  return array(a, dtype, copy=False, order=order)


Calculating the bandpower of time-series data...
freqs:  [0.000e+00 2.000e-01 4.000e-01 ... 4.996e+02 4.998e+02 5.000e+02]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
