In [68]:
import os
import pandas as pd
import tqdm
import regex as re
import polars as pl
import tiktoken
import numpy as np

In [69]:
def categorize_body_group(exam_type):
    categories = {
        "Brain": ["HEAD", "BRAIN", "STROKE PROTOCOL", "NEURO"],
        "Abdomen/Pelvis": ["ABDOMEN", "PELVIS", "PROSTATE", "RENAL", "MRCP", "UROGRAM"],
        "MSK": ["KNEE", "HIP", "SHOULDER", "EXTREMITY", "SPINE", "LUMBAR", 
                "CERVICAL", "SCOLIOSIS", "JOINT"],
        "Chest": ["CHEST", "CARDIAC", "PULMONARY EMBOLISM", "CTA", "HEART"],
        "Head/Neck": ["FACE", "NECK", "CRANIOFACIAL", "MAXILLOFACIAL", "ORBIT", "TEMPORAL BONE", 
                 "SINUS", "THYROID", "MANDIBLE", "SKULL"]
    }
    
    for category, keywords in categories.items():
        if any(keyword in exam_type for keyword in keywords):
            return category
    return "Other"

def categorize_imaging_modality(exam_type):
    categories = {
        "MRI": ["MR", "MRI"],
        "CT": ["CT", "CTA"],
        "XR": ["XR"],
        "US": ["US"]
    }
    
    for category, keywords in categories.items():
        if any(keyword in exam_type for keyword in keywords):
            return category
    return "Other"

# Clinical Indication Dataset

In [70]:
intervals = [
    "0_10000",
    "10000_20000",
    "20000_30000",
    "30000_40000",
    "40000_50000",
    "50000_60000",
    "60000_70000",
    "70000_79032"
]

indication_dataset = []

for interval in tqdm.tqdm(intervals):
    parquet_path = f"/mnt/sohn2022/Adrian/rad-llm-pmhx/dataset/indication_dataset/raw/{interval}.parquet"
    indication_dataset.append(pd.read_parquet(parquet_path).reset_index(drop=True))
    
indication_dataset = pd.concat(indication_dataset) \
.drop_duplicates(subset=["radiology_deid_note_key"]) \
.reset_index(drop=True)

100%|█████████████████████████████████████████████| 8/8 [00:55<00:00,  6.96s/it]


In [71]:
pathophysiological_intervals = [
    "llm_labels_pathophysiological_0_10000_2025-02-13_13-47-32.csv",
    "llm_labels_pathophysiological_10000_20000_2025-02-13_15-46-40.csv",
    "llm_labels_pathophysiological_20000_30000_2025-02-13_21-56-05.csv",
    "llm_labels_pathophysiological_30000_40000_2025-02-14_13-33-10.csv",
    "llm_labels_pathophysiological_40000_50000_2025-02-14_15-49-35.csv",
    "llm_labels_pathophysiological_50000_60000_2025-02-14_18-27-19.csv",
    "llm_labels_pathophysiological_60000_70000_2025-02-14_21-49-20.csv",
    "llm_labels_pathophysiological_70000_77984_2025-02-15_17-20-50.csv"
]

pathophysiological_labels = []

for interval in tqdm.tqdm(pathophysiological_intervals):
    csv_path = f"/mnt/sohn2022/Adrian/rad-llm-pmhx/dataset/pathophysiological_labels/{interval}"
    pathophysiological_labels.append(pd.read_csv(csv_path).reset_index(drop=True))
    
pathophysiological_labels = pd.concat(pathophysiological_labels)

100%|█████████████████████████████████████████████| 8/8 [00:00<00:00, 15.64it/s]


In [72]:
indication_dataset = (
    indication_dataset
    .merge(pathophysiological_labels)
    .drop_duplicates(subset=["radiology_deid_note_key"])
    .rename(columns={"generated_category": "pathophysiological_category"})
)
indication_dataset["pathophysiological_category"] = indication_dataset["pathophysiological_category"].replace({
    "cancer/mass": "Cancer/Mass",
    "surgical": "Surgical",
    "infection/inflammatory": "Infection/Inflammatory",
    "symptom-based": "Symptom-based",
    "structural": "Structural"
})

In [84]:
indication_dataset["patientdurablekey"].nunique()

28313

In [74]:
indication_dataset["radiology_deid_note_key"].nunique()

77626

In [75]:
indication_dataset["note_texts"].apply(len).sum()

740867

In [9]:
indication_dataset[["patientdurablekey", "radiology_deid_note_key"]].to_parquet("indication_dataset_keys.parquet")

In [76]:
indication_dataset.loc[
    indication_dataset["exam_type"].notna(), "imaging_modality"
] = indication_dataset.loc[indication_dataset["exam_type"].notna()]["exam_type"].apply(categorize_imaging_modality)

indication_dataset["imaging_modality"].value_counts(normalize=True)

imaging_modality
MRI      0.491627
CT       0.461276
US       0.024038
XR       0.012921
Other    0.010138
Name: proportion, dtype: float64

In [77]:
indication_dataset.loc[
    indication_dataset["exam_type"].notna(), "body_system"
] = indication_dataset.loc[indication_dataset["exam_type"].notna()]["exam_type"].apply(categorize_body_group)

indication_dataset["body_system"].value_counts(normalize=True)

body_system
Brain             0.695257
Head/Neck         0.125602
MSK               0.072038
Abdomen/Pelvis    0.050962
Chest             0.037346
Other             0.018795
Name: proportion, dtype: float64

In [78]:
indication_dataset["pathophysiological_category"] = indication_dataset["pathophysiological_category"].fillna("Other").replace("inflammatory/infection", "infection/inflammatory")

# Dataset Stratification

In [85]:
_TIMESTAMP_RE = re.compile(
    r"^\d{1,2}/\d{1,2}/\d{4}\s+\d{1,2}:\d{2}\s*(?:AM|PM)$"
)

def extract_timestamp(s):
    if not isinstance(s, str):
        return None

    candidates = []

    # Method 1
    part = s.split("    ")[0]
    pieces = part.split(":", 1)
    if len(pieces) > 1:
        candidates.append(pieces[1].strip().split("  ")[0])

    # Method 2
    pieces = part.split("  ", 1)
    if len(pieces) > 1:
        candidates.append(pieces[1].strip().split("  ")[0])

    # Method 3
    alt = s.split("    ")
    if len(alt) > 1:
        candidates.append(alt[1].strip())

    # Validate each candidate
    for r in candidates:
        if len(r) >= 18 and _TIMESTAMP_RE.fullmatch(r):
            return r

def filter_notes_before_report(row):
    dates       = pd.to_datetime(row['deid_service_dates'], errors='coerce')
    report_date = pd.to_datetime(row['radiology_report_date'], errors='coerce')
    mask = dates < report_date
    filtered = {
        'note_texts'        : [d for d, keep in zip(row['note_texts'], mask)    if keep],
        'enc_dept_names'    : [d for d, keep in zip(row['enc_dept_names'], mask)    if keep],
        'note_types'        : [t for t, keep in zip(row['note_types'], mask)        if keep],
        'auth_prov_types'   : [p for p, keep in zip(row['auth_prov_types'], mask)   if keep],
        'deid_service_dates': [s for s, keep in zip(row['deid_service_dates'], mask) if keep],
    }
    return pd.Series(filtered)

indication_dataset_reader_study = indication_dataset.copy()
raw_times = indication_dataset_reader_study["radiology_text"].apply(extract_timestamp)
indication_dataset_reader_study['radiology_report_date'] = pd.to_datetime(
    raw_times,
    format='%m/%d/%Y %I:%M %p',   
    errors='coerce'                
)

filtered = indication_dataset_reader_study.apply(filter_notes_before_report, axis=1, result_type='expand')

indication_dataset_reader_study[[
    'note_texts',
    'enc_dept_names',
    'note_types',
    'auth_prov_types',
    'deid_service_dates'
]] = filtered

indication_dataset_reader_study = indication_dataset_reader_study[
    indication_dataset_reader_study["note_texts"].apply(len) > 0
].reset_index(drop=True)

In [107]:
RANDOM_STATE = 123
MIN_LENGTH = 200

BODY_SYSTEMS = ["Brain", "Head/Neck", "MSK", "Abdomen/Pelvis", "Chest"]
PATHOPHYSIOLOGICAL_CATEGORIES = ["Cancer/Mass", "Symptom-based", "Surgical", "Infection/Inflammatory", "Structural"]

def average_note_length(notes):
    average = 0
    for note in notes:
        average += len(note.split())
    return average / len(notes)

def stratify(dataset, total_samples):
    n_samples = total_samples // (len(BODY_SYSTEMS) * len(PATHOPHYSIOLOGICAL_CATEGORIES))
    stratified_dataset = []
    for body_system in BODY_SYSTEMS:
        for pathophysiological_category in PATHOPHYSIOLOGICAL_CATEGORIES:
            subset = dataset[
                (dataset["body_system"] == body_system) & 
                (dataset["pathophysiological_category"] == pathophysiological_category)
            ].sample(
                n=n_samples, random_state=RANDOM_STATE
            ).reset_index(drop=True)
            stratified_dataset.append(subset)
    return pd.concat(stratified_dataset).sample(frac=1).reset_index(drop=True)

llm_automated_test_dataset = stratify(indication_dataset, total_samples=1000)
print("LLM Automated Test Dataset")
print(
    llm_automated_test_dataset["patientdurablekey"].nunique(),
    llm_automated_test_dataset["radiology_deid_note_key"].nunique(),
    llm_automated_test_dataset["note_texts"].apply(len).sum()
)
print("-" * 20)

subset = indication_dataset_reader_study[indication_dataset_reader_study["note_texts"].apply(average_note_length) > MIN_LENGTH]
clinical_reader_study_test_dataset = stratify(subset, total_samples=250)
print("Reader Evaluation Dataset")
print(
    clinical_reader_study_test_dataset["patientdurablekey"].nunique(),
    clinical_reader_study_test_dataset["radiology_deid_note_key"].nunique(),
    clinical_reader_study_test_dataset["note_texts"].apply(len).sum()
)

LLM Automated Test Dataset
962 1000 9505
--------------------
Reader Evaluation Dataset
247 250 2127


In [61]:
llm_validation_dataset_parquet_path = f"/mnt/sohn2022/Adrian/rad-llm-pmhx/dataset/pathophysiological_labels/manual_validation/random_sample_for_llm_labels.parquet"
llm_validation_dataset = pd.read_parquet(llm_validation_dataset_parquet_path)
print("LLM Validation Dataset")
print(
    llm_validation_dataset["patientdurablekey"].nunique(),
    llm_validation_dataset["radiology_deid_note_key"].nunique(),
    llm_validation_dataset["note_texts"].apply(len).sum()
)

LLM Validation Dataset
100 100 920


In [62]:
indication_dataset_metadata = indication_dataset[["radiology_deid_note_key", "pathophysiological_category", "imaging_modality", "body_system"]]
llm_validation_dataset = llm_validation_dataset.merge(indication_dataset_metadata)

In [63]:
llm_validation_dataset["imaging_modality"].value_counts()

imaging_modality
MRI      48
CT       46
US        3
Other     2
XR        1
Name: count, dtype: int64

In [64]:
llm_validation_dataset["body_system"].value_counts()

body_system
Brain             67
Head/Neck         16
MSK               10
Abdomen/Pelvis     5
Other              2
Name: count, dtype: int64

In [65]:
llm_validation_dataset["pathophysiological_category"].value_counts()

pathophysiological_category
Cancer/Mass               35
Symptom-based             34
Surgical                  20
Structural                 6
Infection/Inflammatory     5
Name: count, dtype: int64

In [91]:
llm_automated_test_dataset["imaging_modality"].value_counts()

imaging_modality
CT       584
MRI      307
US        73
XR        31
Other      5
Name: count, dtype: int64

In [93]:
llm_automated_test_dataset["body_system"].value_counts()

body_system
Brain             200
Chest             200
MSK               200
Head/Neck         200
Abdomen/Pelvis    200
Name: count, dtype: int64

In [94]:
llm_automated_test_dataset["pathophysiological_category"].value_counts()

pathophysiological_category
Structural                200
Cancer/Mass               200
Symptom-based             200
Surgical                  200
Infection/Inflammatory    200
Name: count, dtype: int64

In [95]:
clinical_reader_study_test_dataset["imaging_modality"].value_counts()

imaging_modality
CT       153
MRI       65
US        21
XR         7
Other      4
Name: count, dtype: int64

In [96]:
clinical_reader_study_test_dataset["body_system"].value_counts()

body_system
Brain             50
Head/Neck         50
MSK               50
Chest             50
Abdomen/Pelvis    50
Name: count, dtype: int64

In [97]:
clinical_reader_study_test_dataset["pathophysiological_category"].value_counts()

pathophysiological_category
Cancer/Mass               50
Surgical                  50
Symptom-based             50
Structural                50
Infection/Inflammatory    50
Name: count, dtype: int64

In [66]:
llm_validation_dataset.to_parquet("llm_validation_dataset.parquet")

In [115]:
llm_automated_evaluation_dataset.to_parquet("llm_automated_evaluation_dataset.parquet")
clinical_reader_study_test_dataset.to_parquet("reader_evaluation_dataset.parquet")