In [88]:
import numpy as np
import pandas as pd
from scipy.io import loadmat
import matplotlib.pyplot as plt
from data_utils import trim_intervals, get_data
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, GridSearchCV
import GAN
import torch

In [102]:
def PCA(data, k=70):
    """Project differenced, mean-centred data onto its leading principal axes.

    Parameters
    ----------
    data : numpy.ndarray
        Input array. It is first-differenced along its last axis with
        ``np.diff`` before the decomposition. The transpose / matrix-multiply
        steps below require the differenced result to be 2-D.
    k : int, optional
        Number of leading columns of ``U`` to project onto (default 70).

    Returns
    -------
    torch.Tensor
        The centred, differenced matrix multiplied by the first ``k`` columns
        of ``U``: one row per input row and up to ``k`` columns.
    """
    # Preprocess: difference along the last axis, then subtract each column's mean.
    samples = torch.from_numpy(np.diff(data))
    centred = samples - samples.mean(dim=0)

    # SVD of the transposed matrix, so the columns of `axes` live in feature space.
    axes, _singular_values, _right_vectors = torch.svd(centred.t())
    return centred.mm(axes[:, :k])

def load_data(train_size, test_size, pca_flag=False, random_state=None):
    """Load the left/right-hand trials for channel C3 and split them into train/test sets.

    Trials labelled 3 are discarded; label 1 is remapped to 0 (left hand) and
    label 2 to 1 (right hand). Each remaining trial is trimmed with
    ``trim_intervals`` and flattened to one feature row.

    Parameters
    ----------
    train_size, test_size
        Forwarded unchanged to ``train_test_split``.
    pca_flag : bool, optional
        When true, the feature matrix is passed through ``PCA`` before the
        split. NOTE(review): the projection is therefore fitted on all trials,
        including those that end up in the test set.
    random_state : optional
        Forwarded to ``train_test_split``. The default ``None`` keeps the
        original behaviour of drawing a different random split on every call.

    Returns
    -------
    X_train, X_test, y_train, y_test
        The four arrays produced by ``train_test_split``.
    """
    keep_channels = ['C3']
    trial_len = 1.5

    # X, y = get_data("../data/CLASubjectA1601083StLRHand.mat", trial_len, keep_channels)
    X, y = get_data("../data/CLASubjectB1512153StLRHand.mat", trial_len, keep_channels)

    # Keep only the trials whose label is not 3.
    keep = y != 3
    X = X[keep]
    y = y[keep]
    # 0 is left hand
    y[y == 1] = 0
    # 1 is right hand
    y[y == 2] = 1

    interval_len = .45
    X = trim_intervals(X, .15, interval_len)

    num_channels = len(keep_channels)
    # Features per trial. 0.005 is presumably the sample period in seconds
    # (200 Hz) -- TODO confirm against data_utils.
    d2 = int(np.ceil(num_channels * interval_len / 0.005))
    # One row per remaining trial. The row count is taken from the labels
    # rather than hard-coded, so recordings with a different number of trials
    # also load (train_test_split requires len(X) == len(y) anyway).
    X = X.reshape(len(y), d2)

    if pca_flag:
        X = PCA(X).numpy()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_size, test_size=test_size, random_state=random_state
    )

    return X_train, X_test, y_train, y_test

In [103]:
def shuffle(X, y):
    """Shuffle ``X`` and ``y`` together along their first axis.

    A single random permutation of the row indices is drawn from NumPy's
    global random state and applied to both arrays, so every row of ``X``
    stays paired with its entry in ``y``. The inputs are not modified and
    each output keeps the dtype of its input.

    Parameters
    ----------
    X : array_like
        Samples, indexed along the first axis.
    y : array_like
        Labels, with the same length as ``X`` along the first axis.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The shuffled samples and the correspondingly shuffled labels.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` differ in length along the first axis.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            "X and y must have the same number of rows, got %d and %d"
            % (X.shape[0], y.shape[0])
        )

    # The previous version stacked y onto X, shuffled, and then returned the
    # whole stacked matrix plus its last *row*; indexing both arrays with one
    # permutation returns the intended (features, labels) pair.
    order = np.random.permutation(X.shape[0])
    return X[order], y[order]

In [104]:
def train_GAN(X_train, y_train, epochs=10000, n_samples=100):
    """Train a ``GAN.GAN`` on the training split and return generated samples.

    The generator input/output width and the discriminator input width are
    all set to the number of columns of ``X_train``.

    Parameters
    ----------
    X_train : numpy.ndarray
        2-D training matrix; ``X_train.shape[1]`` sets the network widths.
    y_train
        Training labels, passed to ``GAN.GAN`` together with ``X_train``.
    epochs : int, optional
        Value passed to ``gan.train`` (default 10000). Presumably the number
        of training iterations -- confirm against the GAN module.
    n_samples : int, optional
        Number of samples requested from ``gan.generate_data`` (default 100).

    Returns
    -------
    The output of ``gan.generate_data(n_samples)`` reshaped to
    ``(n_samples, X_train.shape[1])``.
    """
    n_features = X_train.shape[1]
    gan = GAN.GAN((X_train, y_train), g_in=n_features, g_hid=100, g_out=n_features,
                  d_in=n_features, d_hid=10, d_out=1)
    gan.train(epochs)
    # Reshape to the generator's own output width instead of a fixed 90, so
    # inputs of any width (e.g. 70 columns after PCA) work.
    return gan.generate_data(n_samples).reshape((n_samples, n_features))
                                                           

In [105]:
def classify(X_train, y_train, X_test, y_test):
    """Fit an ``SVC(gamma='scale')`` on the training split and score it on the test split.

    Parameters
    ----------
    X_train, y_train
        Features and labels the classifier is fitted on.
    X_test, y_test
        Features and labels the fitted classifier is scored on.

    Returns
    -------
    The value of ``SVC.score(X_test, y_test)`` for the fitted model.
    """
    # fit() returns the estimator itself, so it can be bound directly.
    fitted_model = SVC(gamma='scale').fit(X_train, y_train)
    return fitted_model.score(X_test, y_test)



In [106]:
# Compare the classifier on raw features against PCA features, averaging the
# test score over several independently drawn train/test splits.
trials = 5
accuracy = 0
accuracy_PCA = 0

for i in range(trials):
    # Same order as before: the raw-feature run first, then the PCA run,
    # each on its own freshly loaded and split data.
    for pca_flag in (False, True):
        X_train, X_test, y_train, y_test = load_data(0.8, 0.2, pca_flag)
        score = classify(X_train, y_train, X_test, y_test)
        if pca_flag:
            accuracy_PCA += score
        else:
            accuracy += score

print("Classifier Accuracy:", accuracy/trials)
print("Classifier Accuracy after PCA:", accuracy_PCA/trials)

Classifier Accuracy: 0.7689922480620155
Classifier Accuracy after PCA: 0.5488372093023256
