In [None]:
import json
import pandas as pd

from tqdm import tqdm
from typing import List, Dict, Any

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import classification_report

from breastfeeding_nlp.extraction.ner import MedSpaCyLabeler
from breastfeeding_nlp.utils.utils import filter_dataset

%load_ext autoreload
%autoreload 2

# Load the data

In [None]:
# Raw notes export; `row_ix` keeps each note's original row position in the spreadsheet.
orig_df = pd.read_excel("/Volumes/RISIDataServices_MPrint_NCH/MPRINT_LACTATE_BF_1_20250424.xlsx").reset_index(names=['row_ix'])

# Attach the cohort split assignment (the test-split filter further down reads a `split` column).
# NOTE(review): this is an inner join on PAT_ID, so notes whose patient is missing from the split
# file are dropped, and a PAT_ID repeated in that file would duplicate notes — confirm the split
# file has exactly one row per patient.
cohort_splits = pd.read_csv("/Volumes/RISIDataServices_MPrint_NCH/data/cohort-split-unstratified.csv")
orig_df = orig_df.merge(cohort_splits, on="PAT_ID")

# Drop notes with no text (2 at the time of writing).
orig_df = orig_df.dropna(subset='NOTE_TEXT')

# Drop WIC formula request forms, recognized by their form title (30 at the time of writing).
is_wic_form = orig_df["NOTE_TEXT"].str.contains('Ohio WIC Prescribed Formula and Food Request Form')
wic_ids = orig_df.loc[is_wic_form, "row_ix"]
orig_df = orig_df[~orig_df["row_ix"].isin(wic_ids)]

# Note types excluded from the analysis (roughly 3k notes at the time of writing).
filter_out_note_types = [
    "Patient Instructions",
    "Discharge Instructions",
    "MR AVS Snapshot",
    "ED AVS Snapshot",
    "IP AVS Snapshot",
    "Training",
    "Operative Report",
    "D/C Planning",
    "Pharmacy",
]
orig_df = orig_df.loc[~orig_df["NOTE_TYPE"].isin(filter_out_note_types)]
# NOTE(review): this frame includes PAT_ID and NOTE_TEXT — clear rendered output before sharing.
orig_df

In [None]:
# Evaluation is restricted to the held-out test split.
is_test = orig_df["split"] == "test"
test_df = orig_df[is_test]
test_df.shape

# Load the medspaCy pipeline

In [None]:
# Construct the project's MedSpaCy labeler (from breastfeeding_nlp.extraction.ner) with its default configuration.
medspacy_labeler = MedSpaCyLabeler()

# Run the pipeline on the test split

In [None]:
# Run MedSpaCy over the test-split notes (entity-level output), then derive document-level labels
# from those entities. The evaluation cells below join `medspacy_doc_labels` on `row_ix` and read its
# `medspacy_document_label` column, so both must be present in the returned frame.
entities_df = medspacy_labeler.process_dataframe(test_df)
medspacy_doc_labels = medspacy_labeler.label_documents(entities_df)

In [None]:
# Inspect the document-level MedSpaCy labels for the test split.
# NOTE(review): outputs derived from clinical notes may contain PHI — clear them before sharing.
medspacy_doc_labels

In [None]:
# Inspect the entity-level MedSpaCy output for the test split.
# NOTE(review): entity output from clinical notes may contain PHI — clear it before sharing.
entities_df

# Load the LLM-generated reference labels (treated as the gold standard)

In [None]:
# Raw LLM batch-inference output: one JSON record per line.
# Decode explicitly as UTF-8 (the JSON interchange encoding) so the model's free-text reasoning
# is read the same way regardless of the OS locale default.
# NOTE(review): absolute path under a personal home directory — move to a shared, configurable location.
with open("/Users/cxg042/Documents/git/ods-preglac/dev_cg/output_all_notes.jsonl.out", 'r', encoding='utf-8') as f:
    results = f.readlines()

def organize_batch_results(batch_results):
    """Parse every line of a batch-inference output file into one DataFrame.

    Args:
        batch_results: Iterable of JSON strings, one per batch record (e.g. the
            lines returned by ``readlines()``). Blank or whitespace-only lines are
            skipped.

    Returns:
        DataFrame with one row per record, in input order, indexed 0..n-1, with the
        columns produced by ``parse_batch_results``. An input with no records yields
        an empty frame with those columns instead of raising.
    """
    # Blank lines would make json.loads fail inside parse_batch_results.
    dfs = [parse_batch_results(line) for line in tqdm(batch_results) if line.strip()]
    if not dfs:
        # pd.concat raises on an empty list.
        return pd.DataFrame(columns=["recordID", "Label", "Reasoning", "input_cost", "output_cost"])
    # Each parsed frame carries index [0]; ignore_index avoids a column of duplicate 0 labels.
    return pd.concat(dfs, ignore_index=True)

def parse_batch_results(
    single_result: str,
    input_price_per_million: float = 3.0,
    output_price_per_million: float = 15.0,
) -> pd.DataFrame:
    """Parse one line of the batch-inference output into a single-row DataFrame.

    The line is a JSON object with a ``recordId``, token counts under
    ``modelOutput.usage``, and the model reply at ``modelOutput.content[0].text``.
    That reply is itself a JSON string holding ``Label`` and ``Reasoning``.

    Args:
        single_result: One raw JSON line (a string passed to ``json.loads``).
        input_price_per_million: Price charged per 1,000,000 input tokens.
        output_price_per_million: Price charged per 1,000,000 output tokens.
            NOTE(review): the defaults (3 and 15) are presumably USD list prices
            for the model used — confirm.

    Returns:
        One-row DataFrame with columns recordID, Label, Reasoning, input_cost,
        output_cost. recordID is returned exactly as parsed; callers cast it.

    Raises:
        json.JSONDecodeError: if the line or the embedded reply is not valid JSON.
        KeyError / IndexError: if an expected field is missing.
    """
    res = json.loads(single_result)
    model_output = res['modelOutput']
    usage = model_output['usage']

    # The reply text is a JSON document of its own; decode it once and read both fields.
    answer = json.loads(model_output['content'][0]['text'])

    return pd.DataFrame([
        {
            "recordID": res['recordId'],
            "Label": answer['Label'],
            "Reasoning": answer['Reasoning'],
            "input_cost": (usage['input_tokens'] / 1_000_000) * input_price_per_million,
            "output_cost": (usage['output_tokens'] / 1_000_000) * output_price_per_million,
        }
    ])

def standardize_label(label):
    """Normalize an LLM label to a short lowercase form.

    "Absent / Insufficient Evidence" becomes "absent"; any other string is
    lowercased (e.g. "Positive" -> "positive").

    Non-string values, such as the NaN a left merge leaves for notes without an LLM
    result, are returned unchanged instead of raising AttributeError on ``.lower()``.
    """
    if not isinstance(label, str):
        return label
    if label == "Absent / Insufficient Evidence":
        return "absent"
    return label.lower()

# One row per batch record: LLM label, reasoning, and token costs.
res = organize_batch_results(results)

# Cast record ids to int and shift them down by one so they can be matched against test_df's
# index in the merge below. NOTE(review): presumably the batch ids were 1-based — confirm.
res = res.assign(recordID=res["recordID"].astype(int) - 1)

res

# Merge the LLM labels onto the test notes

In [None]:
# Attach each test note's LLM label. The LLM records are matched against test_df's *index*
# (left_index), not against `row_ix`.
# NOTE(review): this is only correct if the batch input was numbered from this same index. It
# coincides with `row_ix` only when the PAT_ID merge kept the original row order and dropped no
# rows — confirm.
note_columns = ["PAT_ID", "BF1", "BF2", "NOTE_TYPE", "NOTE_TEXT", "row_ix", "split"]
llm_labels = res.drop(columns=["input_cost", "output_cost"])
merged = test_df[note_columns].merge(
    llm_labels,
    how='left',
    left_index=True,
    right_on="recordID",
)

merged["Label"] = merged["Label"].apply(standardize_label)

# Keep only the note type, the `row_ix` join key, and the normalized LLM label.
res_df = merged.drop(columns=["BF1", "BF2", "NOTE_TEXT", "PAT_ID", "split", "recordID", "Reasoning"])
res_df

In [None]:
# Pair each note's LLM label with its MedSpaCy document label; only notes present in both are kept.
eval_df = (
    res_df
    .merge(medspacy_doc_labels, on='row_ix', how='inner')
    .rename(columns={'medspacy_document_label': 'medspacy_label'})
)
eval_df

In [None]:
# Confusion matrix: MedSpaCy labels on the rows, LLM labels on the columns.
# Both axes share one sorted label set, so agreements always fall on the diagonal even when one
# labeler never emits some label. Combinations that never occur are filled with 0.
label_order = sorted(set(eval_df.medspacy_label.dropna()) | set(eval_df.Label.dropna()))
conf_mat = (
    pd.crosstab(eval_df.medspacy_label, eval_df.Label)
    .reindex(index=label_order, columns=label_order, fill_value=0)
    .rename(index=str.title, columns=str.title)
)

# Set the theme before creating the figure so both the figure and its axes pick it up.
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(8, 6))

# Annotated integer counts with a blue palette; the colour bar is omitted.
ax = sns.heatmap(
    conf_mat,
    annot=True,
    fmt='d',
    cmap='Blues',
    cbar=False,
    ax=ax,
)

ax.set_xlabel('LLM Label', fontsize=12)
ax.set_ylabel('MedspaCy Label', fontsize=12)
ax.set_title('Confusion Matrix: MedspaCy vs LLM labels', fontsize=14, pad=20, loc='left', x=-0.07)

fig.tight_layout()
plt.show()

In [None]:
# Per-class precision/recall/F1 of MedSpaCy, treating the LLM labels as the reference (y_true).
CLASS_ORDER = ["positive", "negative", "absent"]

# Rows missing either label cannot be scored; classification_report rejects NaN mixed with strings.
scored = eval_df.dropna(subset=["Label", "medspacy_label"])
y_true = scored.Label
y_pred = scored.medspacy_label

# Passing `labels` guarantees a row for every class even if it never occurs in this split, so the
# .loc selection below cannot raise KeyError. Per-class scores themselves are unaffected.
report = classification_report(y_true, y_pred, labels=CLASS_ORDER, output_dict=True, zero_division=0)

# One row per class (plus aggregate rows), one column per metric.
report_df = pd.DataFrame(report).transpose()

# Keep only the class-wise rows (drop accuracy and the averaged rows).
class_wise_report = report_df.loc[~report_df.index.str.contains("avg|accuracy")]

performance = class_wise_report.drop(columns=['support']).loc[CLASS_ORDER].round(2)
performance

# Save the results

In [None]:
# Persist the document-level comparison table and the raw entity-level output for the test split.
# NOTE(review): both frames derive from clinical notes — keep these files on the secured volume.
for frame, out_path in (
    (eval_df, "/Volumes/RISIDataServices_MPrint_NCH/data/results/test_set_document_results.csv"),
    (entities_df, "/Volumes/RISIDataServices_MPrint_NCH/data/results/test_set_entity_results.csv"),
):
    frame.to_csv(out_path)