# Import

In [23]:
import os
from pathlib import Path
import sys
import numpy as np 
import tqdm
import torch
parent_root = Path.cwd().parent
project_root = os.path.join(parent_root, "src")
sys.path.insert(0, str(project_root))
from pytorch_lightning import seed_everything
from accelerate import Accelerator
import matplotlib.pyplot as plt
import hydra

from hydra.core.global_hydra import GlobalHydra
from hydra import compose, initialize_config_dir
from omegaconf import DictConfig, OmegaConf
from hydra import compose, initialize

from datamodule import DataModule_nlp
from clinicalbert_pl import Clinicalbert_pl
from sklearn.metrics import f1_score, accuracy_score
from transformers import AutoTokenizer, DataCollatorWithPadding



In [5]:
out_dir = Path("..") / "src"
data_dir = out_dir / "data"

# Def model / Datamodule 

In [7]:
@hydra.main(config_path="config", config_name="config_seg", version_base="1.3")
def main(cfg: DictConfig):

    seed_everything(cfg.get("seed", 42), workers=True)

    cfg_data = cfg.data
    cfg_model = cfg.model

    datamodule = DataModule_nlp(cfg_data)
    model = Clinicalbert_pl(cfg_model)
    return model, datamodule


GlobalHydra.instance().clear()
config_path = os.path.join(out_dir, "config")

with initialize(version_base=None, config_path=config_path, job_name="nb"):
    cfg = compose(config_name="config")
    model, datamodule = main(cfg)
    

Seed set to 42


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
datamodule.setup(stage="test")
test_loader = datamodule.test_dataloader()



Map: 100%|██████████| 126491/126491 [00:21<00:00, 6016.84 examples/s]
Map: 100%|██████████| 15811/15811 [00:02<00:00, 5995.81 examples/s]
Map: 100%|██████████| 15812/15812 [00:02<00:00, 6046.68 examples/s]


In [9]:
ckpt_path = out_dir / "checkpoints/best.ckpt"
model = Clinicalbert_pl.load_from_checkpoint(
    ckpt_path,
    weights_only = False
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
accelerator = Accelerator(mixed_precision="bf16")
model = accelerator.prepare(
    model
)

model.eval()

Clinicalbert_pl(
  (model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features

# Test model

In [18]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

test_logits, test_labels, test_note_map = [], [], []

with torch.no_grad():
    for batch in tqdm.tqdm(test_loader, desc="Test eval"):
        labels = batch.pop("labels")
        note_map = batch.pop("note_map") # [B] mapping each chunk -> original note id (or index)
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits

        
        test_logits.append(logits.detach().cpu())
        test_labels.append(labels.detach().cpu())
        test_note_map.append(note_map.detach().cpu())
        
logits = torch.cat(test_logits, dim=0)
y_true = torch.cat(test_labels, dim=0).numpy()

probs = torch.sigmoid(logits.float()).cpu().numpy()
note_map = torch.cat(test_note_map, dim=0).numpy()
unique_note = np.unique(note_map)
C = probs.shape[1]

probs_note = np.zeros((len(unique_note), C), dtype = np.float32)
y_true_note = np.zeros((len(unique_note), C), dtype = np.int32)

note_to_row = {nid : i for i, nid in enumerate(unique_note)}

for i in range(probs.shape[0]):
    r = note_to_row[note_map[i]]
    probs_note[r] = np.maximum(probs_note[r], probs[i])
    y_true_note[r] = np.maximum(y_true_note[r], y_true[i])
    
y_pred_note = (probs_note >= 0.5).astype(np.int32)
f1_micro = f1_score(y_true_note, y_pred_note, average="micro", zero_division=0)
f1_macro = f1_score(y_true_note, y_pred_note, average="macro", zero_division=0)
subset_acc = accuracy_score(y_true_note, y_pred_note) 

with open(data_dir / "label_names.txt") as f:
    label_names = [l.strip() for l in f.readlines()]        

results = {
    "f1_micro": float(f1_micro),
    "f1_macro": float(f1_macro),
    "accuracy": float(subset_acc),
    "probs_note" : probs_note,
    "y_true_note" : y_true_note,
    "label_names" : label_names
}
print(results)


Test eval: 100%|██████████| 321/321 [00:16<00:00, 19.61it/s]


{'f1_micro': 0.993259198832167, 'f1_macro': 0.9934201149211789, 'accuracy': 0.9907032633442955, 'probs_note': array([[2.7803096e-04, 4.5831289e-04, 1.0889691e-04, ..., 2.1654405e-04,
        1.2339458e-04, 8.2958920e-04],
       [2.3050673e-04, 5.8840885e-04, 1.3982208e-04, ..., 1.3982208e-04,
        1.5843622e-04, 5.7031569e-04],
       [3.8244836e-03, 3.1726826e-03, 2.1827165e-03, ..., 1.8102111e-03,
        1.5978456e-03, 3.4834242e-03],
       ...,
       [2.3050673e-04, 3.3535014e-04, 1.0889691e-04, ..., 2.1654405e-04,
        1.5843622e-04, 1.8102111e-03],
       [8.2958920e-04, 9.1105117e-04, 4.5831289e-04, ..., 7.5540564e-04,
        4.7285444e-04, 9.9566853e-01],
       [5.5549247e-03, 6.2899021e-03, 5.9110685e-03, ..., 2.8009270e-03,
        3.0753696e-03, 9.9193794e-01]], shape=(15812, 11), dtype=float32), 'y_true_note': array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0

In [19]:
def print_note_preds(
    probs_note,
    y_true_note,
    label_names,
    k=5,
    threshold=0.5
):
    for i in range(k):
        print(f"\nNOTE {i}")
        
        true_labels = [
            label_names[j]
            for j in range(len(label_names))
            if y_true_note[i, j] == 1
        ]

        pred_labels = [
            f"{label_names[j]} ({probs_note[i, j]:.2f})"
            for j in range(len(label_names))
            if probs_note[i, j] >= threshold
        ]

        print("True:", true_labels)
        print("Pred:", pred_labels)
        print("-" * 50)


In [21]:
print_note_preds(results["probs_note"], results["y_true_note"], results["label_names"])


NOTE 0
True: []
Pred: []
--------------------------------------------------

NOTE 1
True: []
Pred: []
--------------------------------------------------

NOTE 2
True: ['depression_anxiety']
Pred: ['depression_anxiety (0.99)']
--------------------------------------------------

NOTE 3
True: ['depression_anxiety', 'obesity']
Pred: ['depression_anxiety (0.99)', 'obesity (0.99)']
--------------------------------------------------

NOTE 4
True: ['cancer']
Pred: ['cancer (1.00)']
--------------------------------------------------


In [28]:
def predict_single_note(text, model, tokenizer, device):
    enc = tokenizer(
        text,
        truncation=True,
        max_length=256,
        stride=128,
        return_overflowing_tokens=True,
        return_tensors="pt"
    )
    enc.pop("overflow_to_sample_mapping", None)
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        logits = model(**enc).logits
        probs = torch.sigmoid(logits)
    
    probs_note = probs.max(dim=0).values
    return probs_note.cpu().numpy()


In [29]:
text = "Patient with long history of type 2 diabetes and chronic hypertension..."
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", use_fast=True)
probs = predict_single_note(text, model, tokenizer, device)


In [30]:
for label, p in zip(results["label_names"], probs):
    if p >= 0.5:
        print(f"{label}: {p:.2f}")

diabetes: 0.89
hypertension: 0.95
