In [None]:
import pandas as pd
import numpy as np
import random
import os
from os.path import isfile, join
import sys
sys.path.append('../')
from modules import utils, constants
from modules.env import AnemiaEnv
import stable_baselines
import warnings
import tensorflow
from gym.spaces import Box
from stable_baselines.common.vec_env import DummyVecEnv
warnings.filterwarnings('ignore')

In [2]:
# Pin every random number generator used by this notebook to one shared seed
# so that repeated validation runs are comparable.
SEED = constants.SEED

# NOTE: PYTHONHASHSEED is read when an interpreter starts, so assigning it here
# can only influence processes launched after this point.
os.environ['PYTHONHASHSEED'] = str(SEED)

random.seed(SEED)
np.random.seed(SEED)

# TensorFlow is seeded through both the top-level and the compat.v1 entry points.
tensorflow.set_random_seed(SEED)
tensorflow.compat.v1.set_random_seed(SEED)

# torch is not seeded because it is not imported here; if a torch-based model is
# evaluated, add torch.manual_seed(SEED) and torch.use_deterministic_algorithms(True).




#### Useful functions

In [5]:
def get_steps(filename, prefix):
    """Extract the training-step count encoded in a checkpoint file name.

    The name is expected to look like ``<prefix><steps><10-character tail>``:
    the prefix and the last ten characters are removed and the remainder is
    parsed as an integer.

    Args:
        filename: File name to parse.
        prefix: Leading text to strip from ``filename``.

    Returns:
        The step count as an ``int``, or ``None`` when parsing fails (the file
        name and the error are printed in that case).
    """
    try:
        without_prefix = filename[len(prefix):]
        # Drop the fixed-width tail (ten characters; presumably '_steps.zip').
        steps = int(without_prefix[:-10])
    except Exception as err:
        print(f'Filename: {filename}')
        print(f'Exception: {err}')
        return None
    return steps

In [6]:
def load_model(filepath, model_type, env):
    """Load a saved DQN agent with the library that matches ``model_type``.

    Args:
        filepath: Path of the saved model, forwarded to ``DQN.load``.
        model_type: ``'dqn3'`` for stable_baselines3 or ``'dqn'`` for
            stable_baselines (compared case-insensitively).
        env: Environment forwarded to ``DQN.load`` as ``env``.

    Returns:
        Whatever the selected library's ``DQN.load`` returns.

    Raises:
        ValueError: If ``model_type`` is neither ``'dqn3'`` nor ``'dqn'``.
        ImportError: If ``'dqn3'`` is requested but stable_baselines3 cannot
            be imported.
    """
    kind = model_type.lower()
    if kind == 'dqn3':
        # stable_baselines3 is not among the notebook's top-level imports, so
        # referencing it directly raised NameError. Import it lazily so the
        # stable_baselines-only workflow does not need it installed.
        import stable_baselines3
        return stable_baselines3.DQN.load(filepath, env=env)
    if kind == 'dqn':
        return stable_baselines.DQN.load(filepath, env=env)
    raise ValueError(f'Unknown model type {model_type}')

In [7]:
def get_val_metrics(model, validation_env):
    """Run ``model`` over the validation environment and summarise the episodes.

    Episodes are played with deterministic predictions until a ``StopIteration``
    propagates out of the environment, which this loop treats as "no more
    validation samples". The ``info`` reported on the terminal step of each
    episode becomes one row of the results frame; the frame must provide the
    ``episode_length`` and ``trajectory`` columns read below.

    Args:
        model: Agent exposing ``predict(obs, deterministic=True)``.
        validation_env: Environment exposing ``reset()`` and ``step(action)``.
            ``info`` may be a single dict or an iterable of dicts (as returned
            by a vectorised env).

    Returns:
        Tuple ``(acc, f1, roc_auc, min_path_length, average_path_length,
        max_path_length, min_sample_pathway, max_sample_pathway)``. The first
        three come from ``utils.get_metrics``; the path lengths are the
        min/mean/max of ``episode_length``; the sample pathways are the
        ``trajectory`` of the first episode with the min / max length.

    Raises:
        ValueError: If the environment is exhausted before any episode ends.
    """
    # NOTE(review): a vectorised env such as DummyVecEnv is documented to reset
    # itself when an episode ends, so the explicit reset() below may advance
    # past one extra sample per episode - confirm against AnemiaEnv.reset.
    episode_infos = []
    try:
        while True:
            obs, done = validation_env.reset(), False
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _reward, done, info = validation_env.step(action)
                if done:
                    # Collect records and build the frame once afterwards;
                    # growing a DataFrame row by row is quadratic and
                    # DataFrame.append no longer exists in pandas 2.x.
                    if isinstance(info, dict):
                        episode_infos.append(dict(info))
                    else:
                        episode_infos.extend(dict(entry) for entry in info)
    except StopIteration:
        # Raised by the environment once the validation samples are used up.
        pass

    if not episode_infos:
        raise ValueError('No completed validation episodes were recorded')

    val_df = pd.DataFrame(episode_infos)
    acc, f1, roc_auc = utils.get_metrics(val_df)

    lengths = val_df.episode_length
    min_path_length = lengths.min()
    average_path_length = lengths.mean()
    max_path_length = lengths.max()
    min_sample_pathway = val_df[lengths == min_path_length].trajectory.iloc[0]
    max_sample_pathway = val_df[lengths == max_path_length].trajectory.iloc[0]
    return (acc, f1, roc_auc, min_path_length, average_path_length,
            max_path_length, min_sample_pathway, max_sample_pathway)

In [8]:
def create_val_df(folder, X_val, y_val, prefix, model_type='dqn3'):
    """Evaluate every saved checkpoint in ``folder`` and tabulate the results.

    Each regular ``.zip`` file whose name starts with ``prefix`` is loaded with
    ``load_model`` and scored with ``get_val_metrics`` on a fresh, seeded
    ``AnemiaEnv`` wrapped in a ``DummyVecEnv``. The checkpoints with the highest
    accuracy, F1 and ROC-AUC seen so far are re-saved as ``best_acc_model``,
    ``best_f1_model`` and ``best_roc_auc_model`` inside ``folder``, and the
    table is written to ``validation_results.csv`` in the same folder.

    Directory entries are visited in sorted name order, so ties between
    checkpoints are always resolved in favour of the same file.

    Args:
        folder: Directory holding the checkpoints; outputs are written here too.
        X_val: Validation features passed to ``AnemiaEnv``.
        y_val: Validation labels passed to ``AnemiaEnv``.
        prefix: File-name prefix that identifies checkpoints.
        model_type: Loader selector forwarded to ``load_model``.

    Returns:
        DataFrame with one row per evaluated checkpoint, sorted by ``steps``.
        It is empty (columns only) when no checkpoint matched.
    """
    columns = ['steps', 'acc', 'f1', 'roc_auc', 'min_path_length',
               'average_path_length', 'max_path_length', 'min_sample_pathway',
               'max_sample_pathway']
    best_f1, best_acc, best_roc_auc = -1, -1, -1
    perf_list = []

    for count, item in enumerate(sorted(os.listdir(folder))):
        # Coarse progress indicator over all directory entries.
        if count % 100 == 0:
            print(count)
        if not item.startswith(prefix):
            continue
        path = join(folder, item)
        if not (isfile(path) and path.endswith('.zip')):
            continue

        validation_env = DummyVecEnv([lambda: AnemiaEnv(X_val, y_val, random=False)])
        validation_env.seed(SEED)
        model = load_model(path, model_type, validation_env)
        (acc, f1, roc_auc, min_path_length, average_path_length, max_path_length,
         min_sample_pathway, max_sample_pathway) = get_val_metrics(model, validation_env)

        perf_list.append({'steps': get_steps(item, prefix), 'acc': acc, 'f1': f1,
                          'roc_auc': roc_auc, 'min_path_length': min_path_length,
                          'average_path_length': average_path_length,
                          'max_path_length': max_path_length,
                          'min_sample_pathway': min_sample_pathway,
                          'max_sample_pathway': max_sample_pathway})

        # Strict '>' keeps the first checkpoint that reaches each best score.
        if acc > best_acc:
            best_acc = acc
            model.save(f'{folder}/best_acc_model')
        if f1 > best_f1:
            best_f1 = f1
            model.save(f'{folder}/best_f1_model')
        if roc_auc > best_roc_auc:
            best_roc_auc = roc_auc
            model.save(f'{folder}/best_roc_auc_model')

    # Build, sort and persist the table once, after the scan. Doing this inside
    # the loop rewrote the CSV per entry and failed on an empty result list.
    # Passing ``columns`` keeps 'steps' present even when nothing was evaluated.
    val_df = pd.DataFrame(perf_list, columns=columns)
    val_df = val_df.sort_values(by=['steps']).reset_index(drop=True)
    val_df.to_csv(f'{folder}/validation_results.csv', index=False)
    return val_df

#### Validation

In [9]:
# Read the validation CSV from the data directory and preview its first rows.
validation_df = pd.read_csv('../data/val_set_constant.csv')
validation_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,9.496631,-1.0,-1.0,3.515439,375.685261,101.027779,122.174205,2.820006,1,1.27622,147.925454,74.18824,23.800042,11.545421,97.593498,28.489894,32.520362,1
1,7.09278,336.562609,-1.0,-1.0,115.507513,75.569193,35.449822,2.815743,0,0.870413,105.709466,118.836405,30.607322,2.113748,88.833122,21.278341,30.690491,3
2,6.554136,-1.0,1.774495,6.14581,-1.0,86.8497,-1.0,2.263958,0,0.569184,139.078814,-1.0,-1.0,-1.0,-1.0,19.662409,-1.0,6
3,12.417159,252.249921,1.70186,0.188413,515.449324,76.102948,-1.0,4.89488,1,0.576132,4.58823,-1.0,-1.0,-1.0,-1.0,37.251478,-1.0,3
4,7.748672,258.549935,2.980693,5.137341,363.214987,82.395181,99.002425,2.821283,1,0.885522,84.419465,33.706997,19.591641,27.675132,69.578682,23.246016,27.257252,5


In [10]:
# Split positionally: the last column is the target, every column before it is
# a feature. Both parts are converted to NumPy arrays.
y_val = np.array(validation_df.iloc[:, -1])
X_val = np.array(validation_df.iloc[:, :-1])
X_val.shape, y_val.shape

((5600, 17), (5600,))

In [11]:
# Directory that is scanned for checkpoints and that receives the results.
folder = '../models/sb/dqn'
# Only files whose names start with this prefix are evaluated.
prefix = 'dqn_basic_'
# 'dqn' selects the stable_baselines DQN loader in load_model.
val_df = create_val_df(folder, X_val, y_val, prefix, 'dqn')
val_df.head()

0






Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Please use `layer.__call__` method instead.

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where







100


Unnamed: 0,steps,acc,f1,roc_auc,min_path_length,average_path_length,max_path_length,min_sample_pathway,max_sample_pathway
0,500000,9.0,2.06422,50.0,2,3.861071,10,"[tibc, Inconclusive diagnosis]","[tibc, cholestrol, ferritin, glucose, copper, ..."
1,1000000,9.0,2.06422,50.0,3,3.580714,5,"[hemoglobin, tibc, Inconclusive diagnosis]","[hemoglobin, tibc, tsat, glucose, Inconclusive..."
2,1500000,9.0,2.06422,50.0,2,2.891786,4,"[tibc, Inconclusive diagnosis]","[tibc, tsat, glucose, Inconclusive diagnosis]"
3,2000000,9.0,2.06422,50.0,2,2.891786,4,"[tibc, Inconclusive diagnosis]","[tibc, ethanol, glucose, Inconclusive diagnosis]"
4,2500000,9.0,2.06422,50.0,2,2.611786,3,"[tibc, Inconclusive diagnosis]","[tibc, mcv, Inconclusive diagnosis]"
