# Step 3: Accuracy Calculation

This notebook:

1. Loads **shared** ground truth from `output/ground_truth/` (generated by Step 1)
2. Loads LLM predictions from Step 2 (`set1_llm_output.csv`, `set2_llm_output.csv`)
3. Compares Set 1 (A4-A7) and Set 2 (B4-B7) predictions against ground truth
4. Calculates accuracy metrics (exact-match recall and Jaccard)

**Prerequisites**: Run `step1_ground_truth.ipynb` (in `ground_truth/`) and `step2_llm_queries.ipynb` (in this folder) first.

## Configuration

In [ ]:
import os
import re
from pathlib import Path

import pandas as pd

# ============================================================
# Path Configuration
# ============================================================

# When the working directory is named "testing_<llm>", the repo root is its
# parent and <llm> selects the output sub-folder. Otherwise the working
# directory itself is treated as the repo root and the LLM defaults to "gpt".
_TESTING_PREFIX = "testing_"
_cwd = Path(".").resolve()
if _cwd.name.startswith(_TESTING_PREFIX):
    REPO_ROOT = _cwd.parent
    # Slice off only the leading prefix; str.replace would also delete any
    # later "testing_" occurrence inside the folder name.
    LLM_NAME = _cwd.name[len(_TESTING_PREFIX):]
else:
    REPO_ROOT = _cwd
    LLM_NAME = "gpt"

OUTPUT_ROOT = (REPO_ROOT / "output").resolve()
PIPELINE_ROOT = OUTPUT_ROOT / LLM_NAME

# Shared ground truth
GT_ROOT = OUTPUT_ROOT / "ground_truth"
GT_PATH = GT_ROOT / "ground_truth.csv"

# Find the LATEST run directory for this LLM.
# NOTE(review): "latest" means last in lexicographic path order; this assumes
# run directory names sort chronologically (e.g. zero-padded timestamps) --
# confirm against how step 2 names its run_* folders.
if not PIPELINE_ROOT.is_dir():
    raise FileNotFoundError(
        f"Pipeline output directory not found: {PIPELINE_ROOT}\n"
        "Run step2_llm_queries.ipynb first."
    )
run_dirs = sorted(
    d for d in PIPELINE_ROOT.iterdir()
    if d.is_dir() and d.name.startswith("run_")
)
if not run_dirs:
    raise FileNotFoundError("No run directories found. Run step2_llm_queries.ipynb first.")
RUN_DIR = run_dirs[-1]

# LLM prediction files
SET1_PATH = RUN_DIR / "step2_llm_set1" / "set1_llm_output.csv"
SET2_PATH = RUN_DIR / "step2_llm_set2" / "set2_llm_output.csv"

# Verify prerequisites with explicit exceptions (asserts are stripped when
# Python runs with -O) and before creating any output directory.
for _label, _path, _hint in (
    ("Ground truth", GT_PATH, "Run step1_ground_truth.ipynb in ground_truth/ first."),
    ("Set 1 predictions", SET1_PATH, "Run step2_llm_queries.ipynb first."),
    ("Set 2 predictions", SET2_PATH, "Run step2_llm_queries.ipynb first."),
):
    if not _path.exists():
        raise FileNotFoundError(f"{_label} not found: {_path}\n{_hint}")

# Output directory
STEP3_DIR = RUN_DIR / "step3_accuracy"
STEP3_DIR.mkdir(parents=True, exist_ok=True)

# Output files
OUT_COMPARE_PATH = STEP3_DIR / "comparison_set1_set2_vs_ground_truth.csv"

print("RUN_DIR:", RUN_DIR)
print("GT (shared):", GT_PATH)
print("Set 1:", SET1_PATH)
print("Set 2:", SET2_PATH)


## Accuracy Helpers

In [ ]:
# ============================================================
# Helpers: normalization and accuracy
# ============================================================

def strip_semantic_tag(term: str) -> str:
    """Drop one trailing parenthesised group (e.g. " (disorder)") from *term*.

    Only a group at the very end of the string, containing no nested
    parentheses, is removed. ``None`` or an empty value yields ``""``.
    """
    return re.sub(r"\s*\([^()]*\)\s*$", "", term or "").strip()

def normalize_term(term: str) -> str:
    """Return *term* lower-cased, whitespace-collapsed and without its trailing tag.

    ``None`` or an empty value yields ``""``.
    """
    term = (term or "").lower().strip()
    term = re.sub(r"\s+", " ", term)
    term = strip_semantic_tag(term)
    return term

def pipe_to_set(s: str):
    """Parse a pipe-delimited cell into a set of normalized terms.

    Returns an empty set for a falsy value or a cell that is just "UNKNOWN"
    (case-insensitive). Individual tokens are stripped *before* the UNKNOWN
    check, so " UNKNOWN " inside "a | UNKNOWN | b" is discarded, and tokens
    that normalize to an empty string (e.g. a bare "(disorder)") are dropped
    rather than being counted as a term.
    """
    if not s:
        return set()
    s = str(s).strip()
    if not s or s.upper() == "UNKNOWN":
        return set()

    terms = set()
    for piece in s.split("|"):
        piece = piece.strip()
        if not piece or piece.upper() == "UNKNOWN":
            continue
        normalized = normalize_term(piece)
        if normalized:
            terms.add(normalized)
    return terms

def accuracy_exact_match(gt_set: set, pred_set: set) -> float:
    """Recall of the ground truth: share of GT terms that also appear in pred.

    An empty ground-truth set scores 0.0 regardless of the prediction.
    """
    if not gt_set:
        return 0.0
    matched = gt_set & pred_set
    return len(matched) / len(gt_set)

def accuracy_jaccard(gt_set: set, pred_set: set) -> float:
    """Jaccard index: size of the intersection over size of the union.

    An empty ground-truth set scores 0.0 regardless of the prediction.
    """
    if not gt_set:
        return 0.0
    combined = gt_set | pred_set
    shared = gt_set & pred_set
    # The union cannot be empty once GT is non-empty; the guard is kept only
    # as a defensive division check.
    return len(shared) / len(combined) if combined else 0.0

print("Helpers loaded.")


## Load Ground Truth and Predictions

In [ ]:
# ============================================================
# Load ground truth
# ============================================================
gt_df = pd.read_csv(GT_PATH, dtype=str).fillna("")
print(f"Ground truth: {len(gt_df)} concepts from {GT_PATH}")

# Fail with an explicit message instead of the opaque pandas errors that
# set_index / to_dict(orient="index") would otherwise raise.
if "concept_term" not in gt_df.columns:
    raise KeyError(f"'concept_term' column missing from ground truth: {GT_PATH}")
_gt_dupes = sorted(set(gt_df.loc[gt_df["concept_term"].duplicated(), "concept_term"]))
if _gt_dupes:
    raise ValueError(
        f"Ground truth has duplicate concept_term values: {_gt_dupes[:10]}"
    )

gt_map = gt_df.set_index("concept_term").to_dict(orient="index")
concepts = gt_df["concept_term"].tolist()

# ============================================================
# Load Set 1 / Set 2 LLM predictions
# ============================================================

def _index_predictions(pred_df, required_cols, label):
    """Return *pred_df* indexed by ``concept_term`` with every required column.

    Missing prediction columns are added to *pred_df* (as empty strings)
    BEFORE the indexed view is built, so the indexed frame has them too.
    When ``concept_term`` is absent an empty frame is returned, so no concept
    ever matches. Duplicate ``concept_term`` rows are reduced to the first
    occurrence (with a printed warning) so a ``.loc`` lookup by concept
    yields a single row.
    """
    for col in required_cols:
        if col not in pred_df.columns:
            pred_df[col] = ""
    if "concept_term" not in pred_df.columns:
        return pd.DataFrame(columns=list(required_cols))
    indexed = pred_df.set_index("concept_term")
    dup_mask = indexed.index.duplicated(keep="first")
    if dup_mask.any():
        print(
            f"WARNING: {label} has {int(dup_mask.sum())} duplicate concept_term "
            "row(s); keeping the first occurrence of each."
        )
        indexed = indexed[~dup_mask]
    return indexed

set1 = pd.read_csv(SET1_PATH, dtype=str).fillna("")
set2 = pd.read_csv(SET2_PATH, dtype=str).fillna("")

# Set 1: A4_parents, A5_grandparents, A6_children, A7_siblings
set1_idx = _index_predictions(
    set1, ["A4_parents", "A5_grandparents", "A6_children", "A7_siblings"], "Set 1"
)
# Set 2: B4_immediate_broader, B5_grandparents, B6_immediate_narrower, B7_peer_terms
set2_idx = _index_predictions(
    set2,
    ["B4_immediate_broader", "B5_grandparents", "B6_immediate_narrower", "B7_peer_terms"],
    "Set 2",
)

print(f"Set 1 predictions: {len(set1)} rows")
print(f"Set 2 predictions: {len(set2)} rows")


## Compare Predictions vs Ground Truth

In [ ]:
# ============================================================
# Compare predictions vs ground truth
# ============================================================

# relation -> (ground-truth column, Set 1 column, Set 2 column).
# Insertion order drives the column order of the output CSV.
_RELATION_COLUMNS = {
    "parents":      ("parents",      "A4_parents",      "B4_immediate_broader"),
    "grandparents": ("grandparents", "A5_grandparents", "B5_grandparents"),
    "children":     ("children",     "A6_children",     "B6_immediate_narrower"),
    "siblings":     ("siblings",     "A7_siblings",     "B7_peer_terms"),
}

def _prediction_cells(indexed_df, concept_term, columns):
    """Return the raw pipe-delimited prediction strings for one concept.

    Yields empty strings when the concept has no prediction row. If the
    index holds several rows for the concept, ``.loc`` returns a DataFrame;
    the first row is used so that a stringified Series is never scored as a
    prediction.
    """
    if concept_term not in indexed_df.index:
        return [""] * len(columns)
    hit = indexed_df.loc[concept_term]
    if isinstance(hit, pd.DataFrame):
        hit = hit.iloc[0]
    return [str(hit.get(col, "")) for col in columns]

def _mean_in_order(values):
    """Average *values* by plain left-to-right addition."""
    total = 0.0
    for value in values:
        total += value
    return total / len(values)

def _score_prediction_set(prefix, raw_by_relation, gt_sets):
    """Build one prompt set's output columns: raw cells, exact-match, Jaccard."""
    cells = {f"{prefix}_{rel}": raw for rel, raw in raw_by_relation.items()}

    pred_sets = {rel: pipe_to_set(raw) for rel, raw in raw_by_relation.items()}
    for suffix, metric in (("exact", accuracy_exact_match), ("jaccard", accuracy_jaccard)):
        scores = [metric(gt_sets[rel], pred_sets[rel]) for rel in raw_by_relation]
        for rel, score in zip(raw_by_relation, scores):
            cells[f"{prefix}_{rel}_{suffix}"] = score
        # Concept-level score: unweighted mean over the four relations.
        cells[f"{prefix}_concept_{suffix}"] = _mean_in_order(scores)
    return cells

rows = []

for concept_term in concepts:
    gt_row = gt_map.get(concept_term, {})

    gt_raw = {
        rel: str(gt_row.get(cols[0], ""))
        for rel, cols in _RELATION_COLUMNS.items()
    }
    gt_sets = {rel: pipe_to_set(raw) for rel, raw in gt_raw.items()}

    row = {
        "concept_name": concept_term,
        "gt_snomed_id": str(gt_row.get("snomed_id", "")),
        "gt_fsn": str(gt_row.get("fsn", "")),
    }
    row.update({f"gt_{rel}": raw for rel, raw in gt_raw.items()})

    # Set 1 uses the A4-A7 columns, Set 2 the B4-B7 columns; both are scored
    # against the same ground-truth sets.
    for prefix, indexed_df, col_pos in (("set1", set1_idx, 1), ("set2", set2_idx, 2)):
        pred_columns = [cols[col_pos] for cols in _RELATION_COLUMNS.values()]
        raw_cells = _prediction_cells(indexed_df, concept_term, pred_columns)
        raw_by_relation = dict(zip(_RELATION_COLUMNS, raw_cells))
        row.update(_score_prediction_set(prefix, raw_by_relation, gt_sets))

    rows.append(row)

out_df = pd.DataFrame(rows)
out_df.to_csv(OUT_COMPARE_PATH, index=False)

print(f"Wrote comparison: {OUT_COMPARE_PATH}")
print(out_df[["concept_name", "gt_snomed_id", "gt_fsn"]].head(10))


## Accuracy Summary Tables

In [ ]:
# ============================================================
# Summary tables - average per field and overall, as percentages
# ============================================================

_SUMMARY_FIELDS = ("parents", "grandparents", "children", "siblings")

def _summary_table(suffix):
    """Build the per-field / overall average table for one metric suffix."""
    labelled_columns = [(field, f"{field}_{suffix}") for field in _SUMMARY_FIELDS]
    labelled_columns.append(("overall_avg", f"concept_{suffix}"))

    records = []
    for label, column in labelled_columns:
        records.append({
            "metric": label,
            "set1_avg": round(100 * out_df[f"set1_{column}"].mean(), 2),
            "set2_avg": round(100 * out_df[f"set2_{column}"].mean(), 2),
        })
    return pd.DataFrame(records)

_rule = "=" * 80

# Table 1 - Exact-match (recall)
table_exact = _summary_table("exact")

print(_rule)
print("TABLE 1 - Exact-match (recall, %)")
print(_rule)
print(table_exact.to_string(index=False))

# Table 2 - Jaccard
table_jaccard = _summary_table("jaccard")

print("\n" + _rule)
print("TABLE 2 - Jaccard (%)")
print(_rule)
print(table_jaccard.to_string(index=False))
print("\nStep 3 (Accuracy) complete!")
