In [7]:
import argparse
import ast
from pprint import pprint

import evaluate
import numpy as np
import pandas as pd
import torch
from torch import nn
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    EvalPrediction,
    OPTForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# Hugging Face Hub checkpoint name; passed to AutoTokenizer.from_pretrained below.
MODEL = "facebook/opt-350m"
# NOTE(review): presumably the model's maximum sequence length; it is not
# referenced in the visible cells — confirm where it is used.
MAX_POSITION_EMBEDDINGS = 2048

from dataclasses import dataclass

In [8]:
# CSV files for the train / validation / test splits and for the ICD code list.
# NOTE(review): "DATSET" is misspelled in two constant names; the names are kept
# as-is because other cells reference them.
TRAIN_DATSET_PATH = "./data/train_10_top50.csv"
VAL_DATASET_PATH = "./data/val_10_top50.csv"
TEST_DATSET_PATH = "./data/test_10_top50.csv"
CODE_PATH = "./data/icd10_codes_top50.csv"

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
# Load the tokenizer. `AutoTokenizer.from_pretrained` has no `device` parameter
# (tokenization happens on the CPU), so only the checkpoint name and the
# fast-tokenizer flag are passed.
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

# Load the three CSV splits; each split is addressed by key below
# (e.g. dataset["train"]).
data_files = {
    "train": TRAIN_DATSET_PATH,
    "validation": VAL_DATASET_PATH,
    "test": TEST_DATSET_PATH,
}
code_labels = pd.read_csv(CODE_PATH)
dataset = load_dataset("csv", data_files=data_files)

# Create class dictionaries: ICD code -> integer index, and the inverse mapping.
# Falsy codes (e.g. empty strings) are skipped.
classes = [code for code in code_labels["icd_code"] if code]
class2id = {code: index for index, code in enumerate(classes)}
id2class = {index: code for code, index in class2id.items()}


def multi_labels_to_ids(labels: list[str]) -> list[float]:
    """Encode a collection of class names as a multi-hot float vector.

    Args:
        labels: Class names; each must be a key of the module-level `class2id`.

    Returns:
        A list with one entry per `class2id` key: 1.0 at the index of every
        given label and 0.0 elsewhere. Floats are used because BCELoss requires
        a float target type.

    Raises:
        KeyError: If a label is not present in `class2id`.
    """
    multi_hot = [0.0 for _ in range(len(class2id))]
    for name in labels:
        position = class2id[name]
        multi_hot[position] = 1.0
    return multi_hot


def preprocess_function(example):
    """Tokenize a batch of texts and attach multi-hot label vectors.

    Args:
        example: A batch (this function is mapped with `batched=True`) whose
            "text" entry holds the texts and whose "label" entry holds one
            string per row. Each string is expected to be the literal form of
            a Python list of class names — TODO confirm against the CSV files.

    Returns:
        The tokenizer output for the texts, with a "label" entry added that
        contains one multi-hot float vector per row.

    Raises:
        ValueError / SyntaxError: If a label string is not a valid Python literal.
        KeyError: If a parsed label is not a known class.
    """
    # Only the text is passed: no truncation or padding options are set here,
    # so the tokenizer's defaults apply.
    result = tokenizer(example["text"])
    # `ast.literal_eval` parses literals only; unlike `eval`, it cannot execute
    # arbitrary expressions embedded in the data file.
    result["label"] = [
        multi_labels_to_ids(ast.literal_eval(label)) for label in example["label"]
    ]
    return result


# Tokenize and label-encode every split, in batches, using 8 worker processes.
dataset = dataset.map(function=preprocess_function, num_proc=8, batched=True)

Map (num_proc=8): 100%|██████████| 33768/33768 [00:21<00:00, 1537.07 examples/s]
Map (num_proc=8): 100%|██████████| 4221/4221 [00:02<00:00, 1556.54 examples/s]
Map (num_proc=8): 100%|██████████| 4221/4221 [00:02<00:00, 1504.32 examples/s]


In [13]:
# Display the training split's summary (column names and row count).
dataset["train"]

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 33768
})

In [14]:
import numpy as np
# Collect the tokenized length of every example across all splits
# (used to report the average and maximum text length).
train_dataset = dataset['train']


input_ids_lengths = []
label_lengths = []
sets = ['train', 'test', 'validation']
# The loop variable is `split`, not `set`, so the builtin `set` type is not
# overwritten in the kernel namespace for later cells.
for split in sets:
    input_ids_lengths.extend(len(example['input_ids']) for example in dataset[split])


In [15]:
# Mean and maximum tokenized length over all collected examples.
mean_input_length = np.mean(input_ids_lengths)
max_input_length = np.max(input_ids_lengths)
print(mean_input_length, max_input_length, sep="\n")



3041.86259180289
6262


In [16]:
# Number of active labels per example across all splits.
label_lengths = []

sets = ["train", "test", "validation"]
# The loop variable is `split`, not `set`, so the builtin `set` type is not
# overwritten in the kernel namespace for later cells.
for split in sets:
    # Label vectors hold 0.0 / 1.0 floats; `count(1)` matches 1.0 because 1 == 1.0.
    label_lengths.extend(example['label'].count(1) for example in dataset[split])


In [17]:
# Mean and maximum number of active labels per example.
mean_label_count = np.mean(label_lengths)
max_label_count = np.max(label_lengths)
print(mean_label_count, max_label_count, sep="\n")

5.023146173892442
23


In [18]:
# Row count of shard 1 when the training split is divided into 20 shards.
len(dataset["train"].shard(num_shards=20, index=1))

1689

In [19]:
# Row count of shard 1 when the test split is divided into 20 shards.
len(dataset["test"].shard(num_shards=20, index=1))

211

In [20]:
print("Loading biotech events classification dataset")
# NOTE(review): trust_remote_code=True permits running code shipped with the
# dataset repository — confirm the source is trusted.
bio_dataset = load_dataset(
    "knowledgator/events_classification_biotech",
    trust_remote_code=True,
)

# Class names are taken from the `names` of the "label 1" feature of the
# training split; falsy (e.g. empty) names are dropped.
bio_label_names = bio_dataset["train"].features["label 1"].names
bio_classes = [name for name in bio_label_names if name]
bio_class2id = {name: index for index, name in enumerate(bio_classes)}
bio_id2class = {index: name for name, index in bio_class2id.items()}
def preprocess_function(example):
    """Tokenize one biotech example and attach its multi-hot label vector.

    Note: this reuses the name `preprocess_function` of the CSV-dataset helper
    defined earlier in the notebook and replaces it from this cell onward.

    Args:
        example: A single row providing "title", "content" and "all_labels".

    Returns:
        The tokenizer output for "<title>.\\n<content>", with a "labels" entry
        holding one float per entry of `bio_classes` (1.0 for every label in
        "all_labels", 0.0 otherwise).

    Raises:
        KeyError: If a label is not present in `bio_class2id`.
    """
    document = f"{example['title']}.\n{example['content']}"
    multi_hot = [0.0] * len(bio_classes)
    for name in example["all_labels"]:
        multi_hot[bio_class2id[name]] = 1.0
    encoded = tokenizer(document)
    encoded["labels"] = multi_hot
    return encoded
# Apply the preprocessing to every example of every split, one row at a time.
bio_dataset = bio_dataset.map(function=preprocess_function)

Loading biotech events classification dataset


Map: 100%|██████████| 2759/2759 [00:04<00:00, 620.18 examples/s]
Map: 100%|██████████| 381/381 [00:00<00:00, 645.96 examples/s]


In [21]:
# Multi-hot label vector of the first training example.
bio_dataset["train"][0]["labels"]

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [22]:
# Tokenized length and number of active labels for every biotech example.
bio_input_ids_lengths = []
bio_label_lengths = []
bio_sets = ["train", "test"]
# The loop variable is `split`, not `set`, so the builtin `set` type is not
# overwritten in the kernel namespace for later cells.
for split in bio_sets:
    # One pass per split gathers both statistics instead of iterating twice.
    for example in bio_dataset[split]:
        bio_input_ids_lengths.append(len(example["input_ids"]))
        # Label vectors hold 0.0 / 1.0 floats; `count(1)` matches 1.0 because 1 == 1.0.
        bio_label_lengths.append(example["labels"].count(1))

In [23]:
# Mean / maximum tokenized length, then mean / maximum active-label count.
bio_length_stats = (
    np.mean(bio_input_ids_lengths),
    np.max(bio_input_ids_lengths),
    np.mean(bio_label_lengths),
    np.max(bio_label_lengths),
)
print(*bio_length_stats, sep="\n")

672.0152866242038
3635
1.8289808917197452
5
