In [28]:
import argparse
from pprint import pprint

import evaluate
import numpy as np
import pandas as pd
import torch
from torch import nn
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    EvalPrediction,
    OPTForSequenceClassification,
    Trainer,
    TrainingArguments,
)
import wandb

MODEL = "facebook/opt-350m"
MAX_POSITION_EMBEDDINGS = 2048

from dataclasses import dataclass


In [4]:
CHECKPOINT_DIR = "OPT-350m-mimic-full"
VAL_DATASET_PATH = "data/val_9.csv"
CODE_PATH = "data/icd9_codes.csv"

In [5]:
# Load dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

code_labels = pd.read_csv("data/icd9_codes.csv")
dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

# Create class dictionaries
classes = [class_ for class_ in code_labels["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}


def multi_labels_to_ids(labels: list[str]) -> list[float]:
    ids = [0.0] * len(class2id)  # BCELoss requires float as target type
    for label in labels:
        ids[class2id[label]] = 1.0
    return ids


def preprocess_function(example):
    result = tokenizer(
        example["text"], truncation=True, max_length=MAX_POSITION_EMBEDDINGS
    )
    result["labels"] = [multi_labels_to_ids(eval(label)) for label in example["labels"]]
    return result


dataset = dataset.map(
    preprocess_function, load_from_cache_file=True, batched=True, num_proc=8
)

Generating train split: 15750 examples [00:02, 5659.18 examples/s]
Map (num_proc=8): 100%|██████████| 15750/15750 [00:15<00:00, 1005.66 examples/s]


In [9]:
config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
    return_unused_kwargs=True,
)

if unused_kwargs:
    print(f"Unused kwargs: {unused_kwargs}")

model = OPTForSequenceClassification.from_pretrained(
    MODEL,
    config=config,
)

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
model.load_adapter(CHECKPOINT_DIR)

In [12]:
untokenized_dataset = load_dataset("csv", data_files=VAL_DATASET_PATH)

print(untokenized_dataset['train'][0])

{'text': "Sex:   M\n \nService: SURGERY\n \nAllergies: \nGrass ___, Standard / Lactose\n \n ___.\n \nChief Complaint:\nright popliteal aneurysm\n \nMajor Surgical or Invasive Procedure:\n___: popliteal artery stent graft\n\n \nHistory of Present Illness:\nMr. ___ has a fairly focal aneurysm in the\nmid right popliteal artery and is mostly full of thrombus and\nmeasures 3.1 cm.  It is patent and does have palpable pedal\npulse distally.  He has need of upcoming ankle surgery as well.\nHe has a past medical history notable for breast cancer status\npost mastectomy and chemotherapy/radiation therapy with duodenal\nulcer, pseudogout, depression, hypothyroidism, microvascular\ncerebrovascular disease, hyperlipidemia, and COPD.  He had vein\nmapping performed today which shows the lesser saphenous veins\nto be small and noncompressible bilaterally.  He has the\nthrombosis of the right greater saphenous at the level of the\nknee.  The left greater saphenous is adequate as are both\nbasilic an

In [14]:
inputs = tokenizer(untokenized_dataset["train"][0]['text'], return_tensors="pt", truncation=True, max_length=MAX_POSITION_EMBEDDINGS)

with torch.no_grad():
    logits = model(**inputs).logits

In [30]:
predicted_class_ids = torch.arange(0, logits.shape[-1])[
    torch.sigmoid(logits).squeeze(dim=0) > 0.5
]

# Get the predicted class names
for id in predicted_class_ids:
    predicted_class = id2class[int(id)]
    pprint(code_labels[code_labels.icd_code == predicted_class])

  icd_code  icd_version                           long_title
0   d-2449            9  Unspecified acquired hypothyroidism
  icd_code  icd_version                 long_title
2   d-2720            9  Pure hypercholesterolemia
  icd_code  icd_version                            long_title
3   d-2724            9  Other and unspecified hyperlipidemia
   icd_code  icd_version                                     long_title
15    d-311            9  Depressive disorder, not elsewhere classified
   icd_code  icd_version                                         long_title
29    d-496            9  Chronic airway obstruction, not elsewhere clas...


In [43]:
for i, logit in enumerate(logits[0]):
    pprint(f'{classes[i]}, {logit}')

'd-2449, 2.680248260498047'
'd-25000, -5.279692649841309'
'd-2720, 0.6904276609420776'
'd-2724, 0.1000528335571289'
'd-2749, -4.944847583770752'
'd-2761, -3.923020601272583'
'd-2762, -5.64964485168457'
'd-27651, -3.627403736114502'
'd-27800, -4.593332290649414'
'd-2851, -1.8333784341812134'
'd-2859, -3.437279224395752'
'd-2875, -5.90502405166626'
'd-30000, -3.9239180088043213'
'd-30500, -5.418412208557129'
'd-3051, -1.8782130479812622'
'd-311, 0.6561868190765381'
'd-32723, -3.748591423034668'
'd-33829, -5.080074310302734'
'd-3572, -4.4857282638549805'
'd-4019, -2.3807945251464844'
'd-40390, -3.9646496772766113'
'd-412, -2.66621994972229'
'd-41400, -5.219027996063232'
'd-41401, -2.2151124477386475'
'd-42731, -5.226138591766357'
'd-42789, -4.7224531173706055'
'd-4280, -2.5561134815216064'
'd-486, -3.5420994758605957'
'd-49390, -6.440539836883545'
'd-496, 2.2681386470794678'
'd-53081, -3.4620933532714844'
'd-56400, -4.661619663238525'
'd-5849, -4.719424724578857'
'd-5856, -5.1200141906738