In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import pandas
import sys, os
from sklearn.model_selection import KFold
from lpne.models import DcsfaNmf
import torch

# Pickled inputs. OldDataPath is a template: the "{}" slot is filled with
# "train", "validation" or "test" when the old cohort is loaded.
NewDataPath = "/work/mk423/Anxiety/New_FLX_Animals_April_12.pkl"
OldDataPath = "/work/mk423/Anxiety/FLX_{}_dict_old_features.pkl"

# Feature keys for the new and old dictionaries. Both lists are zipped
# position-by-position with FEATURE_WEIGHT, so the three must stay aligned.
FEATURE_LIST = ["X_psd","X_coh","X_gc"]
OLD_FEATURE_LIST = ["X_power_1_2","X_coh_1_2","X_gc_1_2"]
FEATURE_WEIGHT = [10,1,1]

# Seed every RNG the notebook touches. NumPy drives the validation-mouse draw;
# the model is saved/loaded through torch, so presumably its training draws
# from torch's generator as well -- seed both so a re-run is repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _load_pickle(path):
    """Return the object pickled at `path`, closing the file handle afterwards.

    NOTE: unpickling can execute arbitrary code -- only point this at files
    from a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# New cohort plus the three pre-split partitions of the old cohort.
newDict = _load_pickle(NewDataPath)
oldTrainDict = _load_pickle(OldDataPath.format("train"))
oldValDict = _load_pickle(OldDataPath.format("validation"))
oldTestDict = _load_pickle(OldDataPath.format("test"))

In [3]:
# Show which entries the old training dictionary provides (feature variants and labels).
oldTrainDict.keys()

dict_keys(['X_psd', 'X_psd_first_30', 'X_ds', 'X_ds_first_30', 'y_mouse', 'y_mouse_first_30', 'y_expDate', 'y_expDate_first_30', 'y_time', 'y_time_first_30', 'mice', 'y_flx', 'y_flx_train_first_30', 'X_psd_full', 'X_ds_full', 'y_mouse_full', 'y_expDate_full', 'y_time_full', 'y_flx_full', 'X_power_1_2', 'X_power_1_2_first_30', 'X_power_1_2_full', 'X_coh_1_2', 'X_coh_1_2_first_30', 'X_coh_1_2_full', 'X_gc_1_2', 'X_gc_1_2_first_30', 'X_gc_1_2_full'])

In [4]:
def _stack_weighted(data, feature_names):
    """Column-stack data[name] * weight, pairing names with FEATURE_WEIGHT in order."""
    scaled_blocks = []
    for name, weight in zip(feature_names, FEATURE_WEIGHT):
        scaled_blocks.append(data[name] * weight)
    return np.hstack(scaled_blocks)


# --- New cohort: weighted features and labels ---
X_new = _stack_weighted(newDict, FEATURE_LIST)
y_flx = newDict['y_flx']
y_hab = newDict['y_hab'].squeeze()
y_mouse = newDict['y_mouse']

# Keep only the rows flagged y_hab == 1.
hab_mask = y_hab == 1
X_new_hab = X_new[hab_mask]
y_flx_hab = y_flx[hab_mask]
y_mouse_hab = y_mouse[hab_mask]

# --- Old cohort: same weighting applied to each pre-split partition ---
X_train = _stack_weighted(oldTrainDict, OLD_FEATURE_LIST)
y_train = oldTrainDict['y_flx']
y_mouse_train = oldTrainDict['y_mouse']

X_val = _stack_weighted(oldValDict, OLD_FEATURE_LIST)
y_val = oldValDict['y_flx']
y_mouse_val = oldValDict['y_mouse']

X_test = _stack_weighted(oldTestDict, OLD_FEATURE_LIST)
y_test = oldTestDict['y_flx']
y_mouse_test = oldTestDict['y_mouse']

In [5]:
# Pool the filtered new cohort with the three old partitions, in the same
# order for features, FLX labels and mouse IDs so rows stay aligned.
X_full = np.vstack((X_new_hab, X_train, X_val, X_test))
y_flx_full = np.hstack(
    [labels.squeeze() for labels in (y_flx_hab, y_train, y_val, y_test)]
)
y_mouse_full = np.hstack(
    [ids.squeeze() for ids in (y_mouse_hab, y_mouse_train, y_mouse_val, y_mouse_test)]
)

In [17]:
# Number of distinct mouse IDs in the new cohort (y_mouse comes from newDict, before the y_hab filter).
np.unique(y_mouse).shape

(9,)

In [7]:
# Reload the saved K-fold results; the context manager closes the file handle
# that a bare pickle.load(open(...)) would leave open.
with open("/work/mk423/Anxiety/FLX_model_kfold/flx_kfold_cv_check_April_14th_2023.pkl","rb") as f:
    results_dict = pickle.load(f)

In [10]:
# How many held-out test mice were recorded for fold 0 in the saved results.
print(len(results_dict['test_mice'][0]))

4


In [37]:
###KFold CV
# TRAIN=True  -> fit one DcsfaNmf model per fold, save each model and the results dict.
# TRAIN=False -> reload the saved results/models, re-score them and add 'perc_recon'.
TRAIN=False

KFOLD_DIR = "/work/mk423/Anxiety/FLX_model_kfold/"
KFOLD_RESULTS_PATH = KFOLD_DIR + "flx_kfold_cv_check_April_14th_2023.pkl"
KFOLD_MODEL_PATH = KFOLD_DIR + "SGD_{}_fold_flx.pt"
N_VAL_MICE = 3
# NOTE(review): training uses KFold(n_splits=5) but only folds 0-3 are
# re-evaluated below; kept at 4 to preserve existing behaviour -- confirm
# whether the fifth fold should be included.
N_EVAL_FOLDS = 4


def _select_mice(mice):
    """Return (X, y_flx, y_mouse) restricted to rows whose mouse ID is in `mice`.

    y_flx is returned as a column vector (n_rows, 1); X and y_mouse keep the
    layout they have in X_full / y_mouse_full.
    """
    mask = np.isin(y_mouse_full, mice)
    return X_full[mask], y_flx_full[mask].reshape(-1,1), y_mouse_full[mask]


if TRAIN:
    train_auc_list = []
    val_auc_list = []
    test_auc_list = []
    electome_list = []

    train_mice_list = []
    val_mice_list = []
    test_mice_list = []

    # Folds are formed over mouse IDs, so a mouse never spans train and test.
    all_mice = np.unique(y_mouse_full)
    kf = KFold(n_splits=5)

    for i, (train_idx,test_idx) in enumerate(kf.split(all_mice)):

        fold_train_mice = all_mice[train_idx]

        # replace=False: the default samples WITH replacement, which can pick
        # the same mouse twice and silently shrink the validation set.
        # Raises ValueError if fewer than N_VAL_MICE training mice exist.
        val_mice = np.random.choice(fold_train_mice,size=N_VAL_MICE,replace=False)
        train_mice = [mouse for mouse in fold_train_mice if mouse not in val_mice]
        test_mice = all_mice[test_idx]

        train_mice_list.append(train_mice)
        val_mice_list.append(val_mice)
        test_mice_list.append(test_mice)

        print(i, train_mice, val_mice, test_mice)

        X_kf_train, y_kf_flx_train, y_kf_mouse_train = _select_mice(train_mice)
        X_kf_val, y_kf_flx_val, y_kf_mouse_val = _select_mice(val_mice)
        X_kf_test, y_kf_flx_test, y_kf_mouse_test = _select_mice(test_mice)

        model = DcsfaNmf(n_components=20,
                         optim_name="SGD",
                         save_folder=KFOLD_DIR)

        model.fit(X_kf_train,y_kf_flx_train,n_epochs=1500,n_pre_epochs=400,nmf_max_iter=2000,pretrain=True,X_val=X_kf_val,y_val=y_kf_flx_val,verbose=True)
        torch.save(model,KFOLD_MODEL_PATH.format(i))

        train_auc_list.append(model.score(X_kf_train,y_kf_flx_train,y_kf_mouse_train))
        val_auc_list.append(model.score(X_kf_val,y_kf_flx_val,y_kf_mouse_val))
        test_auc_list.append(model.score(X_kf_test,y_kf_flx_test,y_kf_mouse_test))
        electome_list.append(model.get_factor(0))

        print(train_auc_list[-1],val_auc_list[-1],test_auc_list[-1])

    results_dict = {
    "train_aucs":train_auc_list,
    "val_aucs":val_auc_list,
    "test_aucs":test_auc_list,
    "train_mice":train_mice_list,
    "val_mice":val_mice_list,
    "test_mice":test_mice_list,
    "electomes":electome_list,
    }

    with open(KFOLD_RESULTS_PATH,"wb") as f:
        pickle.dump(results_dict,f)

else:
    # Close the results file explicitly instead of leaking the handle.
    with open(KFOLD_RESULTS_PATH,"rb") as f:
        results_dict = pickle.load(f)
    results_dict['perc_recon'] = []
    for fold in range(N_EVAL_FOLDS):
        # Rebuild the exact splits recorded at training time.
        train_mice = results_dict["train_mice"][fold]
        val_mice = results_dict["val_mice"][fold]
        test_mice = results_dict["test_mice"][fold]

        X_kf_train, y_kf_flx_train, y_kf_mouse_train = _select_mice(train_mice)
        X_kf_val, y_kf_flx_val, y_kf_mouse_val = _select_mice(val_mice)
        X_kf_test, y_kf_flx_test, y_kf_mouse_test = _select_mice(test_mice)

        # NOTE(review): this unpickles a whole model object; recent torch
        # releases may default to weights_only=True and refuse it -- confirm
        # against the installed torch version.
        model = torch.load(KFOLD_MODEL_PATH.format(fold))

        # Ratio of the component-0 reconstruction to the full reconstruction
        # on the test rows, averaged over rows (axis 0). Requires a CUDA device.
        s = model.project(X_kf_test)
        sup_recon = model.get_comp_recon(torch.Tensor(s).to("cuda"),0)
        perc_recon = sup_recon / model.reconstruct(X_kf_test)
        perc_recon = np.mean(perc_recon,axis=0)
        results_dict['perc_recon'].append(perc_recon)
        print(fold)
        print("train auc: ",model.score(X_kf_train,y_kf_flx_train,y_kf_mouse_train))
        print("val auc: ",model.score(X_kf_val,y_kf_flx_val,y_kf_mouse_val))
        print("test auc: ",model.score(X_kf_test,y_kf_flx_test,y_kf_mouse_test))

0
train auc:  [1.]
val auc:  [0.73423706]
test auc:  [0.73402772]
1
train auc:  [1.]
val auc:  [0.81082277]
test auc:  [0.58491514]
2
train auc:  [1.]
val auc:  [0.73179101]
test auc:  [0.69662694]
3
train auc:  [1.]
val auc:  [0.63118101]
test auc:  [0.73878715]


In [38]:
# Write results_dict (now including the 'perc_recon' entries) back over the saved results file.
results_path = "/work/mk423/Anxiety/FLX_model_kfold/flx_kfold_cv_check_April_14th_2023.pkl"
with open(results_path, "wb") as results_file:
    pickle.dump(results_dict, results_file)

In [36]:
# Scratch check of the per-feature length of perc_recon.
# NOTE(review): this cell's execution count (36) precedes the K-fold cell's (37), so the
# recorded output reflects an earlier perc_recon; the K-fold cell already averages over axis 0.
np.mean(perc_recon,axis=0).shape

(5152,)

In [25]:
# Shape of s, the output of model.project(X_kf_test) for the most recently processed fold.
s.shape

(14604, 20)

In [28]:
# Display the component-0 reconstruction for the last fold's projections (needs a CUDA device).
# NOTE(review): this prints a full array repr; consider inspecting .shape or a small slice instead.
model.get_comp_recon(torch.Tensor(s).to("cuda"),0)

array([[0.03460997, 0.05778772, 0.08378834, ..., 0.03386989, 0.05048679,
        0.05238559],
       [0.05343117, 0.08921319, 0.12935317, ..., 0.05228863, 0.07794195,
        0.08087333],
       [0.03068154, 0.05122847, 0.07427787, ..., 0.03002546, 0.04475625,
        0.04643952],
       ...,
       [0.01996471, 0.03333475, 0.04833317, ..., 0.01953779, 0.02912322,
        0.03021854],
       [0.03472406, 0.05797822, 0.08406455, ..., 0.03398154, 0.05065322,
        0.05255828],
       [0.02154848, 0.03597916, 0.05216738, ..., 0.0210877 , 0.03143353,
        0.03261574]], dtype=float32)

In [None]:
# Number of training mice in the most recently assigned train_mice (last fold processed).
len(train_mice)

In [None]:
# Number of validation mice in the most recently assigned val_mice (last fold processed).
len(val_mice)

In [None]:
# Number of test mice in the most recently assigned test_mice (last fold processed).
len(test_mice)