### Evaluating Faithfulness on our model:

In [1]:
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from datasets import load_dataset
import torch.nn.functional as F
from lime.lime_text import LimeTextExplainer
import torch.nn.functional as F
import os 
import numpy as np

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
data_dir = "output/"
destination_dir = "./"
print(device)

cuda


In [2]:

test_short_path = "data/test_10_top50_short.csv"
labels_10_top50 = pd.read_csv('data/icd10_codes_top50.csv')
code_labels_10 = pd.read_csv("data/icd10_codes.csv")


In [3]:
# Model Parameters
MAX_POSITION_EMBEDDINGS = 512
MODEL = "emilyalsentzer/Bio_ClinicalBERT"
CKPT = os.path.join(data_dir,"best_model_state.bin")

In [4]:
# Create class dictionaries
classes = [class_ for class_ in code_labels_10["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}

print("classes")

config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

tokenizer_bert = AutoTokenizer.from_pretrained(MODEL)
model_bert = AutoModel.from_pretrained(MODEL, config=config, cache_dir='./model_ckpt/')
print("bert model and tokenizer initialized")

classes
bert model and tokenizer initialized


In [5]:
class TokenizerWrapper:
    def __init__(self, tokenizer, length, classes):
        self.tokenizer = tokenizer
        self.max_length = length
        self.classes = classes
        self.class2id = {class_: id for id, class_ in enumerate(self.classes)}
        self.id2class = {id: class_ for class_, id in self.class2id.items()}
        
    def multi_labels_to_ids(self, labels: list[str]) -> list[float]:
        ids = [0.0] * len(self.class2id)  # BCELoss requires float as target type
        for label in labels:
            ids[self.class2id[label]] = 1.0
        return ids
    
    def tokenize_function(self, example):
        result = self.tokenizer(
            example["text"],
            max_length = self.max_length,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(device)
        result["label"] = torch.tensor([self.multi_labels_to_ids(eval(label)) for label in example["label"]])
        return result
        
data_files = {
        "test": test_short_path,
    }

tokenizer_wrapper = TokenizerWrapper(tokenizer_bert, MAX_POSITION_EMBEDDINGS, classes)
dataset = load_dataset("csv", data_files=data_files)
dataset = dataset.map(tokenizer_wrapper.tokenize_function, batched=True, num_proc=1)
dataset = dataset.with_format("torch")
print("dataset loaded")

dataset loaded


In [6]:

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.config = config
        self.device = device
        self.bert_model = model_bert
        self.can_generate = model_bert.can_generate
        self.base_model_prefix = model_bert.base_model_prefix
        self.get_input_embeddings = model_bert.get_input_embeddings
        self.dropout = torch.nn.Dropout(0.2)
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 50)
    
    def forward(self, input_ids, attn_mask, token_type_ids):
        output = self.bert_model(
            input_ids, 
            attention_mask=attn_mask, 
            token_type_ids=token_type_ids
        )
        output_dropout = self.dropout(output.pooler_output)
        output = self.linear(output_dropout)
        return output
    
model_bert = BERTClass()
model_bert.load_state_dict(torch.load(CKPT))
model_bert = model_bert.to(device)

In [7]:
explainer = LimeTextExplainer(class_names=classes, bow=False)

def predictor_bert(texts):
    tk = tokenizer_bert(
            texts,
            max_length = MAX_POSITION_EMBEDDINGS,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(device)
    ids = tk['input_ids'].to(device, dtype = torch.long)
    mask = tk['attention_mask'].to(device, dtype = torch.long)
    token_type_ids = tk['token_type_ids'].to(device, dtype = torch.long)
    outputs = model_bert(ids, mask, token_type_ids)
    probas = F.sigmoid(outputs).detach().cpu().numpy()
    del tk, outputs
    return probas


def predictor_model(texts, model, tokenizer):
    tk = tokenizer(
            texts,
            max_length = MAX_POSITION_EMBEDDINGS,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(device)
    ids = tk['input_ids'].to(device, dtype = torch.long)
    mask = tk['attention_mask'].to(device, dtype = torch.long)
    token_type_ids = tk['token_type_ids'].to(device, dtype = torch.long)
    outputs = model(ids, mask, token_type_ids)
    probas = F.sigmoid(outputs).detach().cpu().numpy()
    del tk, outputs
    return probas

The instances are formatted as a list of strings, where each string is one word used by lime. The rationales mask is a list of indices, where the first list refers to the index of the sample the label corresponds to and the second list is the index of string used in that rationale.

### Test code for faithfulness calculation

In [8]:
import faithfulness_lime as faithfulness
# this reimports the library for easy testing in the notebook
import importlib
import numpy as np
importlib.reload(faithfulness)

explainer = LimeTextExplainer(class_names=classes, bow=False)

In [9]:
# Fix N to the first 10 examples and pick k = 5
samples_start = 0
samples_end = 10
input_data = dataset["test"]["text"][samples_start:samples_end]
k = 5

indexed_text, index_array_rationalle = faithfulness.lime_create_index_arrays(input_data, predictor_bert, explainer, k_labels = k)
rationale_removed = faithfulness.remove_rationale_words(indexed_text, index_array_rationalle)
others_removed = faithfulness.remove_other_words(indexed_text, index_array_rationalle)

# We expect a list of rationales and others to be removed that correspond to the number of texts in input_data
print('length of rationale removed list:', len(rationale_removed))
print('length of others removed list:', len(others_removed))
_, faith_5 = faithfulness.calculate_faithfulness(input_data, rationale_removed, others_removed, model_bert, tokenizer_bert, predictor_model)

length of rationale removed list: 10
length of others removed list: 10
Calculating Sufficiency
Calculating Comprehensiveness

-- Metrics -------------------------------------------------------------


Faithfulness:  0.37830418
Comprehensiveness:  0.4156465
Sufficency:  0.9101585

Sufficiency list: [1.1997775  0.43490794 1.3844305  0.6592898  0.3310731  0.7969022
 1.1339489  1.4446026  0.52336097 1.1932914 ]
Comprehensiveness list: [0.12781498 0.15973037 0.47970444 0.09612808 0.17070179 0.5854102
 0.6944684  0.4119725  0.7578804  0.6726538 ]

Comprehensiveness Median:  0.44583845
Comprehensiveness q1 (25% percentile):  0.16247322782874107
Comprehensiveness q3 (75% percentile):  0.6508428901433945


Sufficency Median:  0.96542555
Sufficency q1 (25% percentile):  0.5573431700468063
Sufficency q3 (75% percentile):  1.1981559693813324



In [10]:
# Fix N to the first 10 examples and pick k = 10
samples_start = 0
samples_end = 10
input_data = dataset["test"]["text"][samples_start:samples_end]
k = 10

indexed_text, index_array_rationalle = faithfulness.lime_create_index_arrays(input_data, predictor_bert, explainer, k_labels = k)
rationale_removed = faithfulness.remove_rationale_words(indexed_text, index_array_rationalle)
others_removed = faithfulness.remove_other_words(indexed_text, index_array_rationalle)

# We expect a list of rationales and others to be removed that correspond to the number of texts in input_data
print('length of rationale removed list:', len(rationale_removed))
print('length of others removed list:', len(others_removed))
_, faith_10 = faithfulness.calculate_faithfulness(input_data, rationale_removed, others_removed, model_bert, tokenizer_bert, predictor_model)

length of rationale removed list: 10
length of others removed list: 10
Calculating Sufficiency
Calculating Comprehensiveness

-- Metrics -------------------------------------------------------------


Faithfulness:  0.21069297
Comprehensiveness:  0.20763528
Sufficency:  1.0147263

Sufficiency list: [0.66150486 0.50363344 1.3346727  1.0085003  0.56927866 0.8266784
 1.367616   1.479163   1.1131557  1.2830596 ]
Comprehensiveness list: [0.10953143 0.12813778 0.14402504 0.16545348 0.06882977 0.31301817
 0.31437403 0.362963   0.23182683 0.2381933 ]

Comprehensiveness Median:  0.19864015
Comprehensiveness q1 (25% percentile):  0.132109597325325
Comprehensiveness q3 (75% percentile):  0.2943119555711746


Sufficency Median:  1.060828
Sufficency q1 (25% percentile):  0.7027982473373413
Sufficency q3 (75% percentile):  1.3217694163322449



In [11]:
# Fix N to the first 10 examples and pick k = 15
samples_start = 0
samples_end = 10
input_data = dataset["test"]["text"][samples_start:samples_end]
k = 15

indexed_text, index_array_rationalle = faithfulness.lime_create_index_arrays(input_data, predictor_bert, explainer, k_labels = k)
rationale_removed = faithfulness.remove_rationale_words(indexed_text, index_array_rationalle)
others_removed = faithfulness.remove_other_words(indexed_text, index_array_rationalle)

# We expect a list of rationales and others to be removed that correspond to the number of texts in input_data
print('length of rationale removed list:', len(rationale_removed))
print('length of others removed list:', len(others_removed))
_, faith_15 = faithfulness.calculate_faithfulness(input_data, rationale_removed, others_removed, model_bert, tokenizer_bert, predictor_model)

length of rationale removed list: 10
length of others removed list: 10
Calculating Sufficiency
Calculating Comprehensiveness

-- Metrics -------------------------------------------------------------


Faithfulness:  0.19094196
Comprehensiveness:  0.19981489
Sufficency:  0.95559424

Sufficiency list: [0.5098805  0.4577699  1.1453369  0.73992354 0.5434567  0.774903
 1.2848253  1.570421   1.1037817  1.4256439 ]
Comprehensiveness list: [0.10799496 0.1121013  0.22378412 0.09614211 0.10092095 0.25278768
 0.24575733 0.28585857 0.2403104  0.3324914 ]

Comprehensiveness Median:  0.23204726
Comprehensiveness q1 (25% percentile):  0.10902154445648193
Comprehensiveness q3 (75% percentile):  0.2510300911962986


Sufficency Median:  0.9393424
Sufficency q1 (25% percentile):  0.5925733894109726
Sufficency q3 (75% percentile):  1.2499532103538513



In [12]:
print('For k = 5:')
print('avg faith:', faith_5)
print('For k = 10:')
print('avg faith:', faith_10)
print('For k = 15:')
print('avg faith:', faith_15)

For k = 5:
avg faith: [0.37830418]
For k = 10:
avg faith: [0.21069297]
For k = 15:
avg faith: [0.19094196]


In [15]:
faiths = faith_5 + faith_10 + faith_15
np.mean(faiths)

0.2599797

In [13]:
import math
##################################################################################################################################
# Returns the faithfulness results for choice of:
# - input_data being the same dimension/characteristics as the test/val/train dataset used for our classifier of choice Bert or OPT.
# - start_index being the starting point index of the input_data
# - N being the size of input dataset (this is how many texts you want an explanation for)
# - B being the size of explanation batch (this is how many texts your machine can explain at a given instance)
# - k being the top k features defined as our rationales for explanation
# Precondition: pipeline with tokenizer and model correctly initialized along with explainer for LIME
######################################################################################################################################
def get_faith_lime(input_data, start_index, N, B, k):
    num_steps = math.ceil(N/B)
    tail_n =  N % B    
    overall_ind = []
    overall_faith = []

    i = start_index
    while i < N + tail_n:
        if i >= (N-tail_n) and tail_n > 0:
            input_subset = input_data[i: i+tail_n]
        else:    
            input_subset = input_data[i: i+B]


        indexed_text, index_array_rationalle = faithfulness.lime_create_index_arrays(input_subset, predictor_bert, explainer, k_labels = k)   
        rationale_removed = faithfulness.remove_rationale_words(indexed_text, index_array_rationalle)
        others_removed = faithfulness.remove_other_words(indexed_text, index_array_rationalle)
        
        ind, faith = faithfulness.calculate_faithfulness(input_data, rationale_removed, others_removed, model_bert, tokenizer_bert, predictor_model)
        overall_ind.append(ind)
        overall_faith.extend(faith)

        i += B
        
    return overall_ind, overall_faith, np.mean(overall_faith)


In [14]:
#  For our evaluation, we fix for 10 examples and choose k = 5, 10, 15.
start_index = 0
N = 1
B = 1
explainer = LimeTextExplainer(class_names=classes, bow=False)
input_data = dataset['test']['text']

k = 5
_, overall_faith_5, avg_faith_5 = get_faith_lime(input_data, start_index, N, B, k)
k = 10
_, overall_faith_10, avg_faith_10 = get_faith_lime(input_data, start_index, N, B, k)
k = 15
_, overall_faith_15, avg_faith_15 = get_faith_lime(input_data, start_index, N, B, k)

KeyboardInterrupt: 

In [None]:
print('For k = 5:')
print('overall_faith:', overall_faith_5)
print('avg faith:', avg_faith_5)
print('For k = 10:')
print('overall_faith:', overall_faith_10)
print('avg faith:', avg_faith_10)
print('For k = 15:')
print('overall_faith:', overall_faith_15)
print('avg faith:', avg_faith_15)