In [43]:
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from datasets import load_dataset
from lime.lime_text import LimeTextExplainer
import torch.nn.functional as F
from time import time
import shap


from __future__ import print_function
import os 

In [2]:
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TRAIN_BATCH_SIZE = 1
VALID_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
data_dir = "output/"
print(device)

cuda


## Load Dataset

In [3]:
test_short_path = "data/test_10_top50_short.csv"
labels_10_top50 = pd.read_csv('data/icd10_codes_top50.csv')
code_labels_10 = pd.read_csv("data/icd10_codes.csv")

## Load Bert-Based Custom Tokenized Dataset

In [4]:
# Hyperparameters
MAX_LEN = 512
MODEL = "emilyalsentzer/Bio_ClinicalBERT"
CKPT = os.path.join(data_dir,"best_model_state.bin")

In [5]:
# Create class dictionaries
classes = [class_ for class_ in code_labels_10["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}

config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

tokenizer_bert = AutoTokenizer.from_pretrained(MODEL)
model_bert = AutoModel.from_pretrained(MODEL, config=config, cache_dir='./model_ckpt/')

In [6]:
class TokenizerWrapper:
    def __init__(self, tokenizer, length, classes):
        self.tokenizer = tokenizer
        self.max_length = length
        self.classes = classes
        self.class2id = {class_: id for id, class_ in enumerate(self.classes)}
        self.id2class = {id: class_ for class_, id in self.class2id.items()}
        
    def multi_labels_to_ids(self, labels: list[str]) -> list[float]:
        ids = [0.0] * len(self.class2id)  # BCELoss requires float as target type
        for label in labels:
            ids[self.class2id[label]] = 1.0
        return ids
    
    def tokenize_function(self, example):
        result = self.tokenizer(
            example["text"],
            max_length = self.max_length,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        )
        result["label"] = torch.tensor([self.multi_labels_to_ids(eval(label)) for label in example["label"]])
        return result

In [7]:
data_files = {
        "test": test_short_path,
    }

tokenizer_wrapper = TokenizerWrapper(tokenizer_bert, MAX_LEN, classes)
dataset = load_dataset("csv", data_files=data_files)
dataset = dataset.map(tokenizer_wrapper.tokenize_function, batched=True, num_proc=1)
dataset = dataset.with_format("torch")

test_data_loader = torch.utils.data.DataLoader(dataset['test'], 
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

## Load the Bio_ClinicalBert Tuned Model

In [8]:
class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.bert_model = model_bert
        self.config = config
        self.can_generate = model_bert.can_generate
        self.dropout = torch.nn.Dropout(0.2)
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 50)
    
    def forward(self, input_ids, attn_mask, token_type_ids):
        output = self.bert_model(
            input_ids, 
            attention_mask=attn_mask, 
            token_type_ids=token_type_ids
        )
        output_dropout = self.dropout(output.pooler_output)
        output = self.linear(output_dropout)
        return output

In [9]:
model_bert = BERTClass()
model_bert.load_state_dict(torch.load(CKPT))
model_bert = model_bert.to(device)

In [10]:
text = dataset["test"]['text'][125]
inputs = tokenizer_bert(text, return_tensors="pt", truncation=True, padding = 'max_length', max_length=MAX_LEN).to(device)
ids = inputs['input_ids'].to(device, dtype = torch.long)
mask = inputs['attention_mask'].to(device, dtype = torch.long)
token_type_ids = inputs['token_type_ids'].to(device, dtype = torch.long)

with torch.no_grad():
    output = model_bert(ids, mask, token_type_ids).to(device)

output = torch.sigmoid(output).detach().cpu()

predicted_class_ids = torch.arange(0, output.shape[-1])[
    torch.squeeze(output, dim=0) > 0.5
]

# Get the predicted class names
for id in predicted_class_ids:
    predicted_class = tokenizer_wrapper.id2class[int(id)]
    print(code_labels_10[code_labels_10.icd_code == predicted_class])

  icd_code  icd_version                                       long_title
2   d-A001           10  Cholera due to Vibrio cholerae 01, biovar eltor
  icd_code  icd_version         long_title
9  d-A0103           10  Typhoid pneumonia
   icd_code  icd_version                       long_title
20   d-A022           10  Localized salmonella infections
   icd_code  icd_version                                   long_title
21  d-A0220           10  Localized salmonella infection, unspecified
   icd_code  icd_version                                 long_title
27  d-A0229           10  Salmonella with other localized infection
   icd_code  icd_version                                       long_title
48   d-A048           10  Other specified bacterial intestinal infections


## LIME Interpretation

In [11]:
explainer = LimeTextExplainer(class_names=classes, bow=False)
def predictor_bert(texts):
    result = tokenizer_bert(
            texts,
            max_length = MAX_LEN,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        )
    ids = result['input_ids'].to(device, dtype = torch.long)
    mask = result['attention_mask'].to(device, dtype = torch.long)
    token_type_ids = result['token_type_ids'].to(device, dtype = torch.long)
    outputs = model_bert(ids, mask, token_type_ids)
    # probas = F.softmax(outputs).detach().cpu().numpy()
    probas = F.sigmoid(outputs).detach().cpu().numpy() # if there's more than one possible diagnosis
    return probas

In [12]:
# Explanation Parameters
n_samples = 10
k = 5

In [13]:
with torch.no_grad():
  exp_bert = explainer.explain_instance(text, predictor_bert, num_samples = n_samples, top_labels = k)

In [14]:
exp_bert.show_in_notebook(text=True)

## SHAP Interpretation

In [15]:
from transformers import Pipeline

In [65]:
device = "cuda"
class BERT_ICD10_Pipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text):
        return self.tokenizer(
            text,
            max_length = MAX_LEN,
            padding = 'max_length',
            truncation = True,
            return_tensors='pt'
        ).to(self.device)

    def _forward(self, model_inputs):
        ids = model_inputs['input_ids'].to(self.device, dtype = torch.long)
        mask = model_inputs['attention_mask'].to(self.device, dtype = torch.long)
        token_type_ids = model_inputs['token_type_ids'].to(self.device, dtype = torch.long)
        outputs = self.model(ids, mask, token_type_ids).to(self.device)
        return outputs

    def postprocess(self, model_outputs):
        probs = F.sigmoid(model_outputs).detach().cpu().numpy() # if there's more than one possible diagnosis

        output = []
        for i, prob in enumerate(probs[0]):
            label = self.model.config.id2label[i]
            score = prob
            output.append({"label": label, "score": score})
        # print(output)
        return output

In [66]:
pipeline = BERT_ICD10_Pipeline(model=model_bert, tokenizer=tokenizer_bert, device = device)

In [67]:
pipeline(text)

[{'label': 'd-A00', 'score': 0.3226577},
 {'label': 'd-A000', 'score': 0.30186564},
 {'label': 'd-A001', 'score': 0.4138908},
 {'label': 'd-A009', 'score': 0.37944123},
 {'label': 'd-A01', 'score': 0.09744338},
 {'label': 'd-A010', 'score': 0.043811247},
 {'label': 'd-A0100', 'score': 0.15087287},
 {'label': 'd-A0101', 'score': 0.37450364},
 {'label': 'd-A0102', 'score': 0.3645462},
 {'label': 'd-A0103', 'score': 0.8618943},
 {'label': 'd-A0104', 'score': 0.22992697},
 {'label': 'd-A0105', 'score': 0.36821178},
 {'label': 'd-A0109', 'score': 0.13694131},
 {'label': 'd-A011', 'score': 0.23469485},
 {'label': 'd-A012', 'score': 0.17953075},
 {'label': 'd-A013', 'score': 0.14210102},
 {'label': 'd-A014', 'score': 0.4636024},
 {'label': 'd-A02', 'score': 0.2998755},
 {'label': 'd-A020', 'score': 0.11658162},
 {'label': 'd-A021', 'score': 0.0789383},
 {'label': 'd-A022', 'score': 0.7151925},
 {'label': 'd-A0220', 'score': 0.5994953},
 {'label': 'd-A0221', 'score': 0.3817885},
 {'label': 'd-

In [68]:

masker = shap.maskers.Text(pipeline.tokenizer)

In [69]:
explainer = shap.Explainer(pipeline, masker)

In [70]:
shap_values = explainer(dataset['test']['text'][:2])

MemoryError: Unable to allocate 8.07 GiB for an array with shape (6135, 176478) and data type float64

## Some Other Interp Implementation
TODO: Find another interp. SHAP's a nightmare with custom tokenizers

In [None]:
from interpret import set_visualize_provider
from interpret.provider import InlineProvider
set_visualize_provider(InlineProvider())

In [None]:
from interpret import show
from interpret.blackbox import PartialDependence


pdp = PartialDependence(predictor_bert, [text])

ValueError: text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).