# Capstone 2 – Healthcare Claim Coverage Agent (Deterministic Tools)

This notebook expects an authenticated LLM client named **`chat_client`** to already exist in the session. All input files live under `./Data` and the output CSV is written to `./Data/submission.csv`.


In [None]:
# === Capstone 2: Healthcare Claim Coverage Agent (Deterministic Tools) ===
# Assumes you already created an authenticated LLM client named `chat_client`.
# If package names differ in your env, tweak imports accordingly.

import json, csv, os, re
from typing import TypedDict, List, Annotated, Dict, Any

from langchain.tools import tool
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
try:
    # The `model=` / `tools=` keyword style used for the agent below belongs to
    # LangGraph's prebuilt ReAct agent; LangChain's legacy function of the same
    # name takes (llm, tools, prompt) and rejects those keywords.
    from langgraph.prebuilt import create_react_agent
except ImportError:  # fall back to the originally referenced location
    from langchain.agents import create_react_agent

# ----------------------------------------------------------------------
# 0) Paths (all under ./Data)
# ----------------------------------------------------------------------
DATA_DIR = "./Data"

# Each artifact lives directly inside DATA_DIR and is addressed with a
# forward-slash join: inputs first, then the CSV this notebook writes.
VALIDATION_PATH, TEST_PATH, POLICIES_PATH, REF_CODES_PATH, SUBMISSION_PATH = (
    DATA_DIR + "/" + file_name
    for file_name in (
        "validation_records.json",
        "test_records.json",
        "insurance_policies.json",
        "reference_codes.json",
        "submission.csv",
    )
)

# System instruction handed to the agent: use the tools for facts, then give
# an APPROVE / ROUTE FOR REVIEW decision with a short policy-based reason.
SYSTEM_PROMPT = " ".join([
    "You are a careful, compliance-oriented claims AI.",
    "Use tools to extract facts deterministically, then decide APPROVE or ROUTE FOR REVIEW",
    "with a short factual reason based strictly on policy rules.",
])

# ----------------------------------------------------------------------
# 1) Load reference artifacts
# ----------------------------------------------------------------------
def _safe_load_json(path, default):
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"\u26a0\ufe0f Missing file: {path}. Skipping.")
        return default

policies = _safe_load_json(POLICIES_PATH, default={})
ref_codes = _safe_load_json(REF_CODES_PATH, default={})

# Code -> display-name tables pulled from reference_codes.json, e.g.
# {"J45.909": "Asthma, unspecified", ...} and {"93000": "Electrocardiogram", ...}.
# A missing section falls back to an empty mapping.
ICD_TO_NAME, CPT_TO_NAME = (
    ref_codes.get(section, {})
    for section in ("diagnosis_codes", "procedure_codes")
)

# ----------------------------------------------------------------------
# 2) Helpers
# ----------------------------------------------------------------------
def normalize_age(age_val):
    """Coerce *age_val* to ``int``, or return ``None`` if no integer can be recovered.

    Numbers and numeric strings go through ``int()`` (floats truncate toward
    zero). Anything else is stringified and its first run of digits is used,
    so ``"45 years"`` becomes ``45``. Values with no digits (``None``,
    ``"n/a"``, ``float("inf")``, ``float("nan")``) yield ``None``.
    """
    try:
        return int(age_val)
    except (TypeError, ValueError, OverflowError):
        # Not directly int-convertible: fall back to the first digit run.
        m = re.search(r"\d+", str(age_val))
        return int(m.group()) if m else None

def sex_normalize(s):
    """Map a free-text sex/gender value to ``"M"``, ``"F"`` or ``"U"`` (unknown).

    Matching is case-insensitive and ignores surrounding whitespace; only the
    tokens M/MALE and F/FEMALE are recognised, everything else is ``"U"``.
    """
    token = str(s).strip().upper()
    return {"M": "M", "MALE": "M", "F": "F", "FEMALE": "F"}.get(token, "U")

def dict_get(d, *keys, default=None):
    """Walk nested dicts along *keys* and return the value found there.

    Returns *default* when an intermediate value is not a dict, a key is
    absent, or the final value is ``None``. Falsy non-None values such as
    ``0`` or ``[]`` are returned as-is.
    """
    node = d
    for key in keys:
        if isinstance(node, dict):
            node = node.get(key)
        else:
            return default
    if node is None:
        return default
    return node

# ----------------------------------------------------------------------
# 3) Tool: summarize_patient_record(record_str)
# ----------------------------------------------------------------------
@tool
def summarize_patient_record(record_str: str) -> str:
    """
    Deterministically extract key fields from a semi-structured patient/claim record string.
    Returns a JSON string with: patient_id, age, sex, diagnoses, procedures, policy_id, preauth_provided (bool)
    """
    # Implementation notes (the docstring above is what `@tool` exposes to the
    # LLM as the tool description, so details live here instead):
    # * Scalar fields take the first case-insensitive "label: value" match.
    # * Codes are kept only if they appear in the reference tables
    #   (ICD_TO_NAME / CPT_TO_NAME), de-duplicated in first-appearance order.

    def _first(pattern, group=1):
        # Capture group `group` of the first case-insensitive match, or None.
        m = re.search(pattern, record_str, re.I)
        return m.group(group) if m else None

    patient_id = _first(r"patient[_\s]*id[:\s-]*([A-Za-z0-9_-]+)")
    policy_id = _first(r"\bpolicy[_\s]*id[:\s-]*([A-Za-z0-9_-]+)")

    age_text = _first(r"\bage[:\s-]*([0-9]{1,3})\b")
    age = int(age_text) if age_text is not None else None

    sex_text = _first(r"\b(sex|gender)[:\s-]*([A-Za-z]+)", group=2)
    sex = sex_normalize(sex_text) if sex_text is not None else "U"

    # dict.fromkeys de-duplicates while keeping first-appearance order. A set
    # would make the order depend on per-process string-hash randomisation,
    # so identical records could serialise differently between runs.
    # re.I lets lowercase codes (e.g. "j45.909") match; they are upper-cased
    # before the reference-table lookup.
    diagnoses = list(dict.fromkeys(
        code.upper()
        for code in re.findall(r"\b([A-Z]\d{1,2}[A-Z0-9.\-]+)\b", record_str, re.I)
        if code.upper() in ICD_TO_NAME
    ))
    procedures = list(dict.fromkeys(
        code for code in re.findall(r"\b(\d{4,5})\b", record_str) if code in CPT_TO_NAME
    ))

    # The trailing \b stops the short answer "y" from matching the first
    # letter of words such as "yet" ("Preauth: yet to be submitted").
    preauth_provided = bool(re.search(
        r"pre-?auth(?:orization)?[:\s-]*(yes|y|true|provided|approved)\b",
        record_str,
        re.I,
    ))

    summary = {
        "patient_id": patient_id,
        "age": age,
        "sex": sex,
        "diagnoses": diagnoses,
        "procedures": procedures,
        "policy_id": policy_id,
        "preauth_provided": preauth_provided,
    }
    return json.dumps(summary)

# ----------------------------------------------------------------------
# 4) Tool: summarize_policy_guideline(policy_id)
# ----------------------------------------------------------------------
@tool
def summarize_policy_guideline(policy_id: str) -> str:
    """
    Converts a policy entry into a compact JSON rule set the checker can use.
    Returns JSON with: allowed_diagnoses, allowed_procedures, age_min, age_max, sex, preauth_required, policy_title, notes
    """
    # Implementation notes (the docstring above is the LLM-facing tool description):
    # * The id is produced by the LLM, so surrounding whitespace is ignored.
    # * Unknown ids return {"error": ...} instead of a rule set.
    pid = str(policy_id).strip()

    # `policies` is read as a mapping policy_id -> entry. A JSON list of
    # entries that each carry a "policy_id" field is also accepted.
    # NOTE(review): the list layout is an assumed alternative shape of
    # insurance_policies.json — confirm against the data file.
    if isinstance(policies, dict):
        pol = policies.get(pid)
    elif isinstance(policies, list):
        pol = next(
            (
                p for p in policies
                if isinstance(p, dict) and str(p.get("policy_id", "")).strip() == pid
            ),
            None,
        )
    else:
        pol = None
    if not isinstance(pol, dict) or not pol:
        return json.dumps({"error": f"Unknown policy_id '{policy_id}'"})

    allowed_dx   = dict_get(pol, "criteria", "diagnoses", default=[])
    allowed_cpt  = dict_get(pol, "criteria", "procedures", default=[])
    age_min      = dict_get(pol, "criteria", "age_min", default=None)
    age_max      = dict_get(pol, "criteria", "age_max", default=None)
    sex_rule     = dict_get(pol, "criteria", "sex", default=None)  # e.g. "M", "F", or None

    raw_preauth = dict_get(pol, "criteria", "preauth_required", default=False)
    if isinstance(raw_preauth, str):
        # bool("false") is True, so read the text instead. Anything other than
        # an explicit negative keeps preauthorization required (fail closed).
        preauth_req = raw_preauth.strip().lower() not in {"false", "no", "n", "0", "", "none"}
    else:
        preauth_req = bool(raw_preauth)

    out = {
        "policy_id": pid,
        "allowed_diagnoses": allowed_dx,
        "allowed_procedures": allowed_cpt,
        "age_min": age_min,
        "age_max": age_max,
        "sex": sex_rule,
        "preauth_required": preauth_req,
        # dict_get also maps an explicit null title/notes to "".
        "policy_title": dict_get(pol, "title", default=""),
        "notes": dict_get(pol, "notes", default=""),
    }
    return json.dumps(out)

# ----------------------------------------------------------------------
# 5) Tool: check_claim_coverage(record_summary, policy_summary)
# ----------------------------------------------------------------------
@tool
def check_claim_coverage(record_summary: str, policy_summary: str) -> str:
    """
    Deterministic rules engine: compare patient/claim vs policy rules.
    Returns JSON: {"decision": "...", "reason": "..."}
    """
    # Implementation notes (the docstring above is the LLM-facing tool description):
    # * Every failure mode routes for review; only a claim that passes every
    #   check is APPROVEd.
    # * Values are normalised before comparison so that codes stored as
    #   numbers, sex spelled out ("Female"), age bounds stored as text, or
    #   textual booleans ("false") cannot silently change the outcome.

    def _review(reason):
        # Reasons are capped at 500 characters to keep the CSV cell short.
        return json.dumps({"decision": "ROUTE FOR REVIEW", "reason": reason[:500]})

    def _as_list(value):
        # A bare string/number would otherwise be iterated character-wise.
        if value is None or value == "":
            return []
        if isinstance(value, (list, tuple, set)):
            return list(value)
        return [value]

    def _flag(value, text_default):
        # bool("false") is True, so interpret textual booleans explicitly;
        # unrecognised text falls back to `text_default`.
        if isinstance(value, str):
            text = value.strip().lower()
            if text in {"true", "yes", "y", "1"}:
                return True
            if text in {"false", "no", "n", "0", ""}:
                return False
            return text_default
        return bool(value)

    def _bound(value):
        # Age limits stored as text (e.g. "18") would raise TypeError when
        # compared with an int.
        return normalize_age(value) if isinstance(value, str) else value

    try:
        rs = json.loads(record_summary)
        ps = json.loads(policy_summary)
    except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
        return _review(f"Malformed inputs: {e}")
    if not isinstance(rs, dict) or not isinstance(ps, dict):
        return _review("Malformed inputs: both summaries must be JSON objects.")

    # The summarize_* tools report failures as {"error": ...} (e.g. an unknown
    # policy_id). Such a summary carries no restrictions, so without this
    # check the claim would fall through to APPROVE.
    for label, summary in (("record", rs), ("policy", ps)):
        if summary.get("error"):
            return _review(f"Cannot evaluate {label} summary: {summary['error']}")

    missing = [k for k in ["age", "sex", "diagnoses", "procedures", "policy_id"]
               if rs.get(k) in [None, [], ""]]
    if missing:
        return _review(f"Missing/insufficient fields in record: {', '.join(missing)}.")

    age  = normalize_age(rs.get("age"))
    sex  = sex_normalize(rs.get("sex", "U"))
    dxs  = {str(d).strip().upper() for d in _as_list(rs.get("diagnoses"))}
    cpts = {str(c).strip() for c in _as_list(rs.get("procedures"))}

    allowed_dx  = {str(d).strip().upper() for d in _as_list(ps.get("allowed_diagnoses"))}
    allowed_cpt = {str(c).strip() for c in _as_list(ps.get("allowed_procedures"))}
    age_min = _bound(ps.get("age_min"))
    age_max = _bound(ps.get("age_max"))
    raw_sex_rule = ps.get("sex")
    # "Male"/"female" map to M/F; anything unrecognised becomes "U" (no restriction).
    sex_rule = sex_normalize(raw_sex_rule) if raw_sex_rule else None
    preauth_required = _flag(ps.get("preauth_required", False), True)    # unknown text: required
    preauth_provided = _flag(rs.get("preauth_provided", False), False)   # unknown text: not provided

    reasons = []

    # Both summaries name a policy; judging a record against a different
    # policy's rules would make the decision meaningless.
    policy_in_summary = ps.get("policy_id")
    if (policy_in_summary not in (None, "")
            and str(policy_in_summary).strip().upper() != str(rs.get("policy_id")).strip().upper()):
        reasons.append(
            f"Record policy_id '{rs.get('policy_id')}' does not match evaluated policy '{policy_in_summary}'."
        )

    if allowed_cpt and not (cpts & allowed_cpt):
        reasons.append("Claimed procedure not listed as covered for this policy.")
    if allowed_dx and not (dxs & allowed_dx):
        reasons.append("Diagnosis does not meet covered indications.")
    if age is not None:
        if age_min is not None and age < age_min:
            reasons.append(f"Patient age {age} is below minimum {age_min}.")
        if age_max is not None and age > age_max:
            reasons.append(f"Patient age {age} exceeds maximum {age_max}.")
    else:
        reasons.append("Age unavailable.")
    if sex_rule in {"M", "F"} and sex != sex_rule:
        reasons.append(f"Policy restricted to sex '{sex_rule}'.")
    if preauth_required and not preauth_provided:
        reasons.append("Preauthorization required but not provided.")

    if reasons:
        return _review("; ".join(reasons))
    return json.dumps({"decision": "APPROVE", "reason": "Meets policy criteria."})

# ----------------------------------------------------------------------
# 6) Agent wiring (ReAct-style, single node graph)
# ----------------------------------------------------------------------
BASE_SYS = SYSTEM_PROMPT
TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]

# LangGraph renamed create_react_agent's system-prompt keyword from
# `state_modifier` to `prompt`; recent releases may reject the old name.
# Probe the installed signature so this cell works with either API.
import inspect  # only needed to probe the installed LangGraph API
_prompt_kwarg = (
    "prompt"
    if "prompt" in inspect.signature(create_react_agent).parameters
    else "state_modifier"
)

agent = create_react_agent(
    model=chat_client,     # <-- your authenticated client
    tools=TOOLS,
    **{_prompt_kwarg: BASE_SYS},
)

class State(TypedDict):
    # Conversation history; add_messages is the reducer applied to node updates.
    messages: Annotated[List, add_messages]

graph_builder = StateGraph(State)
def agent_node(state: State):
    """Run the prebuilt ReAct agent on the current state and return its resulting state."""
    return agent.invoke(state)

# Single-node graph: START -> runner -> END.
graph_builder.add_node("runner", agent_node)
graph_builder.add_edge(START, "runner")
graph_builder.add_edge("runner", END)
claim_graph = graph_builder.compile()

# ----------------------------------------------------------------------
# 7) Runner helpers
# ----------------------------------------------------------------------
def run_on_record(record: Dict[str, Any]) -> Dict[str, str]:
    """
    Expects:
      record['patient_id'], record['record_str'], record['policy_id']
    Returns: {"patient_id": ..., "generated_response": "..."}
    """
    record_str = record.get("record_str","")
    policy_id  = record.get("policy_id","")

    prompt = f"""
Use the tools to:
1) summarize_patient_record from the provided record text,
2) summarize_policy_guideline for policy_id={policy_id},
3) check_claim_coverage using both summaries,
then return ONE short sentence:
"Decision: <APPROVE|ROUTE FOR REVIEW>. Reason: <key reason>"

Record:
{record_str}

Policy ID: {policy_id}
"""
    result = claim_graph.invoke({"messages":[("user", prompt)]})
    final_text = str(result["messages"][-1].content).strip()
    return {"patient_id": record.get("patient_id"), "generated_response": final_text}

def load_records(path):
    """Load the list of claim records stored as JSON at *path*.

    Accepts either a bare JSON list or an object whose ``"records"`` key holds
    the list. Any other shape — including a ``"records"`` value that is not a
    list — yields ``[]``, as does a missing file (after printing a warning),
    so callers can always iterate the result. The file is read as UTF-8.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"\u26a0\ufe0f Missing file: {path}. Skipping.")
        return []
    if isinstance(data, dict):
        data = data.get("records")
    return data if isinstance(data, list) else []

# ----------------------------------------------------------------------
# 8) Quick sanity on validation set (prints a couple of rows if present)
# ----------------------------------------------------------------------
val = load_records(VALIDATION_PATH)
if not val:
    print("No validation records found or file missing.")
else:
    print("Validation sample:")
    # Smoke test: run the agent on at most the first three validation records.
    for rec in val[:3]:
        out = run_on_record(rec)
        print(f"{out['patient_id']} -> {out['generated_response']}")

# ----------------------------------------------------------------------
# 9) Produce ./Data/submission.csv from test_records.json
# ----------------------------------------------------------------------
test = load_records(TEST_PATH)

# Response recorded when the agent raises for a record: routing for review is
# the conservative outcome and keeps the submission complete.
_FALLBACK_RESPONSE = (
    "Decision: ROUTE FOR REVIEW. Reason: Automated review failed; manual review required."
)

def _predict_or_route(record):
    """Run the agent on *record*; on any error, log it and return a fallback row."""
    try:
        return run_on_record(record)
    except Exception as exc:  # notebook boundary: one bad record must not abort the batch
        pid = record.get("patient_id") if isinstance(record, dict) else None
        print(f"\u26a0\ufe0f Agent failed for patient {pid}: {exc!r}")
        return {"patient_id": pid, "generated_response": _FALLBACK_RESPONSE}

rows = [_predict_or_route(r) for r in test]

# os.path.dirname() is "" for a bare filename, and os.makedirs("") raises.
out_dir = os.path.dirname(SUBMISSION_PATH)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(SUBMISSION_PATH, "w", newline="", encoding="utf-8") as f:
    w = csv.DictWriter(f, fieldnames=["patient_id","generated_response"])
    w.writeheader()
    w.writerows(rows)

n_fallback = sum(row["generated_response"] == _FALLBACK_RESPONSE for row in rows)
if n_fallback:
    print(f"\u26a0\ufe0f {n_fallback} of {len(rows)} rows used the fallback response.")
print(f"✅ Wrote {SUBMISSION_PATH} with {len(rows)} rows.")
