In [1]:
import pandas as pd
import numpy as np
import random
from os.path import isfile, join
import os
import ast
from stable_baselines3 import DQN
import torch as th
import sys
sys.path.append('..')
from modules import former_constants as constants
from modules.env import LupusEnv
from multiprocessing import Process
from stable_baselines3.common.vec_env import DummyVecEnv
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from modules.env import LupusEnv
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch) from the shared project constant.
SEED = constants.SEED
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects child processes launched afterwards, not this kernel's str hashing.
os.environ['PYTHONHASHSEED'] = str(SEED)
th.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Make PyTorch raise on ops lacking a deterministic implementation instead of silently varying.
th.use_deterministic_algorithms(True)
SEED

126

#### Helper functions: loading, evaluating and scoring DQN models

In [3]:
def load_dqn(filename):
    '''
    Loads a previously saved DQN model and returns it.

    Parameters
    ----------
    filename : str
        Path handed straight to ``DQN.load`` (e.g. a checkpoint ``.zip``).

    Returns
    -------
    DQN
        The loaded model.
    '''
    return DQN.load(filename)

In [4]:
def create_env(X, y, random=True, verbose=True):
    '''
    Creates an environment using the given data.

    Parameters
    ----------
    X, y : array-like
        Features and labels handed to ``LupusEnv``.
    random : bool
        Forwarded as ``LupusEnv``'s third positional argument; evaluation code
        in this notebook passes ``False``. (Shadows the ``random`` module inside
        this function only.)
    verbose : bool
        If True (the default, matching the earlier behaviour), print the value
        returned by ``env.seed()``.

    Returns
    -------
    LupusEnv
        The freshly constructed environment.
    '''
    env = LupusEnv(X, y, random)
    # env.seed() is called even when silent, because it may reseed the env's RNG.
    # TODO(review): confirm in LupusEnv whether that side effect is relied upon.
    env_seed = env.seed()
    if verbose:
        print(f'The environment seed is {env_seed}')
    return env

In [5]:
def evaluate_dqn(dqn_model, X_test, y_test):
    '''
    Play a trained DQN greedily over a dataset and collect one row per episode.

    A non-random environment is built from ``X_test``/``y_test`` via
    ``create_env``. Episodes are played until the environment raises
    ``StopIteration`` (presumably once every sample has been served — see
    ``LupusEnv``). The ``info`` dict returned on the final step of each
    finished episode becomes one row of the result.

    Parameters
    ----------
    dqn_model : object with ``predict(obs, deterministic=True)``
        e.g. a stable-baselines3 ``DQN``.
    X_test, y_test : array-like
        Features and labels passed to ``create_env``.

    Returns
    -------
    pd.DataFrame
        One row per finished episode; columns come from the env's ``info``
        dict (``get_val_metrics`` reads ``y_actual``, ``y_pred``,
        ``episode_length`` and ``trajectory``). Empty if no episode finished.
    '''
    env = create_env(X_test, y_test, random=False)
    # Build the frame once at the end: DataFrame.append was removed in
    # pandas 2.0 and re-copied the whole frame on every call.
    episode_infos = []
    try:
        while True:
            obs, done = env.reset(), False
            while not done:
                action, _ = dqn_model.predict(obs, deterministic=True)
                obs, _reward, done, info = env.step(action)
            episode_infos.append(info)
    except StopIteration:
        pass
    return pd.DataFrame(episode_infos)

In [6]:
def multiclass(actual_class, pred_class, average = 'macro'):
    '''
    Returns the one-vs-rest ROC-AUC for multi-labeled (hard-label) data,
    averaged with equal weight over the classes present in ``actual_class``.

    For each such class, both label sequences are binarised (1 = that class,
    0 = anything else) and scored with ``roc_auc_score``.

    Parameters
    ----------
    actual_class, pred_class : iterable of hashable labels, same length
    average : str
        Forwarded to ``roc_auc_score``; it is ignored for the binary problems
        scored here and kept only for interface compatibility.

    Raises
    ------
    ValueError
        If ``actual_class`` is empty, or (from ``roc_auc_score``) when a
        binarised ``actual_class`` contains a single value, e.g. only one
        class is present.
    '''
    actual = list(actual_class)
    predicted = list(pred_class)
    unique_class = set(actual)
    if not unique_class:
        raise ValueError('multiclass() needs at least one actual label')
    roc_auc_dict = {}
    for per_class in unique_class:
        # Compare with per_class directly: a predicted label that never occurs
        # in actual_class must count as a negative, not as a positive.
        new_actual_class = [1 if x == per_class else 0 for x in actual]
        new_pred_class = [1 if x == per_class else 0 for x in predicted]
        roc_auc_dict[per_class] = roc_auc_score(new_actual_class, new_pred_class, average = average)
    return sum(roc_auc_dict.values()) / len(roc_auc_dict)

In [7]:
def test(ytest, ypred):
    '''
    Return performance metrics for a model
    '''
    acc = accuracy_score(ytest, ypred)
    f1 = f1_score(ytest, ypred, average ='macro', labels=np.unique(ytest))
    try:
        roc_auc = multiclass(ytest, ypred)
    except:
        roc_auc = None
    return acc*100, f1*100, roc_auc*100

In [8]:
def get_val_metrics(model, X_val, y_val):
    '''
    Evaluate ``model`` on the validation data and summarise its episodes.

    Returns an 8-tuple
    (acc, f1, roc_auc, min_path_length, average_path_length, max_path_length,
     min_sample_pathway, max_sample_pathway)
    where the three scores come from ``test``, the lengths summarise the
    ``episode_length`` column, and the two pathways are the ``trajectory`` of
    the first episode having the shortest / longest length.
    '''
    episodes = evaluate_dqn(model, X_val, y_val)
    acc, f1, roc_auc = test(episodes['y_actual'], episodes['y_pred'])

    lengths = episodes['episode_length']
    shortest, mean_length, longest = lengths.min(), lengths.mean(), lengths.max()

    def first_trajectory_of_length(length):
        # Trajectory of the first episode whose length equals `length`.
        return episodes.loc[lengths == length, 'trajectory'].iloc[0]

    return (acc, f1, roc_auc,
            shortest, mean_length, longest,
            first_trajectory_of_length(shortest),
            first_trajectory_of_length(longest))

#### Reading the data

In [9]:
# Validation split used below to score every saved checkpoint.
VAL_CSV = '../new_data/val_set_constant.csv'
val_df = pd.read_csv(VAL_CSV)
val_df.head()

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,proteinuria,biopsy_proven_lupus_nephritis,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
1,1,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,1,0,0,0
4,0,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [10]:
# Features = every column except the last; label = the last column. Both as NumPy arrays.
X_val = val_df.iloc[:, :-1].to_numpy()
y_val = val_df.iloc[:, -1].to_numpy()
X_val.shape, y_val.shape

((5600, 24), (5600,))

In [11]:
def get_steps(filename, prefix, suffix='_steps.zip'):
    '''
    Extract the training-step count from a checkpoint filename.

    Checkpoints are expected to be named ``{prefix}_{steps}{suffix}``,
    e.g. ``robust_dqn3_69000000_steps.zip`` -> ``69000000``.

    Parameters
    ----------
    filename : str
        Bare file name (no directory).
    prefix : str
        Checkpoint prefix, e.g. ``'robust_dqn3'``.
    suffix : str
        Text following the step count; defaults to ``'_steps.zip'``.

    Returns
    -------
    int or None
        The step count, or ``None`` (after printing a note) for names that do
        not follow the pattern, such as ``robust_dqn3_100000000_full_model.zip``.
    '''
    head = f'{prefix}_'
    if filename.startswith(head) and filename.endswith(suffix):
        digits = filename[len(head):len(filename) - len(suffix)]
        # isascii() guards against non-ASCII "digits" that int() would reject.
        if digits.isascii() and digits.isdigit():
            return int(digits)
    print(f'Filename: {filename}')
    print(f'Could not parse step count: expected {head}<steps>{suffix}')
    return None

In [12]:
def validate_model(seed, steps, X_val, y_val, prefix,
                   base_dir='../models/logs/robust_dqn3/noisiness/0.3/biopsy_9/l1_norm'):
    '''
    Score every saved checkpoint of one training run on the validation set.

    Looks in ``{base_dir}/seed_{seed}_{steps}`` for ``.zip`` files whose names
    start with ``prefix``. Each file is loaded as a DQN and scored with
    ``get_val_metrics``. Whenever a checkpoint beats the best accuracy / F1 /
    ROC-AUC seen so far, it is saved as ``best_acc_model`` / ``best_f1_model`` /
    ``best_roc_auc_model`` in the same folder.

    Parameters
    ----------
    seed, steps : int
        Identify the run folder ``seed_{seed}_{steps}``.
    X_val, y_val : array-like
        Validation features and labels.
    prefix : str
        Checkpoint filename prefix, e.g. ``'robust_dqn3'``.
    base_dir : str
        Parent directory of the run folders; defaults to the path this
        notebook was written against.

    Returns
    -------
    pd.DataFrame
        One row per evaluated checkpoint, sorted by ``steps``. The same table
        is written to ``validation_results.csv`` inside the run folder.
    '''
    best_f1, best_acc, best_roc_auc = -1, -1, -1
    perf_list = []
    count = 0

    folder = f'{base_dir}/seed_{seed}_{steps}'
    # sorted() makes the evaluation order (and therefore tie-breaking between
    # equally good checkpoints) independent of the filesystem's listing order.
    for item in sorted(os.listdir(folder)):
        path = join(folder, item)
        if not (item.startswith(prefix) and path.endswith('.zip') and isfile(path)):
            continue
        count += 1
        if count % 10 == 0:
            print(count)
        model = DQN.load(path)
        (acc, f1, roc_auc, min_length, avg_length, max_length,
         min_path, max_path) = get_val_metrics(model, X_val, y_val)
        perf_list.append({'steps': get_steps(item, prefix), 'acc': acc, 'f1': f1, 'roc_auc': roc_auc,
                          'min_path_length': min_length, 'avg_length': avg_length, 'max_length': max_length,
                          'min_path': min_path, 'max_path': max_path})
        if acc > best_acc:
            best_acc = acc
            model.save(f'{folder}/best_acc_model')
        if f1 > best_f1:
            best_f1 = f1
            model.save(f'{folder}/best_f1_model')
        # roc_auc may be None when it could not be computed; None > number raises TypeError.
        if roc_auc is not None and roc_auc > best_roc_auc:
            best_roc_auc = roc_auc
            model.save(f'{folder}/best_roc_auc_model')

    val_df = pd.DataFrame(perf_list)
    # With no matching checkpoints the frame has no 'steps' column to sort on.
    if 'steps' in val_df.columns:
        val_df = val_df.sort_values(by=['steps'])
    val_df = val_df.reset_index(drop=True)
    val_df.to_csv(f'{folder}/validation_results.csv', index=False)
    return val_df

In [13]:
# Total training timesteps of the run to validate; selects the folder seed_{SEED}_{steps}.
steps = 100_000_000

In [14]:
# filename = '../models/logs/robust_dqn3/noisiness/0.0/biopsy_9/l2_norm/seed_42_100000000/robust_dqn3_69000000_steps.zip'
# os.path.exists(filename)

In [15]:
# model = DQN.load(filename)
# type(model)

In [16]:
# Score every checkpoint of the SEED/steps run and display the per-checkpoint table.
validate_model(seed=SEED, steps=steps, X_val=X_val, y_val=y_val, prefix='robust_dqn3')

The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
10
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
Filename: robust_dqn3_100000000_full_model.zip
Exception: invalid literal for int() with base 10: '100000000_full'
The environment seed is [126]
The environment seed is [126]
20
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
The environment seed is [126]
30
The en

Unnamed: 0,steps,acc,f1,roc_auc,min_path_length,avg_length,max_length,min_path,max_path
0,1000000.0,0.000000,0.000000,50.000000,17.0,24.839107,25.0,"[anti_dsdna_antibody, low_c4, cutaneous_lupus,...","[anti_dsdna_antibody, low_c4, cutaneous_lupus,..."
1,2000000.0,0.000000,0.000000,50.000000,17.0,24.904464,25.0,"[pericardial_effusion, low_c4, delirium, low_c...","[pericardial_effusion, low_c4, non_scarring_al..."
2,3000000.0,0.000000,0.000000,50.000000,19.0,24.940893,25.0,"[non_scarring_alopecia, low_c4, pericardial_ef...","[non_scarring_alopecia, pericardial_effusion, ..."
3,4000000.0,0.000000,0.000000,50.000000,18.0,24.898750,25.0,"[low_c4, pericardial_effusion, non_scarring_al...","[low_c4, non_scarring_alopecia, pericardial_ef..."
4,5000000.0,0.000000,0.000000,50.000000,18.0,24.828214,25.0,"[delirium, ana, low_c3, non_scarring_alopecia,...","[delirium, low_c3, ana, non_scarring_alopecia,..."
...,...,...,...,...,...,...,...,...,...
96,97000000.0,88.321429,88.315053,88.324469,7.0,14.008750,23.0,"[auto_immune_hemolysis, delirium, anti_dsdna_a...","[auto_immune_hemolysis, delirium, low_c4, anti..."
97,98000000.0,91.017857,91.210401,91.236484,10.0,17.150357,25.0,"[biopsy_proven_lupus_nephritis, pericardial_ef...","[biopsy_proven_lupus_nephritis, psychosis, leu..."
98,99000000.0,89.196429,89.231372,89.245390,7.0,15.801607,24.0,"[seizure, proteinuria, joint_involvement, lupu...","[seizure, pericardial_effusion, low_c4, biopsy..."
99,100000000.0,91.089286,91.095451,91.101072,6.0,15.446786,24.0,"[biopsy_proven_lupus_nephritis, low_c4, deliri...","[biopsy_proven_lupus_nephritis, pleural_effusi..."
