In [1]:
import ast
from pprint import pprint

import numpy as np
import pandas as pd
import shap
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    OPTForSequenceClassification,
    Pipeline,
)

import wandb

# Identifier of the base checkpoint passed to every `from_pretrained` call in this notebook.
MODEL = "facebook/opt-350m"
# Upper bound on tokens per example; passed as `max_length` wherever text is tokenized.
MAX_POSITION_EMBEDDINGS = 2048

from dataclasses import dataclass

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Allow TF32 in both cuDNN and CUDA matmul kernels.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

print(device)

cuda


In [3]:
# Directory holding the adapter checkpoint that is loaded onto the base model.
CHECKPOINT_DIR = "OPT-350m-mimic-full"
# CSV with the validation examples ("text" and "labels" columns are read from it).
VAL_DATASET_PATH = "data/val_9.csv"
# CSV listing the label vocabulary (its "icd_code" column defines the classes).
CODE_PATH = "data/icd9_codes.csv"

In [4]:
# Load dataset
# Tokenizer for the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True, device=device)

# Read the label table from the configured path instead of repeating the literal,
# so changing CODE_PATH in the config cell actually takes effect.
code_labels = pd.read_csv(CODE_PATH)
dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

# Create class dictionaries
# NOTE(review): `if class_` only drops falsy entries (e.g. empty strings). A
# missing value read by pandas as NaN is truthy and would be kept as a class.
# The filter is left unchanged because the number and order of classes must
# match whatever the checkpoint was trained with.
classes = [class_ for class_ in code_labels["icd_code"] if class_]
# Position in `classes` is the label id; `idx` avoids shadowing the builtin `id`.
class2id = {class_: idx for idx, class_ in enumerate(classes)}
id2class = {idx: class_ for class_, idx in class2id.items()}


def multi_labels_to_ids(labels: list[str]) -> list[float]:
    """Encode label names as a multi-hot vector over ``class2id``.

    Args:
        labels: Label names; each must be a key of ``class2id``.

    Returns:
        A list with one float per known class: 1.0 at the position of every
        label in ``labels`` and 0.0 elsewhere. Floats are used because
        BCELoss requires a float target.

    Raises:
        KeyError: If a label is not present in ``class2id``.
    """
    multi_hot = [0.0 for _ in class2id]
    for name in labels:
        position = class2id[name]
        multi_hot[position] = 1.0
    return multi_hot


def preprocess_function(example):
    """Tokenize a batch of texts and attach multi-hot label vectors.

    Args:
        example: Batch mapping with a "text" entry (the texts to tokenize) and
            a "labels" entry whose items are strings holding a Python list
            literal of label names.

    Returns:
        The tokenizer output for the batch, with an added "labels" entry
        containing one multi-hot vector per example.

    Raises:
        ValueError, SyntaxError: If a "labels" item is not a valid Python literal.
        KeyError: If a parsed label is not a known class.
    """
    result = tokenizer(
        example["text"], truncation=True, max_length=MAX_POSITION_EMBEDDINGS
    )
    # The label strings come from a CSV file, so parse them as literals rather
    # than executing them with eval(): a malformed or malicious cell then raises
    # instead of running arbitrary code.
    result["labels"] = [
        multi_labels_to_ids(ast.literal_eval(label)) for label in example["labels"]
    ]
    return result


# Apply preprocess_function to the dataset in batches using 8 worker processes,
# reusing cached results when they exist.
# NOTE(review): no later cell visible here reads the tokenized `dataset`; the
# pipeline and SHAP cells work from the raw text instead.
dataset = dataset.map(
    preprocess_function, load_from_cache_file=True, batched=True, num_proc=8
)

In [5]:
# Base configuration specialised for multi-label classification over the
# label vocabulary built from the code table.
config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

# Surface any keyword arguments the config did not consume.
if unused_kwargs:
    print(f"Unused kwargs: {unused_kwargs}")

# Base model in bfloat16 with Flash Attention 2, moved to the selected device.
model = OPTForSequenceClassification.from_pretrained(
    MODEL,
    config=config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to(device)

# Attach the adapter weights stored in CHECKPOINT_DIR.
# NOTE(review): the load log reports `score.weight` as newly initialised from
# the base checkpoint. Confirm the adapter checkpoint restores the
# classification head; otherwise predictions come from a random head.
model.load_adapter(CHECKPOINT_DIR)
model.to_bettertransformer()

# This notebook only runs inference and attribution, so put every submodule
# (including the LoRA dropout layers injected by the adapter) into eval mode
# explicitly rather than relying on loader defaults. eval() returns the model,
# so the cell still displays the module tree.
model.eval()

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.


OPTForSequenceClassification(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 512, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
      (project_out): Linear(in_features=1024, out_features=512, bias=False)
      (project_in): Linear(in_features=512, out_features=1024, bias=False)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OptFlashAttention2(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): lora.Linear(
              (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.05, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=1024, out_features=16, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=16, out_fe

In [None]:
# Reload the validation CSV without tokenization so the raw "text" column is
# available to the pipeline and SHAP cells.
untokenized_dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

# NOTE(review): this prints a complete record. The checkpoint name suggests the
# data may be clinical notes — confirm, and clear the output before sharing if so.
print(untokenized_dataset["train"][0])

In [7]:
# Run one raw example through the tokenizer and the model as a sanity check.
first_text = untokenized_dataset["train"][0]["text"]
inputs = tokenizer(
    first_text,
    truncation=True,
    max_length=MAX_POSITION_EMBEDDINGS,
    return_tensors="pt",
)
inputs = inputs.to(device)

# Gradients are not needed for a plain forward pass.
with torch.no_grad():
    logits = model(**inputs).logits

In [8]:
class OPT_ICD9_Pipeline(Pipeline):
    """Pipeline that scores a text against every label in the model config.

    The text is tokenized (with truncation), passed through the model, and the
    logits are turned into independent per-label probabilities with a sigmoid.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route keyword arguments to the pipeline stages.

        Only ``max_length`` is recognised; it is forwarded to ``preprocess``.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple;
            the last two are always empty.
        """
        preprocess_kwargs = {}
        if "max_length" in kwargs:
            preprocess_kwargs["max_length"] = kwargs["max_length"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, max_length=None):
        """Tokenize ``text`` into PyTorch tensors, truncating long inputs.

        Args:
            text: The input text to tokenize.
            max_length: Optional truncation length; defaults to
                ``MAX_POSITION_EMBEDDINGS`` when not given.
        """
        # Resolve the default at call time so the constant is read when used.
        limit = MAX_POSITION_EMBEDDINGS if max_length is None else max_length
        return self.tokenizer(
            text,
            truncation=True,
            max_length=limit,
            return_tensors="pt",
        )

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs and return its raw outputs."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert logits into a list of ``{"label", "score"}`` dicts.

        Returns:
            One dict per label id, in id order, where ``score`` is the sigmoid
            of the first batch row's logit as a plain Python float.
        """
        probs = model_outputs["logits"].sigmoid()
        # Hand back plain floats rather than 0-dim tensors. The upcast keeps the
        # conversion valid for reduced-precision outputs without changing the
        # values that were computed.
        scores = probs[0].float().tolist()

        id2label = self.model.config.id2label
        return [
            {"label": id2label[i], "score": score} for i, score in enumerate(scores)
        ]

In [9]:
# Wrap the loaded model and tokenizer in the custom pipeline, bound to `device`.
pipeline = OPT_ICD9_Pipeline(model=model, tokenizer=tokenizer, device=device)

In [None]:
# Display the device the model currently reports.
model.device

In [None]:
# Smoke-test the pipeline on a single raw text (index 2); the result is a list
# of {"label", "score"} dicts, one per label.
pipeline(untokenized_dataset["train"][2]["text"])

In [12]:
# SHAP text masker built on the pipeline's tokenizer; passed to shap.Explainer.
masker = shap.maskers.Text(pipeline.tokenizer)

In [13]:
# Draw 2 texts from the validation set with shap.sample.
# NOTE(review): `sample` is not read by any later cell visible here — confirm
# whether it is still needed.
sample = shap.sample(untokenized_dataset["train"]["text"], 2)

In [14]:
# Build the explainer from the pipeline and the text masker. The recorded
# progress bar of the explanation run shows a PartitionExplainer was used.
explainer = shap.Explainer(pipeline, masker)

In [None]:
# Display the first two raw texts.
# NOTE(review): this shows full records — clear the output before sharing if the
# data is sensitive.
untokenized_dataset["train"][:2]["text"]

In [21]:
# Compute SHAP explanations for the first five raw texts. The progress output
# saved with this cell is at 1/5.
shap_values = explainer(untokenized_dataset["train"][:5]["text"])

PartitionExplainer explainer:  20%|██        | 1/5 [00:00<?, ?it/s]

In [None]:
# Plot the attributions of the first explained text for the output named
# "d-2749" (presumably one of the labels in id2label — verify against the code table).
shap.plots.text(shap_values[0, :, "d-2749"])