In [2]:
import numpy as np
import pandas as pd
import random
import os
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from joblib import dump
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import itertools
from sklearn.utils import shuffle
from scipy import signal
%matplotlib inline


from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

from sklearn.feature_selection import SelectFdr, chi2

from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.metrics import accuracy_score

from utils.svm import preProcess, evaluate_set
from utils.visualize import showMe
from utils.augment import apply_augment
from config.default import *


%load_ext autoreload
%autoreload 2


In [3]:
def create_labels(X, class_names=None):
    """Build one numeric class label per sample for a dict of per-class arrays.

    Parameters
    ----------
    X : mapping
        Maps a class name to an array whose first axis indexes samples
        (only ``X[name].shape[0]`` is read).
    class_names : sequence, optional
        Reference ordering of class names. When given, every sample of a class
        is labelled with that class's index in ``class_names``, so two dicts
        holding different subsets of classes still receive consistent labels.
        When omitted, a class is labelled with its position in ``X``'s
        iteration order.

    Returns
    -------
    numpy.ndarray
        1-D float64 array with ``sum(X[name].shape[0] for name in X)`` entries,
        laid out class after class in ``X``'s iteration order. Empty when ``X``
        is empty.

    Raises
    ------
    ValueError
        If ``class_names`` is given and a key of ``X`` is not in it.
    """
    order = list(X) if class_names is None else list(class_names)
    # One constant block per class, joined once at the end; this avoids
    # re-copying a growing Python list for every class.
    blocks = [
        np.full(X[name].shape[0], order.index(name), dtype=float)
        for name in X
    ]
    if not blocks:
        return np.array([])
    return np.concatenate(blocks)


In [4]:
# Dataset root. Set the EMG_DATA_ROOT environment variable to use a different
# copy of the recordings; the default is the original machine-specific path.
root_path = os.environ.get('EMG_DATA_ROOT', 'C:/resources/EMG/')
post_fix = '_1s_cleaned'  # alternative: '_1s_new'
classes = settings['classes']

# Hold-out rule: a session folder goes to validation when its name is listed in
# sessions_to_val OR it belongs to a subject listed in subject_to_val.
sessions_to_val = ['session_4']  # alternatives: ['session_1','session_2','session_3','session_4'], []
subject_to_val = ['S001', 'S105']

train_sessions = []
val_sessions = []

# os.listdir returns entries in arbitrary order; sorting keeps the sample order
# (and therefore the later seeded split) the same from run to run.
for subject in sorted(os.listdir(root_path)):
    subject_dir = os.path.join(root_path, subject)
    if not os.path.isdir(subject_dir):
        continue  # stray file next to the subject folders
    for session in sorted(os.listdir(subject_dir)):
        session_dir = os.path.join(subject_dir, session)
        if not os.path.isdir(session_dir):
            continue  # stray file next to the session folders
        if session in sessions_to_val or subject in subject_to_val:
            val_sessions.append(session_dir)
        else:
            train_sessions.append(session_dir)


def _load_class_arrays(sessions, class_name, report_empty_as=None):
    """Load `<session>/<class_name><post_fix>.npy` for each session, keeping non-empty arrays.

    When `report_empty_as` is given, a message naming that split is printed
    for every session whose array has zero rows.
    """
    arrays = []
    for session_dir in sessions:
        # NOTE(review): allow_pickle=True unpickles objects stored in the file;
        # only load recordings from a trusted source.
        data = np.load(os.path.join(session_dir, class_name + post_fix + '.npy'), allow_pickle=True)
        if data.shape[0] != 0:
            arrays.append(data)
        elif report_empty_as is not None:
            print(f"No data available for {report_empty_as} for class {class_name}")
    return arrays


train_records = {}
for c in classes:
    class_data = _load_class_arrays(train_sessions, c, report_empty_as="train")
    if not class_data:
        # Every class must be present for training; fail with a readable
        # message instead of np.concatenate's "need at least one array".
        raise ValueError(
            f"No training data found for class {c!r} in any of the "
            f"{len(train_sessions)} training sessions under {root_path!r}"
        )
    train_records[c] = np.concatenate(class_data)
print(f"{len(train_sessions)} sessions loaded for training")


val_records = {}
for c in classes:
    class_data = _load_class_arrays(val_sessions, c)
    if len(class_data) != 0:
        val_records[c] = np.concatenate(class_data)
    else:
        # The class is left out of val_records entirely.
        print(f"No data available for validation for class {c}")

print(f"{len(val_sessions)} sessions loaded for validation")

No data available for train for class Chew
No data available for train for class Chew
No data available for train for class Chew
No data available for train for class Smile
No data available for train for class Smile
No data available for train for class Smile
No data available for train for class Smile
No data available for train for class Smile
No data available for train for class Smile
No data available for train for class Smile
39 sessions loaded for training
19 sessions loaded for validation


In [13]:
# Window geometry (axes 1 and 2) is read from the "Rest" class; the per-class
# arrays are stacked along axis 0 below.
n_channels, input_length = train_records["Rest"].shape[1], train_records["Rest"].shape[2]

print('Train')
train_y = create_labels(train_records)
train_X = np.concatenate([train_records[name] for name in train_records], axis=0)
print(train_X.shape, train_y.shape, sep="\n")

# NOTE(review): create_labels numbers classes by their position in the dict it
# is given, so val_y only lines up with train_y when val_records holds the same
# classes in the same order as train_records.
print('Validation:')
val_y = create_labels(val_records)
val_X = np.concatenate([val_records[name] for name in val_records], axis=0)
print(val_X.shape, val_y.shape, sep="\n")

Train
(6315, 4, 500)
(6315,)
Validation:
(3033, 4, 500)
(3033,)


In [14]:
# Replace the training arrays with their augmented versions.
# NOTE(review): this cell overwrites its own inputs, so running it a second
# time passes already-augmented data back into apply_augment.
augmented = apply_augment(train_X, train_y)
train_X, train_y = augmented
print("After augmentation", train_X.shape, train_y.shape, sep="\n")

After augmentation
(18945, 4, 500)
(18945,)


In [15]:
# Flatten every window to a single row of n_channels * input_length values,
# giving the 2-D (samples, features) layout passed to the SVM.
flat_width = n_channels * input_length
train_X = train_X.reshape(len(train_X), flat_width)
val_X = val_X.reshape(len(val_X), flat_width)
print(train_X.shape, val_X.shape, sep="\n")


(18945, 2000)
(3033, 2000)


In [8]:
# Hyper-parameter grid handed to GridSearchCV for the SVC.
# LARGER C     -> weaker regularisation -> tighter fit to the training data
# HIGHER gamma -> each sample's influence is more local -> tighter fit
#
# Other grids (commented out):
#param_grid = {'C': [1, 10, 100,1000], 'gamma': [1,0.1,0.01,0.001,0.0001]} #acc 88 test acc 45
#param_grid = {'C': [100,1000], 'gamma': [0.01,0.001,0.0001]} #slow
#param_grid = {'C': [100000,1000000], 'gamma': [0.000001,0.0000001]}
#param_grid = {'C': [1, 10,100], 'gamma': [0,1, 0.01,0.001]}  # NOTE(review): "0,1" is probably meant to be 0.1
param_grid = dict(C=[10], gamma=[0.01])

In [16]:
# Empty accumulators; nothing in the visible cells appends to them.
accs, models = [], []
def grid(X_train,y_train, X_test, y_test):
    """Grid-search an SVC over the module-level ``param_grid`` and return the winner.

    The search is fitted on ``X_train`` / ``y_train`` with ``refit=True``, and
    its ``best_estimator_`` is returned. ``X_test`` and ``y_test`` are accepted
    but not used.
    """
    search = GridSearchCV(estimator=SVC(), param_grid=param_grid, refit=True, verbose=2)
    search.fit(X_train, y_train)
    return search.best_estimator_


# Hold out 33% of the training samples with a fixed seed, then run the grid
# search on the remaining 67%.
# NOTE(review): train_X has already been through apply_augment at this point;
# if augmentation derives new windows from existing ones, variants of one
# recording can end up on both sides of this split and of the CV folds.
split = train_test_split(train_X, train_y, test_size=0.33, random_state=42)
X_train, X_test, y_train, y_test = split

model = grid(X_train, y_train, X_test, y_test)

    

Fitting 5 folds for each of 1 candidates, totalling 5 fits
[CV] END ...................................C=10, gamma=0.01; total time=  51.4s
[CV] END ...................................C=10, gamma=0.01; total time=  51.9s
[CV] END ...................................C=10, gamma=0.01; total time=  52.4s
[CV] END ...................................C=10, gamma=0.01; total time=  57.8s
[CV] END ...................................C=10, gamma=0.01; total time= 1.1min


In [17]:
# Score the fitted model on the session folders assigned to training.
evaluate_set(model, train_sessions, classes, post_fix, log=False)

  0%|          | 0/39 [00:00<?, ?it/s]

Global accuracy: 98.77%
           Accuracy
Subject            
S002      97.000000
S004      97.666667
S005      99.000000
S006      98.666667
S007      97.333333
S008      99.333333
S009      99.000000
S010     100.000000
S011     100.000000
S101      99.666667
S102     100.000000
S103      99.000000
S104      96.333333
S106      99.000000


In [18]:
# Score the fitted model on the held-out validation session folders.
evaluate_set(model, val_sessions, classes, post_fix, log=False)

  0%|          | 0/19 [00:00<?, ?it/s]

Global accuracy: 84.26%
         Accuracy
Subject          
S001        84.00
S004        67.00
S005        82.00
S006        78.00
S007        82.00
S008        85.00
S009        92.00
S010        85.00
S101        92.00
S102        90.00
S104        93.00
S105        85.75
S106        76.00


In [28]:
# Persist the fitted model with joblib. The target folder is created first so
# a fresh checkout without saved_models/ does not fail on the write.
model_path = 'saved_models/svm_9subj_no_val.joblib'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
dump(model, model_path)

['saved_models/svm_9subj_no_val.joblib']