### Evaluating the faithfulness of our model

In [1]:
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from datasets import load_dataset
import torch.nn.functional as F
from lime.lime_text import LimeTextExplainer
import torch.nn.functional as F
import os 
import numpy as np

torch.cuda.empty_cache()  # drop cached allocator blocks left from an earlier run
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input / output locations, relative to the notebook's working directory.
data_dir, destination_dir = "output/", "./"

print(device)

cuda


In [2]:

# Input locations. The test CSV path is consumed by `load_dataset` further down;
# the two ICD-10 code tables are read eagerly here.
test_short_path = "data/test_10_top50_short.csv"
labels_10_top50 = pd.read_csv("data/icd10_codes_top50.csv")  # NOTE(review): not referenced in this excerpt
code_labels_10 = pd.read_csv('data/icd10_codes.csv')          # source of `classes` below
print("dataset loaded?")

dataset loaded?


In [3]:
# Model parameters
MODEL = "emilyalsentzer/Bio_ClinicalBERT"              # Hugging Face hub id of the backbone
MAX_POSITION_EMBEDDINGS = 512                           # tokenizer pad/truncate length
CKPT = os.path.join(data_dir, "best_model_state.bin")  # fine-tuned weights loaded into BERTClass

In [4]:
# Create class dictionaries
classes = [class_ for class_ in code_labels_10["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}

print("classes")

config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

tokenizer_bert = AutoTokenizer.from_pretrained(MODEL)
model_bert = AutoModel.from_pretrained(MODEL, config=config, cache_dir='./model_ckpt/')
print("bert model and tokenizer initialized")

classes
bert model and tokenizer initialized


In [5]:
class TokenizerWrapper:
    """Turn raw ``{"text", "label"}`` batches into model-ready multi-hot examples.

    Meant to be used as the function passed to ``datasets.Dataset.map(batched=True)``.

    Args:
        tokenizer: Hugging Face tokenizer applied to ``example["text"]``.
        length: every sequence is padded/truncated to exactly this many tokens.
        classes: ordered label vocabulary; a label's position is its target index.
    """

    def __init__(self, tokenizer, length, classes):
        self.tokenizer = tokenizer
        self.max_length = length
        self.classes = classes
        self.class2id = {class_: idx for idx, class_ in enumerate(self.classes)}
        self.id2class = {idx: class_ for class_, idx in self.class2id.items()}

    def multi_labels_to_ids(self, labels: list[str]) -> list[float]:
        """Encode ``labels`` as a multi-hot float vector over ``self.classes``.

        Floats rather than ints because BCE losses require float targets.
        Raises KeyError for a label that is not in ``self.classes``.
        """
        ids = [0.0] * len(self.class2id)
        for label in labels:
            ids[self.class2id[label]] = 1.0
        return ids

    def tokenize_function(self, example):
        """Tokenize a batch and attach a ``label`` tensor of multi-hot rows.

        ``example["label"]`` entries are Python-literal list strings as stored
        in the CSV; entries that are already lists/tuples are used as-is.
        Raises ValueError if a label string is not a plain literal.
        """
        import ast  # local import keeps this block self-contained

        # Tensors stay on the CPU: `datasets.map` serialises the result to Arrow
        # anyway, so moving it to the GPU here only adds a round-trip, and it can
        # fail under num_proc > 1 because CUDA cannot be re-initialised in forked
        # worker processes.
        result = self.tokenizer(
            example["text"],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # `ast.literal_eval` instead of `eval`: labels come from a data file and
        # must only ever be parsed as literals, never executed.
        parsed_labels = [
            label if isinstance(label, (list, tuple)) else ast.literal_eval(label)
            for label in example["label"]
        ]
        result["label"] = torch.tensor(
            [self.multi_labels_to_ids(labels) for labels in parsed_labels]
        )
        return result
        
# Load the test CSV, tokenize it in batches (TokenizerWrapper adds the tokenizer
# outputs plus a multi-hot `label` column), then expose the columns as torch tensors.
data_files = {"test": test_short_path}
tokenizer_wrapper = TokenizerWrapper(tokenizer_bert, MAX_POSITION_EMBEDDINGS, classes)
dataset = (
    load_dataset("csv", data_files=data_files)
    .map(tokenizer_wrapper.tokenize_function, batched=True, num_proc=1)
    .with_format("torch")
)
print("dataset loaded")

dataset loaded


In [6]:

class BERTClass(torch.nn.Module):
    """Multi-label classifier: BERT backbone -> dropout on ``pooler_output`` -> linear.

    ``forward`` returns raw logits; no sigmoid is applied here.

    Every constructor argument is optional. Omitted arguments fall back to the
    notebook globals ``model_bert``, ``config`` and ``device``, as the original
    zero-argument version did, so ``BERTClass()`` keeps working.

    Args:
        bert_model: backbone module. It must expose ``config.hidden_size``,
            ``can_generate``, ``base_model_prefix`` and ``get_input_embeddings``,
            and return an object with ``pooler_output``. Defaults to global
            ``model_bert``. If it is itself a classifier of this kind (anything
            with a ``bert_model`` submodule), its backbone is reused instead of
            nesting the two.
        num_labels: output width. The default of 50 matches the checkpoint this
            notebook loads; ``load_state_dict`` is strict, so it must agree.
        dropout_prob: dropout probability applied to the pooled output.
        model_config: stored as ``self.config``; defaults to global ``config``.
        model_device: stored as ``self.device``; defaults to global ``device``.
    """

    def __init__(self, bert_model=None, num_labels=50, dropout_prob=0.2,
                 model_config=None, model_device=None):
        super().__init__()
        if bert_model is None:
            bert_model = model_bert
        # Re-running the cell that does `model_bert = BERTClass()` makes the
        # global name refer to the previous classifier, which is an instance of
        # the previous definition of this class, so isinstance() would miss it.
        # Unwrap it by attribute instead of stacking classifiers.
        inner = getattr(bert_model, "bert_model", None)
        if isinstance(inner, torch.nn.Module):
            bert_model = inner

        self.config = model_config if model_config is not None else config
        self.device = model_device if model_device is not None else device
        self.bert_model = bert_model
        # Mirror a few Hugging Face PreTrainedModel attributes so tooling that
        # expects an HF model can presumably treat this wrapper as one.
        self.can_generate = bert_model.can_generate
        self.base_model_prefix = bert_model.base_model_prefix
        self.get_input_embeddings = bert_model.get_input_embeddings
        self.dropout = torch.nn.Dropout(dropout_prob)
        self.linear = torch.nn.Linear(bert_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return logits whose last dimension is ``num_labels``."""
        pooled = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        ).pooler_output
        return self.linear(self.dropout(pooled))
    
# Re-running this cell would otherwise hand the previous classifier (which the
# global `model_bert` names after the first run) to BERTClass() as its backbone.
model_bert = getattr(model_bert, "bert_model", model_bert)
model_bert = BERTClass()
# map_location: a checkpoint saved from a GPU still loads on a CPU-only machine.
model_bert.load_state_dict(torch.load(CKPT, map_location=device))
# eval(): a freshly built nn.Module starts in training mode, which left the
# head's 0.2 dropout active during LIME / faithfulness scoring and made the
# predictions nondeterministic.
model_bert = model_bert.to(device).eval()

In [64]:
# LIME text explainer labelled with the ICD code vocabulary; bow=False makes
# word positions (not distinct words) the explanation features.
explainer = LimeTextExplainer(
    bow=False,
    class_names=classes,
)

def predictor_bert(texts):
    """LIME ``classifier_fn`` for the notebook's fine-tuned model.

    Binds ``predictor_model`` to the global ``model_bert`` and ``tokenizer_bert``
    so both predictors share a single implementation instead of two copies of
    the same tokenize / forward / sigmoid code.

    Returns a numpy array of per-class sigmoid probabilities, one row per text.
    """
    return predictor_model(texts, model_bert, tokenizer_bert)


def predictor_model(texts, model, tokenizer, batch_size=32, max_length=None):
    """Score raw texts with ``model`` and return per-class sigmoid probabilities.

    Matches LIME's ``classifier_fn`` contract: a list of strings in, a 2-D numpy
    array with one row per string out.

    Args:
        texts: list of strings (a single string is scored as a batch of one).
            Must be non-empty.
        model: called as ``model(input_ids, attention_mask, token_type_ids)``
            and expected to return logits. Dropout is NOT disabled here, so the
            caller should have put the model in ``eval()`` mode.
        tokenizer: Hugging Face-style tokenizer whose output contains
            ``input_ids``, ``attention_mask`` and ``token_type_ids``.
        batch_size: texts per forward pass, so large LIME perturbation sets
            need not fit on the device at once. ``None`` scores everything in
            a single pass.
        max_length: pad/truncate length; defaults to ``MAX_POSITION_EMBEDDINGS``.

    Returns:
        numpy.ndarray with ``len(texts)`` rows of probabilities.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if max_length is None:
        max_length = MAX_POSITION_EMBEDDINGS
    step = batch_size or len(texts) or 1

    # Send inputs to wherever the model's weights live; fall back to the
    # notebook-wide `device` for callables that have no parameters.
    first_param = next(model.parameters(), None) if isinstance(model, torch.nn.Module) else None
    target_device = first_param.device if first_param is not None else device

    chunks = []
    # Inference only: don't build an autograd graph through the whole backbone.
    with torch.no_grad():
        for start in range(0, len(texts), step):
            encoded = tokenizer(
                texts[start:start + step],
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
            )
            ids = encoded['input_ids'].to(target_device, dtype=torch.long)
            mask = encoded['attention_mask'].to(target_device, dtype=torch.long)
            token_type_ids = encoded['token_type_ids'].to(target_device, dtype=torch.long)
            logits = model(ids, mask, token_type_ids)
            chunks.append(torch.sigmoid(logits).cpu().numpy())
    return np.concatenate(chunks, axis=0)

The instances are formatted as a list of strings, where each string is one word as tokenized by LIME. The rationale mask is a pair of index arrays: for each rationale word, the first array gives the index of the sample it belongs to, and the second gives the index of that word's string within the sample.

### Test code for faithfulness calculation

In [71]:
import faithfulness_lime as faithfulness
# Reload so edits to faithfulness_lime.py are picked up without restarting the kernel.
import importlib
from lime.lime_text import IndexedString
import numpy as np
importlib.reload(faithfulness)

# Small slice of the test split used to smoke-test the faithfulness pipeline.
samples_start = 0
samples_end = 3
input_data = dataset["test"]["text"][samples_start:samples_end]
for text in input_data:
    print(len(text))  # characters per input text

explainer = LimeTextExplainer(class_names=classes, bow=False)
indexed_text, index_array_rationalle = faithfulness.lime_create_index_arrays(
    input_data, predictor_bert, explainer, k_labels=10
)
print(len(indexed_text))
for words in indexed_text:
    print(len(words))
print(len(index_array_rationalle))
print(len(index_array_rationalle[0]))
print(len(index_array_rationalle[1]))
print(index_array_rationalle[0])
print(index_array_rationalle[1])

# Per-sample LIME explanations. `tkns` keeps the features in as_list() order;
# `cur_limes` maps each feature to (weight, rank in as_list()), sorted by
# |weight| descending (a repeated feature keeps only its last entry).
explanations = []
for s in input_data:
    # Must agree with the explainer's bow=False. IndexedString defaults to
    # bow=True, whose num_words() counts distinct words and so capped
    # num_features below the number of word positions LIME perturbs.
    indexed_string = IndexedString(s, bow=False)
    with torch.no_grad():
        exp = explainer.explain_instance(s, predictor_bert, num_features=indexed_string.num_words(), num_samples=10)
    # NOTE(review): no `labels=` is passed to explain_instance/as_list, so this
    # explains LIME's default label only — confirm that is the intended class.
    exp = exp.as_list()
    tkns = [feature for feature, _ in exp]
    cur_limes = {feature: (weight, rank) for rank, (feature, weight) in enumerate(exp)}
    cur_limes = sorted(cur_limes.items(), key=lambda x: abs(x[1][0]), reverse=True)
    explanations.append((cur_limes, tkns))


# Build the perturbed corpora for the faithfulness metrics: one with the
# rationale words removed and (judging by the helper name) one with the
# non-rationale words removed.
rationalle_removed = faithfulness.remove_rationale_words(indexed_text, index_array_rationalle)
others_removed = faithfulness.remove_other_words(indexed_text, index_array_rationalle)

print(len(rationalle_removed))
print('first',rationalle_removed[0])
print('second', rationalle_removed[1])
print('third', rationalle_removed[2])
print(len(others_removed))

11849
8074
7224
3
3737
3737
3737
2
21030
21030
[0 0 0 ... 2 2 2]
[ 333  516  616 ... 1731 1288  416]
3
first           No                           Procedure         resection              :
                                                                                                                                                                             ,                                                                   the                                                                                                                                                                                                                                                                           3mm            ,                                                                                                                                                         x                                                                                                                                

In [79]:
# Pick up any edits to faithfulness_lime.py before scoring.
import faithfulness_lime as faithfulness
importlib.reload(faithfulness)

# Score the model on the original inputs vs. the two perturbed corpora built in
# the previous cell. `predictor_model` is handed over together with `model_bert`
# and `tokenizer_bert`, presumably so the library can run the model itself.
# NOTE(review): earlier notes said this function expects a list of instances,
# one per interpretability method, and that the April 19th version worked for
# LIME; here a single method's arrays are passed directly — confirm against
# faithfulness_lime.calculate_faithfulness.
_, faith = faithfulness.calculate_faithfulness(
    input_data,
    rationalle_removed,
    others_removed,
    model_bert,
    tokenizer_bert,
    predictor_model,
)
print(faith)

Calculating Sufficiency
['          No                           Procedure         resection              :\n                                                                                                                                                                             ,                                                                   the                                                                                                                                                                                                                                                                           3mm            ,                                                                                                                                                         x                                                                                                                                             Right5                                 touch                         