In [None]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder

import numpy as np
import pandas as pd
import re

from src.search_funcs import RetrieveReranker

# local vars
BI_ENCODER_MODEL = "answerdotai/ModernBERT-base"
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
CORPUS = "C:/Users/gioc4/Documents/blog/data/falls/neis.csv"
MAX_TOKEN_LENGTH = 256
CORPUS_SIZE = 10000

# we want the observations to be agnostic to patient age, so we remove those
# define remappings of abbreviations
# and strings to remove from narratives

remap = {
    "FX": "FRACTURE",
    "INJ": "INJURY",
    "LAC": "LACERATION",
    "LOC": "LOSS OF CONCIOUSNESS",
    "CONT": "CONTUSION",
    "CHI" : "CLOSED HEAD INJURY",
    "ETOH": "ALCOHOL",
    "SDH": "SUBDURAL HEMATOMA",
    "AFIB": "ATRIAL FIBRILLATION",
    "NH": "NURSING HOME",
    "LTCF": "LONG TERM CARE FACILITY",
    "PT": "PATIENT",
    "LT": "LEFT",
    "RT": "RIGHT",
    "&" : " AND "
}
str_remove = "YOM|YOF|MOM|MOF|C/O|S/P|H/O|DX"


def process_text(txt):
    words = txt.split()
    new_words = [remap.get(word, word) for word in words]
    txt = " ".join(new_words)

    txt = re.sub("[^a-zA-Z ]", "", txt)
    txt = re.sub(str_remove, "", txt)

    return re.sub(r"^\s+", "", txt)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# strings to encode as searchable

# load data
neis_data = pd.read_csv(CORPUS).head(CORPUS_SIZE)
narrative_strings = neis_data['Narrative_1'].apply(process_text).tolist()

# define models and ranker
biencoder = SentenceTransformer(BI_ENCODER_MODEL)
crossencoder = CrossEncoder(CROSS_ENCODER_MODEL)

No sentence-transformers model found with name answerdotai/ModernBERT-base. Creating a new one with mean pooling.


In [3]:
# set up a Retriveal-Ranker class
ranker = RetrieveReranker(
    corpus=narrative_strings,
    bi_encoder_model=biencoder,
    cross_encoder_model=crossencoder,
    save_corpus=True,
    corpus_path="C:/Users/gioc4/Documents/blog/data/corpus.pkl"
)

In [10]:
# now pass queries 

query = ["14YOF PRESENTS WITH LEFT THUMB PAIN AFTER SHE WAS TRYING TO GET UP OUT OF CHAIR WHEN SHE PUSHED UP FROM THE ARM OF THE CHAIR WITH HER THUMB AND FELT HER THUMB HYPEREXTEND AND THE FELT SHARP PAIN. DX: THUMB SPRAIN",
         "96YOF WAS GOING TO BATHROOM AT ASSISTED LIVINGN, MISSTEPPED AND MISSED TOILET, FELL, POSSIBLY HIT HEAD ON SINK, DX FALL, HEAD INJ?",
         "52YOF REPORTS MVA. PT STATES SHE WAS ON A BICYCLE AND BACK OF BIKE GOT CLIPPED BY A CAR AND STATES SHE FEL LON HER LT LEG. DX:CLOSED FX OF SHAFT OF LT FIBULA"]

for q in query:
    output = ranker.query(process_text(q), number_ranks=100, number_results=3)
    print(output)

['PRESENTS WITH LEFT THUMB INJURY SHE WAS DOING A CHEERLEADING MANEUVER HOLDING  OF THE CHEERLEADERS FOOT AND THE LOWER EXTREMITY IN THE AIR AND ACCIDENTALLY HELD THE FOOT UP WITH HER THUMB IN THE WROTN DIRECTION AND FELT LIKE IT HYPEREXTENDED BACKWARD AND SHE FELT A POP  LEFT THUMB SPRAIN', 'WAS PLAYING BASKETBALL WHEN THE BALL HIT HIS LEFT THUMB AND JAMMED IT NOW WITH THUMB PAIN AND SWELLING  SPRAIN OF LEFT THUMB', 'PRESENTS FOR RIGHT THUMB INJURY MOTHER STATES PATIENT HAD HIS RIGHT THUMB IN THE HINGE OF A TOY CHEST AND STATES THE CHEST SLAMMED ON HIS RIGHT THUMB  LACERATION OF RIGHT THUMB WITHOUT FOREIGN BODY NAIL DAMAGE STATUS UNSPECFIED']
['INJURY MID BACK LOST BALANCE IN BATHROOM FELL STRUCK BACK ON EDGE OF BATHTUB AT ASSISTED LIVING  T COMPRESSION FRACTURE ', 'FROM ASSISTED LIVING GETTING OUT OF SHOWER INTO WHEELCHAIR AND FELL TO FLOOR  HEMATOMA TO HEAD HEMORRHAGE SUBDURAL TRAUMATIC FALL HEAD INJURY', 'FELL GOT HIS HEAD WEDGED BETWEEN THE TOILET AND THE BATHTUB  CLOSED HEAD INJU