In [None]:
import re
from datetime import datetime

import polars as pl
from rapidfuzz import fuzz
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from tqdm import tqdm
from vllm import LLM

from make_clinical_dataset.shared.constants import ROOT_DIR
from make_clinical_dataset.epic.combine import get_clinic_prior_to_treatment
from make_clinical_dataset.epic.util import hash_text

In [None]:
# Snapshot date of the final dataset; most input/output paths below are built from DATA_DIR
DATE = '2025-03-29'
DATA_DIR = "/".join([f"{ROOT_DIR}", "data", "final", f"data_{DATE}"])

## ED Risk Summary

In [None]:
# Take ALL notes written within 5 days prior to a treatment session and save them as the LLM input set
interim_dir = f'{DATA_DIR}/interim'
clinic = pl.read_parquet(f'{interim_dir}/clinic_visits.parquet')
chemo = pl.read_parquet(f'{interim_dir}/chemo.parquet')
df = get_clinic_prior_to_treatment(
    clinic,
    chemo,
    lookback_window=5,
    strategy='all',
)
df.write_parquet(f'{interim_dir}/subsets/clinic_visits_prior_to_treatment/notes.parquet')

In [None]:
# Extract ED risk summary from clinical notes by prompting Qwen3-14B
# Run ml4o-batch-inference (see https://github.com/ml4oncology/ml4o-batch-inference)
# Example of the SLURM script below. It runs `source .env` in the job's working directory, and that file
# must define IMAGE_PATH, MODEL_PATH, LLM_PATH, DATA_PATH and OUTPUT_PATH.
# NOTE(review): DATA_PATH presumably points at the notes.parquet written above and OUTPUT_PATH at the
# generated_output directory loaded in the next cell — confirm.
# Raw string so the `\` line continuations are kept verbatim if this text is printed or copied out
# (in a normal string Python would silently join those lines).
r"""
#!/bin/bash
#SBATCH --job-name=batch-inference
#SBATCH --partition=gpu
#SBATCH --account=grantgroup_gpu
#SBATCH --time=23:59:59
#SBATCH --nodes=1
#SBATCH --gres=gpu:l40:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --output=/cluster/home/%u/logs/%j.out
#SBATCH --error=/cluster/home/%u/logs/%j.err

mkdir -p /cluster/home/$USER/logs

# Load Apptainer module
module load apptainer

# Load the paths
source .env

# Set up bind paths
export APPTAINER_BINDPATH=$APPTAINER_BINDPATH,$MODEL_PATH

# Run batch inference script inside the container
apptainer exec --nv $IMAGE_PATH python3.10 ~/repos/ml4o-batch-inference/batch_inference.py \
        --data-path $DATA_PATH \
        --output-path $OUTPUT_PATH \
        --prompt-path ~/repos/make-clinical-dataset/epic/prompts/ed_risk_summarizer.txt \
        --model-name Qwen_Qwen3-14B-IQ4_XS.gguf \
        --tokenizer-path $LLM_PATH/Qwen3-14B \
        --max-model-len 5120 \
        --max-num-seqs 42
"""

In [None]:
# Load the generated output and attach the original note text to each generated summary
# NOTE(review): absolute, user-specific path — consider moving it to the config cell.
OUTPUT_PATH = "/cluster/projects/gliugroup/work_dir/kevin_he/BatchInferOutput/ed_risk_summary/batch_infer/generated_output"
df = pl.read_parquet(f"{OUTPUT_PATH}/*.parquet")
notes = pl.read_parquet(
    f'{DATA_DIR}/interim/subsets/clinic_visits_prior_to_treatment/notes.parquet',
    columns=['note_id', 'note'],
)

# Take the first row if a note_id is duplicated. polars' `unique` defaults to keep='any' (arbitrary row)
# and does not preserve row order, so both must be set explicitly to actually keep the first occurrence.
df = df.unique('note_id', keep='first', maintain_order=True)
notes = notes.unique('note_id', keep='first', maintain_order=True)
df = df.join(notes, on='note_id', how='left')

In [None]:
# Constants
# Expected section headers of the generated ED risk summary, mapped to the column name that section's
# content is stored under. SECTION_COLS lists those column names in header order.
SECTION_MAP = {
    header: column
    for header, column in (
        ("=== ACTIVE SYMPTOMS ===", "active_symptoms"),
        ("=== RECENT COMPLICATIONS / ADVERSE EVENTS ===", "recent_complications"),
        ("=== RECENT HEALTHCARE UTILIZATION ===", "healthcare_utilization"),
        ("=== FUNCTIONAL STATUS / DECLINE ===", "functional_status"),
        ("=== MEDICATION-RELATED RISKS ===", "medication_risks"),
        ("=== PSYCHOSOCIAL / SUPPORT RISKS ===", "psychosocial_risks"),
        ("=== CLINICAL UNCERTAINTY / WATCHFUL WAITING ===", "clinical_uncertainty"),
        ("=== OVERALL ACUITY ASSESSMENT ===", "acuity_assessment"),
    )
}
SECTION_COLS = [*SECTION_MAP.values()]

In [None]:
# Separate into individual sections
#
# Each generated summary is expected to contain headers of the form `=== SECTION NAME ===` (the keys of
# SECTION_MAP). The LLM does make typos in the headers unfortunately, so headers are matched fuzzily.
# TODO: fix from source, ensure guided regex with vllm
HEADER_PATTERN = re.compile(r"===\s*([^=]+)\s*===")
MIN_HEADER_SCORE = 90  # minimum fuzz.ratio (0-100) for a header to count as a known section


def match_section(header):
    """Return the SECTION_MAP column name whose header is closest to `header`.

    Returns None when even the closest known header scores below MIN_HEADER_SCORE, so unrecognized
    headers are dropped instead of being silently assigned to whichever section was checked last.
    """
    best = max(SECTION_MAP, key=lambda section: fuzz.ratio(header, section))
    if fuzz.ratio(header, best) < MIN_HEADER_SCORE:
        return None
    return SECTION_MAP[best]


def parse_sections(text):
    """Split one generated summary into a {section column: content} dict.

    A section's content runs from the end of its header to the start of the next header (recognized
    or not) or to the end of the text, with surrounding whitespace stripped. Sections missing from the
    text are absent from the dict; if a section appears twice, the last occurrence wins. A null or
    empty text yields an empty dict.
    """
    sections = {}
    if not text:
        return sections
    matches = list(HEADER_PATTERN.finditer(text))
    for i, match in enumerate(matches):
        name = match_section(match.group(0))
        if name is None:
            continue
        section_end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        sections[name] = text[match.end():section_end].strip()
    return sections


records = [parse_sections(text) for text in tqdm(df['generated_output'])]

# An explicit schema guarantees every section column exists (null where the section is missing), even if
# a section never appears or is absent from the rows polars would use to infer the schema.
data = pl.DataFrame(records, schema={col: pl.String for col in SECTION_COLS})

# Report how often each section is missing
for col in SECTION_COLS:
    n_missing = data[col].null_count()
    print(f"{n_missing} ({n_missing / max(len(data), 1) * 100:.2f}%) samples without section {col}")

# Keep note id and the original generated output
data = pl.concat([data, df.select('note_id', 'generated_output')], how='horizontal')

# Create text hash for each section (map_elements skips nulls by default, so missing sections stay null)
data = data.with_columns([pl.col(col).map_elements(hash_text, return_dtype=pl.String).alias(f"{col}_text_id") for col in SECTION_COLS])

data.write_parquet(f"{DATA_DIR}/interim/embedding/ed_risk_summary.parquet")

In [None]:
# Pre-compute the embeddings
# Stack every section column into a single (text, text_id) table of unique, non-null texts to embed.
data = pl.read_parquet(f"{DATA_DIR}/interim/embedding/ed_risk_summary.parquet")
data = (
    pl.concat([data.select(pl.col(col).alias("text"), pl.col(f"{col}_text_id").alias("text_id")) for col in SECTION_COLS])
    .drop_nulls()
    .unique(maintain_order=True)  # deterministic row order, so re-runs write identical embedding files
)

# Provenance stored in the parquet metadata of each embedding file
version = "1.0.0"
# NOTE(review): the embedded texts come from ed_risk_summary.parquet; confirm clinic_visits.parquet is the intended "source"
source_path = f"{DATA_DIR}/interim/clinic_visits.parquet"
current_timestamp = datetime.now().date()

### PubMedBERT

In [None]:
# Embed every unique section text with PubMedBERT (sentence-transformers) and save the vectors
# next to the text and its hash.
model_path = f"{ROOT_DIR}/LLMs/PubMedBERT"
model = SentenceTransformer(model_path)
embed_dim = model.get_sentence_embedding_dimension()

# metadata values need to be all string
metadata = {
    "source": source_path,
    "model": model_path,
    "created_at": str(current_timestamp),
    "embedding_dim": str(embed_dim),
    "version": str(version)
}

# Optional diagnostic: count texts longer than the model's max sequence length (they get truncated)
CHECK_SEQ_LENGTH = False
if CHECK_SEQ_LENGTH:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    token_lengths = tokenizer(data['text'].to_list(), add_special_tokens=False, return_length=True)["length"]
    print(f"Number of texts exceeding max seq length: {(pl.Series(token_lengths) > model.max_seq_length).sum()}")
    del tokenizer

# encode is documented to take a str or list of str, so convert the polars Series explicitly
outputs = model.encode(data['text'].to_list(), batch_size=64)
# keep `data` text-only so the other embedding cells do not depend on this one having run
embedded = data.with_columns(pl.Series("embedding", outputs))
embedded.write_parquet(f"{DATA_DIR}/interim/embedding/PubMedBERT/text_embedding_0.parquet", metadata=metadata)

### ModernBERT-large

In [None]:
# Embed every unique section text with ModernBERT-large (sentence-transformers) and save the vectors
# next to the text and its hash.
model_path = f"{ROOT_DIR}/LLMs/ModernBERT-large"
model = SentenceTransformer(model_path)
embed_dim = model.get_sentence_embedding_dimension()

# metadata values need to be all string (the date and the int dimension must be converted)
metadata = {
    "source": source_path,
    "model": model_path,
    "created_at": str(current_timestamp),
    "embedding_dim": str(embed_dim),
    "version": str(version)
}

# encode is documented to take a str or list of str, so convert the polars Series explicitly
outputs = model.encode(data['text'].to_list(), batch_size=64)
# keep `data` text-only so the other embedding cells do not depend on this one having run
embedded = data.with_columns(pl.Series("embedding", outputs))
embedded.write_parquet(f"{DATA_DIR}/interim/embedding/ModernBERT-large/text_embedding_0.parquet", metadata=metadata)

### Try vLLM (experimental)

In [None]:
# Experimental: embed the texts with vLLM instead of sentence-transformers.
# Need to write in batches or else OOM error occurs
# TODO: write each batch of embeddings to disk; for now they are only collected in memory.

model_path = f"{ROOT_DIR}/LLMs/PubMedBERT"
# model_path = f"{ROOT_DIR}/LLMs/ModernBERT-large" # dang, not supported yet
model = LLM(model=model_path, task="embed")

batch_size = int(1e5)
MAX_BATCHES = 1  # trial run: only embed the first batch; set to None to embed every batch
n_texts = len(data) if MAX_BATCHES is None else min(len(data), MAX_BATCHES * batch_size)

embeddings = []
for i in tqdm(range(0, n_texts, batch_size)):
    # pass a plain list of prompts rather than a polars Series
    outputs = model.embed(data['text'][i:i+batch_size].to_list())
    embeddings.extend(output.outputs.embedding for output in outputs)