In [1]:
import numpy as np
import pandas as pd
import torch
import pickle
import matplotlib.pyplot as plt
from lpne.models import DcsfaNmf

# --- Input / output locations (absolute cluster paths) ----------------------
MODEL_FILE = "/hpc/home/mk423/Anxiety/FullDataWork/Models/Final_mt_Model_500_epochs.pt"
DATA_PATH = "/work/mk423/Anxiety/"
PROJECT_PATH = "/hpc/home/mk423/Anxiety/FullDataWork/Projections/"

data_file = DATA_PATH + "OFT_test_dict_old_features_hand_picked.pkl"  # held-out OFT features (pickle)
proj_file = PROJECT_PATH + "OFT_Holdout_Projections_w_agg.csv"        # per-window export
mean_file = PROJECT_PATH + "OFT_Holdout_mean_scores_w_agg.csv"        # per-mouse export

# --- Feature configuration ---------------------------------------------------
# Feature groups stacked into the design matrix, and the multiplier applied to
# each (power features are up-weighted 10x relative to coherence / GC).
old_feature_list = ["X_power_1_2", "X_coh_1_2", "X_gc_1_2"]
feature_weights = [10, 1, 1]

# --- Trained model -----------------------------------------------------------
# NOTE(review): torch.load unpickles the full model object, so only load
# trusted checkpoints. Newer PyTorch releases default to weights_only=True,
# which may reject a whole-module checkpoint like this one; passing
# weights_only=False could then be needed -- confirm for the installed torch.
model = torch.load(MODEL_FILE, map_location="cpu")
model.device = "cpu"

# --- Lab helper library (provides umc_dt.lpne_auc) ---------------------------
import os
import sys
umc_data_tools_path = "/hpc/home/mk423/Anxiety/Universal-Mouse-Code/"
sys.path.append(umc_data_tools_path)
import umc_data_tools as umc_dt

def get_3_net_aucs(s,y,y_group=None):
    """Mean and standard error of the per-group AUC for each of the first three networks.

    For each column ``s[:, i]`` (i = 0, 1, 2), ``umc_dt.lpne_auc`` is called with
    the Mann-Whitney U option. The per-group AUCs it returns are then summarised
    across groups.

    Parameters
    ----------
    s : np.ndarray
        Score matrix indexed as ``s[:, i]``; only columns 0-2 are used.
    y : array-like
        Labels forwarded to ``umc_dt.lpne_auc`` (passed as both of its first
        two arguments, as in the original call).
    y_group : array-like, optional
        Group id per row (e.g. mouse). Defaults to a single group (all ones).

    Returns
    -------
    (list of float, list of float)
        Per-network mean AUC and standard error. The standard error is the
        population std divided by sqrt(number of groups).
    """
    auc_mean_list = []
    auc_stderr_list = []

    if y_group is None:
        y_group = np.ones(s.shape[0])

    for i in range(3):
        auc_dict = umc_dt.lpne_auc(y,y,y_group,s[:,i].reshape(-1,1),mannWhitneyU=True)
        # "auc_method" is metadata, not a group. Exclude it from the statistics
        # AND from the sample count. Previously the count used len(auc_dict),
        # which included this key and understated the standard error.
        group_aucs = [auc_dict[key][0] for key in auc_dict if key != "auc_method"]

        auc_mean_list.append(np.mean(group_aucs))
        auc_stderr_list.append(np.std(group_aucs) / np.sqrt(len(group_aucs)))

    return auc_mean_list, auc_stderr_list

def get_3_net_auc_dict(s,y,y_group=None):
    """Raw ``umc_dt.lpne_auc`` results for each of the first three networks.

    Parameters
    ----------
    s : np.ndarray
        Score matrix indexed as ``s[:, i]``; only columns 0-2 are used.
    y : array-like
        Labels forwarded to ``umc_dt.lpne_auc`` as both of its first two arguments.
    y_group : array-like, optional
        Group id per row (e.g. mouse). Defaults to a single group (all ones).

    Returns
    -------
    list
        Three result dicts from ``umc_dt.lpne_auc``, in column order 0, 1, 2.
    """
    groups = np.ones(s.shape[0]) if y_group is None else y_group
    return [
        umc_dt.lpne_auc(y, y, groups, s[:, net_idx].reshape(-1, 1), mannWhitneyU=True)
        for net_idx in range(3)
    ]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the held-out OFT feature/label dictionary.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(data_file, 'rb') as f:
    test_dict = pickle.load(f)

# Design matrix: each feature group scaled by its weight, stacked column-wise.
oft_X_test = np.hstack([
    test_dict[feat_name] * feat_weight
    for feat_name, feat_weight in zip(old_feature_list, feature_weights)
])

# Window labels: home-cage flag, and the in-task flag as its complement.
oft_y_hc_test = test_dict['y_Homecage'].astype(bool)
oft_y_task_test = np.logical_not(oft_y_hc_test)

# Per-window metadata carried through to the exported CSVs.
oft_y_ROI_test = test_dict['y_ROI']
oft_y_vel_test = test_dict['y_vel']
oft_y_mouse_test = test_dict['y_mouse']
oft_y_time_test = test_dict['y_time']
oft_y_expDate_test = test_dict['y_expDate']


In [3]:
# Number of distinct mice in the held-out OFT set.
np.unique(oft_y_mouse_test).size

9

In [3]:
# Full-model AUC (home cage vs. task) on the held-out windows, one entry per mouse.
oft_test_auc = model.score(
    oft_X_test,
    oft_y_task_test.reshape(-1, 1),
    oft_y_mouse_test,
    return_dict=True,
)

# Summarise across mice: mean and standard error (population std / sqrt(n_mice)).
per_mouse_auc = list(oft_test_auc.values())
oft_mean_test_auc = np.mean(per_mouse_auc)
oft_stderr_test_auc = np.std(per_mouse_auc) / np.sqrt(len(per_mouse_auc))
print("OFT test auc: {:.3f} +/- {:.3f}".format(oft_mean_test_auc, oft_stderr_test_auc))
print(oft_test_auc)

OFT test auc: 0.844 +/- 0.025
{'Mouse04203': [0.9039771437220919], 'Mouse39115': [0.8642398389570553], 'Mouse39121': [0.9121263527347176], 'Mouse39122': [0.7343512031678343], 'Mouse39132': [0.9265035101827903], 'Mouse39135': [0.8670534203529128], 'Mouse69061': [0.6935489654677808], 'Mouse69071': [0.8263132216620588], 'Mouse69075': [0.8662951995012468]}


In [11]:
# Project every held-out window onto the model's networks.
s_oft_test = model.project(oft_X_test)

# Broadcast per-mouse AUCs (full model, then each of the first three
# networks) onto every window recorded from that mouse.
auc_oft_test = [oft_test_auc[mouse][0] for mouse in oft_y_mouse_test]
auc_oft_test_3_net = get_3_net_auc_dict(s_oft_test,
                                        oft_y_task_test,
                                        oft_y_mouse_test)

# Composite score: first three network scores weighted by |classifier weight|.
coeffs = np.abs(model.classifier[0].weight[0].detach().cpu().numpy())
s_agg_test = s_oft_test[:, :3] @ coeffs

# The exported CSV's column order follows the dict insertion order below.
proj_dict = {"composite scores": s_agg_test}
proj_dict.update(zip(
    ("net 1 scores", "net 2 scores", "net 3 scores"),
    (s_oft_test[:, net] for net in range(3)),
))
proj_dict.update({
    "in-task": oft_y_task_test,
    "mouse": oft_y_mouse_test,
    "time": oft_y_time_test,
    "expDate": oft_y_expDate_test,
    "auc (Homecage vs. Task)": auc_oft_test,
})
proj_dict.update(zip(
    ("net 1 auc (Homecage vs. Task)",
     "net 2 auc (Homecage vs. Task)",
     "net 3 auc (Homecage vs. Task)"),
    ([auc_oft_test_3_net[net][mouse][0] for mouse in oft_y_mouse_test] for net in range(3)),
))
# NOTE(review): in the saved preview of this cell, "roi" is empty and "velocity"
# reads -2147484000.0 (about int32 min, -2147483648) on every shown row. This
# looks like a missing-value sentinel -- confirm before analysing these columns.
proj_dict["roi"] = oft_y_ROI_test
proj_dict["velocity"] = oft_y_vel_test

df = pd.DataFrame.from_dict(proj_dict)
df.to_csv(proj_file)
df.head()

Unnamed: 0,composite scores,net 1 scores,net 2 scores,net 3 scores,in-task,mouse,time,expDate,auc (Homecage vs. Task),net 1 auc (Homecage vs. Task),net 2 auc (Homecage vs. Task),net 3 auc (Homecage vs. Task),roi,velocity
0,0.47157,0.02642,0.012433,0.007595,False,Mouse04203,1,90421,0.903977,0.895378,0.890808,0.740237,,-2147484000.0
1,2.237217,0.085713,0.082628,0.027619,False,Mouse04203,3,90421,0.903977,0.895378,0.890808,0.740237,,-2147484000.0
2,2.893158,0.066529,0.133424,0.023311,False,Mouse04203,4,90421,0.903977,0.895378,0.890808,0.740237,,-2147484000.0
3,0.251323,0.005939,0.009126,0.056668,False,Mouse04203,5,90421,0.903977,0.895378,0.890808,0.740237,,-2147484000.0
4,2.85297,0.042559,0.145444,0.015247,False,Mouse04203,6,90421,0.903977,0.895378,0.890808,0.740237,,-2147484000.0


In [10]:
# Expect one composite score per held-out window.
np.shape(s_agg_test)

(5455,)

In [12]:
# Per-mouse average network scores over all windows, home-cage windows and
# in-task (OFT) windows, plus each network's fractional contribution to the
# composite score. Written to `mean_file`.
s = model.project(oft_X_test)
y_mouse = oft_y_mouse_test
y_task = oft_y_task_test


def _per_mouse_mean(scores, mouse_ids, mice, window_mask):
    """Mean of the first three columns of ``scores`` for each mouse in ``mice``.

    Only rows where ``mouse_ids == mouse`` AND ``window_mask`` is True are
    averaged. Returns one row per entry of ``mice``, in that order. A mouse
    with no selected windows yields a NaN row (NumPy emits a warning).
    """
    mouse_ids = np.asarray(mouse_ids)
    return np.array([
        np.mean(scores[(mouse_ids == mouse) & window_mask, :3], axis=0)
        for mouse in mice
    ])


mouse_list = list(np.unique(y_mouse))
# Explicit boolean mask: True = in-task (OFT) window, False = home-cage window.
task_mask = np.asarray(y_task, dtype=bool)

avg_score_list = _per_mouse_mean(s, y_mouse, mouse_list, np.ones_like(task_mask))
avg_hc_score_list = _per_mouse_mean(s, y_mouse, mouse_list, ~task_mask)
avg_task_score_list = _per_mouse_mean(s, y_mouse, mouse_list, task_mask)

# Weight each network's mean score by |classifier weight|. The row sum is the
# composite score, using the same weighting as s_agg_test uses per window.
coeffs = np.abs(model.classifier[0].weight[0].detach().cpu().numpy())

mag_score_list = avg_score_list * coeffs
mag_hc_score_list = avg_hc_score_list * coeffs
mag_task_score_list = avg_task_score_list * coeffs

# Fraction of the composite contributed by each network (each row sums to 1).
net_impact_scores = mag_score_list / np.sum(mag_score_list, axis=1, keepdims=True)
net_hc_scores = mag_hc_score_list / np.sum(mag_hc_score_list, axis=1, keepdims=True)
net_task_scores = mag_task_score_list / np.sum(mag_task_score_list, axis=1, keepdims=True)

proj_dict = {
    "mouse":mouse_list,
    
    "composite avgScore":np.sum(mag_score_list,axis=1),
    "composite avgHCScore":np.sum(mag_hc_score_list,axis=1),
    "composite avgTaskScore":np.sum(mag_task_score_list,axis=1),
    
    "net 1 avgScore":avg_score_list[:,0],
    "net 1 avgHCScore":avg_hc_score_list[:,0],
    "net 1 avgTaskScore":avg_task_score_list[:,0],
    "net 2 avgScore":avg_score_list[:,1],
    "net 2 avgHCScore":avg_hc_score_list[:,1],
    "net 2 avgTaskScore":avg_task_score_list[:,1],
    "net 3 avgScore":avg_score_list[:,2],
    "net 3 avgHCScore":avg_hc_score_list[:,2],
    "net 3 avgTaskScore":avg_task_score_list[:,2],
    
    "net 1 avgImpact":net_impact_scores[:,0],
    "net 1 avgHCImpact":net_hc_scores[:,0],
    "net 1 avgTaskImpact":net_task_scores[:,0],
    "net 2 avgImpact":net_impact_scores[:,1],
    "net 2 avgHCImpact":net_hc_scores[:,1],
    "net 2 avgTaskImpact":net_task_scores[:,1],
    "net 3 avgImpact":net_impact_scores[:,2],
    "net 3 avgHCImpact":net_hc_scores[:,2],
    "net 3 avgTaskImpact":net_task_scores[:,2],
}
df_means = pd.DataFrame.from_dict(proj_dict)
df_means.to_csv(mean_file)
df_means.head()

Unnamed: 0,mouse,composite avgScore,composite avgHCScore,composite avgTaskScore,net 1 avgScore,net 1 avgHCScore,net 1 avgTaskScore,net 2 avgScore,net 2 avgHCScore,net 2 avgTaskScore,...,net 3 avgTaskScore,net 1 avgImpact,net 1 avgHCImpact,net 1 avgTaskImpact,net 2 avgImpact,net 2 avgHCImpact,net 2 avgTaskImpact,net 3 avgImpact,net 3 avgHCImpact,net 3 avgTaskImpact
0,Mouse04203,3.141445,2.065341,4.009461,0.086621,0.061394,0.106969,0.135739,0.086547,0.175418,...,0.050979,0.270071,0.291152,0.261312,0.720452,0.6987,0.72949,0.009477,0.010147,0.009198
1,Mouse39115,3.16838,1.896296,3.667851,0.077542,0.036479,0.093664,0.142014,0.090194,0.162361,...,0.059839,0.239709,0.188418,0.250121,0.747354,0.793056,0.738077,0.012937,0.018526,0.011803
2,Mouse39121,3.537434,2.185519,4.631442,0.124493,0.089341,0.15294,0.13739,0.077413,0.185925,...,0.046178,0.344702,0.400388,0.323438,0.647586,0.590595,0.669349,0.007712,0.009017,0.007213
3,Mouse39122,3.904745,3.295088,4.321552,0.069922,0.059114,0.077311,0.191165,0.161196,0.211655,...,0.048741,0.175391,0.175716,0.175222,0.816296,0.815676,0.816619,0.008313,0.008608,0.00816
4,Mouse39132,3.592385,2.348602,4.637017,0.101218,0.066415,0.130448,0.154132,0.100351,0.199302,...,0.050096,0.275969,0.276976,0.27554,0.715387,0.712434,0.716644,0.008644,0.01059,0.007816
