In [1]:
import torch
import numpy as np
import pickle
from lpne.models import DcsfaNmf
from lpne.plotting import circle_plot
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from PIL import Image
import matplotlib.pyplot as plt
import os, sys
from sklearn.model_selection import KFold

# Directory appended to sys.path so that `umc_data_tools` can be imported on the next line.
umc_data_tools_path = "/hpc/home/mk423/Anxiety/Universal-Mouse-Code/"
sys.path.append(umc_data_tools_path)
import umc_data_tools as umc_dt

# Dataset path templates. The "{}" placeholder is filled with the split name via
# str.format when the pickles are opened ("train"/"val" for FLX and EPM,
# "train"/"validation" for OFT).
flx_data_path = "/work/mk423/Anxiety/final_FLX_{}.pkl"
epm_data_path = "/work/mk423/Anxiety/EPM_{}_dict_May_17.pkl"
oft_data_path = "/work/mk423/Anxiety/OFT_{}_dict_old_features_hand_picked.pkl"

# NOTE(review): not referenced by any of the cells visible in this notebook section.
anx_info_dict = "/work/mk423/Anxiety/Anx_Info_Dict.pkl"

# NOTE(review): model / projection / figure locations; not referenced by the visible cells.
saved_model_path = "/hpc/home/mk423/Anxiety/FullDataWork/Models/"
saved_model_name = "all_mt_model.pt"

projection_save_path = "/hpc/home/mk423/Anxiety/FullDataWork/Projections/"
plots_path = "/hpc/home/mk423/Anxiety/FullDataWork/Figures/"

# Dictionary keys of the feature blocks. `feature_list` is used for the FLX pickles,
# `old_feature_list` for the OFT and EPM pickles. Each block is multiplied by the
# weight at the same position in `feature_weights` before the blocks are stacked.
feature_list = ["X_psd","X_coh","X_gc"]
old_feature_list = ["X_power_1_2","X_coh_1_2","X_gc_1_2"]
feature_weights = [10,1,1]

# NOTE(review): RANDOM_STATE is not passed to KFold (no shuffle argument is given),
# and it is not used elsewhere in the visible cells.
RANDOM_STATE = 42
# 4-fold splitter applied to arrays of mouse identifiers in the cells below.
kf = KFold(n_splits=4)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the pickled FLX train / val dictionaries.
with open(flx_data_path.format("train"), "rb") as f:
    train_dict = pickle.load(f)

with open(flx_data_path.format("val"), "rb") as f:
    val_dict = pickle.load(f)

# (feature key, weight) pairs: every feature block is scaled by its weight and the
# scaled blocks are then joined with np.hstack.
flx_weighted_features = list(zip(feature_list, feature_weights))

# Train arrays
flx_X_train = np.hstack([train_dict[key] * scale for key, scale in flx_weighted_features])
flx_y_train = train_dict['y_flx']
flx_y_mouse_train = train_dict['y_mouse']
flx_y_hab_train = train_dict['y_hab']
flx_y_time_train = train_dict['y_time']

# Validation arrays
flx_X_val = np.hstack([val_dict[key] * scale for key, scale in flx_weighted_features])
flx_y_val = val_dict['y_flx']
flx_y_mouse_val = val_dict['y_mouse']
flx_y_hab_val = val_dict['y_hab']
flx_y_time_val = val_dict['y_time']

# Only rows whose 'y_hab' entry equals 1 are kept in the pooled FLX arrays.
flx_hab_rows_train = flx_y_hab_train == 1
flx_hab_rows_val = flx_y_hab_val == 1

flx_X = np.vstack([flx_X_train[flx_hab_rows_train], flx_X_val[flx_hab_rows_val]])
flx_y = np.hstack([flx_y_train[flx_hab_rows_train], flx_y_val[flx_hab_rows_val]])
flx_y_mouse = np.hstack([flx_y_mouse_train[flx_hab_rows_train], flx_y_mouse_val[flx_hab_rows_val]])
flx_y_time = np.hstack([flx_y_time_train[flx_hab_rows_train], flx_y_time_val[flx_hab_rows_val]])

In [21]:
# Read the pickled OFT train / validation dictionaries.
with open(oft_data_path.format("train"), 'rb') as f:
    train_dict = pickle.load(f)

with open(oft_data_path.format("validation"), 'rb') as f:
    val_dict = pickle.load(f)

# (feature key, weight) pairs: every feature block is scaled by its weight and the
# scaled blocks are then joined with np.hstack.
oft_weighted_features = list(zip(old_feature_list, feature_weights))

# Train arrays
oft_X_train = np.hstack([train_dict[key] * scale for key, scale in oft_weighted_features])
oft_y_hc_train = train_dict['y_Homecage'].astype(bool)
# The task label is the logical negation of the homecage flag.
oft_y_task_train = np.logical_not(oft_y_hc_train)
oft_y_ROI_train = train_dict['y_ROI']
oft_y_vel_train = train_dict['y_vel']
oft_y_mouse_train = train_dict['y_mouse']
oft_y_time_train = train_dict['y_time']

# Validation arrays
oft_X_val = np.hstack([val_dict[key] * scale for key, scale in oft_weighted_features])
oft_y_hc_val = val_dict['y_Homecage'].astype(bool)
oft_y_task_val = np.logical_not(oft_y_hc_val)
oft_y_ROI_val = val_dict['y_ROI']
oft_y_vel_val = val_dict['y_vel']
oft_y_mouse_val = val_dict['y_mouse']
oft_y_time_val = val_dict['y_time']

# Pool train and validation rows; the pooled label is the task (non-homecage) flag.
oft_X = np.vstack((oft_X_train, oft_X_val))
oft_y = np.hstack((oft_y_task_train, oft_y_task_val))
oft_y_mouse = np.hstack((oft_y_mouse_train, oft_y_mouse_val))

In [19]:
# Read the pickled EPM train / val dictionaries.
with open(epm_data_path.format("train"),"rb") as f:
    epm_train_dict = pickle.load(f)

with open(epm_data_path.format("val"),"rb") as f:
    epm_validation_dict = pickle.load(f)


# ---- Train arrays ----
# Feature blocks are scaled by their weights and joined with np.hstack.
X_train = np.hstack([epm_train_dict[feature]*weight for feature,weight in zip(old_feature_list,feature_weights)])
# Negative feature values are clipped to zero.
X_train[X_train<0] = 0
# Label is True where y_ROI is odd.
y_train = (epm_train_dict['y_ROI']%2).astype(bool)
# In-task flag is the logical negation of the homecage flag.
y_in_task_mask_train = ~epm_train_dict['y_Homecage'].astype(bool)
y_mouse_train = epm_train_dict['y_mouse']
y_time_train = epm_train_dict['y_time']
# True where y_ROI > 0; the comparison is False for NaN and for non-positive values.
train_nan_mask = (epm_train_dict['y_ROI'] > 0)

# Rows that are in-task and have y_ROI > 0; computed once and reused for all
# four task-only train arrays.
train_task_rows = np.logical_and(y_in_task_mask_train, train_nan_mask)
X_train_task = X_train[train_task_rows]
y_train_task = y_train[train_task_rows]
y_mouse_train_task = y_mouse_train[train_task_rows]
y_time_train_task = y_time_train[train_task_rows]


# ---- Validation arrays ----
X_val = np.hstack([epm_validation_dict[feature]*weight for feature,weight in zip(old_feature_list,feature_weights)])
# Apply the same zero clip as for X_train so that the pooled epm_X (and X_val_task)
# are preprocessed consistently; previously only the train rows were clipped.
X_val[X_val<0] = 0
y_val = (epm_validation_dict['y_ROI']%2).astype(bool)
y_in_task_mask_val= ~epm_validation_dict['y_Homecage'].astype(bool)
y_mouse_val = epm_validation_dict['y_mouse']
y_time_val = epm_validation_dict['y_time']
val_nan_mask = (epm_validation_dict['y_ROI'] > 0)

val_task_rows = np.logical_and(y_in_task_mask_val, val_nan_mask)
X_val_task = X_val[val_task_rows]
y_val_task = y_val[val_task_rows]
y_mouse_val_task = y_mouse_val[val_task_rows]
y_time_val_task = y_time_val[val_task_rows]

# Pool ALL train and validation rows (not only the task rows); the pooled label
# is the in-task flag rather than the ROI-parity label.
epm_X = np.vstack([X_train,X_val])
epm_y = np.hstack([y_in_task_mask_train,y_in_task_mask_val])
epm_y_mouse = np.hstack([y_mouse_train,y_mouse_val])
epm_y_time = np.hstack([y_time_train,y_time_val])

## Identify dataset overlap

In [29]:
# Mice that appear in two experiments.
epm_oft_overlap = np.intersect1d(np.unique(epm_y_mouse), np.unique(oft_y_mouse))
flx_oft_overlap = np.intersect1d(np.unique(oft_y_mouse), np.unique(flx_y_mouse))
#flx_epm_overlap = np.intersect1d(np.unique(epm_y_mouse),np.unique(flx_y_mouse)) This is empty

shared_mice = np.union1d(epm_oft_overlap, flx_oft_overlap)
all_mice = np.union1d(np.unique(epm_y_mouse), np.union1d(oft_y_mouse, flx_y_mouse))

# (labels, mouse ids) for each experiment, checked in the order FLX, EPM, OFT.
label_sources = (
    (flx_y, flx_y_mouse),
    (epm_y, epm_y_mouse),
    (oft_y, oft_y_mouse),
)

# A mouse is forced into training when, in at least one experiment, its rows carry
# exactly one distinct label value. A mouse absent from an experiment yields zero
# distinct values there, which does not count as single-class.
single_class_mice = []
for mouse in all_mice:
    distinct_label_counts = [
        np.unique(labels[mouse_ids == mouse]).shape[0]
        for labels, mouse_ids in label_sources
    ]
    if 1 in distinct_label_counts:
        single_class_mice.append(mouse)

always_training_mice = np.array(single_class_mice)

### Collect First Pass KFolds

In [50]:
flx_mice = np.unique(flx_y_mouse)

# Mice that must not be distributed by this K-fold pass. Built once here instead of
# being recomputed for every mouse inside the comprehension.
flx_excluded_mice = np.union1d(always_training_mice, shared_mice)

# FLX mice that are neither always-training nor shared with another experiment.
mc_unique_flx_mice = np.array([mouse
                               for mouse in flx_mice
                               if mouse not in flx_excluded_mice])

# FLX mice that are always placed in the training set.
flx_always_train_mice = np.array([mouse
                                  for mouse in flx_mice
                                  if mouse in always_training_mice])

print(flx_always_train_mice)
mc_flx_kf_train_mice = []
mc_flx_kf_val_mice = []

# Split the FLX-only mice (not their rows) into folds; fold numbers are printed 1-based.
for i, (train_idxs,val_idxs) in enumerate(kf.split(mc_unique_flx_mice)):
    mc_flx_kf_train_mice.append(mc_unique_flx_mice[train_idxs])
    mc_flx_kf_val_mice.append(mc_unique_flx_mice[val_idxs])

    print(i+1)
    print("train mice",mc_flx_kf_train_mice[-1])
    print("val mice",mc_flx_kf_val_mice[-1])

['Mouse3192' 'Mouse3194' 'Mouse3202' 'Mouse3203' 'Mouse99002' 'Mouse99003'
 'Mouse99021']
1
train mice ['Mouse61635' 'Mouse78744' 'Mouse78752']
val mice ['Mouse61631']
2
train mice ['Mouse61631' 'Mouse78744' 'Mouse78752']
val mice ['Mouse61635']
3
train mice ['Mouse61631' 'Mouse61635' 'Mouse78752']
val mice ['Mouse78744']
4
train mice ['Mouse61631' 'Mouse61635' 'Mouse78744']
val mice ['Mouse78752']


In [51]:
epm_mice = np.unique(epm_y_mouse)

# Mice that must not be distributed by this K-fold pass. Built once here instead of
# being recomputed for every mouse inside the comprehension.
epm_excluded_mice = np.union1d(always_training_mice, shared_mice)

# EPM mice that are neither always-training nor shared with another experiment.
mc_unique_epm_mice = np.array([mouse
                               for mouse in epm_mice
                               if mouse not in epm_excluded_mice])

# EPM mice that are always placed in the training set.
epm_always_train_mice = np.array([mouse
                                  for mouse in epm_mice
                                  if mouse in always_training_mice])

print(epm_always_train_mice)
mc_epm_kf_train_mice = []
mc_epm_kf_val_mice = []

# Split the EPM-only mice (not their rows) into folds; fold numbers are printed 1-based.
for i, (train_idxs, val_idxs) in enumerate(kf.split(mc_unique_epm_mice)):

    mc_epm_kf_train_mice.append(mc_unique_epm_mice[train_idxs])
    mc_epm_kf_val_mice.append(mc_unique_epm_mice[val_idxs])

    print(i+1)
    print("train mice",mc_epm_kf_train_mice[-1])
    print("val mice",mc_epm_kf_val_mice[-1])

['Mouse04215']
1
train mice ['Mouse0643' 'Mouse1551' 'Mouse6291' 'Mouse6292' 'Mouse6293' 'Mouse8580'
 'Mouse8581' 'Mouse8582' 'Mouse8891' 'Mouse8894']
val mice ['Mouse0630' 'Mouse0633' 'Mouse0634' 'Mouse0642']
2
train mice ['Mouse0630' 'Mouse0633' 'Mouse0634' 'Mouse0642' 'Mouse6293' 'Mouse8580'
 'Mouse8581' 'Mouse8582' 'Mouse8891' 'Mouse8894']
val mice ['Mouse0643' 'Mouse1551' 'Mouse6291' 'Mouse6292']
3
train mice ['Mouse0630' 'Mouse0633' 'Mouse0634' 'Mouse0642' 'Mouse0643' 'Mouse1551'
 'Mouse6291' 'Mouse6292' 'Mouse8582' 'Mouse8891' 'Mouse8894']
val mice ['Mouse6293' 'Mouse8580' 'Mouse8581']
4
train mice ['Mouse0630' 'Mouse0633' 'Mouse0634' 'Mouse0642' 'Mouse0643' 'Mouse1551'
 'Mouse6291' 'Mouse6292' 'Mouse6293' 'Mouse8580' 'Mouse8581']
val mice ['Mouse8582' 'Mouse8891' 'Mouse8894']


In [55]:
# OFT has only 2 completely unique mice with multiple classes, so these mice are
# assigned to the training and validation sets by hand (no K-fold split is run here).
oft_mice = np.unique(oft_y_mouse)

# Mice excluded from the "unique" group. Built once here instead of being recomputed
# for every mouse inside the comprehension.
oft_excluded_mice = np.union1d(always_training_mice, shared_mice)

# OFT mice that are neither always-training nor shared with another experiment.
mc_unique_oft_mice = np.array([mouse
                               for mouse in oft_mice
                               if mouse not in oft_excluded_mice])

print("unique oft mice",mc_unique_oft_mice)

# OFT mice that are always placed in the training set.
oft_always_train_mice = np.array([mouse
                                  for mouse in oft_mice
                                  if mouse in always_training_mice])

print("oft always train mice",oft_always_train_mice)
# NOTE(review): these two lists are initialised but never filled in the visible cells.
mc_oft_kf_train_mice = []
mc_oft_kf_val_mice = []



unique oft mice ['Mouse04191' 'Mouse69072']
oft always train mice ['Mouse04215' 'Mouse3192' 'Mouse3194' 'Mouse3203']


### Overlapping Mice KFold Splits

In [62]:
# OFT / FLX overlap: mice present in both experiments that are not forced into training.
oft_flx_candidates = [overlap_mouse
                      for overlap_mouse in flx_oft_overlap
                      if overlap_mouse not in always_training_mice]
mc_oft_flx_overlap_mice = np.array(oft_flx_candidates)

print(mc_oft_flx_overlap_mice)
# Only 2 mice fall into this category, so they are assigned to the folds by hand.

['Mouse3191' 'Mouse3193']


In [61]:
# OFT / EPM overlap: mice present in both experiments that are not forced into training.
oft_epm_candidates = [overlap_mouse
                      for overlap_mouse in epm_oft_overlap
                      if overlap_mouse not in always_training_mice]
mc_oft_epm_overlap_mice = np.array(oft_epm_candidates)

mc_oft_epm_kf_train_mice = []
mc_oft_epm_kf_val_mice = []

# Split the overlapping mice (not their rows) into folds; fold numbers are printed 1-based.
for fold_number, (train_idxs, val_idxs) in enumerate(kf.split(mc_oft_epm_overlap_mice), start=1):
    fold_train_mice = mc_oft_epm_overlap_mice[train_idxs]
    fold_val_mice = mc_oft_epm_overlap_mice[val_idxs]

    mc_oft_epm_kf_train_mice.append(fold_train_mice)
    mc_oft_epm_kf_val_mice.append(fold_val_mice)

    print(fold_number)
    print("train mice", fold_train_mice)
    print("val mice", fold_val_mice)

1
train mice ['Mouse04205' 'Mouse39114' 'Mouse39124' 'Mouse39125' 'Mouse39133'
 'Mouse69064' 'Mouse69065' 'Mouse69074']
val mice ['Mouse04193' 'Mouse04201' 'Mouse04202']
2
train mice ['Mouse04193' 'Mouse04201' 'Mouse04202' 'Mouse39125' 'Mouse39133'
 'Mouse69064' 'Mouse69065' 'Mouse69074']
val mice ['Mouse04205' 'Mouse39114' 'Mouse39124']
3
train mice ['Mouse04193' 'Mouse04201' 'Mouse04202' 'Mouse04205' 'Mouse39114'
 'Mouse39124' 'Mouse69065' 'Mouse69074']
val mice ['Mouse39125' 'Mouse39133' 'Mouse69064']
4
train mice ['Mouse04193' 'Mouse04201' 'Mouse04202' 'Mouse04205' 'Mouse39114'
 'Mouse39124' 'Mouse39125' 'Mouse39133' 'Mouse69064']
val mice ['Mouse69065' 'Mouse69074']


### Combine all splits

In [71]:
# Per-fold lists of train / validation mouse identifiers for each experiment.
kf_flx_train_mice = []
kf_flx_val_mice = []
kf_epm_train_mice = []
kf_epm_val_mice = []
kf_oft_train_mice = []
kf_oft_val_mice = []

# Number of folds is taken from the first-pass FLX split rather than repeated as a
# literal, so it follows the n_splits of the KFold object that produced the lists.
n_folds = len(mc_flx_kf_train_mice)

for i in range(n_folds):

    # FLX: always-train mice plus this fold's FLX-only train mice. The OFT/FLX overlap
    # mice are placed in validation in every fold.
    temp_flx_train_mice = np.hstack([flx_always_train_mice,mc_flx_kf_train_mice[i]])
    temp_flx_val_mice = np.hstack([mc_flx_kf_val_mice[i],mc_oft_flx_overlap_mice])

    # EPM: always-train mice, this fold's EPM-only mice and this fold's OFT/EPM overlap mice.
    temp_epm_train_mice = np.hstack([epm_always_train_mice,mc_epm_kf_train_mice[i],mc_oft_epm_kf_train_mice[i]])
    temp_epm_val_mice = np.hstack([mc_epm_kf_val_mice[i],mc_oft_epm_kf_val_mice[i]])

    # OFT: the OFT/EPM overlap mice follow the same fold assignment as in EPM, and the
    # OFT/FLX overlap mice are validation-only as in FLX. The OFT-only mice alternate
    # by fold parity: index i % 2 goes to train, the other index to validation.
    # NOTE(review): only indices 0 and 1 of mc_unique_oft_mice are ever used, so this
    # assumes exactly two OFT-only mice.
    temp_oft_train_mice = np.hstack([oft_always_train_mice,mc_oft_epm_kf_train_mice[i],mc_unique_oft_mice[i % 2]])
    temp_oft_val_mice = np.hstack([mc_oft_epm_kf_val_mice[i],mc_oft_flx_overlap_mice,mc_unique_oft_mice[(i + 1) % 2]])


    kf_flx_train_mice.append(temp_flx_train_mice)
    kf_flx_val_mice.append(temp_flx_val_mice)
    kf_epm_train_mice.append(temp_epm_train_mice)
    kf_epm_val_mice.append(temp_epm_val_mice)
    kf_oft_train_mice.append(temp_oft_train_mice)
    kf_oft_val_mice.append(temp_oft_val_mice)

### Save the fixed splits

In [72]:
def _make_fold_dict(X, y, y_mouse, train_mask, val_mask, train_mice, val_mice):
    """Bundle the train / validation arrays of one fold into a dictionary.

    Parameters
    ----------
    X, y, y_mouse : np.ndarray
        Feature, label and mouse-id arrays; rows are selected with the masks.
    train_mask, val_mask : np.ndarray of bool
        Row selectors for the training and validation parts.
    train_mice, val_mice : np.ndarray
        Mouse identifier lists stored unchanged under "train_mice" / "val_mice".

    Returns
    -------
    dict
        Keys "X_train", "y_train", "y_mouse_train", "train_mice", "X_val", "y_val",
        "y_mouse_val", "val_mice", in that order.
    """
    return {
        "X_train":X[train_mask],
        "y_train":y[train_mask],
        "y_mouse_train":y_mouse[train_mask],
        "train_mice":train_mice,

        "X_val":X[val_mask],
        "y_val":y[val_mask],
        "y_mouse_val":y_mouse[val_mask],
        "val_mice":val_mice
    }


def _dump_pickle(obj, path):
    """Serialise `obj` to the file at `path` with pickle (overwrites an existing file)."""
    with open(path,"wb") as f:
        pickle.dump(obj,f)


# One pickle per experiment and fold; fold numbers in the file names are 1-based.
# In every experiment the validation rows are the complement of the training rows,
# so the stored "val_mice" list is metadata and is not used to select rows.
for i in range(len(kf_flx_train_mice)):

    #FLX
    flx_train_mask = np.isin(flx_y_mouse, kf_flx_train_mice[i])
    flx_val_mask = ~flx_train_mask

    flx_fold_dict = _make_fold_dict(flx_X, flx_y, flx_y_mouse,
                                    flx_train_mask, flx_val_mask,
                                    kf_flx_train_mice[i], kf_flx_val_mice[i])

    _dump_pickle(flx_fold_dict, "/work/mk423/Anxiety/fixed_flx_kf_dict_fold_{}.pkl".format(i+1))

    #EPM
    epm_train_mask = np.isin(epm_y_mouse, kf_epm_train_mice[i])
    epm_val_mask = ~epm_train_mask

    epm_fold_dict = _make_fold_dict(epm_X, epm_y, epm_y_mouse,
                                    epm_train_mask, epm_val_mask,
                                    kf_epm_train_mice[i], kf_epm_val_mice[i])

    _dump_pickle(epm_fold_dict, "/work/mk423/Anxiety/fixed_epm_kf_dict_fold_{}.pkl".format(i+1))

    #OFT
    oft_train_mask = np.isin(oft_y_mouse, kf_oft_train_mice[i])
    oft_val_mask = ~oft_train_mask

    oft_fold_dict = _make_fold_dict(oft_X, oft_y, oft_y_mouse,
                                    oft_train_mask, oft_val_mask,
                                    kf_oft_train_mice[i], kf_oft_val_mice[i])

    _dump_pickle(oft_fold_dict, "/work/mk423/Anxiety/fixed_oft_kf_dict_fold_{}.pkl".format(i+1))
    