In [4]:
import torch
import numpy as np
import pickle
from lpne.models import DcsfaNmf
from lpne.plotting import circle_plot
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from PIL import Image
import matplotlib.pyplot as plt
import os, sys

# Append the directory that holds umc_data_tools to sys.path so the import
# on the last line of this group can be resolved; the order of these
# statements therefore matters.
umc_data_tools_path = "/hpc/home/mk423/Anxiety/Universal-Mouse-Code/"
sys.path.append(umc_data_tools_path)
import umc_data_tools as umc_dt

# Component-count constant; it is not referenced in the visible part of this cell.
N_COMPONENTS = 30

# Fold index substituted into every per-fold filename below. It is hard-coded
# here; the SLURM array-task lookup is kept as a commented-out alternative.
fold = 4  # int(os.environ['SLURM_ARRAY_TASK_ID'])

# Per-fold dataset pickles and the shared info dictionary pickle.
flx_data_path = f"/work/mk423/Anxiety/flx_kf_dict_fold_{fold}.pkl"
epm_data_path = f"/work/mk423/Anxiety/epm_kf_dict_fold_{fold}.pkl"
oft_data_path = f"/work/mk423/Anxiety/oft_kf_dict_fold_{fold}.pkl"
anx_info_dict = "/work/mk423/Anxiety/Anx_Info_Dict.pkl"

# Directory and filename of the saved model for this fold.
saved_model_path = "/work/mk423/Anxiety/kfold_models/"
saved_model_name = f"bootstrap_321_net_{fold}_model.pt"

# Destination of the results pickle written at the end of the cell.
results_path = "/hpc/home/mk423/Anxiety/FullDataWork/Validations/"
results_file = f"{results_path}bootstrap_321_net_{fold}_results.pkl"

# Output directories; neither is used in the visible part of this cell.
projection_save_path = "/hpc/home/mk423/Anxiety/FullDataWork/Projections/"
plots_path = "/hpc/home/mk423/Anxiety/FullDataWork/Figures/"


def reshapeData(X_psd,X_coh,n_rois,n_freqs,pow_features,coh_features,areas):
    """Arrange flat power and coherence values into an (n_rois, n_rois, n_freqs) array.

    Parameters
    ----------
    X_psd : sequence
        Flat power values; entry ``[i, i, :]`` of the result is the slice
        ``X_psd[i*n_freqs:(i+1)*n_freqs]``.
    X_coh : numpy.ndarray
        Flat coherence values, indexed with a boolean mask built from
        ``coh_features``.
    n_rois : int
        Size of the first two axes of the result.
    n_freqs : int
        Size of the last axis of the result.
    pow_features
        Accepted for interface compatibility; not used by this function.
    coh_features : iterable of str
        One name per entry of ``X_coh``. Only the text before the first
        space is compared, against ``"<area>-<area>"`` strings.
    areas : sequence of str
        Area name for each index ``0 .. n_rois-1``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_rois, n_rois, n_freqs)``. Off-diagonal entries
        ``[i, j, :]`` and ``[j, i, :]`` both hold the coherence values of the
        pair, matched in either name order. Pairs with no matching feature
        are reported with ``print`` and stay zero.
    """
    cube = np.zeros((n_rois, n_rois, n_freqs))

    # Diagonal: consecutive n_freqs-long slices of the power vector.
    for roi in range(n_rois):
        start = roi * n_freqs
        cube[roi, roi, :] = X_psd[start:start + n_freqs]

    # Keep only the pair label (text before the first space) of each feature.
    pair_labels = np.array([name.split(' ')[0] for name in coh_features])
    known_pairs = np.unique(pair_labels)

    for row in range(n_rois):
        for col in range(n_rois):
            if row == col:
                continue

            forward = areas[row] + "-" + areas[col]
            backward = areas[col] + "-" + areas[row]

            # The pair may be stored under either name order; prefer "row-col".
            if forward in known_pairs:
                selector = pair_labels == forward
            elif backward in known_pairs:
                selector = pair_labels == backward
            else:
                print("temp_feature: {} not found".format(forward))
                continue

            # Fill both triangles so the result is symmetric in (row, col).
            cube[row, col, :] = X_coh[selector]
            cube[col, row, :] = X_coh[selector]

    return cube

# Read the three per-fold dataset pickles and the info-dictionary pickle,
# in that order, then bind them to their module-level names.
# NOTE(review): pickle.load can execute arbitrary code; these paths should
# only ever point at trusted files.
_loaded_pickles = []
for _pickle_path in (flx_data_path, epm_data_path, oft_data_path, anx_info_dict):
    with open(_pickle_path, "rb") as f:
        _loaded_pickles.append(pickle.load(f))

flx_dict, epm_dict, oft_dict, anxInfo = _loaded_pickles

info_dict = anxInfo

# (start, stop) offsets of the power, coherence and gc feature lists when the
# three lists are laid end to end in that order.
_n_power = len(info_dict["powerFeatures"])
_n_coh = len(info_dict["cohFeatures"])
_n_gc = len(info_dict["gcFeatures"])
_coh_stop = _n_power + _n_coh
feature_groups = [
    (0, _n_power),
    (_n_power, _coh_stop),
    (_coh_stop, _coh_stop + _n_gc),
]

# The three task dictionaries, always combined in the order FLX, EPM, OFT.
_tasks = (flx_dict, epm_dict, oft_dict)

# Pooled training data; labels become a single column, then are repeated
# into three identical columns.
mt_X_train = np.vstack([task["X_train"] for task in _tasks])
mt_y_train = np.hstack([task["y_train"] for task in _tasks]).reshape(-1, 1)
mt_y_train_3_net = np.hstack([mt_y_train] * 3)
mt_y_mouse_train = np.hstack([task["y_mouse_train"] for task in _tasks])

# Per-sample task code: 0.0 for FLX rows, 1.0 for EPM rows, 2.0 for OFT rows.
mt_y_exp_train = np.hstack([
    np.full(task["X_train"].shape[0], float(code))
    for code, task in enumerate(_tasks)
])

# One-hot (dense) and ordinal encodings of the training mouse identifiers.
_mouse_column = mt_y_mouse_train.reshape(-1, 1)
intercept_mask = OneHotEncoder().fit_transform(_mouse_column).todense()
sample_groups = OrdinalEncoder().fit_transform(_mouse_column)

# Pooled validation data, built the same way as the training arrays.
mt_X_val = np.vstack([task["X_val"] for task in _tasks])
mt_y_val = np.hstack([task["y_val"] for task in _tasks]).reshape(-1, 1)
mt_y_val_3_net = np.hstack([mt_y_val] * 3)
mt_y_mouse_val = np.hstack([task["y_mouse_val"] for task in _tasks])

# Restore the saved model object for this fold and switch it to eval mode.
model = torch.load(saved_model_path + saved_model_name)
model.eval()

# Multitask performance on the pooled training / validation arrays.
mt_train_auc = model.score(mt_X_train, mt_y_train)
mt_val_auc = model.score(mt_X_val, mt_y_val)

# NOTE(review): for every task below the training labels are stacked into two
# identical columns while the validation labels stay a single column --
# confirm this asymmetry is intended.

# FLX performance
flx_y_train = np.hstack([flx_dict["y_train"].reshape(-1, 1)] * 2)
flx_y_val = flx_dict["y_val"].reshape(-1, 1)

flx_train_auc = model.score(
    flx_dict["X_train"], flx_y_train, flx_dict["y_mouse_train"], return_dict=True
)
flx_val_auc = model.score(
    flx_dict["X_val"], flx_y_val, flx_dict["y_mouse_val"], return_dict=True
)

# EPM performance
epm_y_train = np.hstack([epm_dict["y_train"].reshape(-1, 1)] * 2)
epm_y_val = epm_dict["y_val"].reshape(-1, 1)

epm_train_auc = model.score(
    epm_dict["X_train"], epm_y_train, epm_dict["y_mouse_train"], return_dict=True
)
epm_val_auc = model.score(
    epm_dict["X_val"], epm_y_val, epm_dict["y_mouse_val"], return_dict=True
)

# OFT performance
oft_y_train = np.hstack([oft_dict["y_train"].reshape(-1, 1)] * 2)
oft_y_val = oft_dict["y_val"].reshape(-1, 1)

oft_train_auc = model.score(
    oft_dict["X_train"], oft_y_train, oft_dict["y_mouse_train"], return_dict=True
)
oft_val_auc = model.score(
    oft_dict["X_val"], oft_y_val, oft_dict["y_mouse_val"], return_dict=True
)

# Report the per-task score dictionaries, each preceded by a blank line.
for _label, _scores in (
    ("\nflx train", flx_train_auc),
    ("\nflx val", flx_val_auc),
    ("\nepm train", epm_train_auc),
    ("\nepm val", epm_val_auc),
    ("\noft train", oft_train_auc),
    ("\noft val", oft_val_auc),
):
    print(_label, _scores)

# Project the pooled validation data, then request the reconstruction for
# component indices 0, 1 and 2 (in that order) followed by the full
# reconstruction.
# NOTE(review): the device is hard-coded to "cuda", so this requires a GPU.
s = model.project(mt_X_val)
X_sup_recon, X_sup_recon_2, X_sup_recon_3 = [
    model.get_comp_recon(torch.Tensor(s).to("cuda"), net_index)
    for net_index in range(3)
]
X_recon = model.reconstruct(mt_X_val)

# For each component: elementwise ratio of its reconstruction to the full
# reconstruction, averaged along axis 0.
(
    net_1_recon_contribution,
    net_2_recon_contribution,
    net_3_recon_contribution,
) = [
    np.mean(component_recon / X_recon, axis=0)
    for component_recon in (X_sup_recon, X_sup_recon_2, X_sup_recon_3)
]

# Collect the per-task scores and reconstruction contributions for this fold.
results_dict = dict(
    flx_train_auc=flx_train_auc,
    flx_val_auc=flx_val_auc,
    epm_train_auc=epm_train_auc,
    epm_val_auc=epm_val_auc,
    oft_train_auc=oft_train_auc,
    oft_val_auc=oft_val_auc,
    recon_cont_net_1=net_1_recon_contribution,
    recon_cont_net_2=net_2_recon_contribution,
    recon_cont_net_3=net_3_recon_contribution,
)

# Write the results for this fold to results_file as a pickle.
with open(results_file, "wb") as f:
    pickle.dump(results_dict, f)


flx train {'Mouse3191': [1.0], 'Mouse3192': [1.0], 'Mouse3193': [1.0], 'Mouse3194': [1.0], 'Mouse3202': [nan], 'Mouse3203': [1.0], 'Mouse61635': [1.0], 'Mouse78752': [1.0], 'Mouse99002': [nan], 'Mouse99003': [nan], 'Mouse99021': [nan]}

flx val {'Mouse61631': [0.5658370453552934], 'Mouse78744': [0.5738507948869839]}

epm train {'Mouse04193': [0.9996683507080713], 'Mouse04201': [1.0], 'Mouse04202': [1.0], 'Mouse04205': [1.0], 'Mouse04215': [1.0], 'Mouse0630': [1.0], 'Mouse0633': [1.0], 'Mouse0634': [1.0], 'Mouse0642': [1.0], 'Mouse0643': [1.0], 'Mouse1551': [1.0], 'Mouse39114': [1.0], 'Mouse39124': [1.0], 'Mouse39125': [1.0], 'Mouse39133': [1.0], 'Mouse6291': [1.0], 'Mouse6292': [1.0], 'Mouse6293': [1.0], 'Mouse69064': [1.0], 'Mouse69065': [1.0]}

epm val {'Mouse69074': [0.6388175077934113], 'Mouse8580': [0.5622400934804275], 'Mouse8581': [0.6155283446712019], 'Mouse8582': [0.6904761904761905], 'Mouse8891': [0.6197321265074425], 'Mouse8894': [0.792918057100482]}

oft train {'Mouse04191