In [5]:
import numpy as np
import pandas as pd
import torch
import pickle
import matplotlib.pyplot as plt
from lpne.models import DcsfaNmf

MODEL_FILE = "/hpc/home/mk423/Anxiety/FullDataWork/Models/Final_mt_Model_500_epochs.pt"
DATA_PATH = "/work/mk423/Anxiety/"
PROJECT_PATH = "/hpc/home/mk423/Anxiety/FullDataWork/Projections/"
FIGURE_PATH = "/hpc/home/mk423/Anxiety/FullDataWork/Figures/"
data_file = DATA_PATH + "EPM_{}_dict_May_17.pkl"
proj_file = PROJECT_PATH + "EPM_{}_Projections.csv"
mean_file = PROJECT_PATH + "EPM_{}_mean_scores.csv"

model = torch.load(MODEL_FILE,map_location="cpu")
model.device="cpu"


old_feature_list = ["X_power_1_2","X_coh_1_2","X_gc_1_2"]
feature_weights = [10,1,1]



import os, sys
umc_data_tools_path = "/hpc/home/mk423/Anxiety/Universal-Mouse-Code/"
sys.path.append(umc_data_tools_path)
import umc_data_tools as umc_dt

def get_3_net_aucs(s,y,y_group=None):
    
    auc_mean_list = []
    auc_stderr_list =[]
    
    if y_group is None:
        y_group = np.ones(s.shape[0])
        
    for i in range(3):
        auc_dict = umc_dt.lpne_auc(y,y,y_group,s[:,i].reshape(-1,1),mannWhitneyU=True)
        mean = np.mean([auc_dict[key][0] for key in auc_dict.keys() if key != "auc_method"])
        stderr = np.std([auc_dict[key][0] for key in auc_dict.keys() if key != "auc_method"]) / np.sqrt(len(auc_dict.keys()))
        
        auc_mean_list.append(mean)
        auc_stderr_list.append(stderr)
        
    return auc_mean_list, auc_stderr_list

def get_3_net_auc_dict(s,y,y_group=None):
    auc_dict_list = []
    
    if y_group is None:
        y_group = np.ones(s.shape[0])
        
    for i in range(3):
        auc_dict = umc_dt.lpne_auc(y,y,y_group,s[:,i].reshape(-1,1),mannWhitneyU=True)
        auc_dict_list.append(auc_dict)
        
    return auc_dict_list

In [8]:
splits = ["train","val"]
df_list = []
for split in splits:
    with open(data_file.format(split),"rb") as f:
        epm_test_dict = pickle.load(f)

    X_test = np.hstack([epm_test_dict[feature]*weight for feature,weight in zip(old_feature_list,feature_weights)])
    y_test = ~(epm_test_dict['y_ROI']%2).astype(bool)
    y_in_task_mask_test= ~epm_test_dict['y_Homecage'].astype(bool)
    y_mouse_test = epm_test_dict['y_mouse']
    y_time_test = epm_test_dict['y_time']
    test_nan_mask = (epm_test_dict['y_ROI'] > 0)
    epm_y_expDate_test = epm_test_dict['y_expDate']


    epm_test_auc = model.score(X_test,
                              y_in_task_mask_test.reshape(-1,1),
                              y_mouse_test,
                              return_dict=True)

    epm_mean_test_auc = np.mean([epm_test_auc[key] for key in epm_test_auc.keys()])
    epm_stderr_test_auc = np.std([epm_test_auc[key] for key in epm_test_auc.keys()]) / np.sqrt(len(epm_test_auc.keys()))

    print("EPM {} auc: {:.3f} +/- {:.3f}".format(split,epm_mean_test_auc,epm_stderr_test_auc))
    print(epm_test_auc)
    
    s_epm_test = model.project(X_test)
    auc_epm_test = [epm_test_auc[mouse][0] for mouse in y_mouse_test]
    auc_epm_test_3_net = get_3_net_auc_dict(s_epm_test,
                                            y_in_task_mask_test,
                                            y_mouse_test)

    coeffs = np.abs(model.classifier[0].weight[0].detach().cpu().numpy())
    agg_score = s_epm_test[:,:3] @ coeffs
    proj_dict = {
        "agg score":agg_score,
        "net 1 scores":s_epm_test[:,0],
        "net 2 scores":s_epm_test[:,1],
        "net 3 scores":s_epm_test[:,2],
        "in-task":y_in_task_mask_test,
        "mouse":y_mouse_test,
        "time":y_time_test,
        "expDate":epm_y_expDate_test,
        "auc (Homecage vs. Task)":auc_epm_test,
        "net 1 auc (Homecage vs. Task)":[auc_epm_test_3_net[0][mouse][0] for mouse in y_mouse_test],
        "net 2 auc (Homecage vs. Task)":[auc_epm_test_3_net[1][mouse][0] for mouse in y_mouse_test],
        "net 3 auc (Homecage vs. Task)":[auc_epm_test_3_net[2][mouse][0] for mouse in y_mouse_test],
        "roi":epm_test_dict["y_ROI"],
        "velocity":epm_test_dict["y_vel"],

    }

    df_proj = pd.DataFrame.from_dict(proj_dict)
    #df_proj.to_csv(proj_file.format(split))
    df_list.append(df_proj)

EPM train auc: 0.997 +/- 0.001
{'Mouse04201': [0.9906270262023723], 'Mouse04202': [0.9986905960992583], 'Mouse04205': [0.9790123456790123], 'Mouse04215': [0.9996155865586558], 'Mouse0630': [1.0], 'Mouse0634': [0.9999857366994722], 'Mouse0643': [1.0], 'Mouse1551': [0.999028872371419], 'Mouse39114': [0.995920745920746], 'Mouse39124': [0.9894177294790418], 'Mouse39133': [0.9914634804185996], 'Mouse6291': [1.0], 'Mouse6292': [1.0], 'Mouse6293': [0.9999674160964483], 'Mouse69064': [0.9976228847703466], 'Mouse69074': [0.9860350492880613], 'Mouse8580': [0.9997766092724336], 'Mouse8581': [0.9999818594104308], 'Mouse8582': [1.0], 'Mouse8891': [1.0], 'Mouse8894': [1.0]}
EPM val auc: 0.997 +/- 0.001
{'Mouse04193': [0.9933780691378223], 'Mouse0633': [1.0], 'Mouse0642': [0.9999826144404458], 'Mouse39125': [0.9987444044109618], 'Mouse69065': [0.994653564290473]}


In [9]:
df_allTrain = pd.concat(df_list)

In [10]:
df_allTrain

Unnamed: 0,agg score,net 1 scores,net 2 scores,net 3 scores,in-task,mouse,time,expDate,auc (Homecage vs. Task),net 1 auc (Homecage vs. Task),net 2 auc (Homecage vs. Task),net 3 auc (Homecage vs. Task),roi,velocity
0,1.345047,0.037202,0.056049,0.063754,False,Mouse04201,1,090521,0.990627,0.980835,0.989302,0.619165,-2147483648,-2147483648
1,0.405382,0.022869,0.008693,0.050391,False,Mouse04201,2,090521,0.990627,0.980835,0.989302,0.619165,-2147483648,-2147483648
2,1.112978,0.035689,0.042204,0.082552,False,Mouse04201,3,090521,0.990627,0.980835,0.989302,0.619165,-2147483648,-2147483648
3,2.381636,0.075571,0.094152,0.098962,False,Mouse04201,4,090521,0.990627,0.980835,0.989302,0.619165,-2147483648,-2147483648
4,2.206645,0.068926,0.087007,0.111706,False,Mouse04201,5,090521,0.990627,0.980835,0.989302,0.619165,-2147483648,-2147483648
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3255,4.309340,0.081131,0.206747,0.093246,True,Mouse69065,758,100621,0.994654,0.982828,0.994370,0.476062,2,1
3256,1.596412,0.015531,0.080419,0.142944,True,Mouse69065,759,100621,0.994654,0.982828,0.994370,0.476062,3,0
3257,3.739954,0.062140,0.186490,0.030196,True,Mouse69065,760,100621,0.994654,0.982828,0.994370,0.476062,3,3
3258,4.365373,0.085112,0.211300,0.011862,True,Mouse69065,761,100621,0.994654,0.982828,0.994370,0.476062,1,12


In [11]:
df_allTrain.to_csv(proj_file.format("TrainVal"),index=False)