### Evaluating Faithfulness on our model:

In [1]:
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from datasets import load_dataset
import torch.nn.functional as F
import shap
import shap
from transformers import Pipeline

import os 
import numpy

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
data_dir = "output/"
destination_dir = "./"
print(device)

cuda


In [2]:

test_short_path = "data/test_10_top50_short.csv"
labels_10_top50 = pd.read_csv('data/icd10_codes_top50.csv')
code_labels_10 = pd.read_csv("data/icd10_codes.csv")
print("dataset loaded?")

dataset loaded?


In [3]:
# Model Parameters
MAX_POSITION_EMBEDDINGS = 512
MODEL = "emilyalsentzer/Bio_ClinicalBERT"
CKPT = os.path.join(data_dir,"best_model_state.bin")

In [4]:
# Create class dictionaries
classes = [class_ for class_ in code_labels_10["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}

print("classes")

config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

tokenizer_bert = AutoTokenizer.from_pretrained(MODEL)
model_bert = AutoModel.from_pretrained(MODEL, config=config, cache_dir='./model_ckpt/')
print("bert model and tokenizer initialized")



classes
bert model and tokenizer initialized


In [5]:
class TokenizerWrapper:
    def __init__(self, tokenizer, length, classes):
        self.tokenizer = tokenizer
        self.max_length = length
        self.classes = classes
        self.class2id = {class_: id for id, class_ in enumerate(self.classes)}
        self.id2class = {id: class_ for class_, id in self.class2id.items()}
        
    def multi_labels_to_ids(self, labels: list[str]) -> list[float]:
        ids = [0.0] * len(self.class2id)  # BCELoss requires float as target type
        for label in labels:
            ids[self.class2id[label]] = 1.0
        return ids
    
    def tokenize_function(self, example):
        result = self.tokenizer(
            example["text"],
            max_length = self.max_length,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(device)
        result["label"] = torch.tensor([self.multi_labels_to_ids(eval(label)) for label in example["label"]])
        return result
        
data_files = {
        "test": test_short_path,
    }

tokenizer_wrapper = TokenizerWrapper(tokenizer_bert, MAX_POSITION_EMBEDDINGS, classes)
dataset = load_dataset("csv", data_files=data_files)
dataset = dataset.map(tokenizer_wrapper.tokenize_function, batched=True, num_proc=1)
dataset = dataset.with_format("torch")
print("dataset loaded")

Generating test split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

dataset loaded


In [6]:

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.config = config
        self.device = device
        self.bert_model = model_bert
        self.can_generate = model_bert.can_generate
        self.base_model_prefix = model_bert.base_model_prefix
        self.get_input_embeddings = model_bert.get_input_embeddings
        self.dropout = torch.nn.Dropout(0.2)
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 50)
    
    def forward(self, input_ids, attn_mask, token_type_ids):
        output = self.bert_model(
            input_ids, 
            attention_mask=attn_mask, 
            token_type_ids=token_type_ids
        )
        output_dropout = self.dropout(output.pooler_output)
        output = self.linear(output_dropout)
        return output
    
model_bert = BERTClass()
model_bert.load_state_dict(torch.load(CKPT))
model_bert = model_bert.to(device)

### Pipeline Initialization

In [9]:
class BERT_ICD10_Pipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text):
        return self.tokenizer(
            text,
            max_length = MAX_POSITION_EMBEDDINGS,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(self.device)

    def _forward(self, model_inputs):
        ids = model_inputs['input_ids'].to(self.device, dtype = torch.long)
        mask = model_inputs['attention_mask'].to(self.device, dtype = torch.long)
        token_type_ids = model_inputs['token_type_ids'].to(self.device, dtype = torch.long)
        outputs = self.model(ids, mask, token_type_ids).to(self.device)
        return outputs

    def postprocess(self, model_outputs):
        probs = F.sigmoid(model_outputs).detach().cpu().numpy() # if there's more than one possible diagnosis

        output = []
        for i, prob in enumerate(probs[0]):
            label = self.model.config.id2label[i]
            score = prob
            output.append({"label": label, "score": score})
        # print(output)
        return output

### Test code for faithfulness calculation

In [11]:

pipeline = BERT_ICD10_Pipeline(model=model_bert, tokenizer=tokenizer_bert, device = device)
print("pipeline initialized")


pipeline initialized


In [47]:
masker = shap.maskers.Text(pipeline.tokenizer)
explainer = shap.Explainer(pipeline, masker)
shap_input = dataset['test']['text'][:1]
print("computing shap")

computing shap


In [48]:
shap.sample(shap_input, 2)



In [49]:
shap_values = explainer(
        shap_input,
        batch_size=5,
        outputs=shap.Explanation.argsort.flip[:2]
        )

print(shap_values)

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:55, 55.07s/it]               


.values =
array([[[ 3.24147594e-04, -7.64450838e-05],
        [ 3.24147594e-04, -7.64450838e-05],
        [ 3.24147594e-04, -7.64450838e-05],
        ...,
        [-4.98865204e-05,  9.77840208e-05],
        [-4.98865204e-05,  9.77840208e-05],
        [-4.98865204e-05,  9.77840208e-05]]])

.base_values =
array([[0.19439352, 0.29103452]])

.data =
(array(['', 'Sex', ':   ', ..., '_', '_', ''], dtype=object),)


In [54]:
test_input = ['a test']

In [55]:
shap_values = explainer(test_input)

In [65]:
import numpy as np

In [70]:
shap_values

.values =
array([[[ 0.00884168,  0.02199024,  0.04240134, ...,  0.        ,
          0.        ,  0.        ],
        [ 0.03830965,  0.15435492, -0.00512322, ...,  0.        ,
          0.        ,  0.        ],
        [ 0.07860697,  0.05347701, -0.00376649, ...,  0.        ,
          0.        ,  0.        ],
        [-0.04718865, -0.00847325,  0.04289608, ...,  0.        ,
          0.        ,  0.        ]]])

.base_values =
array([[0.10521022, 0.25801185, 0.39506537, ..., 0.        , 0.        ,
        0.        ]])

.data =
(array(['', 'a ', 'test', ''], dtype=object),)

In [69]:
vals= np.abs(shap_values.values).mean(0)
vals

array([[0.00884168, 0.02199024, 0.04240134, ..., 0.        , 0.        ,
        0.        ],
       [0.03830965, 0.15435492, 0.00512322, ..., 0.        , 0.        ,
        0.        ],
       [0.07860697, 0.05347701, 0.00376649, ..., 0.        , 0.        ,
        0.        ],
       [0.04718865, 0.00847325, 0.04289608, ..., 0.        , 0.        ,
        0.        ]])

In [77]:
features_column = shap_values.data[0]

In [78]:
feature_importance = pd.DataFrame(list(zip(features_column, sum(vals))), columns=['feature','feature_importance_vals'])

In [79]:
feature_importance.sort_values(by=['feature_importance_vals'], ascending=False,inplace=True)

In [80]:
feature_importance['feature']

Unnamed: 0,col_name,feature_importance_vals
1,a,0.238295
0,,0.172947
3,,0.127548
2,test,0.094187


In [None]:
# TODO:
# 1) Find the top k features from feature_importance
# 2) Keep track of their corresponding token index after tokenization
# 3) Retokenize the input text and remove the token indices in step 2
# 4) Detokenize by combining the strings back together
# 5) Feed it back to the model