In [1]:
import numpy as np
import pandas as pd
import torch
import pickle
import matplotlib.pyplot as plt
from lpne.models import DcsfaNmf

MODEL_FILE = "/hpc/home/mk423/Anxiety/FullDataWork/Models/Final_mt_Model_500_epochs.pt"
DATA_PATH = "/work/mk423/Anxiety/"
PROJECT_PATH = "/hpc/home/mk423/Anxiety/FullDataWork/Projections/"
c19_oft_file = DATA_PATH + "C19_OFT_Data.pkl"
proj_file = PROJECT_PATH + "c19_oft_projection.csv"
mean_file = PROJECT_PATH + "c19_oft_mean_scores.csv"
model = torch.load(MODEL_FILE,map_location="cpu")
model.device="cpu"

import os, sys
umc_data_tools_path = "/hpc/home/mk423/Anxiety/Universal-Mouse-Code/"
sys.path.append(umc_data_tools_path)
import umc_data_tools as umc_dt

FEATURE_LIST = ['X_psd','X_coh','X_gc']
FEATURE_WEIGHT = [10,1,1]

def get_3_net_aucs(s,y,y_group=None):
    
    auc_mean_list = []
    auc_stderr_list =[]
    
    if y_group is None:
        y_group = np.ones(s.shape[0])
        
    for i in range(3):
        auc_dict = umc_dt.lpne_auc(y,y,y_group,s[:,i].reshape(-1,1),mannWhitneyU=True)
        mean = np.mean([auc_dict[key][0] for key in auc_dict.keys() if key != "auc_method"])
        stderr = np.std([auc_dict[key][0] for key in auc_dict.keys() if key != "auc_method"]) / np.sqrt(len(auc_dict.keys()))
        
        auc_mean_list.append(mean)
        auc_stderr_list.append(stderr)
        
    return auc_mean_list, auc_stderr_list

def get_3_net_auc_dict(s,y,y_group=None):
    auc_dict_list = []
    
    if y_group is None:
        y_group = np.ones(s.shape[0])
        
    for i in range(3):
        auc_dict = umc_dt.lpne_auc(y,y,y_group,s[:,i].reshape(-1,1),mannWhitneyU=True)
        auc_dict_list.append(auc_dict)
        
    return auc_dict_list

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataDict = pickle.load(open(c19_oft_file,"rb"))

X = np.hstack([weight*dataDict[feature] for weight,feature in zip(FEATURE_WEIGHT,FEATURE_LIST)])
y_time = dataDict['y_time']
y_mouse = np.array(dataDict['y_mouse'])
y_condition = np.array(dataDict['y_condition'])
y_task = dataDict['y_task']
y_expDate = dataDict['y_expDate']
y_sex = np.array(dataDict['y_sex'])
s = model.project(X)


In [4]:
c19_mask = y_condition=="CLOCK"
wt_mask = y_condition=="WT"


c19_vs_wt = model.score(X,
                        wt_mask,
                       )

c19_vs_wt_hc = model.score(X[y_task==0],
                        wt_mask[y_task==0],
                       )

c19_vs_wt_task = model.score(X[y_task==1],
                        wt_mask[y_task==1],
                       )
c19_hc_v_task_auc = model.score(X[c19_mask==1],
                                y_task[c19_mask==1],
                                y_mouse[c19_mask==1],
                                return_dict=True)

c19_hc_v_task_auc_list = [c19_hc_v_task_auc[key] for key in c19_hc_v_task_auc.keys()
                          if not np.isnan(c19_hc_v_task_auc[key])]

c19_hc_v_task_mean = np.mean(c19_hc_v_task_auc_list)
c19_hc_v_task_sterr = np.std(c19_hc_v_task_auc_list) / np.sqrt(len(c19_hc_v_task_auc_list))

wt_hc_v_task_auc = model.score(X[wt_mask==1],
                                y_task[wt_mask==1],
                                y_mouse[wt_mask==1],
                                return_dict=True)

wt_hc_v_task_auc_list = [wt_hc_v_task_auc[key] for key in wt_hc_v_task_auc.keys() 
                         if not np.isnan(wt_hc_v_task_auc[key])]

wt_hc_v_task_mean = np.mean(wt_hc_v_task_auc_list)
wt_hc_v_task_sterr = np.std(wt_hc_v_task_auc_list) / np.sqrt(len(wt_hc_v_task_auc_list))

print("C19 HC vs Task {:.3f} +/- {:.3f} (n={})".format(c19_hc_v_task_mean,c19_hc_v_task_sterr,
                                               len(c19_hc_v_task_auc_list)))

print("WT HC vs Task {:.3f} +/- {:.3f} (n={})".format(wt_hc_v_task_mean,wt_hc_v_task_sterr,
                                               len(wt_hc_v_task_auc_list)))

print("Homecage wt vs c19: ",c19_vs_wt_hc)
print("Task wt vs c19: ",c19_vs_wt_task)
print("Overall wt vs c19: ",c19_vs_wt)

C19 HC vs Task 0.512 +/- 0.006 (n=4)
WT HC vs Task 0.514 +/- 0.025 (n=4)
Homecage wt vs c19:  0.4876571031258824
Task wt vs c19:  0.5444318509880958
Overall wt vs c19:  0.5208998757530277


In [5]:
c19_hc_v_task_3_net = get_3_net_aucs(s[c19_mask==1],y_task[c19_mask==1],y_mouse[c19_mask==1])
wt_hc_v_task_3_net = get_3_net_aucs(s[wt_mask==1],y_task[wt_mask==1],y_mouse[wt_mask==1])
hc_wt_vs_c19_3_net = get_3_net_aucs(s[y_task==0],wt_mask[y_task==0])
task_wt_vs_c19_3_net = get_3_net_aucs(s[y_task==1],wt_mask[y_task==1])

print("c19 hc vs task (means,stderr) ",c19_hc_v_task_3_net)
print("wt hc vs task (means,stderr) ",wt_hc_v_task_3_net)
print("hc wt vs c19 (means,stderr) ",hc_wt_vs_c19_3_net)
print("task wt vs c19 (means,stderr) ",task_wt_vs_c19_3_net)

c19 hc vs task (means,stderr)  ([0.5267833056781356, 0.5072875175173633, 0.501483192591798], [0.009884762161171234, 0.004290612618197886, 0.007499952488878685])
wt hc vs task (means,stderr)  ([0.5112291239856317, 0.5139069988882127, 0.5004377000055195], [0.02579647955659582, 0.02070721328767653, 0.015366170766802833])
hc wt vs c19 (means,stderr)  ([0.5289917230793965, 0.4763398037640854, 0.5251074721760943], [0.0, 0.0, 0.0])
task wt vs c19 (means,stderr)  ([0.6080747373921627, 0.5211843238483237, 0.5218109054224708], [0.0, 0.0, 0.0])


In [6]:
all_hc_v_task_aucs = model.score(X,y_task,y_mouse,return_dict=True)
all_hc_v_task_aucs_3_net = get_3_net_auc_dict(s,y_task,y_mouse)

results_dict = {
    "net 1 scores":s[:,0],
    "net 2 scores":s[:,1],
    "net 3 scores":s[:,2],
    "mouse":y_mouse,
    "sex":y_sex,
    "condition":y_condition,
    "in-task":y_task,
    "time":y_time,
    "expDate":y_expDate,
    "hc_v_task auc":[all_hc_v_task_aucs[mouse][0] for mouse in y_mouse],
    "net 1 hc_v_task auc":[all_hc_v_task_aucs_3_net[0][mouse][0] for mouse in y_mouse],
    "net 2 hc_v_task auc":[all_hc_v_task_aucs_3_net[1][mouse][0] for mouse in y_mouse],
    "net 3 hc_v_task auc":[all_hc_v_task_aucs_3_net[2][mouse][0] for mouse in y_mouse],
}

df = pd.DataFrame.from_dict(results_dict)
df.to_csv(proj_file)
df.head()

Unnamed: 0,net 1 scores,net 2 scores,net 3 scores,mouse,sex,condition,in-task,time,expDate,hc_v_task auc,net 1 hc_v_task auc,net 2 hc_v_task auc,net 3 hc_v_task auc
0,0.102891,0.170118,0.025495,Mouse30391,M,WT,False,1,100621,0.572226,0.602942,0.556138,0.542542
1,0.072117,0.168851,0.025139,Mouse30391,M,WT,False,2,100621,0.572226,0.602942,0.556138,0.542542
2,0.103911,0.158096,0.006654,Mouse30391,M,WT,False,3,100621,0.572226,0.602942,0.556138,0.542542
3,0.109749,0.155957,0.003117,Mouse30391,M,WT,False,4,100621,0.572226,0.602942,0.556138,0.542542
4,0.126828,0.202271,0.013822,Mouse30391,M,WT,False,5,100621,0.572226,0.602942,0.556138,0.542542


In [9]:
#Get per mouse average scores, average scores in HC and average scores in EPM
mouse_list = []
avg_score_list = []
avg_hc_score_list = []
avg_task_score_list = []
condition_list = []

for mouse in np.unique(y_mouse):
    mouse_mask = y_mouse==mouse
    hc_mouse_mask = np.logical_and(mouse_mask,y_task==0)
    task_mouse_mask = np.logical_and(mouse_mask,y_task==1)
    avg_score = np.mean(s[mouse_mask==1,:3],axis=0)
    avg_score_hc = np.mean(s[hc_mouse_mask==1,:3],axis=0)
    avg_score_task = np.mean(s[task_mouse_mask==1,:3],axis=0)
    
    mouse_list.append(mouse)
    avg_score_list.append(avg_score)
    avg_hc_score_list.append(avg_score_hc)
    avg_task_score_list.append(avg_score_task)
    condition_list.append(np.unique(y_condition[y_mouse==mouse])[0])
    
avg_score_list = np.array(avg_score_list)
avg_hc_score_list = np.array(avg_hc_score_list)
avg_task_score_list = np.array(avg_task_score_list)

coeffs = np.abs(model.classifier[0].weight[0].detach().cpu().numpy())

mag_score_list = avg_score_list*coeffs
mag_hc_score_list = avg_hc_score_list*coeffs
mag_task_score_list = avg_task_score_list*coeffs

net_impact_scores = mag_score_list / np.sum(mag_score_list,axis=1).reshape(-1,1)
net_hc_scores = mag_hc_score_list / np.sum(mag_hc_score_list,axis=1).reshape(-1,1)
net_task_scores = mag_task_score_list / np.sum(mag_task_score_list,axis=1).reshape(-1,1)

proj_dict = {
    "mouse":mouse_list,
    "condition":condition_list,
    
    "net 1 avgScore":avg_score_list[:,0],
    "net 1 avgHCScore":avg_hc_score_list[:,0],
    "net 1 avgTaskScore":avg_task_score_list[:,0],
    "net 2 avgScore":avg_score_list[:,1],
    "net 2 avgHCScore":avg_hc_score_list[:,1],
    "net 2 avgTaskScore":avg_task_score_list[:,1],
    "net 3 avgScore":avg_score_list[:,2],
    "net 3 avgHCScore":avg_hc_score_list[:,2],
    "net 3 avgTaskScore":avg_task_score_list[:,2],
    
    "net 1 avgImpact":net_impact_scores[:,0],
    "net 1 avgHCImpact":net_hc_scores[:,0],
    "net 1 avgTaskImpact":net_task_scores[:,0],
    "net 2 avgImpact":net_impact_scores[:,1],
    "net 2 avgHCImpact":net_hc_scores[:,1],
    "net 2 avgTaskImpact":net_task_scores[:,1],
    "net 3 avgImpact":net_impact_scores[:,2],
    "net 3 avgHCImpact":net_hc_scores[:,2],
    "net 3 avgTaskImpact":net_task_scores[:,2],
}
df_means = pd.DataFrame.from_dict(proj_dict)
df_means.to_csv(mean_file)
df_means.head()

Unnamed: 0,mouse,condition,net 1 avgScore,net 1 avgHCScore,net 1 avgTaskScore,net 2 avgScore,net 2 avgHCScore,net 2 avgTaskScore,net 3 avgScore,net 3 avgHCScore,net 3 avgTaskScore,net 1 avgImpact,net 1 avgHCImpact,net 1 avgTaskImpact,net 2 avgImpact,net 2 avgHCImpact,net 2 avgTaskImpact,net 3 avgImpact,net 3 avgHCImpact,net 3 avgTaskImpact
0,Mouse30391,WT,0.108739,0.104417,0.113359,0.166796,0.161659,0.172288,0.028558,0.027363,0.029836,0.275434,0.273604,0.27726,0.719223,0.7211,0.71735,0.005343,0.005296,0.00539
1,Mouse30392,WT,0.050447,0.050045,0.050915,0.066983,0.06421,0.070207,0.047569,0.047288,0.047896,0.300303,0.307319,0.292668,0.678782,0.671233,0.686997,0.020916,0.021449,0.020335
2,Mouse69841,CLOCK,0.089825,0.087691,0.092268,0.152558,0.151759,0.153473,0.03011,0.030444,0.029728,0.255364,0.25178,0.25938,0.738313,0.741763,0.734447,0.006323,0.006456,0.006173
3,Mouse69861,WT,0.112302,0.114414,0.11087,0.143594,0.1509,0.138638,0.031148,0.031188,0.03112,0.312778,0.306247,0.317519,0.680814,0.687587,0.675898,0.006408,0.006166,0.006583
4,Mouse69862,WT,0.121645,0.123674,0.121084,0.156372,0.154852,0.156793,0.043155,0.047101,0.042062,0.31109,0.316494,0.309595,0.680759,0.674603,0.682462,0.008152,0.008903,0.007944
