In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras import regularizers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling1D, Conv1D

import numpy as np
import pandas as pd

from datetime import datetime
import os
import sys
import pickle
import time
import argparse

sys.path.insert(0, '/vast/nj594/ecg_explain/fastshap_ecg')
from surrogate import Surrogate

sys.path.insert(0, '/vast/nj594/xai/helpers')
from evaluate import evaluate
from fastshap_dkl import FastSHAP_ECG as FastSHAP

# IMPORTANT: SET RANDOM SEEDS FOR REPRODUCIBILITY
os.environ['PYTHONHASHSEED'] = str(420)
import random
random.seed(420)
np.random.seed(420)
tf.random.set_seed(420)

In [None]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Command Line Arguements
parser = argparse.ArgumentParser(description='REAL-X Mimic Experiment')
parser.add_argument('--arg_file', type=str, default='', metavar='a',
                    help='Path to File with Grid Search Arguments')
parser.add_argument('--index', type=int, default=9999, metavar='i',
                    help='Index for Job Array')
parser.add_argument('--verbose', type=int, default=1, metavar='v',
                    help='Prints Outputs')
args = parser.parse_args()

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Get Index (Either from argument or from SLURM JOB ARRAY)
if 'SLURM_ARRAY_TASK_ID' in os.environ:
    index = int(os.environ['SLURM_ARRAY_TASK_ID'])
    print('SLURM_ARRAY_TASK_ID found..., using index %s' % index)
else:
    print('no SLURM_ARRAY_TASK_ID... using index %s' % index)
    

In [2]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Get Arguments
arg_file = 'fastshap/arg_file.pkl'
index = 0
print(arg_file)
with open(arg_file, "rb") as arg_file:
    arg_file = pickle.load(arg_file)

arg_file = arg_file[index]
print(arg_file)

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#Set Model Dir
method = 'fastshap-dkl'
run = index
model_dir = os.path.join(method, str(run))
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Load Data

data_dir = os.path.join(os.getcwd(), 'data')

X_train = np.load(os.path.join(data_dir, 'X_train.npy'), allow_pickle=True)
X_val = np.load(os.path.join(data_dir, 'X_val.npy'), allow_pickle=True)
X_test = np.load(os.path.join(data_dir, 'X_test.npy'), allow_pickle=True)

y_train = np.load(os.path.join(data_dir, 'y_train.npy'), allow_pickle=True)
y_val = np.load(os.path.join(data_dir, 'y_val.npy'), allow_pickle=True)
y_test = np.load(os.path.join(data_dir, 'y_test.npy'), allow_pickle=True)

preds = np.load(os.path.join(data_dir, 'predictions.npy'), allow_pickle=True)
preds_discrete = np.eye(2)[preds.argmax(1)]

preds_val = np.load(os.path.join(data_dir, 'predictions_val.npy'), allow_pickle=True)
preds_discrete_val = np.eye(2)[preds_val.argmax(1)]


#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Load Surrogate

superpixel_size = 8
surrogate_model = tf.keras.models.load_model(os.path.join(os.getcwd(), 'surrogate', 'surrogate.h5'))    
surrogate = Surrogate(surrogate_model = surrogate_model,
                           baseline = 0,
                           width = 1000, 
                           superpixel_size = superpixel_size)

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# FastSHAP

### Specify Explainer Architecture

params = {
    #NN Hyperparameters
    "input_shape": [1000, 1],
    "num_categories": 2,
    "conv_subsample_lengths": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
    "conv_filter_length": 8,
    "conv_num_filters_start": 32,
    "conv_init": "he_normal",
    "conv_activation": "relu",
    "conv_dropout": 0.2,
    "conv_num_skip": 2,
    "conv_increase_channels_at": 4,
    "compile": False,
    "is_regular_conv": False,
    "is_by_time": False,
    "is_by_lead": False,
    "ecg_out_size": 64,
    "nn_layer_sizes" : None,
    "is_multiply_layer": False, 
}
num_classes = 2

#Stanford Model
sys.path.insert(0, '/scratch/nj594/ecg/models/stanford')
import network

cnn = network.build_network(**params) 
base_model = Model(cnn.inputs, cnn.layers[67].output)

model_input = Input(shape=(1000,1))
net = base_model(model_input)
out = Conv1D(1,1)(net)

explainer = Model(model_input, out)

### Extract Superpixel Size

superpixel_size = int(1000/explainer.output.shape[1])

fastshap/arg_file.pkl
{'batch_size': 16, 'eff_lambda': 0.0, 'epochs': 100, 'lookback': 20, 'lr': 0.001, 'normalization': None, 'num_samples': 1, 'paired_sampling': False}


2022-05-04 09:53:26.102554: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-04 09:53:26.103051: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 32252 MB memory:  -> device: 0, name: Vega 20, pci bus id: 0000:48:00.0




In [4]:
arg_file['num_samples']

1

In [5]:
### Train FastSHAP

fastshap = FastSHAP(explainer = explainer,
                    imputer = surrogate,
                    baseline = 0,
                    normalization=arg_file['normalization'],
                    link='identity')

t = time.time()
fastshap.train(train_data = X_train,
               val_data = X_val,
               batch_size = arg_file['batch_size'],
               num_samples = 4,
               max_epochs = 2,
               validation_batch_size = arg_file['batch_size'],
               lr=arg_file['lr'],
               min_lr=1e-5,
               lr_factor=0.9,
               eff_lambda=arg_file['eff_lambda'],
               paired_sampling=arg_file['paired_sampling'],
               lookback=arg_file['lookback'],
               verbose=1, 
               model_dir = model_dir)
training_time = time.time() - t

with open(os.path.join(model_dir, 'training_time.pkl'), 'wb') as f:
    pickle.dump(training_time, f)

2022-05-04 09:55:26.341875: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-04 09:55:26.345491: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-04 09:55:26.347986: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.


The following Variables were used a Lambda layer's call (f_x), but
are not present in its tracked objects:
  <tf.Variable 'conv1d_33/kernel:0' shape=(8, 1, 32) dtype=float32>
  <tf.Variable 'conv1d_33/bias:0' shape=(32,) dtype=float32>
  <tf.Variable 'batch_normalization_33/gamma:0' shape=(32,) dtype=float32>
  <tf.Variable 'batch_normalization_33/beta:0' shape=(32,) dtype=float32>
  <tf.Variable 'conv1d_34/kernel:0' shape=(8, 32, 32) dtype=float32>
  <tf.Variable 'conv1d_34/bias:0' shape=(32,) dtype=float32>
  <tf.Variable 'batch_normalization_34/gamma:0' shape=(32,) dtype=float32>
  <tf.Variable 'batch_normalization_34/beta:0' shape=(32,) dtype=float32>
  <tf.Variable 'conv1d_35/kernel:0' shape=(8, 32, 32) dtype=float32>
  <tf.Variable 'conv1d_35/bias:0' shape=(32,) dtype=float32>
  <tf.Variable 'batch_normalization_35/gamma:0' shape=(32,) dtype=float32>
  <tf.Variable 'batch_normalization_35/beta:0' shape=(32,) dtype=float32>
  <tf.Variable 'conv1d_36/kernel:0' shape=(8, 32, 32) dty

Epoch 1/2


2022-05-04 09:55:34.233031: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.


   1/1070 [..............................] - ETA: 2:04:53 - loss: 2.3431 - shap_loss: 2.3431

MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record



MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record



MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record



MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486
MIOpen(HIP): Error [FindRecordUnsafe] Ill-formed record: key not found: /home/nj594/.config/miopen//gfx906_60.HIP.2_11_0_.ufdb.txt#486



Epoch 00001: val_shap_loss improved from inf to 0.74457, saving model to fastshap-dkl/0/explainer_weights.h5
Epoch 2/2

Epoch 00002: val_shap_loss did not improve from 0.74457
Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           [(None, 1000, 1)]         0         
_________________________________________________________________
model_1 (Functional)         (None, 125, 1)            249441    
_________________________________________________________________
reshape_4 (Reshape)          (None, 125, 1)            0         
_________________________________________________________________
permute_1 (Permute)          (None, 1, 125)            0         
_________________________________________________________________
phi (Layer)                  (None, 1, 125)            0         
Total params: 249,441
Trainable params: 248,097
Non-trainable params: 1,344
____



In [None]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Explain w/ FastSHAP

### Explain

t = time.time()
shap_values = fastshap.explainer.predict(X_test)
explaining_time = time.time() - t

shap_values_val = fastshap.explainer.predict(X_val)

### Save

with open(os.path.join(model_dir, 'explaining_time.pkl'), 'wb') as f:
    pickle.dump(explaining_time, f)
    
with open(os.path.join(model_dir, 'shap_values.pkl'), 'wb') as f:
    pickle.dump(shap_values, f)
    
    

In [None]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Load Evaluator Model

eval_dir = os.path.join(os.getcwd(), 'evaluation', 'evaluator-data')
evaluator_model = tf.keras.models.load_model(os.path.join(eval_dir, 'surrogate.h5'))

OPTIMIZER = tf.keras.optimizers.Adam(1e-3)
METRICS = [ 
  tf.keras.metrics.AUC(name='auroc'),
  tf.keras.metrics.AUC(curve='PR', name='auprc'),
  tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='accuracy'),
]

evaluator_model.compile(
    loss='categorical_crossentropy',
    optimizer=OPTIMIZER,
    metrics=METRICS,
)

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#### Retrospective Evaluation ####

# Exclusion
retro_ex_val = evaluate(X_val, shap_values_val, evaluator_model, y_val, y_val, 
                        mode = 'exclude', method = method)
retro_ex_test = evaluate(X_test, shap_values, evaluator_model, y_test, y_test, 
                         mode = 'exclude', method = method)

# Inclusion
retro_in_val = evaluate(X_val, shap_values_val, evaluator_model, y_val, y_val, 
                        mode = 'include', method = method)
retro_in_test = evaluate(X_test, shap_values, evaluator_model, y_test, y_test, 
                         mode = 'include', method = method)


#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#### Prospective Evaluation ####

# Exclusion
pro_ex_val = evaluate(X_val, shap_values_val, evaluator_model, preds_discrete_val, y_val, 
                        mode = 'exclude', method = method)
pro_ex_test = evaluate(X_test, shap_values, evaluator_model, preds_discrete, y_test, 
                         mode = 'exclude', method = method)

# Inclusion
pro_in_val = evaluate(X_val, shap_values_val, evaluator_model, preds_discrete_val, y_val, 
                        mode = 'include', method = method)
pro_in_test = evaluate(X_test, shap_values, evaluator_model, preds_discrete, y_test, 
                         mode = 'include', method = method)

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Combine Results
tags = ['retro_ex_val','retro_ex_test','retro_in_val','retro_in_test', 
        'pro_ex_val','pro_ex_test','pro_in_val','pro_in_test']
result_list = [retro_ex_val,retro_ex_test,retro_in_val,retro_in_test,
               pro_ex_val,pro_ex_test,pro_in_val,pro_in_test]

results = {}
for res, tag  in zip(result_list, tags):
    res = {k+'-'+tag:v for k,v in res.items()}
    results = {**results, **res}

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Save

### Create Results Dictionary
header = ["model_dir", "lr", "epochs", "batch_size", "lookback", 
          "num_samples", "paired_sampling", "eff_lambda", "normalization", 
          "training_time", "explaining_time"]
metrics = ['AUC_acc','AUC_auroc','AUC_log_likelihood','AUC_log_odds']
for tag in tags:
    header += [x+'-'+tag for x in metrics]
    
results = {**results, **arg_file}
results['model_dir'] = model_dir
results["explaining_time"] = explaining_time
results["training_time"] = training_time
results = {k:v for k,v in results.items() if k in header}

### Convert to DataFrame
results_df = pd.DataFrame(results, index=[0])
results_df = results_df[header]

### Append DataFrame to csv
results_path = method+'/results.csv'
if os.path.exists(results_path):
    results_df.to_csv(results_path, mode='a',  header=False)
else:
    results_df.to_csv(results_path, mode='w',  header=True)