# Capstone 2 – Healthcare Claim Coverage Agent (Deterministic, Agent + Supervisor Style)

This notebook mirrors your housing sample structure with tools, agents, a supervisor, and a runner — **no LLM calls**.

In [None]:

import json, csv
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from pathlib import Path

# Anchor every path to the absolute working directory; input files are
# looked up in the "data" folder directly beneath it.
ROOT = Path('.').resolve()
DATA = ROOT.joinpath('data')

def load_json(path: Path):
    """Read the UTF-8 encoded file at *path* and return its parsed JSON content."""
    raw_text = Path(path).read_text(encoding='utf-8')
    return json.loads(raw_text)

def exists(path: Path) -> bool:
    """Report whether *path* refers to an existing filesystem entry."""
    # Wrapping in Path() lets callers pass a plain string as well.
    target = Path(path)
    return target.exists()

# Echo the resolved locations so a wrong working directory is obvious early.
print(f'Working directory: {ROOT}')
print(f'Data directory: {DATA}')


In [None]:

@dataclass
class PatientSummary:
    """Flat view of one patient record plus the claim being evaluated.

    Instances are built positionally by ``summarize_patient_record``, so the
    field order below is part of the constructor contract.
    """
    # Identifier taken from the record and converted with str().
    patient_id: str
    # Raw "age" value from the record; None when the record has no age.
    age: Optional[int]
    # Upper-cased gender string, or None when missing / not a string.
    gender: Optional[str]
    # Diagnosis codes after normalize_code (blank entries become None).
    diagnoses: List[str]
    # Procedure codes after normalize_code; carried along but not read by
    # check_claim_coverage in this file.
    procedures: List[str]
    # Normalized procedure code of the claim under review.
    claim_procedure_code: Optional[str]
    # The claim's "preauthorization_id"; falsy means no preauthorization.
    preauth_id: Optional[str]

@dataclass
class PolicySummary:
    """Coverage constraints of one insurance policy.

    Populated by ``summarize_policy_guideline`` and consumed by
    ``check_claim_coverage``.
    """
    # Policy key, stored as a string.
    policy_id: str
    # Procedure codes the policy covers, after normalize_code.
    covered_procedure_codes: List[str]
    # Diagnosis codes of which at least one must match; empty means no
    # diagnosis requirement (see _diagnosis_ok).
    diagnosis_requirements: List[str]
    # Inclusive age bounds; None disables the respective bound (see _age_ok).
    min_age: Optional[int]
    max_age: Optional[int]
    # Required gender, compared case-insensitively as a prefix of the
    # patient's gender by _gender_ok; None means unrestricted.
    gender_restriction: Optional[str]
    # True when the claim must carry a preauthorization id.
    preauth_required: bool

def normalize_code(code: Optional[str]) -> Optional[str]:
    """Canonicalize a medical code for comparison.

    Surrounding whitespace is trimmed, letters are upper-cased and dots are
    removed (so ``"e11.9"`` becomes ``"E119"``). A missing or empty code
    yields ``None``.
    """
    if code:
        trimmed = code.strip()
        return trimmed.upper().replace('.', '')
    return None


In [None]:

def summarize_patient_record(record: Dict[str, Any], ref_codes: Dict[str, Any]) -> PatientSummary:
    """Deterministically extract patient details and claim content.

    Args:
        record: Raw patient record. Reads ``patient_id``, ``age``, ``gender``,
            ``diagnoses``, ``procedures`` and the nested ``claim`` mapping
            (``procedure_code`` and ``preauthorization_id``).
        ref_codes: Reference code table. Accepted to keep the tool signature
            uniform; it is not consulted here.

    Returns:
        A ``PatientSummary`` whose codes went through ``normalize_code`` and
        whose gender is upper-cased (``None`` when absent or not a string).
    """
    # dict.get() only falls back to its default when the key is absent. JSON
    # nulls arrive as None, so coalesce explicitly instead of relying on the
    # default; otherwise a null "claim"/"diagnoses"/"procedures" would raise.
    claim = record.get("claim") or {}
    raw_diagnoses = record.get("diagnoses") or []
    raw_procedures = record.get("procedures") or []

    raw_gender = record.get("gender")
    # Empty strings and non-string values both count as "gender unknown".
    gender = raw_gender.upper() if isinstance(raw_gender, str) and raw_gender else None

    return PatientSummary(
        str(record.get("patient_id")),
        record.get("age"),
        gender,
        [normalize_code(dx) for dx in raw_diagnoses],
        [normalize_code(proc) for proc in raw_procedures],
        normalize_code(claim.get("procedure_code")),
        claim.get("preauthorization_id"),
    )

def summarize_policy_guideline(policy_id: str, policies: Dict[str, Any]) -> PolicySummary:
    """Build the ``PolicySummary`` for *policy_id* from the policy table.

    The id is stringified before the lookup; an id that is not a key of
    *policies* raises ``KeyError``. Code lists are normalized with
    ``normalize_code`` and a falsy gender restriction is stored as ``None``.
    """
    key = str(policy_id)
    entry = policies[key]

    covered = [normalize_code(code) for code in entry.get("covered_procedure_codes", [])]
    required_dx = [normalize_code(code) for code in entry.get("diagnosis_requirements", [])]
    restriction = entry.get("gender_restriction") or None
    needs_preauth = bool(entry.get("preauthorization_required", False))

    return PolicySummary(
        policy_id=key,
        covered_procedure_codes=covered,
        diagnosis_requirements=required_dx,
        min_age=entry.get("min_age"),
        max_age=entry.get("max_age"),
        gender_restriction=restriction,
        preauth_required=needs_preauth,
    )

def _age_ok(age: Optional[int], min_age: Optional[int], max_age: Optional[int]) -> bool:
    if age is None: return True
    if min_age is not None and age < min_age: return False
    if max_age is not None and age > max_age: return False
    return True

def _gender_ok(gender: Optional[str], restriction: Optional[str]) -> bool:
    if not restriction: return True
    if not gender: return False
    return gender.upper().startswith(restriction.upper())

def _diagnosis_ok(patient_dx: List[str], required_dx: List[str]) -> bool:
    if not required_dx: return True
    pset = set([d or "" for d in patient_dx])
    return any(req in d for d in pset for req in required_dx)

def check_claim_coverage(record_summary: PatientSummary, policy_summary: PolicySummary) -> Dict[str, str]:
    """Run the coverage rules in a fixed order and report the first failure.

    Rules are checked as: procedure covered, diagnosis requirement, age
    window, gender restriction, preauthorization. The first rule that fails
    produces a ``ROUTE FOR REVIEW`` decision with its reason; if none fails
    the claim is approved.

    Returns:
        A dict with the keys ``"Decision"`` and ``"Reason"``.
    """
    def _route(reason: str) -> Dict[str, str]:
        # Every failing rule shares the same decision label.
        return {"Decision": "ROUTE FOR REVIEW", "Reason": reason}

    patient, policy = record_summary, policy_summary

    procedure = patient.claim_procedure_code
    if procedure not in policy.covered_procedure_codes:
        return _route(f"Procedure {procedure} not covered by policy {policy.policy_id}.")

    if not _diagnosis_ok(patient.diagnoses, policy.diagnosis_requirements):
        return _route("Required diagnosis criteria not met.")

    if not _age_ok(patient.age, policy.min_age, policy.max_age):
        return _route(f"Age {patient.age} outside allowed range.")

    if not _gender_ok(patient.gender, policy.gender_restriction):
        return _route(f"Gender restriction {policy.gender_restriction} not satisfied.")

    preauth_missing = policy.preauth_required and not patient.preauth_id
    if preauth_missing:
        return _route("Preauthorization required but missing.")

    return {
        "Decision": "APPROVE",
        "Reason": "Meets procedure, diagnosis, demographic, and preauthorization criteria.",
    }


In [None]:

# Agents (wrappers around tools)
def patient_summary_agent(record: Dict[str, Any], ref_codes: Dict[str, Any]) -> PatientSummary:
    """Agent wrapper that delegates to the ``summarize_patient_record`` tool."""
    summary = summarize_patient_record(record=record, ref_codes=ref_codes)
    return summary

def policy_summary_agent(policy_id: str, policies: Dict[str, Any]) -> PolicySummary:
    """Agent wrapper that delegates to the ``summarize_policy_guideline`` tool."""
    summary = summarize_policy_guideline(policy_id=policy_id, policies=policies)
    return summary

def coverage_checker_agent(patient_summary: PatientSummary, policy_summary: PolicySummary) -> Dict[str, str]:
    """Agent wrapper that delegates to the ``check_claim_coverage`` tool."""
    verdict = check_claim_coverage(record_summary=patient_summary, policy_summary=policy_summary)
    return verdict


In [None]:

# Supervisor: fixed flow patient -> policy -> coverage
def claim_supervisor(record: Dict[str, Any], policies: Dict[str, Any], refs: Dict[str, Any]) -> Dict[str, str]:
    """Drive one claim through the patient, policy and coverage agents.

    The policy is selected by the ``policy_id`` found in the record's
    ``claim`` mapping.

    Returns:
        A dict with ``patient_id`` and a two-line ``generated_response``
        ("Decision: ..." followed by "Reason: ...").
    """
    patient = patient_summary_agent(record, refs)

    claim = record.get("claim", {})
    policy = policy_summary_agent(claim.get("policy_id"), policies)

    verdict = coverage_checker_agent(patient, policy)
    response = f"Decision: {verdict['Decision']}\nReason: {verdict['Reason']}"
    return {"patient_id": patient.patient_id, "generated_response": response}


In [None]:

def normalize_policies(policies_raw):
    """Return the policies as a mapping keyed by stringified ``policy_id``.

    A list of policy dicts is re-indexed by id (a later duplicate id replaces
    an earlier one); any other input is returned unchanged.
    """
    if not isinstance(policies_raw, list):
        return policies_raw
    by_id = {}
    for policy in policies_raw:
        by_id[str(policy.get('policy_id'))] = policy
    return by_id

def run_claim_validation(records_path: Path, policies_path: Path, refs_path: Path, out_csv: Optional[Path]=None):
    """Evaluate every record in a JSON file and optionally save the results.

    Args:
        records_path: JSON file with the patient records to evaluate.
        policies_path: JSON file with the policies (list or mapping form).
        refs_path: JSON file with the reference codes.
        out_csv: When given, results are also written there as a CSV with the
            columns ``patient_id`` and ``generated_response``.

    Returns:
        One result dict per record, as produced by ``claim_supervisor``.
    """
    records = load_json(records_path)
    policies = normalize_policies(load_json(policies_path))
    refs = load_json(refs_path)

    results = [claim_supervisor(rec, policies, refs) for rec in records]

    if not out_csv:
        return results

    # newline='' lets the csv module control line endings itself.
    with open(out_csv, 'w', newline='', encoding='utf-8') as handle:
        writer = csv.DictWriter(handle, fieldnames=['patient_id', 'generated_response'])
        writer.writeheader()
        writer.writerows(results)
    return results

# Status message emitted once the definitions above have been evaluated.
print("Deterministic agent-supervisor pipeline ready.")


In [None]:

# Validation run (no CSV) and optional comparison to reference
def extract_decision(resp: str) -> str:
    """Pull the decision label out of a generated response.

    The first line that starts with "Decision:" (any letter case) wins; the
    text after its first colon is returned trimmed and upper-cased. An empty
    string is returned when no such line exists.
    """
    marker = "DECISION:"
    candidates = (
        line.split(":", 1)[1].strip().upper()
        for line in resp.splitlines()
        if line.upper().startswith(marker)
    )
    return next(candidates, "")

# Every input file the notebook may read, keyed by role and located in DATA.
paths = {
    role: DATA / filename
    for role, filename in (
        ("validation_records", "validation_records.json"),
        ("test_records", "test_records.json"),
        ("policies", "insurance_policies.json"),
        ("refs", "reference_codes.json"),
        ("validation_reference", "validation_reference_results.csv"),
    )
}

# Show each resolved location together with whether it exists on disk.
print({role: (str(location), exists(location)) for role, location in paths.items()})

# Run the pipeline on the validation set only when all three inputs exist.
if any(not exists(paths[role]) for role in ("validation_records", "policies", "refs")):
    print("Validation run skipped: one or more files missing.")
else:
    val_results = run_claim_validation(
        paths["validation_records"],
        paths["policies"],
        paths["refs"],
    )
    print("Validation sample:", val_results[:3])

# Compare our validation decisions with the reference file, when both exist.
if exists(paths["validation_reference"]) and 'val_results' in globals():
    # The generated_response field spans two lines inside quotes. The csv
    # module requires files to be opened with newline="" so that newlines
    # embedded in quoted fields are interpreted correctly.
    with open(paths["validation_reference"], "r", encoding="utf-8", newline="") as f:
        ref_map = {
            row["patient_id"]: row["generated_response"].strip()
            for row in csv.DictReader(f)
        }

    total = correct = 0
    mismatches = []
    for r in val_results:
        pid = r["patient_id"]
        ours = extract_decision(r["generated_response"])
        ref = extract_decision(ref_map.get(pid, ""))
        if ref == "":
            # No reference decision for this patient: nothing to compare.
            continue
        total += 1
        if ours == ref:
            correct += 1
        else:
            mismatches.append((pid, ours, ref))

    if total:
        print(f"Compared {total} rows. Accuracy: {correct}/{total} = {correct/total:.2%}")
        # Cap the listing so a badly broken run stays readable.
        for m in mismatches[:10]:
            print("Mismatch:", m)
    else:
        print("No comparable rows in reference.")
else:
    print("Reference comparison skipped.")


In [None]:

# Final run -> submission.csv
# The submission file is written relative to the current working directory.
out_path = Path('submission.csv')
if any(not exists(paths[role]) for role in ("test_records", "policies", "refs")):
    print("Test run skipped: one or more files missing.")
else:
    _ = run_claim_validation(
        paths["test_records"],
        paths["policies"],
        paths["refs"],
        out_csv=out_path,
    )
    print("Wrote", out_path)
