In [1]:
import pickle
import numpy as np
import pandas as pd
from scipy.io import loadmat
from collections import Counter

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix

import warnings
# Silence every warning process-wide.
# NOTE(review): this also hides sklearn warnings such as a LogisticRegression
# ConvergenceWarning, so non-converged fits would go unnoticed — consider
# narrowing the filter (e.g. by category) once the results are trusted.
warnings.filterwarnings(action='ignore')

# Load the preprocessed spike data. Later cells read the keys 'target' and
# 'window_size_<N>' from it, so presumably it is a dict — TODO confirm.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open("data/preprocessed_data.pkl", "rb") as f:
    data = pickle.load(f)

In [2]:
def get_spike_from_window_size(data, window_size):
    """Return the ``(features, labels)`` pair stored for one window size.

    Parameters
    ----------
    data : mapping
        Must contain the label entry ``'target'`` and the feature entry
        ``'window_size_<window_size>'``.
    window_size : int or str
        Interpolated into the feature key name.

    Returns
    -------
    tuple
        ``(data['window_size_<window_size>'], data['target'])``.

    Raises
    ------
    KeyError
        If either entry is missing (``'target'`` is looked up first).
    """
    labels = data['target']
    feature_key = f'window_size_{window_size}'
    features = data[feature_key]
    return features, labels


def sample_neuron(X, n_sample=98, rng=None):
    """Keep the same random subset of neurons in every trial.

    One set of neuron indices is drawn and then applied to every element of
    ``X``. The same neurons are therefore kept across all trials.

    Parameters
    ----------
    X : sequence of array-like
        One spike vector per trial. Each vector is indexed along its first
        axis, and the neuron count is taken from ``len(X[0])``. Presumably
        shape (n_trials, n_neurons) — TODO confirm.
    n_sample : int, default 98
        Number of neurons to keep. It must satisfy
        ``0 <= n_sample <= len(X[0])``. The default 98 presumably equals the
        full neuron count of this dataset.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` uses numpy's global RNG, so an
        ``np.random.seed`` call still controls the draw, exactly as before.

    Returns
    -------
    list
        One array per trial, containing only the sampled neurons.

    Raises
    ------
    ValueError
        If ``n_sample`` is negative or larger than the number of neurons.
        Previously that case silently returned a different number of
        neurons than requested.
    """
    if len(X) == 0:
        return []

    # Derive the neuron count from the data instead of hard-coding 98.
    # With the hard-coded value, datasets of any other size either crashed
    # with an IndexError or silently dropped neurons.
    n_neurons = len(X[0])
    if not 0 <= n_sample <= n_neurons:
        raise ValueError(
            f"n_sample must be between 0 and {n_neurons}, got {n_sample}"
        )

    permute = np.random.permutation if rng is None else rng.permutation
    sample_index = permute(n_neurons)[:n_sample]

    return [np.asarray(x)[sample_index] for x in X]


def train_test_split_(X, y, n_sample=98, test_size=0.2, random_state=None):
    """Sub-sample neurons, then make a shuffled, stratified train/test split.

    Parameters
    ----------
    X : sequence of array-like
        Per-trial spike vectors, forwarded to ``sample_neuron``.
    y : array-like
        Class labels. Also used for stratification, so every class keeps
        its proportion in both splits.
    n_sample : int, default 98
        Number of neurons kept by ``sample_neuron``.
    test_size : float or int, default 0.2
        Forwarded to sklearn's ``train_test_split``.
    random_state : int, RandomState instance or None, default None
        Forwarded to sklearn's ``train_test_split``. ``None`` keeps the
        previous behaviour (the split is not fixed by this argument). The
        neuron sub-sampling in ``sample_neuron`` is not controlled by this
        argument; seed numpy's global RNG for a fully reproducible split.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``.
    """
    X_sampled = sample_neuron(X, n_sample=n_sample)
    X_train, X_test, y_train, y_test = train_test_split(
        X_sampled, y,
        test_size=test_size,
        shuffle=True,
        stratify=y,
        random_state=random_state,
    )
    return X_train, X_test, y_train, y_test
     
    
def fit_predict_models(X_train, y_train, X_test, y_test, window_size=500):
    """Fit three default-configured classifiers and score each on the test split.

    Parameters
    ----------
    X_train, y_train :
        Training features and labels.
    X_test, y_test :
        Held-out features and labels used for scoring.
    window_size : int, default 500
        Not used by this function. It is accepted so that callers can pass
        it through.

    Returns
    -------
    tuple of float
        Accuracies of ``(LogisticRegression, SVC, DecisionTreeClassifier)``,
        in that order.
    """
    # Keep this fixed fitting order. With its default random_state=None, the
    # decision tree draws from numpy's global RNG, so reordering would change
    # results under a global seed.
    estimator_classes = (LogisticRegression, SVC, DecisionTreeClassifier)

    scores = []
    for estimator_cls in estimator_classes:
        fitted = estimator_cls().fit(X_train, y_train)
        predictions = fitted.predict(X_test)
        scores.append(compute_metrics(y_test, predictions))

    return tuple(scores)
    

def compute_metrics(y_test, pred):
    """Return the classification accuracy of ``pred`` against ``y_test``.

    Accuracy is the fraction of predictions that equal the true label, as
    computed by sklearn's ``accuracy_score``.
    """
    return accuracy_score(y_test, pred)
    
    
def run(data, n_iter=1,
        window_size_list=(400,),
        sampled_neuron_number_list=(10, 20, 30, 40, 50, 60, 70, 80, 90, 98)):
    """Repeatedly score the three classifiers for each window size and neuron count.

    Each repetition draws a fresh random neuron subset and a fresh
    stratified train/test split, then records the accuracy of every model.

    Parameters
    ----------
    data : mapping
        Preprocessed data passed to ``get_spike_from_window_size``.
    n_iter : int, default 1
        Number of repetitions per (window size, neuron count) setting.
    window_size_list : iterable of int, default (400,)
        Window sizes to evaluate. Previously this was hard-coded to ``[400]``.
    sampled_neuron_number_list : iterable of int
        Neuron counts to sub-sample. The default is 10, 20, ..., 90, 98,
        which was previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        One row per repetition, with columns ``sampled_neuron_number``,
        ``lm_accuracy``, ``svc_accuracy`` and ``tree_accuracy``.

    Notes
    -----
    The window size is not stored in the result. Rows from several window
    sizes therefore cannot be told apart. Call ``run`` once per window size
    if you evaluate more than one.
    """
    records = []

    for window_size in window_size_list:
        # Features and labels depend only on the window size, so fetch them
        # once here instead of on every inner iteration.
        X, y = get_spike_from_window_size(data, window_size)

        for sampled_neuron_number in sampled_neuron_number_list:
            for _ in range(n_iter):
                X_train, X_test, y_train, y_test = train_test_split_(
                    X, y, n_sample=sampled_neuron_number
                )
                lm_accuracy, svc_accuracy, tree_accuracy = fit_predict_models(
                    X_train, y_train, X_test, y_test, window_size
                )
                records.append(
                    [sampled_neuron_number, lm_accuracy, svc_accuracy, tree_accuracy]
                )

    return pd.DataFrame(
        records,
        columns=["sampled_neuron_number", "lm_accuracy", "svc_accuracy", "tree_accuracy"],
    )

In [3]:
# The sweep is stochastic. The neuron sub-sampling (np.random.permutation),
# sklearn's train_test_split (random_state=None) and DecisionTreeClassifier
# (random_state=None) all draw from numpy's global RNG, so one global seed
# makes the whole 100-iteration run reproducible.
SEED = 0
np.random.seed(SEED)
results = run(data, n_iter=100)

# Mean and sample standard deviation (ddof=1) of each model's accuracy,
# per sampled-neuron count.
grouped = results.groupby("sampled_neuron_number")
average_res = grouped.mean().add_suffix("_mean")
std_res = grouped.std().add_suffix("_std")

concated_res = pd.merge(average_res, std_res, on="sampled_neuron_number", how="inner")

# Assign the reordered frame instead of only displaying it. Otherwise the
# table shown here and the one later copied via `concated_res.to_clipboard()`
# would have different column orders.
concated_res = concated_res.loc[:, ["lm_accuracy_mean", "lm_accuracy_std",
                                    "svc_accuracy_mean", "svc_accuracy_std",
                                    "tree_accuracy_mean", "tree_accuracy_std",
                                    ]]
concated_res

Unnamed: 0_level_0,lm_accuracy_mean,lm_accuracy_std,svc_accuracy_mean,svc_accuracy_std,tree_accuracy_mean,tree_accuracy_std
sampled_neuron_number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
10,0.573937,0.086712,0.580562,0.090632,0.451562,0.085518
20,0.74625,0.057728,0.748312,0.067437,0.590375,0.076533
30,0.801312,0.042818,0.822562,0.043044,0.639875,0.055978
40,0.833063,0.03578,0.87575,0.031649,0.682688,0.048663
50,0.857562,0.036349,0.906687,0.027312,0.706375,0.043997
60,0.881437,0.028056,0.926438,0.023784,0.712,0.04786
70,0.892188,0.025552,0.940312,0.019266,0.727313,0.044715
80,0.907562,0.024235,0.946375,0.019056,0.732125,0.039178
90,0.915187,0.018954,0.955812,0.015521,0.7375,0.036335
98,0.918312,0.021801,0.95925,0.013918,0.748063,0.029418


In [4]:
# Copy the summary table to the system clipboard (e.g. to paste it into a
# spreadsheet).
# NOTE(review): pandas needs a clipboard backend here (e.g. xclip/xsel on
# Linux), so this may fail on a headless or remote kernel.
concated_res.to_clipboard()