In [1]:
import pandas as pd
import numpy as np
import random
import tensorflow as tf
from os.path import isfile, join
import os
import ast
from stable_baselines import PPO2
import sys
sys.path.append('..')
from modules import utils, constants
from multiprocessing import Process
from modules.env import AnemiaEnv
import warnings
warnings.filterwarnings('ignore')

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



  "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation."


In [2]:
# Seed of the training run to validate; it selects the checkpoint folder
# ../models/sb/ppo/seed_<seed> that validate_model scans.
seed = 63

In [3]:
# Load the validation split from CSV and preview its first rows.
# NOTE(review): the preview suggests missing measurements are encoded as -1.0 — confirm.
val_df = pd.read_csv('../data/val_set_constant.csv')
val_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,9.496631,-1.0,-1.0,3.515439,375.685261,101.027779,122.174205,2.820006,1,1.27622,147.925454,74.18824,23.800042,11.545421,97.593498,28.489894,32.520362,1
1,7.09278,336.562609,-1.0,-1.0,115.507513,75.569193,35.449822,2.815743,0,0.870413,105.709466,118.836405,30.607322,2.113748,88.833122,21.278341,30.690491,3
2,6.554136,-1.0,1.774495,6.14581,-1.0,86.8497,-1.0,2.263958,0,0.569184,139.078814,-1.0,-1.0,-1.0,-1.0,19.662409,-1.0,6
3,12.417159,252.249921,1.70186,0.188413,515.449324,76.102948,-1.0,4.89488,1,0.576132,4.58823,-1.0,-1.0,-1.0,-1.0,37.251478,-1.0,3
4,7.748672,258.549935,2.980693,5.137341,363.214987,82.395181,99.002425,2.821283,1,0.885522,84.419465,33.706997,19.591641,27.675132,69.578682,23.246016,27.257252,5


In [4]:
# Features are every column except the last; the label is the last column.
# Both are materialised as NumPy arrays, then their shapes are displayed.
X_val = np.array(val_df.iloc[:, :-1])
y_val = np.array(val_df.iloc[:, -1])
(X_val.shape, y_val.shape)

((5600, 17), (5600,))

In [5]:
def get_steps(filename, prefix):
    """Extract the training-step count from a checkpoint file name.

    The name is expected to look like ``<prefix><seed>_<steps>_steps.zip``,
    e.g. ``ppo_seed_63_100000000_steps.zip`` with prefix ``ppo_seed_``.
    The seed may have any number of digits (the previous fixed ``+3`` offset
    only worked for two-digit seeds).

    Args:
        filename: Checkpoint file name (not a full path).
        prefix: Leading part of the name that precedes the seed. Only its
            length is used to skip that part of ``filename``.

    Returns:
        The step count as an ``int``, or ``None`` (after printing the file
        name and the error) when the name cannot be parsed.
    """
    import re  # local import: keeps this cell runnable on its own

    try:
        remainder = filename[len(prefix):]
        match = re.fullmatch(r'\d+_(\d+)_steps\.zip', remainder)
        if match is None:
            raise ValueError(
                f"expected '<seed>_<steps>_steps.zip' after the prefix, got {remainder!r}"
            )
        return int(match.group(1))
    except Exception as e:
        # Best-effort parsing: report the offending name and fall through to None.
        print(f'Filename: {filename}')
        print(f'Exception: {e}')
        return None

In [6]:
# Sanity check: the step count parsed from this example name should be 100000000.
get_steps('ppo_seed_63_100000000_steps.zip', 'ppo_seed_')

100000000

In [7]:
def synthetic_ppo_eval(ppo_model, X_test, y_test):
    """Play a trained policy through an ``AnemiaEnv`` and collect episode results.

    An ``AnemiaEnv`` is created over ``(X_test, y_test)`` with ``random=False``.
    Episodes are played with deterministic actions until a ``StopIteration``
    is raised inside the loop (presumably by the environment once its samples
    are exhausted — confirm against ``AnemiaEnv``).

    Args:
        ppo_model: Object exposing ``predict(obs, deterministic=True)`` that
            returns an ``(action, state)`` pair.
        X_test: Features handed to ``AnemiaEnv``.
        y_test: Labels handed to ``AnemiaEnv``.

    Returns:
        pandas.DataFrame with one row per finished episode, built from the
        ``info`` dict returned by the ``env.step`` call that ended the episode.
        The frame is empty if no episode finished.
    """
    # Collect plain records and build the frame once at the end. Growing a
    # DataFrame with DataFrame.append inside the loop is quadratic and the
    # method no longer exists in pandas >= 2.0.
    # NOTE(review): building the frame in one go keeps native column dtypes
    # (row-wise append onto an empty frame could upcast, e.g. ints to floats).
    episode_infos = []

    env = AnemiaEnv(X_test, y_test, random=False)
    count = 0

    try:
        while True:
            count += 1
            if count % 5000 == 0:
                print(f'Count: {count}')
            obs, done = env.reset(), False
            while not done:
                action, _states = ppo_model.predict(obs, deterministic=True)
                obs, rew, done, info = env.step(action)
                if done:
                    # Shallow copy so a later in-place update of `info` by the
                    # environment cannot alter an already recorded episode.
                    episode_infos.append(dict(info))
    except StopIteration:
        print('Testing done.....')
    return pd.DataFrame(episode_infos)


In [8]:
def test(ytest, ypred):
    acc = accuracy_score(ytest, ypred)
    f1 = f1_score(ytest, ypred, average ='macro', labels=np.unique(ytest))
    try:
        roc_auc = multiclass(ytest, ypred)
    except:
        roc_auc = None
    return acc, f1, roc_auc

In [13]:
def validate_model(seed, X_val, y_val, prefix):
    """Evaluate every saved PPO checkpoint of one seed on the validation data.

    Scans ``../models/sb/ppo/seed_<seed>`` for ``.zip`` files whose name starts
    with ``prefix``, loads each with ``PPO2.load``, evaluates it with
    ``synthetic_ppo_eval`` and records the metrics returned by
    ``utils.get_metrics`` together with episode-length statistics and one
    example trajectory for the shortest and the longest episode.

    Side effects:
        * Prints a running count every 10 checkpoints and each metrics dict.
        * Saves the checkpoint with the highest accuracy so far as
          ``<folder>/best_acc_model`` (on ties the first one evaluated wins).
        * Writes the summary table to ``<folder>/validation_results.csv``.

    Args:
        seed: Seed identifying the model folder.
        X_val: Validation features passed to ``synthetic_ppo_eval``.
        y_val: Validation labels passed to ``synthetic_ppo_eval``.
        prefix: File-name prefix of the checkpoints, also passed to ``get_steps``.

    Returns:
        pandas.DataFrame with one row per evaluated checkpoint, sorted by
        ``steps``; empty if no checkpoint matched.
    """
    folder = f'../models/sb/ppo/seed_{seed}'
    best_acc = -1
    perf_list = []
    count = 0

    # os.listdir returns entries in arbitrary order; sorting makes the printed
    # log and the tie-break for the "best" checkpoint reproducible.
    for item in sorted(os.listdir(folder)):
        path = join(folder, item)
        if not (item.startswith(prefix) and path.endswith('.zip') and isfile(path)):
            continue

        count += 1
        if count % 10 == 0:
            print(count)

        ppo_model = PPO2.load(path)
        model_df = synthetic_ppo_eval(ppo_model, X_val, y_val)
        acc, f1, roc_auc = utils.get_metrics(model_df)

        lengths = model_df.episode_length
        min_path_length = lengths.min()
        max_path_length = lengths.max()
        perf_dict = {
            'steps': get_steps(item, prefix),
            'acc': acc,
            'f1': f1,
            'roc_auc': roc_auc,
            'min_path_length': min_path_length,
            'average_path_length': lengths.mean(),
            'max_path_length': max_path_length,
            # First recorded trajectory among the shortest / longest episodes.
            'min_sample_pathway': model_df[lengths == min_path_length].trajectory.iloc[0],
            'max_sample_pathway': model_df[lengths == max_path_length].trajectory.iloc[0],
        }
        print(perf_dict)
        perf_list.append(perf_dict)

        if acc > best_acc:
            best_acc = acc
            ppo_model.save(f'{folder}/best_acc_model')

    results_df = pd.DataFrame(perf_list)
    # With no matching checkpoint the frame has no 'steps' column; skip the
    # sort explicitly instead of hiding every possible error behind `except: pass`.
    if 'steps' in results_df.columns:
        results_df = results_df.sort_values(by=['steps'])
    results_df = results_df.reset_index(drop=True)
    results_df.to_csv(f'{folder}/validation_results.csv', index=False)
    return results_df

In [14]:
# Validate every 'ppo_seed_' checkpoint of the configured seed; the returned
# summary frame is displayed as the cell output.
validate_model(seed, X_val, y_val, 'ppo_seed_')

Loading a model without an environment, this model cannot be trained until it has a valid environment.
Count: 5000
Testing done.....
{'steps': 320000000, 'acc': 12.892857142857142, 'f1': 2.8551091426763677, 'roc_auc': 50.0, 'min_path_length': 1.0, 'average_path_length': 1.0, 'max_path_length': 1.0, 'min_sample_pathway': ['Hemolytic anemia'], 'max_sample_pathway': ['Hemolytic anemia']}
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Count: 5000
Testing done.....
{'steps': 180000000, 'acc': 12.892857142857142, 'f1': 2.8551091426763677, 'roc_auc': 50.0, 'min_path_length': 1.0, 'average_path_length': 1.0, 'max_path_length': 1.0, 'min_sample_pathway': ['Hemolytic anemia'], 'max_sample_pathway': ['Hemolytic anemia']}
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Count: 5000
Testing done.....
{'steps': 500000000, 'acc': 12.892857142857142, 'f1': 2.8551091426763677, 'roc_auc': 50.0, 

Count: 5000
Testing done.....
{'steps': 440000000, 'acc': 12.892857142857142, 'f1': 2.8551091426763677, 'roc_auc': 50.0, 'min_path_length': 1.0, 'average_path_length': 1.0, 'max_path_length': 1.0, 'min_sample_pathway': ['Hemolytic anemia'], 'max_sample_pathway': ['Hemolytic anemia']}
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Count: 5000
Testing done.....
{'steps': 480000000, 'acc': 12.892857142857142, 'f1': 2.8551091426763677, 'roc_auc': 50.0, 'min_path_length': 1.0, 'average_path_length': 1.0, 'max_path_length': 1.0, 'min_sample_pathway': ['Hemolytic anemia'], 'max_sample_pathway': ['Hemolytic anemia']}
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Count: 5000
Testing done.....
{'steps': 100000000, 'acc': 12.892857142857142, 'f1': 2.8551091426763677, 'roc_auc': 50.0, 'min_path_length': 1.0, 'average_path_length': 1.0, 'max_path_length': 1.0, 'min_sample_pathway': ['Hem

Unnamed: 0,steps,acc,f1,roc_auc,min_path_length,average_path_length,max_path_length,min_sample_pathway,max_sample_pathway
0,20000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
1,40000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
2,60000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
3,80000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
4,100000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
5,120000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
6,140000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
7,160000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
8,180000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
9,200000000,12.892857,2.855109,50.0,1.0,1.0,1.0,[Hemolytic anemia],[Hemolytic anemia]
