In [1]:
# --- Input data ----------------------------------------------------------
DATA_LOCATION = "/work/mk423/Anxiety/"
data_file = "ChR2_pickle_file.pkl"

# --- Code and model locations --------------------------------------------
UMC_PATH = "/hpc/home/mk423/Anxiety/Universal-Mouse-Code/"
MODEL_PATH = "../Models/"
MT_MODEL_PATH = MODEL_PATH + "Positive_MT_10_res_loss_10_power_features.pt"

# --- Outputs -------------------------------------------------------------
PROJECTION_SAVE_PATH = "/hpc/home/mk423/Anxiety/MultiTaskWork/Projections/holdoutExperiments/"

# --- Features ------------------------------------------------------------
# Keys read from the data pickle; each is paired positionally with the scale
# factor at the same index of FEATURE_WEIGHT when the design matrix is built.
FEATURE_LIST = ['X_psd', 'X_coh', 'X_gc']
FEATURE_VECTOR = FEATURE_LIST  # alias: the very same list object
FEATURE_WEIGHT = [10, 1, 1]

# --- Run flags (not read by any of the cells shown) ----------------------
TRAIN = False
PROJ_TEST = True

In [2]:
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import sys
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
sys.path.append(UMC_PATH)
sys.path.append(PROJECTION_SAVE_PATH)
#from dCSFA_model import dCSFA_model
import umc_data_tools as umc_dt
from dCSFA_NMF import dCSFA_NMF

# Report which accelerator is available. The model below is pinned to
# MODEL_DEVICE regardless of this value, so `device` is informational here.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Device the multi-task model is loaded onto and run on.
MODEL_DEVICE = "cpu"

# The previous message ("Using device: cuda:0") was misleading because the
# model always ran on CPU; report both values explicitly instead.
print("Available device: %s; model inference runs on: %s" % (device, MODEL_DEVICE))

# For consistency (seed constant; not consumed by any of the cells shown).
RANDOM_STATE = 42

import pandas as pd  # redundant with the import earlier in this cell; harmless

# NOTE(review): torch.load unpickles a complete model object, which can run
# arbitrary code -- only load checkpoints from a trusted source. On torch>=2.6
# the default weights_only=True rejects pickled nn.Modules; pass
# weights_only=False explicitly if running on such a version.
model = torch.load(MT_MODEL_PATH, map_location=MODEL_DEVICE)
# dCSFA_NMF carries its own `device` attribute; keep it in sync with
# map_location (presumably used by transform() to place inputs -- TODO confirm).
model.device = MODEL_DEVICE
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda:0


dCSFA_NMF(
  (recon_loss_f): MSELoss()
  (Encoder): Sequential(
    (0): Linear(in_features=5152, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Linear(in_features=256, out_features=20, bias=True)
    (4): Softplus(beta=1, threshold=20)
  )
)

In [3]:
# Load the pre-extracted ChR2 features and labels. The `with` block closes the
# file handle (the original `pickle.load(open(...))` leaked it).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(DATA_LOCATION + data_file, "rb") as data_fh:
    dataDict = pickle.load(data_fh)

# zip() silently drops trailing items when lengths differ, which would quietly
# omit a feature block (or a weight); fail loudly instead.
if len(FEATURE_WEIGHT) != len(FEATURE_LIST):
    raise ValueError(
        "FEATURE_WEIGHT has %d entries but FEATURE_LIST has %d"
        % (len(FEATURE_WEIGHT), len(FEATURE_LIST))
    )

# Weighted design matrix: each feature block is scaled by its paired weight and
# the blocks are concatenated horizontally.
X = np.hstack([weight * dataDict[feature] for weight, feature in zip(FEATURE_WEIGHT, FEATURE_LIST)])
y_time = dataDict['y_time']
y_mouse = np.array(dataDict['y_mouse'])
y_expDate = dataDict['y_expDate']
y_BLaser = np.array(dataDict['y_BLaser'])
y_YLaser = np.array(dataDict['y_YLaser'])

# A sample is "laser on" when either laser label (BLaser or YLaser) is set.
y_Laser = np.logical_or(y_BLaser, y_YLaser)

In [4]:
# model.transform(X, None) must return exactly four outputs (the [2:] slice is
# unpacked into two names); keep positions 2 and 3 -- presumably the
# predictions and the per-network scores, whose column 0 is used below.
transform_outputs = model.transform(X, None)
y_pred, s = transform_outputs[2:]
del transform_outputs  # don't keep the unused leading outputs alive

In [5]:
# Per-mouse average of the first score column (s[:, 0]) over samples where any
# laser is on, where the BLaser is on, and where the YLaser is on. The table is
# written to PROJECTION_SAVE_PATH as ChR2_mean_scores.csv.
def _mean_first_score(scores, mask):
    """Return the mean of ``scores[mask, 0]``, or NaN when ``mask`` selects nothing.

    np.mean of an empty selection is also NaN but emits a RuntimeWarning
    ("Mean of empty slice"), so the empty case is short-circuited here.
    """
    if not np.any(mask):
        return np.nan
    return np.mean(scores[mask, 0])


mouse_list = []
avg_laser_list = []
avg_BLaser_list = []
avg_YLaser_list = []

for mouse in np.unique(y_mouse):
    mouse_mask = y_mouse == mouse

    # np.logical_and yields boolean masks even when the laser labels are stored
    # as 0/1 integers, so these index s as boolean masks, not as fancy indices.
    avg_laser_score = _mean_first_score(s, np.logical_and(mouse_mask, y_Laser))
    avg_BLaser_score = _mean_first_score(s, np.logical_and(mouse_mask, y_BLaser))
    avg_YLaser_score = _mean_first_score(s, np.logical_and(mouse_mask, y_YLaser))

    mouse_list.append(mouse)
    avg_laser_list.append(avg_laser_score)
    avg_BLaser_list.append(avg_BLaser_score)
    avg_YLaser_list.append(avg_YLaser_score)

proj_dict = {
    "mouse": mouse_list,
    "avgLaserScore": avg_laser_list,
    "avgBLaserScore": avg_BLaser_list,
    "avgYLaserScore": avg_YLaser_list,
}
df_projections = pd.DataFrame.from_dict(proj_dict)

df_projections.to_csv(PROJECTION_SAVE_PATH + "ChR2_mean_scores.csv")

In [11]:

# Within laser-on samples only, compare BLaser vs YLaser samples via the model
# outputs: per mouse (umc_dt.lpne_auc) and summarised across mice
# (umc_dt.get_mean_std_err_auc). The subsets are computed once and shared.
laser_mask = y_Laser == 1
y_pred_laser = y_pred[laser_mask]
y_BLaser_laser = y_BLaser[laser_mask].squeeze()
y_mouse_laser = y_mouse[laser_mask]
s_laser = s[laser_mask]

mw_auc_dict = umc_dt.lpne_auc(y_pred_laser, y_BLaser_laser, y_mouse_laser, s_laser, True)
mw_mean, mw_std = umc_dt.get_mean_std_err_auc(y_pred_laser, y_BLaser_laser, y_mouse_laser, s_laser, True)
# The label previously read "hc vs task", which does not match this comparison.
print("by mouse BLaser vs YLaser auc: {:.3} +/- {:.3}".format(mw_mean, mw_std))

# Copy each mouse's (auc, p-value) pair onto every one of its samples. A mouse
# missing from mw_auc_dict (presumably one with no laser-on samples) gets NaN
# instead of raising KeyError.
missing_entry = (np.nan, np.nan)
auc_list = [mw_auc_dict.get(mouse, missing_entry)[0] for mouse in y_mouse]
p_val_list = [mw_auc_dict.get(mouse, missing_entry)[1] for mouse in y_mouse]


saveDict = {
    "mouse": y_mouse,
    "time": y_time,
    "expDate": y_expDate,
    "scores": s[:, 0],
    "BLaser vs YLaser auc": auc_list,
    "BLaser vs YLaser pval": p_val_list,
    "y_BLaser": y_BLaser,
    "y_YLaser": y_YLaser,
    "y_Laser": y_Laser,
}

df = pd.DataFrame.from_dict(saveDict)
df.to_csv(PROJECTION_SAVE_PATH + "MT_onto_ChR2.csv")

by mouse hc vs task auc: 0.602 +/- 0.0239


In [8]:
# YLaser mean score for the last mouse in np.unique(y_mouse) order, i.e. the
# final entry appended by the per-mouse loop. Read explicitly from the list
# rather than relying on the loop's leftover avg_YLaser_score variable.
avg_YLaser_list[-1]

0.10947455

In [10]:
# Display the per-mouse results from umc_dt.lpne_auc: each mouse key maps to an
# (auc, p-value) pair, and an 'auc_method' entry names the test used.
mw_auc_dict

{'auc_method': 'mannWhitneyU',
 'Mouse5371': (0.6249247222167074, 9.298713261971126e-13),
 'Mouse5373': (0.6973381641220353, 1.6645357942637355e-29),
 'Mouse5391': (0.630943348688863, 6.490635381370341e-15),
 'Mouse5392': (0.5196898286274341, 0.242040776479166),
 'Mouse5393': (0.5432357350089619, 0.010542567019269485),
 'Mouse5394': (0.7637109958481823, 5.437022253242543e-47),
 'Mouse5395': (0.58918429616428, 9.898047642916311e-08),
 'Mouse9511': (0.5494003773000874, 0.0015403678533822951),
 'Mouse9512': (0.506118147228524, 0.6909146093499594),
 'Mouse9513': (0.5569474814397758, 0.0010546075306827776),
 'Mouse9514': (0.6422875883330584, 1.6142809725134186e-16)}