In [42]:
import pandas as pd
import numpy as np
import random
import os
from os.path import isfile, join
# import torch
import sys
sys.path.append('../')
from modules import utils, constants
from modules.env import LupusEnv
# import stable_baselines3, sb3_contrib
import stable_baselines
import warnings
from stable_baselines.common.vec_env import DummyVecEnv
warnings.filterwarnings('ignore')

In [18]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
os.environ['PYTHONHASHSEED']=str(SEED)
# torch.manual_seed(SEED)
# torch.use_deterministic_algorithms(True)

#### Useful functions

In [19]:
def get_steps(filename, prefix):
    return int(filename[len(prefix):][:-10])

In [20]:
def load_model(filepath, model_type, env):
    if model_type.lower() == 'dqn3': #dqn
        return stable_baselines3.DQN.load(filepath, env=env)
    elif model_type.lower() == 'ppo': #PPO
        return stable_baselines3.PPO.load(filepath, env=env)
    elif model_type.lower() == 'ppo2': #PPO2
        return stable_baselines.PPO2.load(filepath, env=env)
    elif model_type.lower() == 'ppo3': #PPO3
        return stable_baselines3.PPO3.load(filepath, env=env) 
    elif model_type.lower() == 'ac': #AC
        return stable_baselines3.AC.load(filepath, env=env) 
    elif model_type.lower() == 'a2c': #A2C
        return stable_baselines3.A2C.load(filepath, env=env)
    elif model_type.lower() == 'acer': #ACER
        return stable_baselines.ACER.load(filepath, env=env) 
    elif model_type.lower() == 'ddpg': #DDPG
        return stable_baselines3.DDPG.load(filepath, env=env)
    elif model_type.lower() == 'pg': #PG
        return stable_baselines3.PG.load(filepath, env=env)
    elif model_type.lower() == 'acktr': #TD3
        return stable_baselines.ACKTR.load(filepath, env=env)
    elif model_type.lower() == 'trpo': #TRPO
        return sb3_contrib.TRPO.load(filepath, env=env)
    else:
        raise ValueError(f'Unknown model type {model_type}')

In [21]:
def get_val_metrics(model, validation_env):
    val_df = pd.DataFrame()
    try:
        while True:
            obs, done = validation_env.reset(), False
            while not done:
                action, states = model.predict(obs, deterministic=True)
                obs, rew, done, info = validation_env.step(action)
                if done==True:
                    val_df = val_df.append(info, ignore_index=True)

    except StopIteration:
        pass
    acc, f1, roc_auc, = utils.test(val_df['y_actual'], val_df['y_pred'])
    min_path_length = val_df.episode_length.min()
    average_path_length = val_df.episode_length.mean()
    max_path_length = val_df.episode_length.max()
    min_sample_pathway = val_df[val_df.episode_length==min_path_length].trajectory.iloc[0]
    max_sample_pathway = val_df[val_df.episode_length==max_path_length].trajectory.iloc[0]
    return acc, f1, roc_auc, min_path_length, average_path_length, max_path_length, min_sample_pathway, max_sample_pathway

In [28]:
def create_test_df(model_path, model_type, X_test, y_test):
    testing_env = DummyVecEnv([lambda: LupusEnv(X_test, y_test, random=False)])
#     testing_env = LupusEnv(X_test, y_test, random=False)
    model = load_model(model_path, model_type, testing_env)
    test_df = pd.DataFrame()
    try:
        while True:
            obs, done = testing_env.reset(), False
            while not done:
                action, states = model.predict(obs, deterministic=True)
                obs, rew, done, info = testing_env.step(action)
                if done==True:
                    test_df = test_df.append(info, ignore_index=True)

    except StopIteration:
        pass
    min_path_length = test_df.episode_length.min()
    average_path_length = test_df.episode_length.mean()
    max_path_length = test_df.episode_length.max()
    return test_df

In [23]:
def create_val_df(folder, X_val, y_val, prefix, model_type='dqn3'):
    best_f1, best_acc, best_roc_auc = -1, -1, -1
    perf_list = []
        
    for item in os.listdir(folder):
        path = join(folder, item)
        #print(path)
        if (isfile(path)) & (path.endswith('.zip')):
#             validation_env = DummyVecEnv([lambda: LupusEnv(X_val, y_val, random=False)])
            validation_env = LupusEnv(X_val, y_val, random=False)
            model = load_model(path, model_type, validation_env)
            acc, f1, roc_auc, min_path_length, average_path_length, max_path_length, min_sample_pathway, max_sample_pathway = get_val_metrics(model, validation_env)
           
            perf_dict = {'steps': get_steps(item, prefix), 'acc':acc, 'f1':f1, 'roc_auc':roc_auc, 
                          'min_path_length':min_path_length, 'average_path_length':average_path_length, 
                          'max_path_length':max_path_length, 'min_sample_pathway':min_sample_pathway, 
                          'max_sample_pathway':max_sample_pathway} 
            perf_list.append(perf_dict)
            if acc > best_acc:
                best_acc = acc
                model.save(f'{folder}/best_acc_model')
            if f1 > best_f1:
                best_f1 = f1
                model.save(f'{folder}/best_f1_model')
            if roc_auc > best_roc_auc:
                best_roc_auc = roc_auc
                model.save(f'{folder}/best_roc_auc_model')
    
    print('creating dataframe object')
    val_df = pd.DataFrame.from_dict(perf_list) 
    val_df = val_df.sort_values(by=['steps'])
    val_df = val_df.reset_index(drop=True)
    print('saving validation results')
    val_df.to_csv(f'{folder}/validation_results.csv', index=False)
    return val_df

#### Validation

In [24]:
validation_df = pd.read_csv('../data/missingness/0/validation_set.csv')
validation_df.head()

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,joint_involvement,proteinuria,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,1,0,0,0,0,0,0,0,1,0,...,0,0,0,0,0,0,1,1,0,1
1,1,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
4,1,0,0,0,0,1,0,0,1,0,...,0,1,0,1,1,0,0,0,0,1


In [9]:
X_val = validation_df.iloc[:, 0:-1]
y_val = validation_df.iloc[:, -1]
X_val, y_val = np.array(X_val), np.array(y_val)
X_val.shape, y_val.shape

((5600, 23), (5600,))

In [10]:
folder = '../models/logs/dqn2'
prefix = 'dqn2_basic_'
val_df = create_val_df(folder, X_val, y_val, prefix, 'dqn3')
val_df.head()

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Unnamed: 0,steps,acc,f1,roc_auc,min_path_length,average_path_length,max_path_length,min_sample_pathway,max_sample_pathway
0,500000,0.0,0.0,50.0,18.0,23.95125,24.0,"[ana, anti_dsdna_antibody, joint_involvement, ...","[ana, anti_dsdna_antibody, joint_involvement, ..."
1,1000000,0.0,0.0,50.0,17.0,23.961429,24.0,"[ana, anti_dsdna_antibody, joint_involvement, ...","[ana, anti_dsdna_antibody, joint_involvement, ..."
2,1500000,0.0,0.0,50.0,18.0,23.929286,24.0,"[ana, fever, anti_dsdna_antibody, leukopenia, ...","[ana, fever, anti_dsdna_antibody, leukopenia, ..."
3,2000000,0.0,0.0,50.0,19.0,23.926607,24.0,"[ana, anti_dsdna_antibody, fever, joint_involv...","[ana, anti_dsdna_antibody, joint_involvement, ..."
4,2500000,0.0,0.0,50.0,19.0,23.933393,24.0,"[ana, anti_dsdna_antibody, joint_involvement, ...","[ana, anti_dsdna_antibody, joint_involvement, ..."


#### delete from here

#### end here

#### Testing

In [25]:
testing_df = pd.read_csv('../data/missingness/0/testing_set.csv')
testing_df.head()

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,joint_involvement,proteinuria,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,1,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,1
1,1,1,0,0,0,0,0,0,0,0,...,0,1,0,0,0,0,0,1,0,1
2,1,0,0,0,0,0,1,0,0,0,...,1,0,0,0,1,1,0,0,0,1
3,1,1,0,0,0,0,0,0,1,0,...,0,0,0,0,1,0,0,1,0,1
4,1,0,0,0,0,0,0,0,1,0,...,1,0,0,0,0,1,0,0,0,1


In [26]:
X_test = testing_df.iloc[:, 0:-1]
y_test = testing_df.iloc[:, -1]
X_test, y_test = np.array(X_test), np.array(y_test)
X_test.shape, y_test.shape

((14000, 23), (14000,))

In [44]:
test_df = create_test_df('../models/logs/dqn2/best_acc_model.zip', 'dqn3', X_test, y_test)
test_df.head()

NameError: name 'stable_baselines3' is not defined

In [39]:
acc, f1, roc_auc = utils.test(test_df['y_actual'], test_df['y_pred'])
acc, f1, roc_auc

(84.97142857142858, 84.97039674315599, 85.13631614504776)

In [40]:
test_df.episode_length.min(), test_df.episode_length.mean(), test_df.episode_length.max()

(3, 4.0842857142857145, 5)

In [41]:
test_df.to_csv('../test_dfs/23_03_23/acktr_basic_test_df.csv', index=False)