In [1]:
import argparse
from pprint import pprint

import evaluate
import numpy as np
import pandas as pd
import torch
from torch import nn
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    EvalPrediction,
    OPTForSequenceClassification,
    Trainer,
    TrainingArguments,
)
import wandb

# Hugging Face model identifier used for the tokenizer, the config and the base weights.
MODEL = "facebook/opt-350m"
# Upper bound on tokens per example; passed as `max_length` whenever text is tokenized.
# NOTE(review): presumably mirrors the model's own position-embedding limit — confirm.
MAX_POSITION_EMBEDDINGS = 2048

from dataclasses import dataclass


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Opt in to TF32 for CUDA matmuls and cuDNN operations.
# NOTE(review): TF32 trades some float32 precision for speed on supporting GPUs —
# confirm that is acceptable for this evaluation.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [3]:
# Directory of the fine-tuned adapter checkpoint (handed to `model.load_adapter`).
CHECKPOINT_DIR = "OPT-350m-mimic-full"
# CSV file with the validation examples (loaded through `load_dataset("csv", ...)`).
VAL_DATASET_PATH = "data/val_9.csv"
# CSV file listing the ICD-9 codes; the same path is read when the class dictionaries are built.
CODE_PATH = "data/icd9_codes.csv"

In [4]:
# Load the tokenizer, the ICD-9 code table and the validation split.
# NOTE(review): `device` is forwarded to the tokenizer as an extra keyword; it does
# not look like a tokenizer option — confirm whether it can be dropped.
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True, device=device)

# Read the code table from the configured path instead of repeating the literal,
# so changing CODE_PATH in the configuration cell actually takes effect.
code_labels = pd.read_csv(CODE_PATH)
dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

# Create class dictionaries: `classes` fixes the label order, `class2id` and
# `id2class` map between a code and its position in that order.
# NOTE(review): the `if class_` filter drops falsy entries only; NaN is truthy in
# Python, so missing codes would survive it — confirm the column has none.
classes = [class_ for class_ in code_labels["icd_code"] if class_]
# `idx` rather than `id`, to avoid shadowing the builtin inside the comprehensions.
class2id = {class_: idx for idx, class_ in enumerate(classes)}
id2class = {idx: class_ for class_, idx in class2id.items()}


def multi_labels_to_ids(labels: list[str]) -> list[float]:
    """Encode label names as a multi-hot vector ordered by ``class2id``.

    Args:
        labels: Label names; each one must be a key of ``class2id``.

    Returns:
        A list with one float per known class: 1.0 at the position of every
        label in ``labels`` and 0.0 everywhere else.

    Raises:
        KeyError: If a label is not present in ``class2id``.
    """
    # Floats rather than ints because BCELoss requires a float target type.
    multi_hot = [0.0 for _ in range(len(class2id))]
    for name in labels:
        position = class2id[name]
        multi_hot[position] = 1.0
    return multi_hot


def preprocess_function(example):
    """Tokenize a batch of examples and attach multi-hot label vectors.

    Args:
        example: Batch mapping with a "text" entry (the raw strings) and a
            "labels" entry whose items are string-encoded Python literals
            (a list of label names per example).

    Returns:
        The tokenizer output with an added "labels" entry holding one
        multi-hot vector per example.

    Raises:
        ValueError, SyntaxError: If a "labels" item is not a plain Python literal.
        KeyError: If a decoded label is unknown to ``multi_labels_to_ids``.
    """
    # Imported locally so the function carries its own dependency.
    import ast

    result = tokenizer(
        example["text"], truncation=True, max_length=MAX_POSITION_EMBEDDINGS
    )
    # SECURITY: the labels column comes from a CSV file, i.e. external data.
    # It used to be decoded with eval(), which executes arbitrary expressions;
    # ast.literal_eval only accepts literals (lists, strings, numbers, ...).
    result["labels"] = [
        multi_labels_to_ids(ast.literal_eval(label)) for label in example["labels"]
    ]
    return result


# Tokenize and label-encode the dataset in batches across 8 worker processes,
# reusing a cached result when one is available.
dataset = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=8,
    load_from_cache_file=True,
)

In [5]:
# Build the model config from the base checkpoint, overriding the label space with
# the ICD-9 classes and selecting the multi-label classification problem type.
# `return_unused_kwargs=True` makes the call return a (config, leftovers) pair.
config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

# Surface any keyword the config did not consume instead of ignoring it silently.
if unused_kwargs:
    print(f"Unused kwargs: {unused_kwargs}")

# Instantiate the sequence-classification model from the base checkpoint with the
# config above and move it to the selected device.
model = OPTForSequenceClassification.from_pretrained(
    MODEL,
    config=config,
).to(device)

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Attach the fine-tuned adapter stored in CHECKPOINT_DIR to the base model.
# NOTE(review): the load warning above reports `score.weight` as newly initialized —
# confirm the adapter checkpoint also restores the classification head.
model.load_adapter(CHECKPOINT_DIR)
# This notebook only runs inference. Freshly created modules start in training
# mode, so switch the whole model (adapter layers included) to eval mode to keep
# dropout disabled. The trailing `;` suppresses the model repr in the cell output.
model.eval();

In [None]:
# Reload the validation CSV without tokenization so the raw text stays available.
untokenized_dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

# Show the first record as a sanity check.
# NOTE(review): this prints a full record; if the data contains patient
# information, clear this output before sharing the notebook.
print(untokenized_dataset['train'][0])

In [8]:
# Tokenize the first raw example (truncated to the configured maximum length)
# and move the resulting tensors to the model's device.
inputs = tokenizer(
    untokenized_dataset["train"][0]["text"],
    return_tensors="pt",
    truncation=True,
    max_length=MAX_POSITION_EMBEDDINGS,
).to(device)

# Forward pass only; no gradients are needed for inference.
with torch.no_grad():
    logits = model(**inputs).logits

In [None]:
# Bring the logits back to the CPU for thresholding and lookup.
logits = logits.to("cpu")

# Indices of the classes whose sigmoid probability exceeds 0.5
# (the leading batch dimension is squeezed away first).
predicted_class_ids = torch.arange(0, logits.shape[-1])[
    torch.sigmoid(logits).squeeze(dim=0) > 0.5
]

# Get the predicted class names and print the matching rows of the code table.
# The loop variable is deliberately not called `id`: at cell level that would
# rebind the builtin `id` for the rest of the kernel session.
for class_id in predicted_class_ids:
    predicted_class = id2class[int(class_id)]
    pprint(code_labels[code_labels.icd_code == predicted_class])

In [None]:
# Print the raw (pre-sigmoid) logit of every class for the first row of `logits`,
# one "<class>, <logit>" line per class.
for i, logit in enumerate(logits[0]):
    pprint(f'{classes[i]}, {logit}')

In [27]:
import lime
from lime import lime_text
from lime.lime_text import LimeTextExplainer
from lime.lime_text import IndexedString
import numpy as np
import torch.nn.functional as F
from time import time


# LIME text explainer over the ICD-9 class names.
# bow=False: explain in terms of word positions rather than a bag of words.
# random_state is fixed because LIME samples random perturbations; without a
# seed the explanation changes from run to run.
explainer = LimeTextExplainer(class_names=classes, bow=False, random_state=42)

def predictor_opt(texts, batch_size=8):
    """Return per-class sigmoid probabilities for a list of raw texts.

    This is the classifier function handed to LIME: it receives strings and
    returns a NumPy array with one row per text, in the same order.

    Args:
        texts: Sequence of raw input strings.
        batch_size: Number of texts sent through the model per forward pass, so
            a large number of LIME samples does not have to fit in memory at once.

    Returns:
        ``numpy.ndarray`` of sigmoid probabilities, one row per input text; an
        empty ``(0, len(classes))`` array when ``texts`` is empty.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, len(classes)))

    chunks = []
    # Inference only: never build an autograd graph, even when the caller did
    # not wrap the call in torch.no_grad().
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            # padding=True is required to stack texts of different token
            # lengths into one tensor; without it tensor creation fails as
            # soon as two texts in the batch differ in length.
            tk = tokenizer(
                texts[start : start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_POSITION_EMBEDDINGS,
            ).to(device)
            tensor_logits = model(**tk).logits
            chunks.append(torch.sigmoid(tensor_logits).cpu().numpy())
    return np.concatenate(chunks, axis=0)

In [28]:
# Raw text of the record at index 2, used as the document to explain.
sentence = untokenized_dataset["train"][2]["text"]
n_samples = 10  # passed to LIME as `num_samples`
k = 5  # passed to LIME as `top_labels`

In [29]:
# Run LIME on the selected document. Gradients stay disabled while the
# explainer queries the model through `predictor_opt`.
with torch.no_grad():
    exp_bert = explainer.explain_instance(
        sentence,
        predictor_opt,
        top_labels=k,
        num_samples=n_samples,
    )

In [None]:
# Render the explanation inline; text=True also shows the explained text.
exp_bert.show_in_notebook(text=True)

In [None]:
# Show the stored labels of the explained record (index 2) next to the explanation.
untokenized_dataset["train"][2]['labels']