In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense

from transformers import TFBertForSequenceClassification

import numpy as np
import pandas as pd

from datetime import datetime
import os
import sys
import pickle
import time
import argparse

sys.path.insert(0, '/vast/nj594/mimic_explain/fastshap_text')
from surrogate import TextSurrogate

sys.path.insert(0, '/vast/nj594/xai/helpers')
from evaluate import evaluate_mimic as evaluate
from fastshap_dkl import FastSHAP_TEXT as FastSHAP

# IMPORTANT: SET RANDOM SEEDS FOR REPRODUCIBILITY
# The same fixed seed (420) is handed to every random number source used in
# this notebook: Python's `random`, NumPy's legacy global RNG and TensorFlow's
# global RNG.
import random

# NOTE(review): PYTHONHASHSEED is exported after the interpreter has already
# started, so it presumably only reaches child processes -- confirm whether
# hash randomisation of this process matters for the experiment.
os.environ['PYTHONHASHSEED'] = '420'

random.seed(420)         # Python stdlib RNG
np.random.seed(420)      # NumPy global RNG
tf.random.set_seed(420)  # TensorFlow global RNG


#Select GPU
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Command Line Arguments
# Three optional flags are accepted:
#   --arg_file : path to the pickled grid-search arguments
#   --index    : position in the job array (default 9999)
#   --verbose  : integer verbosity switch (default 1)
parser = argparse.ArgumentParser(description='REAL-X Mimic Experiment')
parser.add_argument('--arg_file', type=str, default='', metavar='a',
                    help='Path to File with Grid Search Arguments')
parser.add_argument('--index', type=int, default=9999, metavar='i',
                    help='Index for Job Array')
parser.add_argument('--verbose', type=int, default=1, metavar='v',
                    help='Prints Outputs')
# parse_known_args() is used instead of parse_args(): inside a Jupyter kernel
# sys.argv carries the kernel launcher's own options (e.g. `-f <connection
# file>`), which parse_args() rejects by raising SystemExit. Unrecognised
# options are collected in `_unknown_args` and otherwise ignored, so the same
# cell works both in a notebook and when the code is run as a script.
args, _unknown_args = parser.parse_known_args()

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Get Index (Either from argument or from SLURM JOB ARRAY)
# A SLURM array task id, when present in the environment, takes precedence
# over the --index value.
if 'SLURM_ARRAY_TASK_ID' in os.environ:
    args.index = int(os.environ['SLURM_ARRAY_TASK_ID'])
    print('SLURM_ARRAY_TASK_ID found..., using index %s' % args.index)
else:
    print('no SLURM_ARRAY_TASK_ID... using index %s' % args.index)
    

In [2]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Get Arguments
# The pickle holds an indexable collection of hyper-parameter settings; the
# entry at position `index` ends up in `arg_file`, which later cells read by key.
# NOTE(review): `index` is fixed to 0 here -- confirm this is intended for
# job-array runs.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
index = 0
with open('fastshap/arg_file.pkl', "rb") as pickled_grid:
    hyperparameter_grid = pickle.load(pickled_grid)
arg_file = hyperparameter_grid[index]

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#Set Model Dir
# Outputs of this run go to <method>/<index>/, created on demand.
method = 'fastshap-dkl'
run = f'{index}'
model_dir = os.path.join(method, run)
os.makedirs(model_dir, exist_ok=True)

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#Load Data

# Dataset constants: two labels, sequences of 128 token ids.
data_dir = './data'
label_list = [0, 1]
max_seq_length = 128
num_classes = len(label_list)

### Initialize Tokenizer

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Vocabulary id of the '[MASK]' token, passed on as `baseline` / `mask_token`
# to the surrogate, FastSHAP and the evaluation calls.
mask_token = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]

### Load

train_dir, val_dir, test_dir = [os.path.join(data_dir, split + '_dataset')
                                for split in ('train', 'val', 'test')]

# Structure of one stored example: a dict of three int32 vectors of length
# `max_seq_length`, paired with an int32 label vector of length `num_classes`.
element_spec = (
    {feature: tf.TensorSpec(shape=(max_seq_length,), dtype=tf.int32, name=None)
     for feature in ('input_ids', 'attention_mask', 'token_type_ids')},
    tf.TensorSpec(shape=(num_classes,), dtype=tf.int32, name=None),
)

train_data = tf.data.experimental.load(train_dir, element_spec)
val_data = tf.data.experimental.load(val_dir, element_spec)
test_data = tf.data.experimental.load(test_dir, element_spec)


def _stack_ids_and_labels(dataset):
    """Return (input_ids, labels) of an unbatched dataset as stacked NumPy arrays."""
    ids, labels = [], []
    for features, label in dataset:
        ids.append(features['input_ids'].numpy())
        labels.append(label.numpy())
    return np.vstack(ids), np.vstack(labels)


# Materialise the still-unbatched val/test splits; the arrays are used later
# by the evaluation calls.
X_test, y_test = _stack_ids_and_labels(test_data)
X_val, y_val = _stack_ids_and_labels(val_data)

### Batch

batch_size = 16
autotune = tf.data.experimental.AUTOTUNE
# Only the training split is shuffled (buffer of 20000 elements).
train_data = train_data.shuffle(20000).batch(batch_size).prefetch(autotune)
val_data = val_data.batch(batch_size).prefetch(autotune)
test_data = test_data.batch(batch_size).prefetch(autotune)

# Load Surrogate

# TFBertForSequenceClassification is already imported at the top of the notebook.
surrogate_model = TFBertForSequenceClassification.from_pretrained('surrogate/surrogate')
surrogate = TextSurrogate(
    surrogate_model = surrogate_model,
    seq_length = max_seq_length,
    baseline = mask_token,
)

### Get Predicted Class
bert_model = './model/model'
base_model = TFBertForSequenceClassification.from_pretrained(bert_model)

# Wrap the classifier so that calling it yields softmax probabilities
# computed from the `.logits` field of the BERT output.
model = Sequential([
    base_model,
    tf.keras.layers.Lambda(lambda x: x.logits),
    tf.keras.layers.Activation('softmax'),
])

# Call the model once on the first test batch before using it.
# NOTE(review): the whole dataset element (features, labels) is passed here,
# not only the feature dict -- confirm this is intended.
for first_batch in test_data.take(1):
    model(first_batch)

model.trainable = False

# Class probabilities, plus one-hot encodings of the arg-max class, for the
# test and validation splits.
preds = model.predict(test_data)
preds_discrete = np.eye(2)[np.argmax(preds, axis=1)]

preds_val = model.predict(val_data)
preds_discrete_val = np.eye(2)[np.argmax(preds_val, axis=1)]

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# FastSHAP

### Specify Explainer Architecture

from transformers.models.bert.modeling_tf_bert import TFBertPredictionHeadTransform

# A second copy of the classifier checkpoint is loaded; its first layer is
# reused as the body of the explainer network.
bert_model = './model/model'
base_model = TFBertForSequenceClassification.from_pretrained(bert_model)

bert_main = base_model.layers[0]

# One Keras Input per feature of an unbatched training example, with the
# shape and dtype taken from the dataset's element spec.
unbatched_feature_spec = train_data.unbatch().element_spec[0]
inputs = {
    name: tf.keras.layers.Input(shape = spec.shape, name = name, dtype = spec.dtype)
    for name, spec in unbatched_feature_spec.items()
}
# First feature whose name contains 'input' (case-insensitive).
input_key = [name for name in train_data.element_spec[0] if 'input' in name.lower()][0]

bert_out = bert_main(inputs)

# Head: BERT prediction-head transform on the first BERT output, followed by
# a single-unit dense layer.
# NOTE(review): presumably bert_out[0] is the per-token sequence output, so
# the explainer emits one value per token -- confirm.
head_transform = TFBertPredictionHeadTransform(config = bert_main.config)
net = head_transform(bert_out[0])
out = Dense(1)(net)

explainer = Model(inputs, out)

2022-05-05 16:38:49.938090: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-05 16:38:49.938629: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 32252 MB memory:  -> device: 0, name: Vega 20, pci bus id: 0000:c5:00.0
2022-05-05 16:38:50.819989: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2022-05-05 16:38:50.838798: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:38:50.841328: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:38:51.294670: I tensorflo

array([[  101,  3058,  1997, ..., 17850, 12104,   102],
       [  101, 18583,  1012, ..., 10210,  7941,   102],
       [  101,  2089, 16755, ...,  6292,  2089,   102],
       ...,
       [  101,  3460,  2012, ...,  2566,  2154,   102],
       [  101,  1997,  3052, ...,  2590,  2008,   102],
       [  101,  3058,  1997, ...,  5219,  2006,   102]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(16, 128), dtype=int32, numpy=
array([[1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1]], dtype=int32)>, 'token_type_ids': <tf.Tensor: shape=(16, 128), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>}, <tf.Tensor: shape=(16, 2), dtype=int32, numpy=
array([[0, 1],

2022-05-05 16:39:03.224183: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:03.226407: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:03.349671: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:03.351923: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:03.358973: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:03.361071: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:08.570973: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:23.949413: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:23.951612: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 16:39:23

In [61]:
from importlib import reload
import fastshap_dkl

# Development helper: re-execute the fastshap_dkl module and rebind FastSHAP
# to the freshly loaded class, so edits to the module are picked up without
# restarting the kernel.
fastshap_dkl = reload(fastshap_dkl)
FastSHAP = fastshap_dkl.FastSHAP_TEXT

In [62]:
print('Training FastSHAP')
# Wire the explainer network and the surrogate (used as the imputer) into
# FastSHAP; the normalization mode comes from the selected grid entry.
fastshap_init_kwargs = dict(
    explainer = explainer,
    imputer = surrogate,
    baseline = mask_token,
    normalization = arg_file['normalization'],
    link = 'identity',
)
fastshap = FastSHAP(**fastshap_init_kwargs)

Training FastSHAP


2022-05-05 17:00:43.376347: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 17:00:43.380299: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 17:00:43.382837: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.


In [63]:
# Training settings: values read from `arg_file` are the grid-searched ones,
# the literals (10 epochs, min_lr, lr_factor) are fixed for every run.
train_kwargs = dict(
    train_data = train_data,
    val_data = val_data,
    batch_size = arg_file['batch_size'],
    num_samples = arg_file['num_samples'],
    max_epochs = 10,
    validation_batch_size = arg_file['batch_size'],
    lr = arg_file['lr'],
    min_lr = 1e-5,
    lr_factor = 0.9,
    eff_lambda = arg_file['eff_lambda'],
    paired_sampling = arg_file['paired_sampling'],
    lookback = arg_file['lookback'],
    verbose = 1,
    model_dir = model_dir,
)

# Wall-clock duration of the training call, in seconds.
t = time.time()
fastshap.train(**train_kwargs)
training_time = time.time() - t

Epoch 1/10


2022-05-05 17:01:08.721465: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 17:01:09.089850: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.
2022-05-05 17:01:09.104219: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:507] ROCm Fusion is enabled.


 106/1641 [>.............................] - ETA: 10:00 - loss: 22.7397 - shap_loss: 22.7397

KeyboardInterrupt: 

In [None]:
def _dump_to_model_dir(filename, payload):
    """Pickle `payload` into the file `filename` inside `model_dir`."""
    with open(os.path.join(model_dir, filename), 'wb') as handle:
        pickle.dump(payload, handle)


_dump_to_model_dir('training_time.pkl', training_time)

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Explain w/ FastSHAP

## TEST
print('Explaining Test')

### Explain

# NOTE(review): .squeeze() drops every length-1 axis, so a split with a
# single example would also lose its leading axis -- confirm this is acceptable.
explain_start = time.time()
shap_values = fastshap.shap_values(test_data).squeeze()
explaining_time = time.time() - explain_start

### Save

_dump_to_model_dir('explaining_time.pkl', explaining_time)
_dump_to_model_dir('shap_values.pkl', shap_values)

## VAL
print('Explaining Val')

### Explain

explain_start = time.time()
shap_values_val = fastshap.shap_values(val_data).squeeze()
explaining_time_val = time.time() - explain_start

### Save

_dump_to_model_dir('explaining_time_val.pkl', explaining_time_val)
_dump_to_model_dir('shap_values_val.pkl', shap_values_val)

# Smallest validation loss recorded by the trainer.
loss = np.min(fastshap.val_losses)
_dump_to_model_dir('loss.pkl', loss)
    

In [None]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
### Add Small Random Noise to prevent ties
# Uniform jitter in [0, 1e-7) is added in place to both attribution arrays.
# Re-running this cell adds a further round of noise on top.
tie_break_scale = 1e-7
#### TEST
test_jitter = tie_break_scale * np.random.random(shap_values.shape)
shap_values += test_jitter
#### VAL
val_jitter = tie_break_scale * np.random.random(shap_values_val.shape)
shap_values_val += val_jitter

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Evaluate

### Load DataFrame
df_test, df_val = (pd.read_csv(os.path.join(data_dir, csv_name))
                   for csv_name in ("test.csv", "val.csv"))

### Load Evaluator
evaluator_base = TFBertForSequenceClassification.from_pretrained('evaluation/evaluator-data/surrogate')

# Wrap the evaluator so that calling it yields softmax probabilities
# computed from the `.logits` field of the BERT output.
evaluator_model = tf.keras.models.Sequential([
    evaluator_base,
    tf.keras.layers.Lambda(lambda x: x.logits),
    tf.keras.layers.Activation('softmax'),
])

# Call the evaluator once on the first test batch before asking for a summary.
for first_batch in test_data.take(1):
    evaluator_model(first_batch)
evaluator_model.summary()

def eval_model(x):
    """Run the evaluator on an array of token ids.

    The ids in `x` are cast to int and paired with an all-ones attention mask
    and all-zeros token type ids of the same shape; the resulting dict is
    passed to `evaluator_model.predict`, whose output is returned unchanged.
    """
    return evaluator_model.predict({
        'input_ids': x.astype(int),
        'attention_mask': np.ones_like(x, dtype=int),
        'token_type_ids': np.zeros_like(x, dtype=int),
    })

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

def _score_split(split, targets, labels, mode):
    """Call `evaluate` on a copy of one split's dataframe with shared settings.

    `split` is a (dataframe, token_ids, shap_values) triple; `targets` and
    `labels` are forwarded as the fifth and sixth positional arguments.
    """
    frame, token_ids, attributions = split
    return evaluate(frame.copy(), token_ids, attributions, evaluator_model, targets, labels,
                    mode = mode, method = method, mask_token = mask_token)


val_split = (df_val, X_val, shap_values_val)
test_split = (df_test, X_test, shap_values)

#### Retrospective Evaluation ####
# The true labels are passed in both label positions.

# Exclusion
retro_ex_val = _score_split(val_split, y_val, y_val, 'exclude')
retro_ex_test = _score_split(test_split, y_test, y_test, 'exclude')

# Inclusion
retro_in_val = _score_split(val_split, y_val, y_val, 'include')
retro_in_test = _score_split(test_split, y_test, y_test, 'include')

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#### Prospective Evaluation ####
# The model's one-hot predicted class replaces the true label in the first
# label position.

# Exclusion
pro_ex_val = _score_split(val_split, preds_discrete_val, y_val, 'exclude')
pro_ex_test = _score_split(test_split, preds_discrete, y_test, 'exclude')

# Inclusion
pro_in_val = _score_split(val_split, preds_discrete_val, y_val, 'include')
pro_in_test = _score_split(test_split, preds_discrete, y_test, 'include')

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Combine Results
# Every evaluation result dict is flattened into one dict whose keys are
# '<original key>-<tag>', where the tag names the evaluation it came from.
tags = [
    'retro_ex_val', 'retro_ex_test', 'retro_in_val', 'retro_in_test',
    'pro_ex_val', 'pro_ex_test', 'pro_in_val', 'pro_in_test',
]
result_list = [
    retro_ex_val, retro_ex_test, retro_in_val, retro_in_test,
    pro_ex_val, pro_ex_test, pro_in_val, pro_in_test,
]

results = {
    key + '-' + tag: value
    for res, tag in zip(result_list, tags)
    for key, value in res.items()
}


#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Save

### Create Results Dictionary
# Column layout of results.csv: run settings and timings first, then the four
# AUC metrics for each evaluation tag.
header = ["model_dir", "lr", "epochs", "batch_size", "lookback", 
          "num_samples", "paired_sampling", "eff_lambda", "normalization", 
          "training_time", "explaining_time"]
metrics = ['AUC_acc','AUC_auroc','AUC_log_likelihood','AUC_log_odds']
for tag in tags:
    header += [x+'-'+tag for x in metrics]

# Merge the evaluation results with the hyper-parameters and run metadata,
# then keep only the entries that have a column in `header`.
header_columns = set(header)
results = {**results, **arg_file}
results['model_dir'] = model_dir
results["explaining_time"] = explaining_time
results["training_time"] = training_time
results = {k:v for k,v in results.items() if k in header_columns}

# Report header columns that received no value instead of failing on them:
# a missing column would otherwise raise KeyError only at this final step,
# after training and evaluation have already been paid for.
# NOTE(review): "epochs" is listed in the header, but the training call
# hard-codes max_epochs = 10 -- confirm `arg_file` actually provides 'epochs'.
missing_columns = [col for col in header if col not in results]
if missing_columns:
    print('Warning: no value for result columns %s; writing NaN' % missing_columns)

### Convert to DataFrame
results_df = pd.DataFrame(results, index=[0])
# reindex fixes the column order and fills absent columns with NaN, so every
# appended row stays aligned with the CSV header.
results_df = results_df.reindex(columns=header)

### Append DataFrame to csv
# The header row is written only when the file is first created.
results_path = method+'/results.csv'
if os.path.exists(results_path):
    results_df.to_csv(results_path, mode='a',  header=False)
else:
    results_df.to_csv(results_path, mode='w',  header=True)