In [1]:
import pickle
import numpy as np

# Pickle file that is loaded into `newDict` below.
NewDataPath = "/work/mk423/Anxiety/New_FLX_Animals_April_12.pkl"
# Path template for the old-format pickles; the placeholder is filled with
# "train", "validation" and "test" when the files are loaded.
OldDataPath = "/work/mk423/Anxiety/FLX_{}_dict_old_features.pkl"

# Feature names; each one is also a key of the dictionary loaded from NewDataPath.
FEATURE_LIST = ["X_psd","X_coh","X_gc"]
# NOTE(review): presumably one weight per FEATURE_LIST entry (same order) —
# not used by any cell visible here, confirm against downstream code.
FEATURE_WEIGHT = [10,1,1]

In [2]:
def _load_pickle(path):
    """Load and return the object pickled at `path`.

    The file is opened inside a `with` block so the handle is closed as soon
    as loading finishes (a bare `pickle.load(open(...))` leaves it open).
    """
    # NOTE: pickle.load can execute arbitrary code while unpickling; only use
    # this on files that come from a trusted source.
    with open(path, "rb") as f:
        return pickle.load(f)


newDict = _load_pickle(NewDataPath)
oldTrainDict = _load_pickle(OldDataPath.format("train"))
oldValDict = _load_pickle(OldDataPath.format("validation"))
oldTestDict = _load_pickle(OldDataPath.format("test"))

In [3]:
# Show which entries the dictionary loaded from NewDataPath provides.
newDict.keys()

dict_keys(['X_psd', 'X_coh', 'X_gc', 'y_flx', 'y_sal', 'y_hab', 'y_mouse', 'y_expDate', 'y_time'])

In [4]:
# Show which entries the old-format training dictionary provides.
oldTrainDict.keys()

dict_keys(['X_psd', 'X_psd_first_30', 'X_ds', 'X_ds_first_30', 'y_mouse', 'y_mouse_first_30', 'y_expDate', 'y_expDate_first_30', 'y_time', 'y_time_first_30', 'mice', 'y_flx', 'y_flx_train_first_30', 'X_psd_full', 'X_ds_full', 'y_mouse_full', 'y_expDate_full', 'y_time_full', 'y_flx_full', 'X_power_1_2', 'X_power_1_2_first_30', 'X_power_1_2_full', 'X_coh_1_2', 'X_coh_1_2_first_30', 'X_coh_1_2_full', 'X_gc_1_2', 'X_gc_1_2_first_30', 'X_gc_1_2_full'])

In [18]:
# Mouse IDs listed as the OFT training group.
oft_train_mice = [
    'Mouse04191', 'Mouse04201', 'Mouse04202', 'Mouse04205',
    'Mouse04215', 'Mouse3191', 'Mouse3192', 'Mouse3193',
    'Mouse3194', 'Mouse3203', 'Mouse39114', 'Mouse39124',
    'Mouse39133', 'Mouse69064', 'Mouse69072', 'Mouse69074',
]

# Mouse IDs listed as the OFT validation group.
oft_val_mice = [
    'Mouse04193',
    'Mouse39125',
    'Mouse69065',
]

# Every OFT mouse: the training IDs followed by the validation IDs.
oft_mice = [*oft_train_mice, *oft_val_mice]


In [19]:
# Old-format dictionaries, in the order their rows are appended after the new data.
_old_dicts = (oldTrainDict, oldValDict, oldTestDict)

# Threshold applied to y_time to build y_hab_full.
# NOTE(review): presumably 30 minutes expressed in seconds — confirm the units of y_time.
HAB_TIME_THRESHOLD = 60 * 30


def _stack_features(new_key, old_key):
    """Row-stack newDict[new_key] on top of `old_key` from each old dictionary."""
    return np.vstack([newDict[new_key]] + [d[old_key] for d in _old_dicts])


def _concat_labels(key):
    """Concatenate the squeezed newDict[key] with `key + "_full"` from each old dictionary."""
    return np.hstack([newDict[key].squeeze()] + [d[key + "_full"] for d in _old_dicts])


X_psd_full = _stack_features("X_psd", "X_power_1_2_full")
X_coh_full = _stack_features("X_coh", "X_coh_1_2_full")
X_gc_full = _stack_features("X_gc", "X_gc_1_2_full")

y_flx_full = _concat_labels("y_flx")
y_mouse_full = _concat_labels("y_mouse")
y_expDate_full = _concat_labels("y_expDate")
y_time_full = _concat_labels("y_time")

# All arrays are later indexed with the same per-row masks, so they must
# describe the same number of rows; fail here rather than at slicing time.
for _name, _arr in (
    ("X_psd_full", X_psd_full),
    ("X_coh_full", X_coh_full),
    ("X_gc_full", X_gc_full),
    ("y_flx_full", y_flx_full),
    ("y_expDate_full", y_expDate_full),
    ("y_time_full", y_time_full),
):
    assert len(_arr) == len(y_mouse_full), (
        "{} has {} rows but y_mouse_full has {}".format(_name, len(_arr), len(y_mouse_full))
    )

# NOTE(review): newDict also carries a 'y_hab' entry that is not used here;
# the flag is recomputed from y_time for every row instead.
y_hab_full = y_time_full > HAB_TIME_THRESHOLD

In [20]:
# List the distinct mouse IDs present in the combined data.
np.unique(y_mouse_full)

array(['Mouse3191', 'Mouse3192', 'Mouse3193', 'Mouse3194', 'Mouse3202',
       'Mouse3203', 'Mouse61631', 'Mouse61635', 'Mouse69061',
       'Mouse78732', 'Mouse78743', 'Mouse78744', 'Mouse78745',
       'Mouse78751', 'Mouse78752', 'Mouse78764', 'Mouse99002',
       'Mouse99003', 'Mouse99021'], dtype='<U10')

In [21]:
# Mouse IDs of the combined data that also appear in oft_mice.
np.intersect1d(oft_mice,np.unique(y_mouse_full))

array(['Mouse3191', 'Mouse3192', 'Mouse3193', 'Mouse3194', 'Mouse3203'],
      dtype='<U10')

In [22]:
## Define Training Data
# Start from the mice that also appear in oft_mice; these always go to training.
default_training_mice = list(np.intersect1d(oft_mice,np.unique(y_mouse_full)))

# Mice only given one condition must be in the training data as well.
# NOTE(review): the test below only catches mice whose y_flx is (almost) always 1
# within the y_hab_full rows; a mouse whose y_flx is always 0 is not flagged —
# confirm that this one-sided check is intended.
for mouse in np.unique(y_mouse_full):
    mouse_mask = np.logical_and(y_mouse_full==mouse,y_hab_full)

    if not mouse_mask.any():
        # No rows selected for this mouse: np.mean of an empty slice would be
        # nan (with a RuntimeWarning) and never exceed the threshold anyway.
        continue

    if np.mean(y_flx_full[mouse_mask]) > .9:
        print("{} only has one class and will be kept in the training data".format(mouse))
        # A mouse already taken from the oft_mice overlap must not be listed
        # twice, otherwise the count printed below is inflated.
        if mouse not in default_training_mice:
            default_training_mice.append(mouse)


remaining_mice = [mouse for mouse in np.unique(y_mouse_full) if mouse not in default_training_mice]
print(default_training_mice,len(default_training_mice))
print(remaining_mice,len(remaining_mice))

Mouse3202 only has one class and will be kept in the training data
Mouse99002 only has one class and will be kept in the training data
Mouse99003 only has one class and will be kept in the training data
Mouse99021 only has one class and will be kept in the training data
['Mouse3191', 'Mouse3192', 'Mouse3193', 'Mouse3194', 'Mouse3203', 'Mouse3202', 'Mouse99002', 'Mouse99003', 'Mouse99021'] 9
['Mouse61631', 'Mouse61635', 'Mouse69061', 'Mouse78732', 'Mouse78743', 'Mouse78744', 'Mouse78745', 'Mouse78751', 'Mouse78752', 'Mouse78764'] 10


In [24]:
#Get other splits randomly
SPLIT_SEED = 42   # fixes the shuffle so the split is reproducible
N_VAL_MICE = 4    # mice drawn for validation; the rest of remaining_mice go to test

np.random.seed(SPLIT_SEED)

# Shuffle positions into remaining_mice rather than the list itself.
rem_idxs = np.arange(len(remaining_mice))
np.random.shuffle(rem_idxs)

# Copy so that later changes to train_mice cannot alter default_training_mice.
train_mice = list(default_training_mice)
val_mice = [remaining_mice[idx] for idx in rem_idxs[:N_VAL_MICE]]
test_mice = [remaining_mice[idx] for idx in rem_idxs[N_VAL_MICE:]]


print("train",train_mice)
print("val",val_mice)
print("test",test_mice)

train ['Mouse3191', 'Mouse3192', 'Mouse3193', 'Mouse3194', 'Mouse3203', 'Mouse3202', 'Mouse99002', 'Mouse99003', 'Mouse99021']
val ['Mouse78752', 'Mouse61635', 'Mouse78744', 'Mouse61631']
test ['Mouse78751', 'Mouse69061', 'Mouse78764', 'Mouse78743', 'Mouse78732', 'Mouse78745']


In [25]:
def _split_indicator(mice):
    """Return a 0/1 integer array with one entry per element of y_mouse_full.

    An entry is 1 when the corresponding mouse ID is contained in `mice`
    and 0 otherwise.
    """
    # Vectorised membership test instead of a Python-level loop over every row.
    return np.isin(y_mouse_full, mice).astype(int)


train_idxs = _split_indicator(train_mice)
val_idxs = _split_indicator(val_mice)
test_idxs = _split_indicator(test_mice)

# Each row must belong to exactly one split; this fails if the mouse lists
# overlap or miss a mouse (e.g. lists left over from an earlier kernel state).
assert np.all(train_idxs + val_idxs + test_idxs == 1), \
    "train/val/test mouse lists do not partition y_mouse_full"

In [26]:
# Fraction of rows assigned to the training split (mean of the 0/1 indicator).
np.mean(train_idxs)

0.3988027157910106

In [27]:
# Fraction of rows assigned to the validation split (mean of the 0/1 indicator).
np.mean(val_idxs)

0.23797240405908548

In [28]:
# Fraction of rows assigned to the test split (mean of the 0/1 indicator).
np.mean(test_idxs)

0.3632248801499039

In [29]:
#Create and Save New Dictionaries

def _make_split_dict(slice_idxs):
    """Collect the rows of every combined array that belong to one split.

    slice_idxs: 0/1 array with one entry per row of the *_full arrays; rows
        whose entry equals 1 are kept.
    Returns a dict holding the selected feature rows, the selected label rows
    and a fixed "date-created" string.
    """
    keep = slice_idxs==1
    return {
        "X_psd":X_psd_full[keep],
        "X_coh":X_coh_full[keep],
        "X_gc":X_gc_full[keep],

        "y_flx":y_flx_full[keep],
        "y_mouse":y_mouse_full[keep],
        "y_expDate":y_expDate_full[keep],
        "y_time":y_time_full[keep],
        "y_hab":y_hab_full[keep],

        # Fixed label, not computed from the current date.
        "date-created":"April 27 2023"
    }


# Output path template; the placeholder receives "train", "val" or "test".
_SPLIT_SAVE_PATH = "/work/mk423/Anxiety/final_FLX_{}.pkl"

# Build and write one split at a time, in the order train, val, test.
_split_dicts = {}
for _split_name, slice_idxs in (("train",train_idxs),("val",val_idxs),("test",test_idxs)):
    _split_dicts[_split_name] = _make_split_dict(slice_idxs)
    with open(_SPLIT_SAVE_PATH.format(_split_name),"wb") as f:
        pickle.dump(_split_dicts[_split_name],f)

# Keep the per-split names that the following cells refer to.
train_dict = _split_dicts["train"]
val_dict = _split_dicts["val"]
test_dict = _split_dicts["test"]

In [75]:
# List the distinct mouse IDs contained in the test dictionary.
# NOTE(review): this cell was executed at count 75, long after the cells above,
# and its recorded output does not match the test mice printed when the split
# was created (it even contains IDs printed there as train/val mice). The
# kernel state had changed in between — re-run the notebook from a fresh
# kernel to confirm which split was actually written to disk.
np.unique(test_dict["y_mouse"])

array(['Mouse3194', 'Mouse3203', 'Mouse61635', 'Mouse69061', 'Mouse78744',
       'Mouse78751'], dtype='<U10')

In [76]:
# Display the current list of test mice.
# NOTE(review): executed at count 76; the recorded output differs from the
# test list printed by the split cell, so `test_mice` was reassigned by code
# that is not part of the visible cells — verify with Restart & Run All.
test_mice

['Mouse3203',
 'Mouse69061',
 'Mouse78744',
 'Mouse78751',
 'Mouse3194',
 'Mouse61635']