In [34]:
import pandas as pd
# Load the RNA table and order its rows by 'sn', then by 'period', with a
# fresh 0..n-1 index. Later cells reshape the value matrix positionally, so
# this row order matters.
rna_data = (
    pd.read_csv("gene_data/rna_common_complete.csv")
    .sort_values(by=['sn', 'period'])
    .reset_index(drop=True)
)

In [35]:
# Turn the long-format table (one row per 'sn' / 'period' pair) into a
# 3-D array of shape (n_subjects, n_periods, n_features) plus one binary
# label per 'sn'.
meta_cols = ['sn', 'group', 'caarms_status', 'period']  # non-feature columns
n_periods = 3  # rows per 'sn' that the positional reshape below relies on

# The reshape is purely positional: it is only meaningful when every 'sn'
# contributes exactly n_periods rows. Otherwise time points of different
# subjects would be mixed silently, so fail loudly instead.
rows_per_sn = rna_data['sn'].value_counts(dropna=False)
if not (rows_per_sn == n_periods).all():
    offending = rows_per_sn[rows_per_sn != n_periods].to_dict()
    raise ValueError(
        f"every 'sn' must have exactly {n_periods} rows for the reshape; "
        f"offending row counts: {offending}"
    )

# The rows of one 'sn' must also be contiguous (i.e. the frame is sorted by
# 'sn'); each consecutive group of n_periods rows has to share a single 'sn'.
sn_grid = rna_data['sn'].values.reshape(-1, n_periods)
if not (sn_grid == sn_grid[:, [0]]).all():
    raise ValueError("rows must be grouped by 'sn' (sort by ['sn', 'period']) before reshaping")

X_og_shape = rna_data.drop(meta_cols, axis=1).values
n_subjects = len(set(rna_data['sn']))
X_reshaped = X_og_shape.reshape(n_subjects, n_periods, X_og_shape.shape[1])

# One label per 'sn', read from its period-24 row. Group 'C' maps to 0 and
# every other group value maps to 1.
labels_group = rna_data[rna_data['period'] == 24]['group'].values
if len(labels_group) != n_subjects:
    raise ValueError(
        f"expected one period-24 row per 'sn' ({n_subjects}), found {len(labels_group)}; "
        "labels would not line up with X_reshaped"
    )
labels = [0 if i == 'C' else 1 for i in labels_group]

In [36]:
import torch
import numpy as np
from neucube.utils import SNR
from neucube.utils import interpolate
from neucube.encoder import Delta

# Score every feature with SNR on the first time point (index 0 of axis 1)
# and keep the indices of the 20 highest-scoring features.
# NOTE(review): the scores are computed from the labels of *all* samples,
# before the cross-validation split done later in this notebook, so the
# selection may leak test-fold information. Confirm this is intended.
first_timepoint = X_reshaped[:, 0, :]
ratios = SNR(first_timepoint, labels)
ranked_features = torch.argsort(ratios, descending=True)
top_idx = ranked_features[:20]

# Restrict the data to the selected features, then interpolate
# (num_points=104) before spike encoding.
X_reshaped_topidx = X_reshaped[:, :, top_idx]
interpolated_X = interpolate(X_reshaped_topidx, num_points=104)

# Targets as a tensor; inputs encoded with the Delta encoder.
y = torch.tensor(labels)
encoder = Delta(threshold=0.008)
X = encoder.encode_dataset(interpolated_X)

In [37]:
# Per-neuron-type parameter sets, keyed by a short type code. Each entry maps
# the parameter names 'a', 'b', 'c', 'd' to their values (listed below in that
# order). Presumably the codes stand for regular-spiking, chattering,
# intrinsically-bursting and fast-spiking neurons. TODO confirm.
neuron_parm_dict = {
    kind: dict(zip('abcd', values))
    for kind, values in (
        ('rs', (0.02, 0.2, -65, 8)),
        ('ch', (0.02, 0.55, -45, 4)),
        ('ib', (0.06, 0.55, -55, 3)),
        ('fs', (0.1, 0.2, -65, 2)),
    )
}

In [None]:
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn import svm, metrics
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from tqdm import tqdm

from neucube import IzhReservoir, Reservoir
from neucube.sampler import SpikeCount, MeanFiringRate, TemporalBinning, ISIstats, DeSNN
from neucube.utils import SeparationIndex

# Benchmark every (neuron type, state-vector sampler) combination.
#
# For each combination the encoded data X / y is split with k-fold CV. Per
# fold a fresh IzhReservoir is built, train and test spikes are simulated,
# the sampler turns spikes into state vectors, and an SVM tuned by an inner
# grid search (refit on MCC) predicts the test fold. Per-fold accuracy, MCC
# and separation indices are collected and written to two '|'-separated CSVs.

# Seed the global NumPy and PyTorch RNGs so a full re-run is repeatable.
# NOTE(review): this assumes neucube draws its randomness from these global
# generators. Confirm against the library.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

num_folds = 10
kf = KFold(n_splits=num_folds)

samplers = [SpikeCount(), MeanFiringRate(), TemporalBinning(bin_size=10), ISIstats(), DeSNN()]
sampler_names = ['Spike Count', 'Mean Firing Rate', 'Temporal Binning', 'ISI Stats', 'DeSNN']
neuron_types = ['rs', 'ch', 'ib', 'mix']

# Loop-invariant pieces of the inner model selection, built once.
param_grid = {'C': [2, 3, 4, 5, 6, 7, 8], 'gamma': [0.1, 0.01, 0.001], 'kernel': ['rbf', 'linear', 'poly']}
mcc_scorer = make_scorer(metrics.matthews_corrcoef)

result_dict_full = {}  # neuron type -> sampler name -> per-fold metric lists
result_dict_avg = {}   # neuron type -> sampler name -> fold-averaged metrics

for n_type in neuron_types:
    sampler_results_full = {}
    sampler_results_avg = {}
    print(f"Neuron Type: {n_type}")
    for sampler, s_names in zip(samplers, sampler_names):
        print(f"Sampler: {s_names}")
        # Pooled labels across folds. They are reset per combination, so after
        # the run they hold only the last (neuron type, sampler) pair. Kept
        # for ad-hoc inspection such as a confusion matrix.
        true_labels = []
        predicted_labels = []
        separation_values = []
        sampler_acc_fold = []
        sampler_mcc_fold = []
        # kf.split() is a generator without a length; pass total so tqdm can
        # render a real progress bar.
        for train_index, test_index in tqdm(kf.split(X), total=num_folds):
            X_train_fold, X_test_fold = X[train_index], X[test_index]
            y_train_fold, y_test_fold = y[train_index], y[test_index]

            # A new reservoir per fold; the same instance simulates both the
            # train and the test split of that fold.
            izh_res = IzhReservoir(inputs=X.shape[2], c=0.7, l=0.18, input_conn_prob=0.85)
            if n_type == 'mix':
                # Draw a type for every neuron uniformly from rs/ch/ib, then
                # overwrite a random 20% of the neurons with 'fs'.
                init_n_type = np.random.choice(['rs', 'ch', 'ib'], izh_res.n_neurons, replace=True)
                fs_indices = np.random.choice(len(init_n_type), int(0.2 * len(init_n_type)), replace=False)
                init_n_type[fs_indices] = 'fs'
                # One tensor per parameter, holding each neuron's value.
                # NOTE(review): 'c' and 'd' are integers in neuron_parm_dict, so
                # those two tensors get an integer dtype. Confirm update_parms
                # accepts that.
                a, b, c, d = [
                    torch.tensor([neuron_parm_dict[kind][parm] for kind in init_n_type])
                    for parm in ['a', 'b', 'c', 'd']
                ]
                izh_res.update_parms(a=a, b=b, c=c, d=d)
            else:
                # Single excitatory type from the table; the inhibitory
                # parameters equal the 'fs' entry of neuron_parm_dict.
                izh_res.set_exc_parms(**neuron_parm_dict[n_type])
                izh_res.set_inh_parms(a=0.1, b=0.2, c=-65, d=2)

            X_train_opt_spike = izh_res.simulate(X_train_fold, mem_thr=30, train=False, verbose=False)
            X_test_opt_spike = izh_res.simulate(X_test_fold, mem_thr=30, train=False, verbose=False)
            X_train_state_vec = sampler.sample(X_train_opt_spike)
            X_test_state_vec = sampler.sample(X_test_opt_spike)

            # Inner 10-fold grid search on the training fold only; the best
            # model by MCC is refit and used for the outer test fold.
            svm_model = svm.SVC()
            grid_search = GridSearchCV(estimator=svm_model, param_grid=param_grid, cv=10, scoring={'accuracy': 'accuracy', 'mcc': mcc_scorer}, refit='mcc')
            grid_search.fit(X_train_state_vec, y_train_fold)
            y_pred = grid_search.best_estimator_.predict(X_test_state_vec)

            true_labels.extend(y_test_fold)
            predicted_labels.extend(y_pred)
            separation_values.append([SeparationIndex(X_train_state_vec, y_train_fold), SeparationIndex(X_test_state_vec, y_test_fold)])
            sampler_acc_fold.append(accuracy_score(y_test_fold, y_pred))
            sampler_mcc_fold.append(metrics.matthews_corrcoef(y_test_fold, y_pred))

        sampler_results_full[s_names] = {'accuracy': sampler_acc_fold, 'mcc': sampler_mcc_fold, 'separation': separation_values}
        sampler_results_avg[s_names] = {'accuracy': np.mean(sampler_acc_fold), 'mcc': np.mean(sampler_mcc_fold)}

    result_dict_full[n_type] = sampler_results_full
    result_dict_avg[n_type] = sampler_results_avg

# Columns are neuron types and the index holds the sampler names. The index
# must be written out, otherwise the rows of the CSV cannot be told apart.
pd.DataFrame(result_dict_full).to_csv('result_full.csv', sep='|', index_label='sampler')
pd.DataFrame(result_dict_avg).to_csv('result_avg.csv', sep='|', index_label='sampler')