### Evaluating Faithfulness on our model:

In [1]:
import pandas as pd
import pickle
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from datasets import load_dataset
from lime.lime_text import LimeTextExplainer
import torch.nn.functional as F
import shap
from transformers import Pipeline
import os 
import numpy

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
data_dir = "output/"
destination_dir = "./"
print(device)

cuda


In [2]:

test_short_path = "data/test_10_top50_short.csv"
labels_10_top50 = pd.read_csv('data/icd10_codes_top50.csv')
code_labels_10 = pd.read_csv("data/icd10_codes.csv")
print("dataset loaded?")

dataset loaded?


In [3]:
# Model Parameters
MAX_POSITION_EMBEDDINGS = 512
MODEL = "emilyalsentzer/Bio_ClinicalBERT"
CKPT = os.path.join(data_dir,"best_model_state.bin")

In [4]:
# Create class dictionaries
classes = [class_ for class_ in code_labels_10["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}

print("classes")

config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

tokenizer_bert = AutoTokenizer.from_pretrained(MODEL)
model_bert = AutoModel.from_pretrained(MODEL, config=config, cache_dir='./model_ckpt/')
print("bert model and tokenizer initialized")



classes
bert model and tokenizer initialized


In [8]:
class TokenizerWrapper:
    def __init__(self, tokenizer, length, classes):
        self.tokenizer = tokenizer
        self.max_length = length
        self.classes = classes
        self.class2id = {class_: id for id, class_ in enumerate(self.classes)}
        self.id2class = {id: class_ for class_, id in self.class2id.items()}
        
    def multi_labels_to_ids(self, labels: list[str]) -> list[float]:
        ids = [0.0] * len(self.class2id)  # BCELoss requires float as target type
        for label in labels:
            ids[self.class2id[label]] = 1.0
        return ids
    
    def tokenize_function(self, example):
        result = self.tokenizer(
            example["text"],
            max_length = self.max_length,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(device)
        result["label"] = torch.tensor([self.multi_labels_to_ids(eval(label)) for label in example["label"]])
        return result
        
data_files = {
        "test": test_short_path,
    }

tokenizer_wrapper = TokenizerWrapper(tokenizer_bert, MAX_POSITION_EMBEDDINGS, classes)
dataset = load_dataset("csv", data_files=data_files)
dataset = dataset.map(tokenizer_wrapper.tokenize_function, batched=True, num_proc=1)
dataset = dataset.with_format("torch")
print("dataset loaded")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

dataset loaded


In [9]:

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.config = config
        self.device = device
        self.bert_model = model_bert
        self.can_generate = model_bert.can_generate
        self.base_model_prefix = model_bert.base_model_prefix
        self.get_input_embeddings = model_bert.get_input_embeddings
        self.dropout = torch.nn.Dropout(0.2)
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 50)
    
    def forward(self, input_ids, attn_mask, token_type_ids):
        output = self.bert_model(
            input_ids, 
            attention_mask=attn_mask, 
            token_type_ids=token_type_ids
        )
        output_dropout = self.dropout(output.pooler_output)
        output = self.linear(output_dropout)
        return output
    
model_bert = BERTClass()
model_bert.load_state_dict(torch.load(CKPT))
model_bert = model_bert.to(device)

In [37]:
import lime
from lime import lime_text
from lime.lime_text import LimeTextExplainer
from lime.lime_text import IndexedString
import numpy as np
import torch.nn.functional as F
from time import time


explainer = LimeTextExplainer(class_names=classes, bow=False)

def predictor_opt(texts):
    print(len(texts))
    tk = tokenizer_bert(
            texts,
            max_length = MAX_POSITION_EMBEDDINGS,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(device)
    ids = tk['input_ids'].to(device, dtype = torch.long)
    mask = tk['attention_mask'].to(device, dtype = torch.long)
    token_type_ids = tk['token_type_ids'].to(device, dtype = torch.long)
    outputs = model_bert(ids, mask, token_type_ids)
    # tensor_logits = outputs[0]
    # probas = tensor_logits.sigmoid().detach().cpu().numpy()
    probas = F.sigmoid(outputs).detach().cpu().numpy()
    del tk, outputs
    # probas = F.sigmoid(tensor_logits).detach().cpu().numpy()
    return probas


def predictor_model(texts, model, tokenizer):
    print(len(texts))
    tk = tokenizer_bert(
            texts,
            max_length = MAX_POSITION_EMBEDDINGS,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(device)
    ids = tk['input_ids'].to(device, dtype = torch.long)
    mask = tk['attention_mask'].to(device, dtype = torch.long)
    token_type_ids = tk['token_type_ids'].to(device, dtype = torch.long)
    outputs = model_bert(ids, mask, token_type_ids)
    # tensor_logits = outputs[0]
    # probas = tensor_logits.sigmoid().detach().cpu().numpy()
    probas = F.sigmoid(outputs).detach().cpu().numpy()
    del tk, outputs
    # probas = F.sigmoid(tensor_logits).detach().cpu().numpy()
    return probas

The instances are formatted as a list of strings, where each string is one word used by lime. The rationales mask is a list of indices, where the first list refers to the index of the sample the label corresponds to and the second list is the index of string used in that rationale.

### Test code for faithfulness calculation

### 

In [45]:
# create a test with 10 instances for faithfulness evaluation
# from transformers import AutoTokenizer

# get the lime evaluations of each instance
import faithfulness
# this reimports the library for easy testing in the notebook
import importlib
importlib.reload(faithfulness)

samples_start = 0
samples_end = 50

instances = dataset["test"][samples_start:samples_end]["text"]
# print(len(instances))

# print(instances)
explainer = LimeTextExplainer(class_names=classes, bow=False)

indexed_text, index_array_rationalle = faithfulness.lime_create_index_arrays(instances, predictor_opt, explainer)
# print(indexed_text)
# print(index_array_rationalle)

# # remove the rationale words
rationalle_removed = faithfulness.remove_rationale_words(indexed_text, index_array_rationalle)
others_removed = faithfulness.remove_other_words(indexed_text, index_array_rationalle)

# rationalle_removed = rationalle_removed + rationalle_removed + rationalle_removed + rationalle_removed + rationalle_removed
# others_removed = others_removed + others_removed + others_removed + others_removed + others_removed 
# instances = instances + instances + instances + instances + instances

# print(rationalle_removed)

# print(len(rationalle_removed))
# print(len(others_removed))



10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10


In [46]:
# the extra list is needed since the function expects a list of instances each coming from a different interpretability method
# testing multi input by duplicating the arrays, don't actually do this
ind, faith = faithfulness.calculate_faithfulness(instances, [rationalle_removed, rationalle_removed], [others_removed, others_removed], model_bert, tokenizer_bert, predictor_model)
print(ind)
print(faith)

8
8
8
8
8
8
2
Currently interpreting instance:  0
Calculating Sufficiency
8
8
8
8
8
8
2
Calculating Comprehensiveness
8
8
8
8
8
8
2

-- Metrics -------------------------------------------------------------


Faithfulness for iteration:  0.09531664
Comprehensiveness for iteration:  0.41091752
Sufficency for iteration:  0.23196052


Comprehensiveness Median:  0.2714703
Comprehensiveness q1 (25% percentile):  0.2267022393643856
Comprehensiveness q3 (75% percentile):  0.4419192001223564


Sufficency Median:  0.23101819
Sufficency q1 (25% percentile):  0.19091162458062172
Sufficency q3 (75% percentile):  0.25325431674718857

Currently interpreting instance:  1
Calculating Sufficiency
8
8
8
8
8
8
2
Calculating Comprehensiveness
8
8
8
8
8
8
2

-- Metrics -------------------------------------------------------------


Faithfulness for iteration:  0.09601459
Comprehensiveness for iteration:  0.42271277
Sufficency for iteration:  0.22713907


Comprehensiveness Median:  0.29238832
Comprehensivene