In [None]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder

import torch
import pandas as pd
import re
import numpy as np

MODEL = "answerdotai/ModernBERT-base"
CORPUS = "C:/Users/gioc4/Documents/blog/data/falls/neis.csv"
MAX_TOKEN_LENGTH = 256
CORPUS_SIZE = 10000

# load data
neis_data = pd.read_csv(CORPUS).head(CORPUS_SIZE)

# define a sentence transformer model
model = SentenceTransformer(MODEL)

No sentence-transformers model found with name answerdotai/ModernBERT-base. Creating a new one with mean pooling.


In [230]:
# we want the observations to be agnostic to patient age, so we remove those
# define remappings of abbreviations
# and strings to remove from narratives

remap = {
    "FX": "FRACTURE",
    "INJ": "INJURY",
    "LAC": "LACERATION",
    "CONT": "CONTUSION",
    "CHI" : "CLOSED HEAD INJURY",
    "ETOH": "ALCOHOL",
    "SDH": "SUBDURAL HEMATOMA",
    "AFIB": "ATRIAL FIBRILLATION",
    "NH": "NURSING HOME",
    "LTCF": "LONG TERM CARE FACILITY",
    "PT": "PATIENT",
    "LT": "LEFT",
    "RT": "RIGHT",
    "&" : " AND "
}
str_remove = "YOM|YOF|MOM|MOF|C/O|S/P|H/O|DX"


def process_text(txt):
    words = txt.split()
    new_words = [remap.get(word, word) for word in words]
    txt = " ".join(new_words)

    txt = re.sub("[^a-zA-Z ]", "", txt)
    txt = re.sub(str_remove, "", txt)

    return re.sub(r"^\s+", "", txt)

In [231]:
narrative_strings = neis_data['Narrative_1'].apply(process_text).tolist()

In [232]:
# encode verified falls, and neis narratives
narrative_embed = model.encode(narrative_strings)

In [307]:
# rank re-rank method

# first we get the top n most semantically similar sentences
# then use a bi-encoder to re-rank them
query = "HEAD INJURY AND RIB FRACTURES S/P FALLING DOWN A FLIGHT OF 10 STAIRS WHILE INTOXICATED WITH ALCOHOL. BAC NS."
N = 100

query_embed = model.encode(query)
sims = model.similarity(query_embed,narrative_embed)
idx = np.array(torch.topk(sims, N).indices)[0]

  idx = np.array(torch.topk(sims, N).indices)[0]


In [308]:
ce_list = []

for i in idx:
    ce_list.append([query, narrative_strings[i]])

In [309]:
ce_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2")
scores = ce_model.predict(ce_list)

In [310]:
ce_list[scores.argmax()][0]

'HEAD INJURY AND RIB FRACTURES S/P FALLING DOWN A FLIGHT OF 10 STAIRS WHILE INTOXICATED WITH ALCOHOL. BAC NS.'

In [311]:
ce_list[scores.argmax()][1]

'TRIPPED DOWN A FLIGHT OF STAIRS WHILE INTOXICATED NO BAC DRAWN  ACUTE ALCOHOL INTOXICATION HEMATOMA TO SCALP'