### Training a Sensitive Information NER Model using Auto-Annotated Clinical Notes

This script auto-annotates clinical notes from the `SYNTHETIC_TEXT` column in your CSV.
Sensitive information is marked by placeholders of the form `[** ... **]`. We use custom
heuristics to label the text inside these placeholders:
- If the content matches a date pattern (YYYY-MM-DD), label as DATE.
- If the content matches a phone pattern (e.g., 555-0109), label as PHONE.
- If the content contains "hospital" (case-insensitive), label as LOCATION.
- If the content is a numeric range (e.g., 5-9), label as NUMERIC.
- Otherwise, label as NAME.
Tokens are then annotated in a BIO scheme (e.g., B-DATE, I-DATE).

After annotation, the data is converted into a Hugging Face Dataset, tokenized (with label alignment)
and used to fine-tune a pre-trained model (`BioBERT` in this example) for token classification.

Dependencies:
- pandas
- nltk
- transformers
- datasets
- scikit-learn



In [1]:
%pip install datasets transformers torch
%pip install accelerate>=0.26.0

Note: you may need to restart the kernel to use updated packages.
zsh:1: 0.26.0 not found
Note: you may need to restart the kernel to use updated packages.


In [2]:
import re
import random
import pandas as pd
import nltk
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification

nltk.download('punkt_tab')

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt_tab to /Users/areef/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [3]:
# ----------------------------
# Module 1: Data Loading & Preprocessing
# ----------------------------
def load_data(file_path: str) -> pd.DataFrame:
    """Loads the CSV file with clinical notes from the column SYNTHETIC_TEXT."""
    try:
        df = pd.read_csv(file_path, dtype={4: str, 5: str})
        print("Number of notes loaded:", len(df))
        return df
    except Exception as e:
        print("Error loading dataset:", str(e))
        exit(1)

def preprocess_text(text: str) -> str:
    """For training annotation we leave the placeholders intact and just strip extra whitespace."""
    return text.strip()

def preprocess_dataframe(df: pd.DataFrame, text_column: str = "SYNTHETIC_TEXT") -> pd.DataFrame:
    df['TEXT_PRE'] = df[text_column].apply(preprocess_text)
    return df

In [4]:
# ----------------------------
# Module 2: Auto-Annotation of Sensitive Tokens
# ----------------------------
def infer_label(placeholder_content: str) -> str:
    """
    Infers a sensitive label from the content inside a placeholder.
    Heuristics:
      - If content matches a date pattern (YYYY-MM-DD), return "DATE".
      - If content matches a phone pattern (e.g., 555-0109), return "PHONE".
      - If content matches a numeric range (e.g., 5-9), return "NUMERIC".
      - If content contains "hospital" (case-insensitive), return "LOCATION".
      - Otherwise, return "NAME".
    """
    content = placeholder_content.strip()
    if re.match(r"^\d{4}-\d{1,2}-\d{1,2}$", content):
        return "DATE"
    if re.match(r"^\d{3}-\d{4}$", content):
        return "PHONE"
    if re.match(r"^\d+-\d+$", content):
        return "NUMERIC"
    if "hospital" in content.lower():
        return "LOCATION"
    return "NAME"

def annotate_text(text: str):
    """
    Auto-annotates a clinical note.
    Sensitive information is assumed to be enclosed within [** ... **].
    Tokens inside such a span are annotated with a BIO scheme using the label
    inferred from the content. Tokens outside are labeled "O".
    
    Returns:
        tokens (list[str]), ner_tags (list[str])
    """
    annotations = []
    last_index = 0
    # Find placeholders
    for match in re.finditer(r"\[\*\*(.*?)\*\*\]", text):
        start, end = match.span()
        sensitive_content = match.group(1)
        label = infer_label(sensitive_content)
        # Annotate tokens before the sensitive span as non-sensitive.
        before = text[last_index:start]
        tokens_before = nltk.word_tokenize(before)
        for token in tokens_before:
            annotations.append((token, "O"))
        # Annotate sensitive tokens using BIO scheme.
        sensitive_tokens = nltk.word_tokenize(sensitive_content)
        for i, token in enumerate(sensitive_tokens):
            if i == 0:
                annotations.append((token, f"B-{label}"))
            else:
                annotations.append((token, f"I-{label}"))
        last_index = end
    # Annotate any remaining tokens after the last placeholder.
    after = text[last_index:]
    tokens_after = nltk.word_tokenize(after)
    for token in tokens_after:
        annotations.append((token, "O"))
    tokens = [token for token, tag in annotations]
    ner_tags = [tag for token, tag in annotations]
    return tokens, ner_tags

def convert_csv_to_ner_dicts(csv_path: str):
    """
    Converts the CSV file into a list of dictionaries with keys:
      "id", "tokens", "ner_tags".
    Each row is a complete clinical note.
    """
    df = load_data(csv_path)
    ner_data = []
    for i, row in df.iterrows():
        text = row["SYNTHETIC_TEXT"]
        tokens, tags = annotate_text(text)
        ner_data.append({
            "id": i,
            "tokens": tokens,
            "ner_tags": tags
        })
    return ner_data

# Test annotation on a sample note.
sample_note = """Admission Date: [**1914-12-13**]       Discharge Date: [**1952-09-09**]

Date of Birth:                    Sex:  F

Service:  MICU and then to [**Thompson**] Medicine

HISTORY OF PRESENT ILLNESS:  This is an 81-year-old female with a history of emphysema (not on home O2), who presents with three days of shortness of breath. Presented to the [**County Hospital**] Emergency Room. Followup with Dr. [**Jackson**] at [**555-0109**]."""
tokens, ner_tags = annotate_text(sample_note)
print("Sample Tokens:", tokens)
print("Sample Labels:", ner_tags)


Sample Tokens: ['Admission', 'Date', ':', '1914-12-13', 'Discharge', 'Date', ':', '1952-09-09', 'Date', 'of', 'Birth', ':', 'Sex', ':', 'F', 'Service', ':', 'MICU', 'and', 'then', 'to', 'Thompson', 'Medicine', 'HISTORY', 'OF', 'PRESENT', 'ILLNESS', ':', 'This', 'is', 'an', '81-year-old', 'female', 'with', 'a', 'history', 'of', 'emphysema', '(', 'not', 'on', 'home', 'O2', ')', ',', 'who', 'presents', 'with', 'three', 'days', 'of', 'shortness', 'of', 'breath', '.', 'Presented', 'to', 'the', 'County', 'Hospital', 'Emergency', 'Room', '.', 'Followup', 'with', 'Dr', '.', 'Jackson', 'at', '555-0109', '.']
Sample Labels: ['O', 'O', 'O', 'B-DATE', 'O', 'O', 'O', 'B-DATE', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NAME', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOCATION', 'I-LOCATION', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NAME',

In [5]:
# ----------------------------
# Module 3: Tokenization and Label Alignment for Training
# ----------------------------

model_checkpoint = "dmis-lab/biobert-base-cased-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Define our label list (using BIO scheme).
label_list = ["O", "B-DATE", "I-DATE", "B-NAME", "I-NAME", "B-LOCATION", "I-LOCATION", "B-NUMERIC", "I-NUMERIC", "B-PHONE", "I-PHONE"]
num_labels = len(label_list)
label_to_id = { label: i for i, label in enumerate(label_list) }
id_to_label = { i: label for i, label in enumerate(label_list) }

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"],
                                 truncation=True,
                                 padding="max_length",
                                 max_length=128,
                                 is_split_into_words=True)
    all_labels = []
    for i, words in enumerate(examples["tokens"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        labels = []
        for word_idx in word_ids:
            if word_idx is None:
                labels.append(-100)
            else:
                label = examples["ner_tags"][i][word_idx]
                if word_idx == previous_word_idx and label.startswith("B-"):
                    label = "I-" + label[2:]
                labels.append(label_to_id[label])
            previous_word_idx = word_idx
        all_labels.append(labels)
    tokenized_inputs["labels"] = all_labels
    return tokenized_inputs

In [6]:
# ----------------------------
# Module 4: Prepare the Dataset for Training
# ----------------------------

ner_data = convert_csv_to_ner_dicts("data/SYNTHETIC_DISCHARGE_REPORTS.csv")
dataset = Dataset.from_list(ner_data)
dataset_dict = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset_dict["train"]
val_dataset = dataset_dict["test"]

train_dataset = train_dataset.map(tokenize_and_align_labels, batched=True)
val_dataset = val_dataset.map(tokenize_and_align_labels, batched=True)

Number of notes loaded: 11940


Map: 100%|██████████| 9552/9552 [00:37<00:00, 258.02 examples/s]
Map: 100%|██████████| 2388/2388 [00:09<00:00, 257.06 examples/s]


In [7]:
# ----------------------------
# Module 5: Training the NER Model
# ----------------------------

model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    id2label=id_to_label,
    label2id=label_to_id
)

training_args = TrainingArguments(
    output_dir="model_checkpoint_v2",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_steps=50,
    weight_decay=0.01,
)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

def compute_metrics(p):
    # Implement token-level precision/recall/F1 here if desired.
    return {}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()
trainer.evaluate()
trainer.save_model("model_v2")

Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,0.0041,0.002975
2,0.0021,0.002027
3,0.0014,0.00189
