In [2]:
import pandas as pd
import numpy as np
import random
import os
from os.path import isfile, join
# import torch
import sys
sys.path.append('../')
from modules import utils, constants
from modules.env import LupusEnv
# import stable_baselines3, sb3_contrib
import stable_baselines
import warnings
import tensorflow
from stable_baselines.common.vec_env import DummyVecEnv
warnings.filterwarnings('ignore')

In [3]:
SEED = 42

# Seed every RNG this notebook touches from the single SEED constant above,
# so changing SEED changes all of them consistently.
random.seed(SEED)
np.random.seed(SEED)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects subprocesses started later, not hash randomisation in this kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Graph-level seed via the TensorFlow 1.x API.
tensorflow.set_random_seed(SEED)

#### Useful functions

In [4]:
def get_steps(filename, prefix):
    """Parse the integer step count embedded in a saved-model filename.

    The name is treated as ``<prefix><steps><10 trailing characters>``: the
    prefix is cut off, the last 10 characters are dropped, and the remainder
    is parsed as an int. For example, ``ddqn_basic_1000_steps.zip`` with prefix
    ``ddqn_basic_`` gives ``1000``. (Presumably the 10 characters are
    ``_steps.zip`` from the checkpoint naming scheme; confirm.)

    Parameters
    ----------
    filename : str
        File name (not a full path) of the saved model.
    prefix : str
        Leading part of the name that precedes the step count.

    Returns
    -------
    int or None
        The parsed step count, or ``None`` when the middle part is not an
        integer. A diagnostic is printed in that case.
    """
    try:
        stem = filename[len(prefix):]
        return int(stem[:-10])
    except (TypeError, ValueError) as e:
        # Best-effort: report the unparsable name but let the caller continue.
        print(f'Filename: {filename}')
        print(f'Exception: {e}')
        return None

In [5]:
# Maps the lower-cased ``model_type`` accepted by ``load_model`` to the
# (module, class) that implements it. Modules are imported lazily, so only the
# library actually requested has to be installed/imported.
# 'ppo3', 'ac' and 'pg' were dropped: stable_baselines3 has no such classes,
# so those branches could never load anything.
_MODEL_CLASSES = {
    # stable_baselines (v2)
    'dqn': ('stable_baselines', 'DQN'),
    'ppo2': ('stable_baselines', 'PPO2'),
    'a2c': ('stable_baselines', 'A2C'),
    'acer': ('stable_baselines', 'ACER'),
    'acktr': ('stable_baselines', 'ACKTR'),
    # stable_baselines3
    'dqn3': ('stable_baselines3', 'DQN'),
    'ppo': ('stable_baselines3', 'PPO'),
    'ddpg': ('stable_baselines3', 'DDPG'),
    # sb3_contrib
    'trpo': ('sb3_contrib', 'TRPO'),
}


def load_model(filepath, model_type, env):
    """Load a saved RL model of the given type and attach it to ``env``.

    Parameters
    ----------
    filepath : str
        Path of the saved model archive.
    model_type : str
        Case-insensitive key of ``_MODEL_CLASSES`` (e.g. ``'dqn'``, ``'a2c'``,
        ``'dqn3'``).
    env :
        Environment passed to the library's ``load(..., env=env)``.

    Returns
    -------
    The object returned by ``<ModelClass>.load``.

    Raises
    ------
    ValueError
        If ``model_type`` is not a known key.
    ImportError
        If the library that provides the requested class is not installed.
    """
    import importlib

    try:
        module_name, class_name = _MODEL_CLASSES[model_type.lower()]
    except KeyError:
        raise ValueError(f'Unknown model type {model_type}') from None
    # Import on demand: stable_baselines3 / sb3_contrib are not imported at the
    # top of the notebook, which made their branches fail with NameError.
    model_cls = getattr(importlib.import_module(module_name), class_name)
    return model_cls.load(filepath, env=env)

In [6]:
def get_val_metrics(model, validation_env):
    """Play ``model`` deterministically over the validation env and summarise it.

    Episodes are played until ``validation_env`` raises ``StopIteration``
    (presumably LupusEnv does this once every validation row has been served;
    confirm in modules/env.py). The ``info`` emitted on each episode's final
    step is collected. It must contain ``y_actual``, ``y_pred``,
    ``episode_length`` and ``trajectory``.

    Parameters
    ----------
    model :
        Object with ``predict(obs, deterministic=True) -> (action, state)``.
    validation_env :
        Gym-style env or VecEnv with ``reset()`` and ``step(action)``.

    Returns
    -------
    tuple
        ``(acc, f1, roc_auc, min_path_length, average_path_length,
        max_path_length, min_sample_pathway, max_sample_pathway)``.
        The first three come from ``utils.test``. The sample pathways are the
        trajectory of the first shortest / first longest episode.

    Raises
    ------
    ValueError
        If no episode finished before the environment was exhausted.
    """
    records = []
    try:
        while True:
            obs, done = validation_env.reset(), False
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, done, info = validation_env.step(action)
            # A VecEnv (e.g. DummyVecEnv) returns a list of per-env info dicts,
            # a plain gym env a single dict; accept both.
            # NOTE(review): DummyVecEnv auto-resets when an episode ends, so the
            # explicit reset() above resets again. If LupusEnv.reset advances to
            # the next patient, every other patient may be skipped. Confirm.
            records.extend(info if isinstance(info, (list, tuple)) else [info])
    except StopIteration:
        pass

    if not records:
        raise ValueError('No validation episode finished; cannot compute metrics.')
    # Build the frame once instead of DataFrame.append per episode (quadratic,
    # and removed in pandas 2.0).
    val_df = pd.DataFrame(records)
    acc, f1, roc_auc = utils.test(val_df['y_actual'], val_df['y_pred'])
    lengths = val_df['episode_length']
    min_path_length = lengths.min()
    average_path_length = lengths.mean()
    max_path_length = lengths.max()
    # .iloc[0]: take the first episode when several tie for shortest/longest.
    min_sample_pathway = val_df.loc[lengths == min_path_length, 'trajectory'].iloc[0]
    max_sample_pathway = val_df.loc[lengths == max_path_length, 'trajectory'].iloc[0]
    return (acc, f1, roc_auc, min_path_length, average_path_length,
            max_path_length, min_sample_pathway, max_sample_pathway)

In [7]:
def create_val_df(folder, X_val, y_val, prefix, model_type='dqn3'):
    """Score every saved model in ``folder`` on the validation data.

    Each regular ``.zip`` file in ``folder`` whose name starts with ``prefix``
    is loaded via ``load_model`` into a fresh, seeded ``DummyVecEnv`` that wraps
    a ``LupusEnv(X_val, y_val, random=False)``, then scored with
    ``get_val_metrics``. Whenever a model beats the best accuracy, F1 or
    ROC-AUC seen so far, it is saved as ``best_acc_model``, ``best_f1_model``
    or ``best_roc_auc_model`` inside ``folder``.

    Parameters
    ----------
    folder : str
        Directory containing the saved models.
    X_val, y_val :
        Validation features and labels passed to ``LupusEnv``.
    prefix : str
        Filename prefix selecting which models to score; also used by
        ``get_steps`` to parse the step count.
    model_type : str
        Key understood by ``load_model``.

    Returns
    -------
    pandas.DataFrame
        One row per scored model, sorted by ``steps``. It is also written to
        ``<folder>/validation_results.csv``. When nothing matched, the frame is
        empty but still has the usual columns.
    """
    columns = ['steps', 'acc', 'f1', 'roc_auc', 'min_path_length',
               'average_path_length', 'max_path_length',
               'min_sample_pathway', 'max_sample_pathway']
    best_f1, best_acc, best_roc_auc = -1, -1, -1
    perf_list = []

    # Sorted so that ties between models resolve the same way on every run
    # (os.listdir order is arbitrary).
    for count, item in enumerate(sorted(os.listdir(folder))):
        if count % 100 == 0:
            print(count)  # coarse progress indicator
        path = join(folder, item)
        if not (item.startswith(prefix) and item.endswith('.zip') and isfile(path)):
            continue

        validation_env = DummyVecEnv([lambda: LupusEnv(X_val, y_val, random=False)])
        validation_env.seed(SEED)
        model = load_model(path, model_type, validation_env)
        (acc, f1, roc_auc, min_path_length, average_path_length, max_path_length,
         min_sample_pathway, max_sample_pathway) = get_val_metrics(model, validation_env)

        perf_list.append({'steps': get_steps(item, prefix), 'acc': acc, 'f1': f1,
                          'roc_auc': roc_auc, 'min_path_length': min_path_length,
                          'average_path_length': average_path_length,
                          'max_path_length': max_path_length,
                          'min_sample_pathway': min_sample_pathway,
                          'max_sample_pathway': max_sample_pathway})
        if acc > best_acc:
            best_acc = acc
            model.save(f'{folder}/best_acc_model')
        if f1 > best_f1:
            best_f1 = f1
            model.save(f'{folder}/best_f1_model')
        if roc_auc > best_roc_auc:
            best_roc_auc = roc_auc
            model.save(f'{folder}/best_roc_auc_model')

    # Build, sort and persist the table once, after every model is scored.
    # Passing `columns` keeps 'steps' present even when nothing matched.
    val_df = (pd.DataFrame(perf_list, columns=columns)
              .sort_values(by=['steps'])
              .reset_index(drop=True))
    val_df.to_csv(f'{folder}/validation_results.csv', index=False)
    return val_df

#### Validation

In [8]:
# Location of the validation split scored below.
VALIDATION_CSV_PATH = '../data/missingness/0/validation_set.csv'
validation_df = pd.read_csv(VALIDATION_CSV_PATH)
validation_df.head()

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,joint_involvement,proteinuria,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,1,0,0,0,0,0,0,0,1,0,...,0,0,0,0,0,0,1,1,0,1
1,1,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
4,1,0,0,0,0,1,0,0,1,0,...,0,1,0,1,1,0,0,0,0,1


In [9]:
# Every column except the last is a feature; the last column is the label.
# NOTE(review): the frame shown above has an 'Unnamed: 0' column (typically a
# pandas index saved by to_csv without index=False). It is kept in X_val here.
# Confirm whether training/LupusEnv expect it before dropping it.
X_val = np.array(validation_df.iloc[:, :-1])
y_val = np.array(validation_df.iloc[:, -1])
X_val.shape, y_val.shape

((5600, 23), (5600,))

In [10]:
# Score every saved 'ddqn_basic_*' model in this folder with the
# stable_baselines (v2) DQN loader ('dqn') and show the first rows.
folder = '../models/logs/sb/vanilla_dqn'
prefix = 'ddqn_basic_'
val_df = create_val_df(folder, X_val, y_val, prefix, model_type='dqn')
val_df.head()

1
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Use tf.cast instead.
2


KeyboardInterrupt: 

#### Scratch section (temporary cells — remove before sharing)

#### End of scratch section

#### Testing