In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd 
import numpy as np
from preprocessing.cleaning_utils import *
from functools import partial
import torch 

from datasets import load_dataset
from datasets import Value, ClassLabel, Features, DatasetDict
import transformers
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from transformers import logging
from transformers import TrainingArguments, Trainer

logging.set_verbosity_warning()

In [3]:
# !python -m torch.utils.collect_env

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [5]:
mimic_dir = "/home/vs428/project/MIMIC/files/mimiciii/1.4/"
n2c2_dir = "/home/vs428/project/n2c2/2022/N2C2-AP-Reasoning/"
n2c2_data_dir =  "/home/vs428/project/n2c2/2022/Data/"


In [None]:
samples = pd.read_csv(n2c2_dir + "n2c2_sample.csv")
samples_raw = pd.read_csv(n2c2_dir + "n2c2_sample_raw.csv")

In [None]:

# classes = ['Not Relevant', 'Neither', 'Indirect', 'Direct']
# features = Features({
#     'ROW ID':Value("int64"),
#     'HADM ID':Value("int64"),
#     'Assessment':Value("string"),
#     'PlanSubsection':Value("string"),
#     "Relation":Value("string")
# }) 
# dataset = load_dataset("csv", data_files=n2c2_dir + "n2c2_sample_raw.csv", 
#                        features=features)
# dataset = dataset.class_encode_column("Relation")
# dataset = dataset.rename_column("Relation", "label")

In [6]:
classes = ['Not Relevant', 'Neither', 'Indirect', 'Direct']
features = Features({
    'ROW ID':Value("int64"),
    'HADM ID':Value("int64"),
    'Assessment':Value("string"),
    'Plan Subsection':Value("string"),
    "Relation":Value("string")
}) 

dataset = load_dataset("csv", data_files={
                            "train":n2c2_data_dir + "train.csv",
                            "valid":n2c2_data_dir + "dev.csv",
                        },

                       features=features)
dataset = dataset.class_encode_column("Relation")
dataset = dataset.rename_column("Relation", "label")

Using custom data configuration default-b1948d86214b7517
Reusing dataset csv (/home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-61f06d309abebffe.arrow
Loading cached processed dataset at /home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-db00bc55293d3377.arrow
Loading cached processed dataset at /home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-9019b65120396e78.arrow
Loading cached processed dataset at /home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-cadbd3a7f24393c8.arrow


In [7]:
dataset['train'][23]

{'ROW ID': 711150,
 'HADM ID': 183333,
 'Assessment': '75 yo M with h/o cardiomyopahy (EF 15%), atrial fibrillation, pulmonary\n   embolism (on coumadin), adrenal insufficiency, and s/p several recent\n   admissions following episode of Klebsilella pneumonia complicated by\n   respiratory failure and septic shock. Presents to hospital with\n   syncope, hypoxia, hypotension and altered mental status.',
 'Plan Subsection': '# Adrenal insufficiency:\n   Patient on 10 mg hydrocortisone at home.\n   - Increase to 50 mg hydrocortisone given stress of acute illness (day 2\n   today)',
 'label': 1}

In [8]:
# # split dataset: 80% train, 10% validation, 10% test 
# train_testvalid = dataset['train'].train_test_split(test_size=0.2)
# # test_valid = train_testvalid['test'].train_test_split(test_size=0.5)
# split_dataset = DatasetDict({
#     'train': train_testvalid['train'],
#     'valid': test_valid['train'],
#     # 'test': test_valid['test'],
# })



In [9]:
bio_clinicalbert_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Train Model

In [10]:
bio_clinicalbert_model = AutoModelForSequenceClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=4)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model

In [11]:
training_args = TrainingArguments(output_dir="test_trainer", 
                                  evaluation_strategy="epoch",
                                  
)

In [12]:
import numpy as np
from datasets import load_metric

acc = load_metric("accuracy")
macrof1 = load_metric("f1")

In [13]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy":acc.compute(predictions=predictions, references=labels),
            "f1-macro":macrof1.compute(predictions=predictions, references=labels, 
                                       average="macro")}

In [14]:
def tokenize_function(examples):
    return bio_clinicalbert_tokenizer(examples['Assessment'], examples['Plan Subsection'],
                                      truncation="longest_first",
                                      max_length=512,
                                      verbose=True)
                    


In [15]:
dataset = dataset.map(tokenize_function, batched=True)

Loading cached processed dataset at /home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-3942956ec260f7ce.arrow
Loading cached processed dataset at /home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-951b17d1e7bb78b3.arrow


In [16]:
bio_clinicalbert_tokenizer.decode(dataset['valid'][52]['input_ids'])

'[CLS] alteration in nutrition, arousal, attention, and cognition, impaired, balance, impaired, gait, impaired, knowledge, impaired, muscle performace, impaired, transfers, impaired, cholecystitis, acalculous, renal failure, end stage ( end stage renal disease, esrd ), [ * * last name 121 * * ] problem - enter description in comments, altered mental status ( not delirium ) assessment and plan : 78yo m with esrd w / acute cholecystitis s / p perc ccy now w / mdr e. coli infected biliary fluid, now new subacute stroke [SEP] neurologic : poor ms d / t cva on coumadin / afib vs. infection. rect in few days. improving. no narcotics / bzds. b carotid stenosis < 40 %. tte w / o thrombus. [SEP]'

In [17]:
dataset['train'].set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
dataset['valid'].set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])

In [18]:
data_collator = DataCollatorWithPadding(bio_clinicalbert_tokenizer,
                                        max_length=512, 
                                        padding="max_length",
                                        return_tensors="pt" )

In [19]:
trainer = Trainer(
    model=bio_clinicalbert_model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['valid'],
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

In [20]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: Assessment, ROW ID, HADM ID, Plan Subsection.
***** Running training *****
  Num examples = 4633
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 1740


Epoch,Training Loss,Validation Loss,Accuracy,F1-macro
1,0.7153,0.597065,{'accuracy': 0.7487437185929648},{'f1': 0.766732642338792}
2,0.4576,0.672382,{'accuracy': 0.7805695142378559},{'f1': 0.7940059075142675}
3,0.3017,0.918875,{'accuracy': 0.7671691792294807},{'f1': 0.7860078751767712}


Saving model checkpoint to test_trainer/checkpoint-500
Configuration saved in test_trainer/checkpoint-500/config.json
Model weights saved in test_trainer/checkpoint-500/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: Assessment, ROW ID, HADM ID, Plan Subsection.
***** Running Evaluation *****
  Num examples = 597
  Batch size = 8
Saving model checkpoint to test_trainer/checkpoint-1000
Configuration saved in test_trainer/checkpoint-1000/config.json
Model weights saved in test_trainer/checkpoint-1000/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: Assessment, ROW ID, HADM ID, Plan Subsection.
***** Running Evaluation *****
  Num examples = 597
  Batch size = 8
Saving model checkpoint to test_trainer/checkpoint-1500
Configuration saved in test_trainer/checkpo

TrainOutput(global_step=1740, training_loss=0.45485844447694973, metrics={'train_runtime': 821.007, 'train_samples_per_second': 16.929, 'train_steps_per_second': 2.119, 'total_flos': 3657046227554304.0, 'train_loss': 0.45485844447694973, 'epoch': 3.0})