In [1]:
import numpy as np
import pandas as pd
import copy
import random
import itertools
import json
import pickle

from pathlib import Path
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn import mixture

import torch.nn.functional as F
import torch

In [2]:
from env import CustomizeEnv
from AHIL.ahil_utils import *
from AHIL.EM_EDM_utils import *
from evaluation import *






In [3]:
base_path = ''
args = {'env_name': 'CCHS', 'trial': 1}
nrCl_dict = {'CCHS': 3}  # number of EM-EDM clusters, keyed by environment name

# Run switches: train from scratch (True) or load a saved model (False),
# and whether a freshly trained model is written to disk.
training_mode = True
save_model_result = True

# ----------------------------------------------------------------------------------------------
# Base hyper-parameters come from a saved config; selected entries are overridden below.
config_path = 'example/cchs_para_config.npy'
# NOTE(review): allow_pickle=True unpickles the file contents — only load trusted config files.
config = np.load(config_path, allow_pickle=True).item()

# Overrides, applied in this exact order (dict.update preserves it for newly added keys).
config.update({
    'NUM_STEPS_TRAIN': 500,
    'BATCH_SIZE': 64,
    'MLP_WIDTHS': 32,
    'EMEDM_nrCl': nrCl_dict[args['env_name']],
    'SGLD_BUFFER_SIZE': 10000,
    'SGLD_LEARN_RATE': 1e-3,
    'SAMPLE_BUFFER': 'balanced',  # one of 'random' / 'stratified' / 'balanced'
    'BASE_PATH': base_path,
    'MODEL_LEARNER': 'EDM',
})
print(config)

# ---------------------------------------------------------------------------------------------
avg_patterns = 'binary'  # metric averaging: 'binary' / 'weighted' / 'micro' / 'macro'
init_mode = 'random'     # cluster initialisation: 'random' / 'dtw' / 'kmeans'

{'NUM_STEPS_TRAIN': 500, 'EMEDM_LLH_THRES_CONV': 1e-05, 'EMEDM_LLH_THRES_DES': 10000.0, 'BATCH_SIZE': 64, 'MLP_WIDTHS': 32, 'EMEDM_ITERS': 10, 'EMEDM_nrCl': 3, 'ENV': 'CCHS', 'GYM_ENV': ['CartPole-v1', 'Acrobot-v1', 'MountainCar-v0'], 'ADAM_ALPHA': 0.001, 'ADAM_BETAS': [0.9, 0.999], 'SGLD_BUFFER_SIZE': 10000, 'SGLD_LEARN_RATE': 0.001, 'SGLD_NOISE_COEF': 0.01, 'SGLD_NUM_STEPS': 20, 'SGLD_REINIT_FREQ': 0.05, 'SAMPLE_BUFFER': 'balanced', 'EMEDM_BETA': 0.5, 'EMEDM_CLUSTER_THRES': 0, 'BASE_PATH': '', 'MODEL_LEARNER': 'EDM'}


In [4]:
# Dataset info
env = 'CCHS'
id_feat = 'VisitIdentifier'  # column used to group rows into per-visit trajectories

# Per-step features fed to the GMM and the policies
sel_feats = [
    'SystolicBP', 'MAP', 'RespiratoryRate', 'PulseOx', 'HeartRate', 'Temperature',
    'WBC', 'BiliRubin', 'BUN', 'Lactate', 'Creatinine', 'Platelet', 'Bands', 'FIO2',
]
act_num = 2  # number of action classes

# Derived column names: 'max_<feat>' and 'min_<feat>' for every selected feature
max_sel_feats = [f'max_{feat}' for feat in sel_feats]
min_sel_feats = [f'min_{feat}' for feat in sel_feats]
all_feats = [*sel_feats, *max_sel_feats, *min_sel_feats]

In [5]:
# Load the data 
# Load the raw per-row data and split it into one DataFrame per visit (trajectory).
data_path = '../baseline/example/cchs_sample_data.csv'
df = pd.read_csv(data_path)

# A single groupby pass replaces one full-table boolean scan per visit
# (O(rows * visits) -> O(rows)), and the visit list is computed once instead of twice.
# sort=True yields visits in ascending id order, matching sorted(np.unique(...));
# within each visit the rows keep their original order and index labels.
vid_list, df_list = [], []
for vid, visit_df in df.groupby(id_feat, sort=True):
    vid_list.append(vid)
    df_list.append(visit_df)

# Every row must land in exactly one visit. groupby silently drops rows whose id is
# missing, so fail loudly rather than train/evaluate on a partial dataset.
assert sum(len(visit_df) for visit_df in df_list) == len(df), f'rows with missing {id_feat} found'

In [6]:
# Train one EDM policy per GMM state-cluster on the training visits, then score the
# per-step action predictions on the held-out visits. Each repetition `rp_idx` uses a
# different train/test split; its overall_eval(...) result is appended to seed_rp_results.
seed_rp_results = []

# ---- Settings that do not change across repetitions / clusters ------------------------
gmm_clusters = 6          # number of GMM state-clusters (one policy is learned per cluster)
gmm_max_iter = 10000
init_seeds = 42           # seed handed to InitEMEDMClusters
nrCl = 1                  # EM-EDM with a single cluster == plain EDM on that cluster's data
config['ENV_VOLUME_PATH'] = base_path + 'example/' + args['env_name']
config['CLUS_BASE_PATH'] = config['ENV_VOLUME_PATH']
model_folder = 'example/AHIL_models/' + args['env_name']
Path(model_folder).mkdir(parents=True, exist_ok=True)


def _select_cluster_demos(sub_demos, sub_labels, clus_idx):
    """Return the {'trajs', 'returns'} subset of `sub_demos` whose label equals `clus_idx`."""
    keep = [i for i, lab in enumerate(sub_labels) if lab == clus_idx]
    return {'trajs': [sub_demos['trajs'][i] for i in keep],
            'returns': [sub_demos['returns'][i] for i in keep]}


def _policy_cluster(clus_idx, student_dict, seed):
    """Return `clus_idx` if a policy exists for it, otherwise a seeded random trained cluster."""
    if clus_idx in student_dict:
        return clus_idx
    print('RANDOMLY SELECT A CLUSTER TO FIT ...')
    # random.sample() rejects dict views/sets on Python >= 3.11, so sample from a sorted list.
    # A local Random(seed) gives the same pick as random.seed(seed) without touching global state.
    return random.Random(seed).sample(sorted(student_dict), 1)[0]


for rp_idx in range(1):
    # ------------------------------------------------------------------------------------
    # Split the visits (not rows) into train / test
    tr_idx, te_idx = train_test_split(np.arange(len(df_list)), test_size=0.2, random_state=rp_idx)
    tr_df = [df_list[i] for i in tr_idx]
    te_df = [df_list[i] for i in te_idx]

    # Cluster all training rows by GMM. random_state makes the labels reproducible, which
    # matters when policies are reloaded from disk and must map to the same clusters.
    tr_gmm_data = pd.concat([visit_df[sel_feats] for visit_df in tr_df]).to_numpy()
    gmm = mixture.GaussianMixture(n_components=gmm_clusters, max_iter=gmm_max_iter,
                                  random_state=rp_idx)
    gmm.fit(tr_gmm_data)
    tr_gmm_pred = gmm.predict(tr_gmm_data)

    # Cut the training visits into sub-trajectories labelled by GMM cluster
    tr_sub_demos, tr_sub_demos_lab = subTrajectoriesbyCluster(tr_df, all_feats, tr_gmm_pred, clus_num=gmm_clusters)

    # ------------------------------------------------------------------------------------
    # Train (or load) one policy per cluster
    student_dict = {}
    for clus_idx in range(gmm_clusters):
        print('Cluster Idx:', clus_idx)

        clus_teachers = _select_cluster_demos(tr_sub_demos, tr_sub_demos_lab, clus_idx)
        clus_trajs = clus_teachers['trajs']
        if not clus_trajs:
            # No training data for this cluster: leave it out; test sub-trajectories that
            # fall here are routed to a random trained cluster during evaluation.
            print('  no training sub-trajectories for this cluster; skipping')
            continue
        traj_num = len(clus_trajs)

        # The cluster index is part of the file name so clusters with equal trajectory
        # counts no longer overwrite (and later reload) each other's model.
        model_path = f'{model_folder}/AHIL_{traj_num}_clus_{clus_idx}_fold_{rp_idx}.sav'
        config['MODEL_PATH'] = model_path

        print('** Applying EM-EDM to learn policies... ')
        if training_mode:
            # NOTE(review): the third argument is the number of training *visits*, not
            # len(clus_trajs); kept as in the original call — confirm this is intended.
            pred_labs = InitEMEDMClusters(clus_trajs, nrCl, len(tr_df), init_mode=init_mode,
                                          DTW_thres=1e5, max_iter=10, init_seed=init_seeds)
            # Only the learned policies and mixing weights are kept; the remaining outputs
            # (LLH, final cluster count, labels, probabilities) are discarded.
            student_list, rho, _, _, _, _ = EMEDMWarped(clus_teachers, nrCl, pred_labs,
                                                         config, decay_expert=False, verbo=False)
            if save_model_result:
                with open(model_path, 'wb') as fh:
                    pickle.dump((student_list, rho), fh)
        else:
            # NOTE: pickle.load can execute arbitrary code — only load model files you created.
            with open(model_path, 'rb') as fh:
                student_list, rho = pickle.load(fh)

        student_dict[clus_idx] = student_list[0]

    # ------------------------------------------------------------------------------------
    # Label the test rows with the training GMM and cut them into sub-trajectories
    te_gmm_data = pd.concat([visit_df[sel_feats] for visit_df in te_df]).to_numpy()
    te_gmm_pred = gmm.predict(te_gmm_data)
    te_sub_demos, te_sub_demos_lab = subTrajectoriesbyCluster(te_df, all_feats, te_gmm_pred, clus_num=gmm_clusters)

    # ------------------------------------------------------------------------------------
    # Evaluate: every test sub-trajectory is scored by its own cluster's policy
    true_act, pred_act, pred_prob = [], [], []
    te_trajs = te_sub_demos['trajs']
    for idx, clus_label in enumerate(te_sub_demos_lab):
        clusIdx = _policy_cluster(clus_label, student_dict, seed=idx + rp_idx)
        policy = student_dict[clusIdx]
        for state_action in te_trajs[idx]:
            tmp_act, _, tmp_qvalue = policy.select_action(state_action[0])
            # NOTE(review): softmax over dim 0 assumes tmp_qvalue is 1-D (one value per action).
            action_probabilities = F.softmax(torch.tensor(tmp_qvalue), dim=0).numpy()

            true_act.append(state_action[1][0])
            pred_act.append(tmp_act)
            pred_prob.append(action_probabilities)

    seed_rp_results.append(overall_eval(true_act, pred_act, pred_prob, act_num, avg_patterns=avg_patterns))



* Num of Partitioned trajs:  1580
* Sub-traj for each cluster: 
  -  {0: 545, 1: 278, 2: 43, 3: 88, 4: 78, 5: 548}
Cluster Idx: 0
** Applying EM-EDM to learn policies... 
*** Iteration:  0


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [00:19<00:00, 25.78it/s]


Collapsed to 1 cluster at 0-th iteration
Cluster Idx: 1
** Applying EM-EDM to learn policies... 
*** Iteration:  0


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [00:15<00:00, 31.27it/s]


Collapsed to 1 cluster at 0-th iteration
Cluster Idx: 2
** Applying EM-EDM to learn policies... 
*** Iteration:  0


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [00:16<00:00, 30.94it/s]


Collapsed to 1 cluster at 0-th iteration
Cluster Idx: 3
** Applying EM-EDM to learn policies... 
*** Iteration:  0


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [00:17<00:00, 29.09it/s]


Collapsed to 1 cluster at 0-th iteration
Cluster Idx: 4
** Applying EM-EDM to learn policies... 
*** Iteration:  0


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [00:16<00:00, 30.39it/s]


Collapsed to 1 cluster at 0-th iteration
Cluster Idx: 5
** Applying EM-EDM to learn policies... 
*** Iteration:  0


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [00:16<00:00, 30.74it/s]


Collapsed to 1 cluster at 0-th iteration
* Num of Partitioned trajs:  358
* Sub-traj for each cluster: 
  -  {0: 118, 1: 51, 2: 21, 3: 28, 4: 14, 5: 126}
Performance Measurements:
Confusion matrix: 
 [[173  26]
 [114  11]]
Accuracy:  0.5679012345679012
Recall:  0.088
Precision:  0.2972972972972973
F-score:  0.13580246913580246
AUC:  0.6408643216080402
APR:  0.6080200507597875
Jaccard Score:  0.0728476821192053
