# Train relations model

This notebook shows the workflow I used to fine-tune the relations model of the end-to-end clinical reasoning engine. Some tweaks have been left out; contact me if you want to go into every detail as to how I trained it.

**Data:** synthetic, de-identified examples in `./db/training_data.jsonl`  
**Model base:** `michiyasunaga/BioLinkBERT-large`  
**Outputs:** checkpoints stored to `../pipeline_ingest/db/relations_model`, sample predictions stored to `./db/example_predictions.jsonl`

All used EHR notes are synthetic; no real patient data was used. 

In [12]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import numpy as np
import torch
import json
from collections import Counter
from itertools import product
from datasets import Dataset
from typing import List, Dict, Tuple
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    AutoModelForSequenceClassification,
    AutoConfig,
    AutoModel
)
from relations_helpers import *

In [2]:
# make sure GPU is available
print(torch.cuda.is_available())
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

True


In [17]:
model_name = 'michiyasunaga/BioLinkBERT-large'
config = AutoConfig.from_pretrained(model_name,num_labels=3,hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2) # use 0.2
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
model = model.to(device)

label_encoder = LabelEncoder()
label_encoder.fit(["TIME_RELATION", "NEGATION_RELATION", "NO_RELATION"])
data_trainset, data_devset = load_train_and_devsets_from_jsonl('/scratch/users/jsokol/escalation_project/data/extractions/batch2/batch2_prodigy_and_gpt_annotations_v4_manuallycorrected_fixed_formatting.jsonl',True,4)
data_trainset = prepare_relation_examples_with_chunking(data_trainset, tokenizer, label_encoder, overlap=0.25, max_positives_token_distance=100, max_control_token_distance=30, max_control_token_distance_lr=100, lr_controls_prop=0.25) 
data_devset = prepare_relation_examples_with_chunking(data_devset, tokenizer, label_encoder, overlap=0.25, max_positives_token_distance=100, max_control_token_distance=30, max_control_token_distance_lr=100, lr_controls_prop=0.25) 

# compute weighting scheme to feed to trainer
label_counts = Counter(data_trainset['labels'])
num_classes = len(label_counts)
total = sum(label_counts.values())
weights = torch.tensor([
    total / (label_counts.get(i, 1)) for i in range(num_classes)
], dtype=torch.float)

print([len(data_trainset),len(data_devset)])
print(next(model.parameters()).device)


[3746, 1023]
cuda:0


In [4]:
# inspect token distance distributions for each label
# import matplotlib.pyplot as plt
# dist_list = [data_trainset[i]['distance'] for i in range(3700) if data_trainset[i]['labels']==2]
# dist_list = [data_devset[i]['distance'] for i in range(len(data_devset)) if data_devset[i]['labels']==2]
# plt.hist(dist_list)
# plt.show()

In [5]:
# inspect a few examples to see if they are formatted correctly
# example_i = 296
# id2labels = {0:'NEGATION_RELATION',1:'NO_RELATION',2:'TIME_RELATION'}
# print(id2labels[data_trainset[example_i]['labels']])
# print(data_trainset[example_i]['distance'])
# print(data_trainset[example_i]['debug'])

In [15]:
# set up training arguments
training_args = TrainingArguments(
    output_dir="../pipeline_ingest/db/relations_model",
    learning_rate=5e-6, 
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=16,
    num_train_epochs=20, 
    gradient_accumulation_steps=4,  # use 4-1
    warmup_steps=800, # use 800 
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    logging_steps=1000,
    weight_decay = 0.05, # use 0.1-0.05
    max_grad_norm = 1.0,
    lr_scheduler_type='cosine', 
    label_smoothing_factor=0.2, # 0.2-0.4
    fp16=True, # use for gatortron
)

# define trainer
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=data_trainset,
    eval_dataset=data_devset,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)], # use 2-3
    compute_metrics=compute_metrics_with_confusion_matrix,
    class_weights=weights,
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [8]:
# train model
trainer.train()

In [10]:
# evaluate model using devset
print({0:'NEGATION_RELATION',1:'NO_RELATION',2:'TIME_RELATION'})
# trainer.evaluate(eval_dataset=data_trainset)
trainer.evaluate(eval_dataset=data_devset)

{0: 'NEGATION_RELATION', 1: 'NO_RELATION', 2: 'TIME_RELATION'}
              precision    recall  f1-score   support

           0     1.0000    0.8485    0.9180        33
           1     0.9572    0.9468    0.9520       733
           2     0.8687    0.9085    0.8881       284

    accuracy                         0.9333      1050
   macro avg     0.9420    0.9012    0.9194      1050
weighted avg     0.9346    0.9333    0.9336      1050



{'eval_loss': 0.3981826603412628,
 'eval_model_preparation_time': 0.0067,
 'eval_accuracy': 0.9333333333333333,
 'eval_confusion_matrix_0_0': 28,
 'eval_confusion_matrix_0_1': 5,
 'eval_confusion_matrix_0_2': 0,
 'eval_confusion_matrix_1_0': 0,
 'eval_confusion_matrix_1_1': 694,
 'eval_confusion_matrix_1_2': 39,
 'eval_confusion_matrix_2_0': 0,
 'eval_confusion_matrix_2_1': 26,
 'eval_confusion_matrix_2_2': 258,
 'eval_runtime': 8.0749,
 'eval_samples_per_second': 130.032,
 'eval_steps_per_second': 8.173}

In [7]:
# save/load model
# output_dir = "../pipeline_ingest/db/relations_model"
# trainer.save_model(output_dir) 
# tokenizer.save_pretrained(output_dir)
# tokenizer = AutoTokenizer.from_pretrained(output_dir)
# model = AutoModelForSequenceClassification.from_pretrained(output_dir)
# model = model.to(device)

In [9]:
# run a few predictions and inspect them
id2labels = {0:'NEGATION_RELATION',1:'NO_RELATION',2:'TIME_RELATION'}
label2id = {v: k for k, v in id2labels.items()}
jsonl_input_path = './db/training_data.jsonl'
jsonl_output_path = './db/example_predictions.jsonl'
allowed_relations = [("C_ENT", "TIME", "TIME_RELATION"),("TABLE", "TIME", "TIME_RELATION"),("C_ENT", "NEGATION", "NEGATION_RELATION")]
with open(jsonl_input_path, "r") as f:
    entries = [json.loads(line) for line in f]
with open(jsonl_output_path, "w") as out_f:
    for entry in entries[:5]: ### set how many notes to generate predictions over ###
        counter += 1
        text = entry["text"]
        spans = entry.get("spans", [])
        relations, xml_tagged_texts, xml_tagged_texts_no_relations = predict_relations_with_chunking(text,spans,model,tokenizer,label2id,id2labels,allowed_relations,device,prob_threshold=0.9,max_token_distance=100)
        relations_filtered = filter_relations_for_predictions(relations)
        output = {"text": text, "spans": spans, "relations": relations_filtered}
        out_f.write(json.dumps(output) + "\n")


In [10]:
# show relative counts of predicted negation/time relations
np.unique([relations_filtered[i]['label'] for i in range(len(relations_filtered))],return_counts=True)


(array(['NEGATION_RELATION', 'TIME_RELATION'], dtype='<U17'), array([ 9, 74]))

In [11]:
# spot check predicted relation
print(xml_tagged_texts[10]) # positive example (i.e. predicted negation/time relation)
# print(xml_tagged_texts_no_relations[10]) # negative example (i.e. predicted no relation)


---
**Patient Name:** Harold Hilll  
**MRN:** [Redacted]  
**DOB:** 03/24/1993  
**Sex:** Male  
**Admit Date:** 2024-04-13  
**Attending:** Dr. K. Adams  
**Location:** MICU, Bed 7  
**Consults:** Cardiology, Infectious Disease, PT/OT  
**Allergies:** NKDA  
**Code Status:** Full Code  
**Height:** 172 cm  
**Weight:** 81 kg  
**BMI:** 27.3 kg/m2  

---

### Chief Complaint
Progressive dyspnea, orthopnea, lower extremity edema.

---

### HPI
Mr. Harold Hilll is a 31-year-old male with a history of acute viral pharyngitis and recent right ankle sprain, presenting with 1 week of worsening shortness of breath, orthopnea, and bilateral leg swelling. He initially attributed symptoms to his viral illness and limited mobility. On the <CHILD>day of admission</CHILD>, he developed acute chest discomfort and <HEAD>near-syncope</HEAD>. EMS noted hypotension (SBP 72 mmHg), tachycardia, and hypoxia. He was emergently transferred to the ICU for presumed cardiogenic shock.  

Past 24 hours:  
- Intu