In [84]:
import argparse
import ast
from pprint import pprint

import evaluate
import numpy as np
import pandas as pd
import torch
from torch import nn
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    EvalPrediction,
    OPTForSequenceClassification,
    Trainer,
    TrainingArguments,
)
import wandb

MODEL = "facebook/opt-350m"
MAX_POSITION_EMBEDDINGS = 2048

from dataclasses import dataclass

In [103]:
# Run configuration: checkpoint name, input file locations, and compute device.

# Presumably the directory/name used for model checkpoints; it is not read by
# any of the cells visible here -- TODO confirm against the training code.
CHECKPOINT_DIR = "OPT-350m-mimic-full"
# CSV files for the three splits handed to `load_dataset("csv", ...)`.
# NOTE(review): "DATSET" is misspelled in two of these names; the spelling is
# kept because the data-loading cell refers to the names exactly as written.
TRAIN_DATSET_PATH = "data/train_9.csv"
VAL_DATASET_PATH = "data/val_9.csv"
TEST_DATSET_PATH = "data/test_9.csv"
# CSV read with pandas; its "icd_code" column supplies the class vocabulary.
CODE_PATH = "data/icd9_codes.csv"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [117]:
# Load the tokenizer belonging to the base model.
# Tokenizers are not placed on a device, so no `device` argument is passed
# here (extra keyword arguments are merely forwarded to the tokenizer's
# constructor, where a device has no meaning).
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

# Map each split name to its CSV file so all three splits load in one call.
data_files = {
    "train": TRAIN_DATSET_PATH,
    "validation": VAL_DATASET_PATH,
    "test": TEST_DATSET_PATH,
}
code_labels = pd.read_csv(CODE_PATH)
dataset = load_dataset("csv", data_files=data_files)

# Create class dictionaries: class name -> vector index, and the inverse.
# NOTE(review): the truthiness filter drops empty strings but not NaN (NaN is
# truthy), so a missing value in "icd_code" would become a class -- confirm the
# column has no missing entries.
classes = [class_ for class_ in code_labels["icd_code"] if class_]
# `index` instead of `id` so the builtin `id` is not shadowed.
class2id = {class_: index for index, class_ in enumerate(classes)}
id2class = {index: class_ for class_, index in class2id.items()}


def multi_labels_to_ids(labels: list[str]) -> list[float]:
    """Encode a list of class names as a multi-hot vector.

    Args:
        labels: Class names; each one must be a key of the module-level
            ``class2id`` mapping.

    Returns:
        A list holding one float per entry of ``class2id``: 1.0 at the index
        of every name found in ``labels`` and 0.0 everywhere else.

    Raises:
        KeyError: If a name in ``labels`` is not a key of ``class2id``.
    """
    # Floats rather than ints: BCELoss requires float as target type.
    multi_hot = [0.0 for _ in class2id]
    for position in map(class2id.__getitem__, labels):
        multi_hot[position] = 1.0
    return multi_hot


def preprocess_function(example):
    """Tokenize a batch of texts and attach multi-hot label vectors.

    Written for ``dataset.map(..., batched=True)``: ``example`` is a batch, so
    each column holds a list of values.

    Args:
        example: Mapping with a "text" column (the strings to tokenize) and a
            "labels" column whose entries are strings containing a Python
            literal -- presumably a list of class names such as
            ``"['401.9', '250.00']"``; verify against the CSV files.

    Returns:
        The tokenizer output for the batch with an added "labels" entry: one
        multi-hot float vector per row, built by ``multi_labels_to_ids``.

    Raises:
        ValueError, SyntaxError: If a "labels" cell is not a valid Python
            literal.
        KeyError: If a parsed label is unknown to ``multi_labels_to_ids``.
    """
    # No truncation or padding arguments are passed to the tokenizer here.
    result = tokenizer(example["text"])
    # The label strings come straight from a data file, so they are parsed with
    # ast.literal_eval, which accepts only literals; the previous eval() would
    # have executed any expression found in a cell.
    result["labels"] = [
        multi_labels_to_ids(ast.literal_eval(label)) for label in example["labels"]
    ]
    return result


# Run tokenization and label encoding over every split, batch by batch,
# spread across 8 worker processes.
dataset = dataset.map(function=preprocess_function, batched=True, num_proc=8)

Generating train split: 126004 examples [00:20, 6055.36 examples/s]
Generating validation split: 15750 examples [00:01, 8326.58 examples/s] 
Generating test split: 15751 examples [00:01, 8603.46 examples/s] 
Map (num_proc=8): 100%|██████████| 126004/126004 [02:05<00:00, 1004.94 examples/s]
Map (num_proc=8): 100%|██████████| 15750/15750 [00:16<00:00, 957.40 examples/s] 
Map (num_proc=8): 100%|██████████| 15751/15751 [00:16<00:00, 948.23 examples/s] 


In [118]:
# Show the processed training split (its feature names and row count).
dataset["train"]

Dataset({
    features: ['text', 'labels', 'input_ids', 'attention_mask'],
    num_rows: 126004
})

In [124]:
# Gather the tokenized length of every text across all three splits.
# (numpy is already imported in the notebook's import cell, so it is not
# imported again here.)
train_dataset = dataset["train"]  # not used in this cell; kept for other cells

input_ids_lengths = []
label_lengths = []  # re-initialised and filled by the label-count cell
sets = ["train", "test", "validation"]
# The loop variable is `split_name` rather than `set`: a top-level loop
# variable stays in the kernel namespace, and `set` would replace the builtin
# for every cell run afterwards.
for split_name in sets:
    input_ids_lengths.extend(
        len(example["input_ids"]) for example in dataset[split_name]
    )


In [125]:
# Mean, then maximum, token count per text -- one value per line.
for summary_fn in (np.mean, np.max):
    print(summary_fn(input_ids_lengths))



2830.1608202914194
5685
51.0
51


In [131]:
# Count the active labels of every example across all three splits.
label_lengths = []

sets = ["train", "test", "validation"]
# `split_name` rather than `set`, so the builtin `set` is not overwritten in
# the kernel namespace once this cell has run.
for split_name in sets:
    # Each "labels" entry is a multi-hot list of floats; `.count(1)` matches
    # the 1.0 entries because 1 == 1.0.
    label_lengths.extend(
        example["labels"].count(1) for example in dataset[split_name]
    )


In [132]:
# Mean, then maximum, number of active labels per example -- one per line.
for summary_fn in (np.mean, np.max):
    print(summary_fn(label_lengths))

4.315278880035555
21


In [126]:
# Number of rows in one shard (index 1) when the training split is divided
# into 20 shards.
len(dataset["train"].shard(index=1, num_shards=20))

6301

In [107]:
# Number of rows in one shard (index 1) when the test split is divided into
# 20 shards.
len(dataset["test"].shard(index=1, num_shards=20))

788

In [135]:
print("Loading biotech events classification dataset")
# NOTE(review): trust_remote_code=True permits a loading script shipped with
# the dataset repository to run locally -- confirm the source is trusted.
bio_dataset = load_dataset(
    "knowledgator/events_classification_biotech", trust_remote_code=True
)
# Class vocabulary taken from the names of the "label 1" feature; empty names
# are skipped.
bio_classes = [
    class_ for class_ in bio_dataset["train"].features["label 1"].names if class_
]
# `index` instead of `id` so the builtin `id` is not shadowed.
bio_class2id = {class_: index for index, class_ in enumerate(bio_classes)}
bio_id2class = {index: class_ for class_, index in bio_class2id.items()}


def preprocess_bio_example(example):
    """Tokenize one biotech example and attach its multi-hot label vector.

    Deliberately not named ``preprocess_function``: that name already belongs
    to the batched preprocessing function defined earlier in this notebook, and
    redefining it here would silently replace it for any cell re-run later.

    Args:
        example: A single row (the map call is not batched) providing "title",
            "content" and "all_labels"; every entry of "all_labels" must be a
            key of ``bio_class2id``.

    Returns:
        The tokenizer output for "<title>.\\n<content>" with an added "labels"
        entry: a list of ``len(bio_classes)`` floats, 1.0 for each class named
        in "all_labels" and 0.0 otherwise.

    Raises:
        KeyError: If a label is not a key of ``bio_class2id``.
    """
    text = f"{example['title']}.\n{example['content']}"
    labels = [0.0] * len(bio_classes)
    for label in example["all_labels"]:
        labels[bio_class2id[label]] = 1.0
    # A separate name keeps the incoming row distinct from the tokenizer output.
    encoded = tokenizer(text)
    encoded["labels"] = labels
    return encoded


bio_dataset = bio_dataset.map(preprocess_bio_example)

Loading biotech events classification dataset


Map: 100%|██████████| 2759/2759 [00:05<00:00, 480.13 examples/s]
Map: 100%|██████████| 381/381 [00:00<00:00, 461.50 examples/s]


In [138]:
# Multi-hot label vector of the first training example. The row is selected
# before the column so that the whole "labels" column does not have to be
# loaded just to read one element.
bio_dataset["train"][0]["labels"]

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [141]:
# Gather token counts and active-label counts for every biotech example.
bio_input_ids_lengths = []
bio_label_lengths = []
bio_sets = ["train", "test"]
# `split_name` rather than `set`, so the builtin `set` is not overwritten in
# the kernel namespace. Both statistics are collected in a single pass, so
# each split is iterated once instead of twice.
for split_name in bio_sets:
    for example in bio_dataset[split_name]:
        bio_input_ids_lengths.append(len(example["input_ids"]))
        # `.count(1)` matches the 1.0 entries of the multi-hot label list.
        bio_label_lengths.append(example["labels"].count(1))

In [143]:
# Print mean then maximum, first for token counts and then for active-label
# counts -- four values, one per line.
for lengths in (bio_input_ids_lengths, bio_label_lengths):
    for summary_fn in (np.mean, np.max):
        print(summary_fn(lengths))

672.0152866242038
3635
1.8289808917197452
5
