In [1]:
import os
import json
import jsonlines
import numpy as np
import pandas as pd

from irrCAC.raw import CAC
from itertools import combinations

In [2]:
def load_json(filename):
    with open(filename, "r") as file:
        data = json.load(file)
    return data


def load_jsonl(filename):
    with jsonlines.open(filename, "r") as file:
        data_list = [line for line in file]
    return data_list


def get_profile(scenario_dict, trg_id):
    for profile in scenario_dict:
        if str(int(profile["hadm_id"])) == str(int(trg_id)):
            return profile

In [3]:
data_dir = "."
display_order = ['gemini-2.5-flash-preview-04-17', 'gpt-4o-mini', 'vllm-deepseek-llama-70b', 'vllm-qwen2.5-72b-instruct', 'vllm-llama3.3-70b-instruct','vllm-llama3.1-70b-instruct',  'vllm-llama3.1-8b-instruct', 'vllm-qwen2.5-7b-instruct']

## Patient Profile

In [4]:
patient_profile = load_json(os.path.join(data_dir, "patient_profile.json"))
patient_profile_df = pd.DataFrame(patient_profile)
for _key in ["split", "diagnosis"]:
    print(patient_profile_df[_key].value_counts())
    print()

split
persona    108
info        52
valid       10
Name: count, dtype: int64

diagnosis
Intestinal obstruction     39
Pneumonia                  34
Urinary tract infection    34
Myocardial infarction      34
Cerebral infarction        29
Name: count, dtype: int64



### RQ1: Do LLMs naturally reflect diverse persona traits in their responses?

In [5]:
persona_dir = os.path.join(data_dir, "persona_test", "llm_simulation")
llm_persona_result_list = []
for llm_backbone in os.listdir(persona_dir):
    llm_result = load_jsonl(os.path.join(persona_dir, llm_backbone, "llm_dialogue.jsonl"))
    llm_result = pd.DataFrame(llm_result)
    llm_persona_result_list.append(llm_result)
llm_persona_result_df = pd.concat(llm_persona_result_list)
llm_persona_result_df.groupby("patient_engine_name")[["personality", "cefr", "recall", "confused", "realism"]].mean().round(2).loc[display_order]

Unnamed: 0_level_0,personality,cefr,recall,confused,realism
patient_engine_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
gemini-2.5-flash-preview-04-17,3.94,3.54,3.64,3.38,3.37
gpt-4o-mini,3.58,3.55,3.78,3.88,3.26
vllm-deepseek-llama-70b,3.87,3.58,3.42,2.5,3.19
vllm-qwen2.5-72b-instruct,3.3,3.68,3.62,3.5,3.22
vllm-llama3.3-70b-instruct,3.92,3.4,3.78,4.0,3.28
vllm-llama3.1-70b-instruct,3.65,3.51,3.62,4.0,3.23
vllm-llama3.1-8b-instruct,3.53,3.29,3.7,4.0,3.2
vllm-qwen2.5-7b-instruct,3.23,3.49,3.31,3.5,3.16


In [6]:
llm_persona_result_df.groupby("patient_engine_name")[["personality", "cefr", "recall", "confused", "realism"]].mean().round(2).loc[display_order].mean(axis=1).round(2)

patient_engine_name
gemini-2.5-flash-preview-04-17    3.57
gpt-4o-mini                       3.61
vllm-deepseek-llama-70b           3.31
vllm-qwen2.5-72b-instruct         3.46
vllm-llama3.3-70b-instruct        3.68
vllm-llama3.1-70b-instruct        3.60
vllm-llama3.1-8b-instruct         3.54
vllm-qwen2.5-7b-instruct          3.34
dtype: float64

### RQ2: Do LLMs accurately derive responses based on the given profile? & RQ3: Can LLMs reasonably fill in the blanks?

### Dialogue-level

In [7]:
def is_valid(row, key):
    if row[f"{key}_gt"] == "Not recorded" or row[f"{key}_pred"] == "Not recorded":
        return 0
    if key == "pain" and "(predicted)" in str(row[f"{key}_pred"]):
        return 0
    return 1

def compute_valid_column(df, key):
    gt_col = f"{key}_gt"
    pred_col = f"{key}_pred"
    valid_col = f"{key}_valid"
    if gt_col in df.columns and pred_col in df.columns:
        df[valid_col] = df.apply(lambda row: is_valid(row, key), axis=1)
    return df

def aggregate_per_key(df, key, group_keys):
    valid_col = f"{key}_valid"
    llm_col = f"{key}_llm"
    item_df = df[df[valid_col] == 1]

    grouped_valid = item_df.groupby(group_keys).agg(
        valid_count=(valid_col, "count"),
        llm_score_mean=(llm_col, "mean")
    ).reset_index()

    total_counts = df.groupby(group_keys)[valid_col].count().reset_index(name="total_count")

    merged = pd.merge(grouped_valid, total_counts, on=group_keys, how="right")
    merged["valid_percentage"] = merged["valid_count"].fillna(0) / merged["total_count"]
    merged["item"] = key
    return merged

def get_metric_table(df, value_col, metric_name, groupby_cols, ordered_cats):
    grouped = df.groupby(groupby_cols)[value_col].mean().reset_index()
    pivoted = grouped.pivot(index=groupby_cols[:1], columns="category", values=value_col)
    pivoted = pivoted.reindex(columns=ordered_cats)
    # pivoted["metric"] = metric_name
    return pivoted.reset_index().set_index(groupby_cols[:1])
    # return pivoted.reset_index().set_index(groupby_cols[:2] + ["metric"])


def flatten_dict_simple(d, parent_key="", sep="_"):
    items = []
    for k, v in d.items():
        if parent_key == "present_illness":
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
        else:
            new_key = k
        if isinstance(v, dict):
            items.extend(flatten_dict_simple(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

eval_key_cat = {
    "Social_History": ['tobacco', 'alcohol','illicit_drug', 'exercise', 'marital_status', 'children', 'living_situation', 'occupation'],
    "Previous_Medical_History": ['allergies', 'family_medical_history', 'medical_device', 'medical_history'],
    "Current_Visit_Information": ['chiefcomplaint', 'present_illness_positive', 'present_illness_negative', 'pain', 'medication',],
}
eval_key_to_cat = {key: category for category, keys in eval_key_cat.items() for key in keys}

In [8]:
info_dir = os.path.join(data_dir, "info_test", "llm_simulation")
llm_info_result_list = []
for llm_backbone in os.listdir(info_dir):
    dialogue_data = load_jsonl(os.path.join(info_dir, llm_backbone, "llm_dialogue.jsonl"))
    predict_profiles = load_json(os.path.join(info_dir, llm_backbone, "gemini-2.5-flash-preview-04-17_profile_consistency_Patient.json"))
    LLM_score_dicts = load_json(os.path.join(info_dir, llm_backbone, "gemini-2.5-flash-preview-04-17_profile_consistency_LLMscore_Patient.json"))
    for data in dialogue_data:
        hadm_id = data["hadm_id"]
        profile_gt = get_profile(patient_profile, hadm_id)
        predict_profile = flatten_dict_simple(predict_profiles[hadm_id])
        
        profile_gt = {k: v for k, v in profile_gt.items() if k in predict_profile.keys()}
        LLM_dict = LLM_score_dicts[hadm_id]
        LLM_score = {k: int(v[-1]) for k, v in LLM_dict.items()}
        
        meta_keys = ["doctor_engine_name", "patient_engine_name", "cefr_type",  "personality_type", "recall_level_type", "dazed_level_type"]
        meta_data = {key: data.get(key) for key in meta_keys}
        meta_data["hadm_id"] = hadm_id
        meta_df = pd.DataFrame([meta_data])
        dfs = [
            pd.DataFrame([profile_gt]).add_suffix("_gt"),
            pd.DataFrame([predict_profile]).add_suffix("_pred"),
            pd.DataFrame([LLM_score]).add_suffix("_llm")
        ]
        sample_data = pd.concat([meta_df] + dfs, axis=1)

        llm_info_result_list.append(sample_data)
llm_info_result_df = pd.concat(llm_info_result_list)

In [9]:
group_keys = ["patient_engine_name"]
groupby_cols = group_keys + ["category"]
ordered_cats = ["Social_History", "Previous_Medical_History", "Current_Visit_Information"]
metric_order = ["valid_percentage", "bert_score_mean", "llm_score_mean"]

for key in profile_gt.keys():
    llm_info_result_df = compute_valid_column(llm_info_result_df, key)

grouped_results = [
    aggregate_per_key(llm_info_result_df, key, group_keys)
    for key in profile_gt.keys() if key in eval_key_to_cat
]

final_df = pd.concat(grouped_results, ignore_index=True)
final_df["category"] = final_df["item"].map(eval_key_to_cat)

valid_tbl = get_metric_table(final_df, "valid_percentage", "valid_percentage", groupby_cols, ordered_cats)
llm_tbl = get_metric_table(final_df, "llm_score_mean", "llm_score_mean", groupby_cols, ordered_cats)
final_stacked = pd.concat([valid_tbl, llm_tbl], axis=1)
final_stacked = final_stacked.sort_index(axis=0, level=[1, 0])  # Sort by metric then category
final_stacked = final_stacked.round(2)
final_stacked.loc[display_order]

category,Social_History,Previous_Medical_History,Current_Visit_Information,Social_History,Previous_Medical_History,Current_Visit_Information
patient_engine_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
gemini-2.5-flash-preview-04-17,0.44,0.77,0.88,3.82,3.51,3.18
gpt-4o-mini,0.55,0.76,0.89,3.72,3.33,3.01
vllm-deepseek-llama-70b,0.5,0.76,0.91,3.73,3.31,3.08
vllm-qwen2.5-72b-instruct,0.47,0.77,0.9,3.75,3.5,2.95
vllm-llama3.3-70b-instruct,0.53,0.78,0.89,3.72,3.47,3.1
vllm-llama3.1-70b-instruct,0.56,0.77,0.89,3.82,3.43,3.05
vllm-llama3.1-8b-instruct,0.61,0.78,0.88,3.68,3.19,2.85
vllm-qwen2.5-7b-instruct,0.44,0.75,0.89,3.6,3.32,2.89


### Sentence-level evaluation

In [11]:
def flatten_dict(d):
    items = {}
    for k, v in d.items():
        for _k, _v in v.items():
            items[_k] = _v 
    return items
    
def safe_div(numerator, denominator):
    """Safely divide two numbers, returning None if denominator is zero."""
    return numerator / denominator if denominator > 0 else None

def calc_f1(precision, recall):
    """Calculate F1 score given precision and recall values."""
    return 2 * precision * recall / (precision + recall) if precision is not None and recall is not None and (precision + recall) > 0 else None

In [12]:
def analyze_sentence_data(hadm_id, dialogue, nli_entry):
    patient_utts = [d["content"] for d in dialogue if d["role"] == "Patient"]
    assert set(nli_entry.keys()) == set(patient_utts)
    total_utter_num = len(nli_entry)
    sent_results = flatten_dict(nli_entry)
    
    stats = {
        "total_sent_num": 0,
        "num_infomatic_sent": 0,
        "num_entail": 0,
        "num_support": 0,
        "num_support_only": 0,
        "num_unsupport": 0,
        "num_unsupport_only": 0,
        "both": 0,
        "plausibility_score": np.nan,
        "contradict_cnt": 0,
        "contradict_result": [],
        "plausibility_score_list": [],
        "sent_category": [],
    }

    plausibility_scores = []
    contradict_results = []
    sent_categories = []
    for sent, result in sent_results.items():
        stats["total_sent_num"] += 1
        sent_categories.append(result["step0"]["prediction"])

        if result["step0"]["prediction"] != "information":
            continue

        stats["num_infomatic_sent"] += 1

        related_categories = [s["category"] for s in result.get("step1-1", []) if int(s["prediction"]) == 1]
        unsupport_flag = int(result.get("step1-2", {}).get("prediction", 0)) == 1
        entail_dict = [s for s in result.get("step2-2", []) if s["entailment_prediction"] != 0]

        is_unsupport = unsupport_flag or not entail_dict
        if "step2-1" in result:
            assert is_unsupport

        if related_categories and entail_dict:
            if any(s["entailment_prediction"] == -1 for s in entail_dict):
                contradict_results.append([s for s in result["step2-2"] if s["entailment_prediction"] == -1])
            else:
                stats["num_entail"] += 1

        if is_unsupport:
            stats["num_unsupport"] += 1
            plausibility_scores.append(result["step2-1"]["likelihood_rating"])
            if not entail_dict:
                stats["num_unsupport_only"] += 1
            else:
                stats["num_support"] += 1
                stats["both"] += 1
        else:
            stats["num_support"] += 1
            stats["num_support_only"] += 1

    stats["plausibility_score"] = np.mean(plausibility_scores) if plausibility_scores else np.nan
    stats["contradict_cnt"] = len(contradict_results)
    stats["contradict_result"] = contradict_results
    stats["plausibility_score_list"] = plausibility_scores
    stats["sent_category"] = sent_categories
    stats["total_utter_num"] = total_utter_num

    return stats


def build_llm_info_sent_result_df(info_dir):
    result_list = []
    for llm_backbone in os.listdir(info_dir):
        dialogue_data = load_jsonl(os.path.join(info_dir, llm_backbone, "llm_dialogue.jsonl"))
        nli_data = load_json(os.path.join(info_dir, llm_backbone, "gemini-2.5-flash-preview-04-17_sentence_label.json"))

        for data in dialogue_data:
            hadm_id = data["hadm_id"]
            try:
                stats = analyze_sentence_data(hadm_id, data["dialog_history"], nli_data[hadm_id])
            except:
                print(hadm_id, data["patient_engine_name"])
            stats.update({
                "hadm_id": hadm_id,
                "doctor_engine_name": data["doctor_engine_name"],
                "patient_engine_name": data["patient_engine_name"],
            })
            result_list.append(stats)

    df = pd.DataFrame(result_list)

    df["info_frac"] = df["num_infomatic_sent"] / df["total_sent_num"]
    df["support_frac"] = df["num_support"] / df["num_infomatic_sent"]
    df["entail_frac"] = df["num_entail"] / df["num_support"]
    df["contradict_frac"] = df["contradict_cnt"] / df["num_support"]
    df["unsupport_frac"] = df["num_unsupport"] / df["num_infomatic_sent"]
    df["both_frac"] = df["both"] / df["num_infomatic_sent"]
    df["unsupport_only"] = df["unsupport_frac"] - df["both_frac"]
    return df

info_dir = os.path.join(data_dir, "info_test", "llm_simulation")
llm_info_sent_result_df = build_llm_info_sent_result_df(info_dir)

summary_df = llm_info_sent_result_df.groupby(["patient_engine_name"])[[
    "info_frac", "support_frac", "unsupport_frac", "entail_frac", "contradict_frac", "plausibility_score"
]].mean().round(3)
summary_df.loc[display_order]

Unnamed: 0_level_0,info_frac,support_frac,unsupport_frac,entail_frac,contradict_frac,plausibility_score
patient_engine_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
gemini-2.5-flash-preview-04-17,0.972,0.763,0.316,0.978,0.022,3.953
gpt-4o-mini,0.957,0.721,0.428,0.968,0.032,3.929
vllm-deepseek-llama-70b,0.975,0.762,0.416,0.968,0.032,3.911
vllm-qwen2.5-72b-instruct,0.975,0.683,0.468,0.954,0.046,3.928
vllm-llama3.3-70b-instruct,0.958,0.796,0.387,0.981,0.019,3.963
vllm-llama3.1-70b-instruct,0.948,0.812,0.399,0.962,0.038,3.958
vllm-llama3.1-8b-instruct,0.944,0.771,0.488,0.944,0.056,3.897
vllm-qwen2.5-7b-instruct,0.987,0.703,0.453,0.939,0.061,3.862


In [13]:
llm_info_sent_result_df.groupby(["patient_engine_name"])[["total_utter_num", "total_sent_num", "num_infomatic_sent", "num_support",  "num_unsupport", "num_entail", "contradict_cnt"]].sum().round().loc[display_order]

Unnamed: 0_level_0,total_utter_num,total_sent_num,num_infomatic_sent,num_support,num_unsupport,num_entail,contradict_cnt
patient_engine_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
gemini-2.5-flash-preview-04-17,889,2286,2220,1695,705,1659,36
gpt-4o-mini,786,1937,1852,1331,795,1287,44
vllm-deepseek-llama-70b,806,1657,1614,1225,679,1186,39
vllm-qwen2.5-72b-instruct,824,1820,1774,1201,839,1146,55
vllm-llama3.3-70b-instruct,806,2180,2087,1654,817,1623,31
vllm-llama3.1-70b-instruct,699,1946,1842,1493,745,1438,55
vllm-llama3.1-8b-instruct,742,1774,1672,1284,826,1210,74
vllm-qwen2.5-7b-instruct,877,1579,1558,1092,712,1024,68


### Human Evaluation - Persona Fidelity

In [14]:
human_dialog = load_jsonl(os.path.join(data_dir, "persona_test", "expert_dialogue.jsonl"))
human_dialog_df = pd.DataFrame(human_dialog)
human_dialog_df[["personality", "cefr", "recall", "confused", "realism", "tool_usefulness"]].mean().round(2)

personality        3.96
cefr               3.84
recall             3.87
confused           4.00
realism            3.89
tool_usefulness    3.75
dtype: float64

### Human Evaluation - Plausibility

In [15]:
human_plausibility_label = load_jsonl(os.path.join(data_dir, "info_test", "expert_plausibility_label.jsonl"))
human_plausibility_label = pd.DataFrame(human_plausibility_label)
print("# of labeled utterance (per labeler): ", human_plausibility_label.groupby("labeler_name").utterance_id.nunique().mean())
print("# of labeled utterance: ", human_plausibility_label.utterance_id.nunique())
print("avg score: ", human_plausibility_label.score.mean().round(2))
human_plausibility_label.groupby("labeler_name").score.mean().round(3)

# of labeled utterance (per labeler):  615.75
# of labeled utterance:  821
avg score:  3.91


labeler_name
A    3.955
B    3.923
C    3.985
D    3.781
Name: score, dtype: float64

### Plausibility Agreement

In [16]:
def compute_pairwise_agreement_with_ci(long_df, rater_col="labeler_name", agreement_type="exact_agreement", max_score=4, n_bootstrap=1000, ci=95, seed=42):
    np.random.seed(seed)
    pivot = long_df.pivot_table(index='utterance_id', columns=rater_col, values='score')
    raters = pivot.columns
    results = []

    for r1, r2 in combinations(raters, 2):
        sub = pivot[[r1, r2]].dropna()
        if sub.empty:
            continue

        if agreement_type == "exact_agreement":
            def stat_fn(df): return (df[r1] == df[r2]).mean()
        elif agreement_type == "exact_mse_norm":
            def stat_fn(df): return 1 - (df[r1] - df[r2]).abs().mean() / (max_score - 1)
        elif agreement_type == "gwets_ac1":
            def stat_fn(df):
                try:
                    cac = CAC(df, categories=list(range(1, max_score + 1)), weights='identity')
                    return cac.gwet()["est"]["coefficient_value"]
                except Exception:
                    return np.nan  
        else:
            raise ValueError(f"Unknown agreement_type: {agreement_type}")
        

        point_estimate = stat_fn(sub)

        # Bootstrap
        scores = []
        for _ in range(n_bootstrap):
            sample_df = sub.sample(n=len(sub), replace=True)
            score = stat_fn(sample_df)
            if not np.isnan(score):
                scores.append(score)
        try:
            lower = np.percentile(scores, (100 - ci) / 2)
            upper = np.percentile(scores, 100 - (100 - ci) / 2)
        except:
            return sub

        results.append({
            "rater_pair": f"{r1}-{r2}",
            "agreement": point_estimate,
            f"mean": np.mean(scores),
            f"ci_lower_{ci}%": lower,
            f"ci_upper_{ci}%": upper,
            f"num_sample": len(scores)
        })

    return pd.DataFrame(results)

llm_plausibility_label = load_jsonl(os.path.join(data_dir, "info_test", "llm_plausibility_label.jsonl"))
llm_plausibility_label = pd.DataFrame(llm_plausibility_label)
human_llm_plausibility_label = pd.concat([llm_plausibility_label, human_plausibility_label]).reset_index(drop=True)
agreement_ci = compute_pairwise_agreement_with_ci(human_llm_plausibility_label, agreement_type="gwets_ac1")
agreement_ci

Unnamed: 0,rater_pair,agreement,mean,ci_lower_95%,ci_upper_95%,num_sample
0,A-B,0.94917,0.948846,0.92688,0.96873,1000
1,A-C,0.96819,0.968232,0.95079,0.98296,1000
2,A-D,0.86583,0.865568,0.827517,0.901122,1000
3,A-gemini-2.5-flash-preview-04-17,0.94371,0.943714,0.92341,0.96043,1000
4,B-C,0.96161,0.960926,0.93959,0.97852,1000
5,B-D,0.8536,0.853072,0.818096,0.88567,1000
6,B-gemini-2.5-flash-preview-04-17,0.94459,0.944609,0.9262,0.96109,1000
7,C-D,0.87824,0.878705,0.842892,0.91282,1000
8,C-gemini-2.5-flash-preview-04-17,0.96398,0.963905,0.94733,0.97718,1000
9,D-gemini-2.5-flash-preview-04-17,0.88266,0.88259,0.857062,0.907741,1000


In [17]:
KEY_DESCRIPTION = {
    "age": "Age: {age}",
    "gender": "Gender: {gender}",
    "race": "Race: {race}",
    "tobacco": "Tobacco: {tobacco}",
    "alcohol": "Alcohol: {alcohol}",
    "illicit_drug": "Illicit drug use: {illicit_drug}",
    "sexual_history": "Sexual History: {sexual_history}",
    "exercise": "Exercise: {exercise}",
    "marital_status": "Marital status: {marital_status}",
    "children": "Children: {children}",
    "living_situation": "Living Situation: {living_situation}",
    "occupation": "Occupation: {occupation}",
    "insurance": "Insurance: {insurance}",
    "allergies": "Allergies: {allergies}",
    "family_medical_history": "Family medical history: {family_medical_history}",
    "medical_device": "Medical devices previously used or currently in use before this ED admission: {medical_device}",
    "medical_history": "Medical history prior to this ED admission: {medical_history}",
    "present_illness": "Present illness:\n\tpositive: {present_illness_positive}\n\tnegative (denied): {present_illness_negative}",
    "chief_complaint": "ED chief complaint: {chiefcomplaint}",
    "pain": "Pain level at ED Admission (0 = no pain, 10 = worst pain imaginable): {pain}",
    "medication": "Current medications they are taking: {medication}",
    "arrival_transport": "ED Arrival Transport: {arrival_transport}",
    "diagnosis": "ED Diagnosis: {diagnosis}",
}

def reverse_map_key(target_sentence):
    for key, value_template in KEY_DESCRIPTION.items():
        prefix = value_template.split("{")[0].strip()  # { 이전의 고정된 부분만 추출
        if target_sentence.startswith(prefix):
            return key
    return None


def evaluate_step0(human_label, model_pred):
    is_info = human_label == "information"
    is_correct = human_label == model_pred
    is_tp = is_info and is_correct
    return is_info, is_tp, is_correct

def evaluate_step2(human_step2, model_step2):
    model_preds = {
        reverse_map_key(item["profile"]): item["entailment_prediction"]
        for item in model_step2 if item["entailment_prediction"] != 0
    }
    human_labels = {k: 1 if v == "e" else -1 for k, v in human_step2.items()}
    eval_metrics = evaluate_prediction(human_labels, model_preds)
    return eval_metrics, set(model_preds.keys()), set(human_labels.keys())

def is_unsupported_pred(utterance_data):
    return "step2-1" in utterance_data
    
def evaluate_sets(gt_set, pred_set):
    intersection = len(gt_set & pred_set)
    precision = intersection / len(pred_set) if pred_set else 0
    recall = intersection / len(gt_set) if gt_set else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {
        "Precision": round(precision, 4),
        "Recall": round(recall, 4),
        "F1score": round(f1_score, 4)
    }
def evaluate_prediction(ans, pred):
    ans_keys = set(ans.keys())
    pred_keys = set(pred.keys())
    common_keys = ans_keys.intersection(pred_keys)
    
    key_eval = evaluate_sets(ans_keys, pred_keys)
    correct_values = sum(1 for key in common_keys if ans[key] == pred[key])
    value_accuracy = correct_values / len(common_keys) if common_keys else 0
    value_tp = 0
    for key, value in ans.items():
        if value == pred.get(key):
            value_tp += 1

    return {
        'key_precision': key_eval["Precision"],
        'key_recall': key_eval["Recall"],
        'key_f1': key_eval["F1score"],
        'value_accuracy': value_accuracy,
        "value_recall": value_tp / len(ans_keys) if len(ans_keys) > 0 else 0
    }

In [49]:
sentence_label_human = load_json(os.path.join(data_dir, "sentence_cls_valid", "sentence_label_manual.json"))
dialogue_hists = load_jsonl(os.path.join(data_dir, "sentence_cls_valid", "dialogue.jsonl"))

print(len(dialogue_hists))
print(len(sentence_label_human))
valid_result_dict = []
for llm_path in ["sentence_label_gemini-2.5-flash-preview-04-17.json", "sentence_label_gpt-4o.json"]:
    nli_data = load_json(os.path.join(data_dir, "sentence_cls_valid", llm_path))
    
    for data in dialogue_hists:
        hadm_id = data["hadm_id"]
        counters = {
            "total_sentences": 0,
            "step0_correct": 0,
            "step0_tp_sum": 0,
            "step0_gt_info": 0,
            "step2_gt_count": 0,
            "step2_pred_count": 0,
            "step2_eval_count": 0,
            "step2_key_precision_sum": 0,
            "step2_key_recall_sum": 0,
            "step2_key_f1_sum": 0,
            "step2_value_correct": 0,
            "step2_value_recall": 0,
            "unsupported_gt_count": 0,
            "unsupported_pred_count": 0,
            "unsupported_tp": 0,
            "unsupported_only_gt_count": 0,
            "unsupported_only_pred_count": 0,
        }

        dialogue_info_tp = 0
        dialogue_info_pred_total = 0
        dialogue_info_gt_total = 0

        for utter in data["dialog_history"]:
            if utter["role"] != "Patient":
                continue
            utterance = utter["content"]

            if utterance not in nli_data[hadm_id] or utterance not in sentence_label_human[hadm_id]:
                print(utterance)
                continue

            utterance_data = nli_data[hadm_id][utterance]
            human_utterance_data = sentence_label_human[hadm_id][utterance]

            for sent, model_info in utterance_data.items():
                if sent not in human_utterance_data:
                    continue

                counters["total_sentences"] += 1
                human_step0 = human_utterance_data[sent]["step0"]
                model_step0 = "information" if model_info["step0"]["prediction"] == "information" else "non-information"

                is_info, is_tp, is_correct = evaluate_step0(human_step0, model_step0)
                counters["step0_gt_info"] += is_info
                counters["step0_tp_sum"] += is_tp
                counters["step0_correct"] += is_correct

                if "step2-2" in human_utterance_data[sent]:
                    counters["step2_gt_count"] += 1
                if "step2-2" in model_info:
                    counters["step2_pred_count"] += 1
                if "step2-2" in human_utterance_data[sent] and "step2-2" in model_info:
                    counters["step2_eval_count"] += 1

                # Step1: unsupported
                if is_tp:
                    step1_unsupported_human = "unsupported" in human_utterance_data[sent]["step1"]
                    step1_unsupported_pred = is_unsupported_pred(model_info)

                    counters["unsupported_gt_count"] += step1_unsupported_human
                    counters["unsupported_only_gt_count"] += step1_unsupported_human and not step1_unsupported_pred
                    counters["unsupported_only_pred_count"] += step1_unsupported_pred and not step1_unsupported_human
                    counters["unsupported_tp"] += step1_unsupported_human and step1_unsupported_pred
                    counters["unsupported_pred_count"] += step1_unsupported_pred

                    # Step2: fine-grained eval
                    if "step2-2" in human_utterance_data[sent] and "step2-2" in model_info:
                        eval_metrics, pred_keys, gt_keys = evaluate_step2(human_utterance_data[sent]["step2-2"], model_info["step2-2"])

                        counters["step2_key_precision_sum"] += eval_metrics["key_precision"]
                        counters["step2_key_recall_sum"] += eval_metrics["key_recall"]
                        counters["step2_key_f1_sum"] += eval_metrics["key_f1"]
                        counters["step2_value_correct"] += eval_metrics["value_accuracy"]
                        counters["step2_value_recall"] += eval_metrics["value_recall"]

                        dialogue_info_tp += len(pred_keys & gt_keys)
                        dialogue_info_pred_total += len(pred_keys)
                        dialogue_info_gt_total += len(gt_keys)

        if counters["total_sentences"] > 0:
            metrics = {
                "llm": llm_path.replace("_nli_Patient.json", ""),
                "total_sentences": counters["total_sentences"],
                "step0_accuracy": safe_div(counters["step0_correct"], counters["total_sentences"]),
                "step0_recall": safe_div(counters["step0_tp_sum"], counters["step0_gt_info"]),
                "step2_macro_avg_key_precision": safe_div(counters["step2_key_precision_sum"], counters["step2_eval_count"]),
                "step2_macro_avg_key_recall": safe_div(counters["step2_key_recall_sum"], counters["step2_eval_count"]),
                "step2_macro_avg_key_f1": safe_div(counters["step2_key_f1_sum"], counters["step2_eval_count"]),
                "step2_macro_avg_value_accuracy": safe_div(counters["step2_value_correct"], counters["step2_eval_count"]),
                "step2_macro_avg_value_recall": safe_div(counters["step2_value_recall"], counters["step2_eval_count"]),
                "unsupported_precision": safe_div(counters["unsupported_tp"], counters["unsupported_pred_count"]),
                "unsupported_recall": safe_div(counters["unsupported_tp"], counters["unsupported_gt_count"]),
                "unsupported_f1": calc_f1(
                    safe_div(counters["unsupported_tp"], counters["unsupported_pred_count"]),
                    safe_div(counters["unsupported_tp"], counters["unsupported_gt_count"])
                )
            }
            valid_result_dict.append(metrics)

sample_df = pd.DataFrame(valid_result_dict)
sample_df.groupby("llm").mean().T.round(2)

10
10


llm,sentence_label_gemini-2.5-flash-preview-04-17.json,sentence_label_gpt-4o.json
total_sentences,41.1,41.1
step0_accuracy,0.96,0.94
step0_recall,0.99,0.98
step2_macro_avg_key_precision,0.9,0.92
step2_macro_avg_key_recall,0.96,0.94
step2_macro_avg_key_f1,0.92,0.92
step2_macro_avg_value_accuracy,0.98,0.97
step2_macro_avg_value_recall,0.96,0.94
unsupported_precision,0.84,0.89
unsupported_recall,0.86,0.64
