In [None]:
import re
from datetime import datetime

import polars as pl
from rapidfuzz import fuzz
from tqdm import tqdm
from make_clinical_dataset.shared.constants import ROOT_DIR
from make_clinical_dataset.epic.combine import get_clinic_prior_to_treatment
from make_clinical_dataset.epic.util import hash_text

In [None]:
# Date tag that is embedded in the name of the data directory used below.
DATE = "2025-03-29"
# Every parquet path in this notebook is built relative to this directory.
DATA_DIR = "{}/data/final/data_{}".format(ROOT_DIR, DATE)

## ED Risk Summary

In [None]:
# Gather every clinic note written within the 5 days before a treatment session
# (strategy='all' keeps all of them) and persist that subset to parquet.
clinic_path = f'{DATA_DIR}/interim/clinic_visits.parquet'
chemo_path = f'{DATA_DIR}/interim/chemo.parquet'
notes_path = f'{DATA_DIR}/interim/subsets/clinic_visits_prior_to_treatment/notes.parquet'

clinic, chemo = pl.read_parquet(clinic_path), pl.read_parquet(chemo_path)
df = get_clinic_prior_to_treatment(clinic, chemo, lookback_window=5, strategy='all')
df.write_parquet(notes_path)

In [None]:
# The summaries are generated outside this notebook: run ml4o-batch-inference
# (see https://github.com/ml4oncology/ml4o-batch-inference) before continuing.
# The string below is an example SLURM script for that job. It is a bare string
# literal kept for reference only; nothing in this notebook submits or runs it.
"""
#!/bin/bash
#SBATCH --job-name=batch-inference
#SBATCH --partition=gpu
#SBATCH --account=grantgroup_gpu
#SBATCH --time=23:59:59
#SBATCH --nodes=1
#SBATCH --gres=gpu:l40:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --output=/cluster/home/%u/logs/%j.out
#SBATCH --error=/cluster/home/%u/logs/%j.err

mkdir -p /cluster/home/$USER/logs

# Load Apptainer module
module load apptainer

# Load the paths
source .env

# Set up bind paths
export APPTAINER_BINDPATH=$APPTAINER_BINDPATH,$MODEL_PATH

# Run batch inference script inside the container
apptainer exec --nv $IMAGE_PATH python3.10 ~/repos/ml4o-batch-inference/batch_inference.py \
        --data-path $DATA_PATH \
        --output-path $OUTPUT_PATH \
        --prompt-path ~/repos/make-clinical-dataset/epic/prompts/ed_risk_summarizer.txt \
        --model-name Qwen_Qwen3-14B-IQ4_XS.gguf \
        --tokenizer-path $LLM_PATH/Qwen3-14B \
        --max-model-len 5120 \
        --max-num-seqs 42 \
"""

In [None]:
# Load the generated summaries and attach the original note text to each one.
# NOTE(review): OUTPUT_PATH is a user-specific absolute path; consider deriving
# it from a configurable directory so the notebook runs for other users.
OUTPUT_PATH = "/cluster/projects/gliugroup/work_dir/kevin_he/BatchInferOutput/ed_risk_summary/batch_infer/generated_output"
df = pl.read_parquet(f"{OUTPUT_PATH}/*.parquet")
notes = pl.read_parquet(f'{DATA_DIR}/interim/subsets/clinic_visits_prior_to_treatment/notes.parquet', columns=['note_id', 'note'])

# Take the first row if a note_id is duplicated. polars' `unique` defaults to
# keep='any' with no ordering guarantee, so "first" has to be requested
# explicitly; maintain_order=True also makes the resulting row order reproducible.
df = df.unique('note_id', keep='first', maintain_order=True)
notes = notes.unique('note_id', keep='first', maintain_order=True)

# Left join: every generated summary is kept even if its note text is missing.
df = df.join(notes, on='note_id', how='left')

In [None]:
# Split the generated summary into its individual sections
# TODO: fix from source, ensure guided regex with vllm
# Pairs of (section header expected in the generated output, output column name).
SECTION_MAP = dict(
    (
        ("=== ACTIVE SYMPTOMS ===", "active_symptoms"),
        ("=== RECENT COMPLICATIONS / ADVERSE EVENTS ===", "recent_complications"),
        ("=== RECENT HEALTHCARE UTILIZATION ===", "healthcare_utilization"),
        ("=== FUNCTIONAL STATUS / DECLINE ===", "functional_status"),
        ("=== MEDICATION-RELATED RISKS ===", "medication_risks"),
        ("=== PSYCHOSOCIAL / SUPPORT RISKS ===", "psychosocial_risks"),
        ("=== CLINICAL UNCERTAINTY / WATCHFUL WAITING ===", "clinical_uncertainty"),
        ("=== OVERALL ACUITY ASSESSMENT ===", "acuity_assessment"),
    )
)
# Column names in the same order as the headers above.
SECTION_COLS = [column for column in SECTION_MAP.values()]


# (disabled) remove samples where an expected section header is not present
# for section in SECTION_MAP:
#     mask = ~df['generated_output'].str.contains(section)
#     print(f"Excluding {mask.sum()} ({mask.mean()*100:.2f}%) samples without section {section}")
#     df = df.filter(~mask)

# Matches any "=== SOME HEADER ===" marker in the generated output.
HEADER_PATTERN = re.compile(r"===\s*([^=]+)\s*===")


def split_into_sections(text, min_score=90):
    """Split one generated summary into its named sections.

    Args:
        text: The generated output for a single note. May be None or empty.
        min_score: Minimum `fuzz.ratio` similarity (0-100) a header found in
            the text must reach against a known header to be accepted.

    Returns:
        A dict with one key per entry in SECTION_COLS. The value is the
        stripped section content, or None when the section was not found.
        If the same section appears more than once, the last occurrence wins.
    """
    # Start from a full set of keys so every record has the same shape.
    sections = dict.fromkeys(SECTION_COLS)
    if not text:
        return sections

    matches = list(HEADER_PATTERN.finditer(text))
    for i, match in enumerate(matches):
        header = match.group(0)

        # Find the known header that matches the closest (LLM does make typos unfortunately).
        scored = [(fuzz.ratio(header, section), name) for section, name in SECTION_MAP.items()]
        best_score, best_name = max(scored, key=lambda pair: pair[0])
        if best_score < min_score:
            # Unrecognised header: skip it rather than filing its content
            # under whichever section name happened to be checked last.
            continue

        # The section body runs from the end of this header to the start of the next one.
        section_end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        sections[best_name] = text[match.end():section_end].strip()

    return sections


records = [split_into_sections(text) for text in tqdm(df['generated_output'])]

# An explicit schema keeps the section columns present and string-typed even
# when a section is missing from every sampled row.
data = pl.DataFrame(records, schema={name: pl.Utf8 for name in SECTION_COLS})
# Rows are in the same order as `df`, so a horizontal concat lines them up.
data = pl.concat([data, df.select('note_id', 'generated_output')], how='horizontal')

In [None]:
# Planning notes on how the embeddings and their metadata are to be stored.
# This is a bare string literal that is not assigned to anything, so later cells do not use it.
"""
Metadata
- store the metadata in the parquet file
- source, model, created_at, version, embedding_dim

Filepath - I need two tables
- interim/embedding/<model_name>/text_embeddings.parquet
- interim/embedding/<model_name>/ed_risk_summary.parquet

Schema
text_embedding -> text | text_id | embedding
ed_risk_summary -> section_1_text | section_1_text_id | ... | note_id | generated_output
"""

## PubMedBERT

In [None]:
# Provenance fields to be stored with the embeddings.
version = 1
source_path = f"{DATA_DIR}/interim/clinic_visits.parquet"
model_path = f"{ROOT_DIR}/LLMs/PubMedBERT"
# Calendar date of this run as YYYY-MM-DD (no time-of-day component).
current_timestamp = datetime.now().date().isoformat()

metadata = dict(
    source=source_path,
    model=model_path,
    created_at=current_timestamp,
    version=version,
)

#### Sentence Transformer

In [None]:
from sentence_transformers import SentenceTransformer

# Load the sentence-embedding model from `model_path`.
model = SentenceTransformer(model_path)

# Record the embedding dimension reported by the model in the metadata.
metadata.update(embedding_dim=model.get_sentence_embedding_dimension())

In [None]:
# Work with the first section column only; this binds `col` exactly as a loop
# that breaks on its first pass would.
if SECTION_COLS:
    col = SECTION_COLS[0]

In [None]:
# `prompts` was never defined in this notebook, so this cell raised NameError on
# a fresh kernel. Build it from the section column selected above.
# NOTE(review): assumes the texts to embed are the non-null entries of data[col] — confirm.
prompts = data[col].drop_nulls().to_list()
outputs = model.encode(prompts)