In [1]:
# --- Input data --------------------------------------------------------------
DATA_LOCATION = "/work/mk423/Anxiety/"
controls_data_file = "FC_Extinction_Controls.pkl"
conditioned_data_file = "FC_Extinction_Conditioned.pkl"

# --- Code and model locations -----------------------------------------------
UMC_PATH = "/hpc/home/mk423/Anxiety/Universal-Mouse-Code/"
MODEL_PATH = "../Models/"
MT_MODEL_PATH = f"{MODEL_PATH}Positive_MT_10_res_loss_10_power_features.pt"

# --- Output location ---------------------------------------------------------
PROJECTION_SAVE_PATH = "/hpc/home/mk423/Anxiety/MultiTaskWork/Projections/holdoutExperiments/"

# --- Feature assembly --------------------------------------------------------
# Feature blocks are concatenated in FEATURE_LIST order, each multiplied by the
# weight at the same position in FEATURE_WEIGHT ('X_psd' gets weight 10).
FEATURE_LIST = ['X_psd', 'X_coh', 'X_gc']
FEATURE_VECTOR = FEATURE_LIST  # alias: the very same list object as FEATURE_LIST
FEATURE_WEIGHT = [10, 1, 1]

# --- Run flags -----------------------------------------------------------------
# NOTE(review): neither flag is read by any cell visible here — confirm they are
# still needed.
TRAIN = False
PROJ_TEST = True

In [3]:
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import sys
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
sys.path.append(UMC_PATH)
sys.path.append(PROJECTION_SAVE_PATH)
#from dCSFA_model import dCSFA_model
import umc_data_tools as umc_dt
from dCSFA_NMF import dCSFA_NMF

# Report which accelerator is visible.
# NOTE(review): the model below is loaded onto, and pinned to, the CPU, so
# `device` is informational only — the printed "cuda:0" does not mean the
# projection runs on the GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using device: %s" % (device))

# For consistency: fixed seed for any stochastic step.
RANDOM_STATE = 42

# Load the full pickled multitask model onto the CPU and put its encoder into
# eval mode so its BatchNorm1d layer uses running statistics instead of
# per-batch statistics during projection.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects
# full-model pickles; pass weights_only=False there if loading fails.
model = torch.load(MT_MODEL_PATH, map_location='cpu')
model.device = "cpu"
model.Encoder.eval()

Using device: cuda:0


Sequential(
  (0): Linear(in_features=5152, out_features=256, bias=True)
  (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): LeakyReLU(negative_slope=0.01)
  (3): Linear(in_features=256, out_features=20, bias=True)
  (4): Softplus(beta=1, threshold=20)
)

In [4]:
def _read_pickle(path):
    """Return the object unpickled from the file at ``path``.

    NOTE(review): pickle.load can execute arbitrary code; these inputs are
    assumed to be trusted local files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# One dictionary per cohort: controls and fear-conditioned mice.
controls_dict = _read_pickle(DATA_LOCATION + controls_data_file)
conditioned_dict = _read_pickle(DATA_LOCATION + conditioned_data_file)

In [5]:
# Inspect the fields in the controls dictionary: feature arrays ('X_*'), label
# arrays ('y_*') and metadata. NOTE(review): the conditioned dictionary is
# assumed to share these keys — the next cell indexes both with the same names.
controls_dict.keys()

dict_keys(['X_psd', 'X_coh', 'X_gc', 'y_time', 'y_mouse', 'y_expDate', 'y_tone', 'y_freeze_w_tone', 'labels', 'feature version'])

In [10]:
# zip() below silently truncates to the shorter sequence, so a length mismatch
# would quietly drop a feature block. Fail loudly instead.
assert len(FEATURE_WEIGHT) == len(FEATURE_LIST), \
    "FEATURE_WEIGHT and FEATURE_LIST must have the same length"


def _weighted_features(data_dict):
    """Horizontally stack the FEATURE_LIST blocks of ``data_dict``, each scaled by its FEATURE_WEIGHT."""
    return np.hstack([weight * data_dict[feature]
                      for weight, feature in zip(FEATURE_WEIGHT, FEATURE_LIST)])


def _stack_cohorts(key):
    """Concatenate ``key`` from the controls cohort followed by the conditioned cohort."""
    return np.hstack([controls_dict[key], conditioned_dict[key]])


# Rows of X: all control rows first, then all conditioned rows. Every label
# array below follows the same order.
X_controls = _weighted_features(controls_dict)
X_cond = _weighted_features(conditioned_dict)
X = np.vstack([X_controls, X_cond])

y_time = _stack_cohorts('y_time')
y_mouse = _stack_cohorts('y_mouse')
y_expDate = _stack_cohorts('y_expDate')
y_tone = _stack_cohorts('y_tone')
y_freeze_w_tone = _stack_cohorts('y_freeze_w_tone')
# Cohort indicator per row: 0 = controls file, 1 = conditioned file.
y_conditioned = np.hstack([np.zeros(controls_dict['y_time'].shape),
                           np.ones(conditioned_dict['y_time'].shape)])

# Every label must have exactly one entry per feature row. np.size is used so
# the check also holds if a label is stored as a (1, n) or (n, 1) array.
assert all(np.size(label) == X.shape[0]
           for label in (y_time, y_mouse, y_expDate, y_tone, y_freeze_w_tone, y_conditioned)), \
    "label length does not match number of feature rows"

In [17]:
# Project every row of X through the model. The returned sequence has four
# entries; the last two are kept: `y_pred` and `s` (presumably per-row network
# scores — column 0 is the one summarized below; TODO confirm).
y_pred, s = model.transform(X, None)[2:]


def _masked_mean(values, mask):
    """Return the mean of ``values[mask]``, or NaN when ``mask`` selects no rows.

    NaN matches what ``np.mean`` gives for an empty selection, but this avoids
    its "Mean of empty slice" RuntimeWarning. That warning fires, for example,
    for mice that have no freeze-with-tone rows.
    """
    return np.mean(values[mask]) if np.any(mask) else np.nan


# Per-mouse averages of score column 0, computed over four row sets: all rows,
# tone rows, no-tone rows and freeze-with-tone rows.
# The masks below do not depend on the mouse, so they are built once.
first_scores = s[:, 0]
tone_mask = y_tone.squeeze() == 1
no_tone_mask = y_tone.squeeze() == 0
freeze_w_tone_mask = y_freeze_w_tone.squeeze() == 1
control_mice = set(np.unique(controls_dict['y_mouse']))

mouse_list = []
conditioned_list = []  # 0 = mouse appears in the controls file, 1 = otherwise (conditioned)
avg_score_list = []
avg_tone_score_list = []
avg_no_tone_score_list = []
avg_freeze_w_tone_score_list = []
for mouse in np.unique(y_mouse):
    mouse_mask = y_mouse == mouse

    mouse_list.append(mouse)
    conditioned_list.append(0 if mouse in control_mice else 1)
    avg_score_list.append(_masked_mean(first_scores, mouse_mask))
    avg_tone_score_list.append(_masked_mean(first_scores, mouse_mask & tone_mask))
    avg_no_tone_score_list.append(_masked_mean(first_scores, mouse_mask & no_tone_mask))
    avg_freeze_w_tone_score_list.append(_masked_mean(first_scores, mouse_mask & freeze_w_tone_mask))

proj_dict = {
    "mouse": mouse_list,
    "conditioned": conditioned_list,
    "Average Score": avg_score_list,
    "Average Tone Score": avg_tone_score_list,
    "Average No Tone Score": avg_no_tone_score_list,
    "Average Freeze With Tone Score": avg_freeze_w_tone_score_list,
}

df_projections = pd.DataFrame.from_dict(proj_dict)

df_projections.to_csv(PROJECTION_SAVE_PATH + "FC_extinction_mean_scores.csv")

In [20]:
# Per-mouse AUC of the model predictions, computed against two labelings in
# turn: freezing-with-tone first, then tone presence. A summary line is printed
# for each. After the loop, `mw_auc_dict`, `mw_mean` and `mw_std` hold the
# tone-only results.
auc_label_specs = (
    (y_freeze_w_tone, "by freezing with tone auc: {:.3} +/- {:.3}"),
    (y_tone, "by tone only auc: {:.3} +/- {:.3}"),
)
for label_values, summary_fmt in auc_label_specs:
    target = label_values.squeeze()
    mw_auc_dict = umc_dt.lpne_auc(y_pred, target, y_mouse, s, True)
    mw_mean, mw_std = umc_dt.get_mean_std_err_auc(y_pred, target, y_mouse, s, True)
    print(summary_fmt.format(mw_mean, mw_std))

Mouse  Mouse9071  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  860
Mouse  Mouse9072  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  968
Mouse  Mouse9071  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  860
Mouse  Mouse9072  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  968
by freezing with tone auc: 0.504 +/- 0.00694
by tone only auc: 0.499 +/- 0.00687


In [21]:
# Recompute the freezing-with-tone AUCs: the previous cell left the tone-only
# results in `mw_auc_dict`.
freeze_target = y_freeze_w_tone.squeeze()
mw_auc_dict = umc_dt.lpne_auc(y_pred, freeze_target, y_mouse, s, True)
mw_mean, mw_std = umc_dt.get_mean_std_err_auc(y_pred, freeze_target, y_mouse, s, True)
print("by freezing with tone auc: {:.3} +/- {:.3}".format(mw_mean, mw_std))

# Broadcast each mouse's entry (index 0 = AUC, index 1 = p-value) onto every
# row belonging to that mouse.
auc_list = [mw_auc_dict[row_mouse][0] for row_mouse in y_mouse]
p_val_list = [mw_auc_dict[row_mouse][1] for row_mouse in y_mouse]

# One CSV row per row of X; the column order below is the order written.
saveDict = {
    "mouse": y_mouse.squeeze(),
    "conditioned": y_conditioned,
    "freeze with tone": freeze_target,
    "tone only": y_tone.squeeze(),
    "time": y_time.squeeze(),
    "expDate": y_expDate.squeeze(),
    "scores": s[:, 0],
    "freeze with tone auc": auc_list,
    "freeze with tone pval": p_val_list,
}

df = pd.DataFrame.from_dict(saveDict)
df.to_csv(PROJECTION_SAVE_PATH + "MT_onto_FC_Extinction.csv")

Mouse  Mouse9071  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  860
Mouse  Mouse9072  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  968
Mouse  Mouse9071  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  860
Mouse  Mouse9072  has only one class - AUC cannot be calculated
n_positive samples  0
n_negative samples  968
by freezing with tone auc: 0.504 +/- 0.00694
