In [152]:
from pprint import pprint

import numpy as np
import pandas as pd
import shap
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    OPTForSequenceClassification,
    Pipeline,
)

import wandb

MODEL = "facebook/opt-350m"
MAX_POSITION_EMBEDDINGS = 2048

from dataclasses import dataclass

In [146]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [147]:
CHECKPOINT_DIR = "OPT-350m-mimic-full"
VAL_DATASET_PATH = "data/val_9.csv"
CODE_PATH = "data/icd9_codes.csv"

In [148]:
# Load dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True, device=device)

code_labels = pd.read_csv("data/icd9_codes.csv")
dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

# Create class dictionaries
classes = [class_ for class_ in code_labels["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}


def multi_labels_to_ids(labels: list[str]) -> list[float]:
    ids = [0.0] * len(class2id)  # BCELoss requires float as target type
    for label in labels:
        ids[class2id[label]] = 1.0
    return ids


def preprocess_function(example):
    result = tokenizer(
        example["text"], truncation=True, max_length=MAX_POSITION_EMBEDDINGS
    )
    result["labels"] = [multi_labels_to_ids(eval(label)) for label in example["labels"]]
    return result


dataset = dataset.map(
    preprocess_function, load_from_cache_file=True, batched=True, num_proc=8
)

In [149]:
config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

if unused_kwargs:
    print(f"Unused kwargs: {unused_kwargs}")

model = OPTForSequenceClassification.from_pretrained(
    MODEL,
    config=config,
).to(device)

model.load_adapter(CHECKPOINT_DIR)

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [150]:
untokenized_dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

print(untokenized_dataset["train"][0])

{'text': "Sex:   M\n \nService: SURGERY\n \nAllergies: \nGrass ___, Standard / Lactose\n \n ___.\n \nChief Complaint:\nright popliteal aneurysm\n \nMajor Surgical or Invasive Procedure:\n___: popliteal artery stent graft\n\n \nHistory of Present Illness:\nMr. ___ has a fairly focal aneurysm in the\nmid right popliteal artery and is mostly full of thrombus and\nmeasures 3.1 cm.  It is patent and does have palpable pedal\npulse distally.  He has need of upcoming ankle surgery as well.\nHe has a past medical history notable for breast cancer status\npost mastectomy and chemotherapy/radiation therapy with duodenal\nulcer, pseudogout, depression, hypothyroidism, microvascular\ncerebrovascular disease, hyperlipidemia, and COPD.  He had vein\nmapping performed today which shows the lesser saphenous veins\nto be small and noncompressible bilaterally.  He has the\nthrombosis of the right greater saphenous at the level of the\nknee.  The left greater saphenous is adequate as are both\nbasilic an

In [151]:
inputs = tokenizer(
    untokenized_dataset["train"][0]["text"],
    return_tensors="pt",
    truncation=True,
    max_length=MAX_POSITION_EMBEDDINGS,
).to(device)

with torch.no_grad():
    logits = model(**inputs).logits

In [206]:
class OPT_ICD9_Pipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text):
        return self.tokenizer(
            text,
            truncation=True,
            max_length=MAX_POSITION_EMBEDDINGS,
            return_tensors="pt",
        )

    def _forward(self, model_inputs):
        outputs = self.model(**model_inputs)
        return outputs

    def postprocess(self, model_outputs):
        # logits = model_outputs.logits[0].numpy()
        #print(logits)
        probs = model_outputs["logits"].sigmoid()

        output = []
        for i, prob in enumerate(probs[0]):
            label = self.model.config.id2label[i]
            score = prob
            output.append({"label": label, "score": score})
        return output

In [207]:
pipeline = OPT_ICD9_Pipeline(model=model, tokenizer=tokenizer)

In [220]:
pipeline(untokenized_dataset["train"][2]["text"])

[{'label': 'd-2449', 'score': tensor(0.0058)},
 {'label': 'd-25000', 'score': tensor(0.7576)},
 {'label': 'd-2720', 'score': tensor(0.0405)},
 {'label': 'd-2724', 'score': tensor(0.8866)},
 {'label': 'd-2749', 'score': tensor(0.0047)},
 {'label': 'd-2761', 'score': tensor(0.0078)},
 {'label': 'd-2762', 'score': tensor(0.0121)},
 {'label': 'd-27651', 'score': tensor(0.0071)},
 {'label': 'd-27800', 'score': tensor(0.0812)},
 {'label': 'd-2851', 'score': tensor(0.0089)},
 {'label': 'd-2859', 'score': tensor(0.0045)},
 {'label': 'd-2875', 'score': tensor(0.0114)},
 {'label': 'd-30000', 'score': tensor(0.0135)},
 {'label': 'd-30500', 'score': tensor(0.0025)},
 {'label': 'd-3051', 'score': tensor(0.0127)},
 {'label': 'd-311', 'score': tensor(0.0275)},
 {'label': 'd-32723', 'score': tensor(0.1259)},
 {'label': 'd-33829', 'score': tensor(0.0004)},
 {'label': 'd-3572', 'score': tensor(0.0942)},
 {'label': 'd-4019', 'score': tensor(0.9881)},
 {'label': 'd-40390', 'score': tensor(0.0127)},
 {'lab

In [221]:
masker = shap.maskers.Text(pipeline.tokenizer)

In [222]:
explainer = shap.Explainer(pipeline, masker)

In [223]:
shap_values = explainer(untokenized_dataset["train"][2]["text"])

PartitionExplainer explainer:   7%|▋         | 726/9914 [03:36<48:25,  3.16it/s]  


KeyboardInterrupt: 

In [211]:
shap.plots.text(shap_values[0, :, "d-2749"])