# Train NER model

This notebook shows the workflow I used to fine-tune the named-entity recognition (NER) model of the end-to-end clinical reasoning engine. Some tweaks have been left out; contact me if you want the full details of how I trained it.

**Data:** synthetic, de-identified examples in `./db/training_data.jsonl`  
**Model base:** `UFNLP/gatortron-base`  
**Outputs:** checkpoints stored to `../pipeline_ingest/db/ner_model`, sample predictions stored to `./db/example_predictions.jsonl`

All EHR notes used here are synthetic; no real patient data was used.

In [10]:
import json
import warnings
from pathlib import Path

import numpy as np
import torch
from datasets import Dataset
from transformers import (AutoTokenizer,TrainingArguments,Trainer,DataCollatorForTokenClassification,EarlyStoppingCallback)

# Star import kept so later cells keep working; names used in this notebook include
# load_train_and_devsets_from_jsonl, chunk_text, convert_spans_to_bio_labels,
# WeightedTokenClassificationModel, set_dropout, training_metrics_func,
# render_labeled_text_with_xml_sanity_check, predict_entities,
# filter_overlapping_spans_for_predictions and convert_prodigy_jsonl_to_tagged_text_and_relations.
from ner_helpers import *

In [11]:
# Report whether torch can see a CUDA GPU, then pick the device every later cell runs on.
gpu_available = torch.cuda.is_available()
print(gpu_available)
if gpu_available:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

True


In [12]:
# load data, tokenizer and model
raw_data_trainset, raw_data_devset = load_train_and_devsets_from_jsonl('./db/training_data.jsonl')
model_name = 'UFNLP/gatortron-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)


def collect_entity_labels(examples):
    """Return the set of span labels used in ``examples``.

    Each example is a dict whose optional ``'spans'`` list holds dicts with a ``'label'`` key;
    examples without ``'spans'`` contribute nothing.
    """
    return {span['label'] for example in examples for span in example.get('spans', [])}


def build_bio_label_list(entity_labels):
    """Return ``['O', 'B-<label>', 'I-<label>', ...]`` for ``entity_labels`` in sorted order."""
    return ['O'] + [f"{prefix}-{label}" for label in sorted(entity_labels) for prefix in ('B', 'I')]


label_set_trainset = collect_entity_labels(raw_data_trainset)
label_set_devset = collect_entity_labels(raw_data_devset)
label_list_trainset = build_bio_label_list(label_set_trainset)
label_list_devset = build_bio_label_list(label_set_devset)

# The label vocabulary (and so the model head) is built from the train set only; a dev label
# missing from it can never be predicted, so surface it instead of letting it pass silently.
unseen_dev_labels = label_set_devset - label_set_trainset
if unseen_dev_labels:
    warnings.warn(f"dev set uses labels absent from the train set: {sorted(unseen_dev_labels)}")

label2id = {label: i for i, label in enumerate(label_list_trainset)}
id2label = {i: label for label, i in label2id.items()}

# Split every note into tokenizer-sized chunks, preserving note order.
chunked_data_trainset = [chunk for example in raw_data_trainset for chunk in chunk_text(example, tokenizer)]
chunked_data_devset = [chunk for example in raw_data_devset for chunk in chunk_text(example, tokenizer)]

# Attach per-token BIO label ids to each chunk.
# NOTE(review): the train set passes expand_entities=False while the dev set uses the helper's
# default — confirm this asymmetry is intended.
dataset_trainset = Dataset.from_list(chunked_data_trainset).map(
    lambda x: convert_spans_to_bio_labels(x, tokenizer, label2id, expand_entities=False), batched=False)
dataset_devset = Dataset.from_list(chunked_data_devset).map(
    lambda x: convert_spans_to_bio_labels(x, tokenizer, label2id), batched=False)

# Boost the loss weight of TIME and NEGATION to penalize missed TIME/NEGATION entities more
# (i.e. push the model to find all of them).
model = WeightedTokenClassificationModel(base_model_name=model_name, label_list=label_list_trainset, weight_boosts={"TIME": 3.0, "NEGATION": 3.0})
model = model.to(device)
print(next(model.parameters()).device)

Map:   0%|          | 0/2323 [00:00<?, ? examples/s]

Map:   0%|          | 0/18 [00:00<?, ? examples/s]

cuda:0


In [13]:
# Sanity-check the input data by reviewing a chunked training example with its BIO labels rendered inline.
def show_labeled_training_example(example_i=7):
    """Print the token count, label count and XML-tagged text of ``dataset_trainset[example_i]``.

    Labels are decoded with a local copy of ``id2label`` that also maps ``-100`` to ``'O'``
    (presumably the id for positions excluded from the loss — confirm in ner_helpers), so the
    global ``id2label`` used later for inference is left untouched.
    """
    display_id2label = {**id2label, -100: 'O'}
    example = dataset_trainset[example_i]
    example_labels = [display_id2label[label_id] for label_id in example['labels']]
    # The two lengths should match: one label per input id.
    print(len(example['input_ids']))
    print(len(example_labels))
    print(render_labeled_text_with_xml_sanity_check(example['text'], example['input_ids'], example_labels, tokenizer))


# show_labeled_training_example(7)  # uncomment to inspect an example


In [14]:
# Regularization: the same 0.1 dropout is used for all training.
set_dropout(model, 0.1)

# Optimisation hyperparameters; trailing notes give the ranges that were explored.
hyperparams = dict(
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=45,  # upper bound; early stopping usually ends the run sooner
    gradient_accumulation_steps=2,  # explored 1-2
    warmup_steps=25,
    weight_decay=0.05,  # explored 0.01-0.05
    max_grad_norm=1.0,
    lr_scheduler_type='cosine',  # cosine for the main training run
    label_smoothing_factor=0.1,
)

# Evaluate and checkpoint once per epoch, and restore the best checkpoint when training ends.
# NOTE(review): no metric_for_best_model is given, so best-checkpoint selection and early
# stopping presumably track eval loss (transformers default) — confirm this is intended.
training_args = TrainingArguments(
    output_dir='../pipeline_ingest/db/ner_model',
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    logging_steps=1,
    **hyperparams,
)

label_names = list(label2id.keys())
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_trainset,
    eval_dataset=dataset_devset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=training_metrics_func(label_names),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # explored patience 2-4
)


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
# Fine-tune the model; the object returned by trainer.train() is shown as the cell output.
train_result = trainer.train()
train_result

In [7]:
# Score the model (best checkpoint, since load_best_model_at_end=True) on the held-out dev set.
# Pass eval_dataset=dataset_trainset instead to compare train vs. dev performance.
dev_metrics = trainer.evaluate(eval_dataset=dataset_devset)
dev_metrics

              precision    recall  f1-score   support

    ADM_TIME       1.00      1.00      1.00         5
         AGE       1.00      1.00      1.00         5
       C_ENT       0.95      0.96      0.96       573
         DOB       1.00      1.00      1.00         5
    NEGATION       0.84      1.00      0.92        27
        PLAN       0.94      0.94      0.94       193
         SEX       1.00      1.00      1.00         9
       TABLE       1.00      1.00      1.00        15
        TIME       0.94      0.86      0.90       150

   micro avg       0.95      0.94      0.95       982
   macro avg       0.96      0.97      0.97       982
weighted avg       0.95      0.94      0.95       982



{'eval_loss': 0.6383048295974731,
 'eval_model_preparation_time': 0.0067,
 'eval_accuracy': 0.9870283018867925,
 'eval_precision': 0.9487704918032787,
 'eval_recall': 0.9429735234215886,
 'eval_f1': 0.9458631256384066,
 'eval_runtime': 0.933,
 'eval_samples_per_second': 19.292,
 'eval_steps_per_second': 2.144}

In [7]:
# Save / load the fine-tuned model independently of the Trainer checkpoints.
# Usage:
#   save_ner_artifacts(model, tokenizer, label_list_trainset, label2id)
#   tokenizer, label_list_trainset, label2id, id2label = load_ner_artifacts(model)
NER_MODEL_DIR = "../pipeline_ingest/db/ner_model"


def save_ner_artifacts(model, tokenizer, label_list, label2id, out_dir=NER_MODEL_DIR):
    """Write model weights, tokenizer files and the label mapping to ``out_dir``.

    Creates ``out_dir`` if needed, then writes ``model_weights.pt`` (the model's state dict),
    the tokenizer files via ``tokenizer.save_pretrained``, and ``labels.json`` holding
    ``label_list``, ``label2id`` and the derived ``id2label``.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), out_path / "model_weights.pt")
    tokenizer.save_pretrained(out_dir)
    id2label = {label_id: label for label, label_id in label2id.items()}
    with open(out_path / "labels.json", "w", encoding="utf-8") as f:
        json.dump({"label_list": label_list, "label2id": label2id, "id2label": id2label}, f)


def read_label_mapping(out_dir=NER_MODEL_DIR):
    """Read ``labels.json`` from ``out_dir`` and return ``(label_list, label2id, id2label)``.

    JSON object keys are always strings, so ``id2label`` keys are converted back to ``int``
    to match the in-memory mapping built at training time.
    """
    with open(Path(out_dir) / "labels.json", encoding="utf-8") as f:
        labels = json.load(f)
    id2label = {int(label_id): label for label_id, label in labels["id2label"].items()}
    return labels["label_list"], labels["label2id"], id2label


def load_ner_artifacts(model, out_dir=NER_MODEL_DIR):
    """Load saved weights into ``model`` in place; return ``(tokenizer, label_list, label2id, id2label)``."""
    # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine; load_state_dict
    # then copies the tensors into the model's existing parameters on their current device.
    # torch.load unpickles its input: only load weight files you produced yourself.
    state_dict = torch.load(Path(out_dir) / "model_weights.pt", map_location="cpu")
    model.load_state_dict(state_dict)
    tokenizer = AutoTokenizer.from_pretrained(out_dir)
    label_list, label2id, id2label = read_label_mapping(out_dir)
    return tokenizer, label_list, label2id, id2label

In [8]:
# Run a few predictions and inspect them.
# NOTE: the notes come from training_data.jsonl, which also supplies the train split, so some of
# them may have been seen during training — use dev-set notes for an unbiased look.
jsonl_input_path = './db/training_data.jsonl'
jsonl_output_path = './db/example_predictions.jsonl'
n_prediction_examples = 5  # how many notes to generate predictions over
prediction_prob_threshold = 0.6

# Inference must run with dropout disabled, regardless of which cell last touched the model.
model.eval()

with open(jsonl_input_path, "r", encoding="utf-8") as f:
    # Skip blank lines so a stray empty line does not crash json.loads.
    lines = [json.loads(line) for line in f if line.strip()]

output_annotations_list = []
with open(jsonl_output_path, "w", encoding="utf-8") as out:
    for example in lines[:n_prediction_examples]:
        text = example["text"]
        spans = predict_entities(text, model, tokenizer, id2label, device=device, prob_threshold=prediction_prob_threshold)
        spans = filter_overlapping_spans_for_predictions(spans=spans, min_char_length=1)
        output = {"text": text, "spans": spans}
        output_annotations_list.append(output)
        out.write(json.dumps(output) + "\n")

# show XML-annotated example
predictions_xml_tagged_format = convert_prodigy_jsonl_to_tagged_text_and_relations(jsonl_output_path)
example_i = 1
print(predictions_xml_tagged_format[example_i][1])

**EPIC EHR - ICU ADMISSION SUMMARY**

---
**Patient Name:** Melina Zieme  
**MRN:** 00938491-ICU  
**DOB:** <dob id="e1">09/24/1968</dob>  
**Sex:** <sex id="e2">Female</sex>  
**Admit Date:** <adm_time id="e3">2024-05-22</adm_time>  
**Attending:** Dr. R. Carver  
**Consults:** Cardiology, Critical Care, Infectious Disease, Endocrinology  
**Location:** MICU, Bed 14  

---

### Hospital Course Summary

**History of Present Illness:**  
Melina Zieme, a <age id="e4">55-year-old</age> <sex id="e5">female</sex> with a <time id="e6">history</time> notable for <c_ent id="e7">cardiac arrest</c_ent>, <c_ent id="e8">obesity</c_ent>, and <time id="e9">recent</time> <c_ent id="e10">upper respiratory tract infection</c_ent>, was <c_ent id="e11">found unresponsive at home</c_ent> by her daughter. EMS reported <c_ent id="e12">pulseless electrical activity (PEA) arrest</c_ent>; <c_ent id="e13">ROSC achieved</c_ent> <time id="e14">after 12 minutes of ACLS</time>. <c_ent id="e15">Brought to ED intubat