
# Capstone — Healthcare Claim Coverage Agent (Parallelized Runner, Fully Commented)

This notebook implements **one ReAct agent** that can call **three tools** to evaluate insurance claim coverage against policy rules.
The logic is identical to `code_code4.ipynb`; the only difference is that **every line is heavily commented** to help a junior developer follow the flow and explain it.

## Overall design & reasoning

- **Single source of truth for policies:** We parse the policy JSON once into an **index** keyed by normalized `policy_id`. Lookups are O(1), and the parser accepts several input shapes (a `policies` list, a dict of policies keyed by id, or a bare list).
- **Normalize early:** Real-world data is messy. We normalize dates, sex, ICD and CPT codes *before* any rule checks. This keeps the checker deterministic and simple.
- **Three focused tools:**  
  1. `summarize_patient_record` — shape one raw claim into a clean dict.  
  2. `summarize_policy_guideline` — fetch policy rules by `policy_id`.  
  3. `check_claim_coverage` — decide **APPROVE** vs **ROUTE FOR REVIEW** with a concise reason.
- **Strict agent prompt:** We instruct the agent to always call tools in the same order and to produce a fixed output format. This reduces LLM variance.
- **Parallel execution:** LLM round-trips dominate latency, so we use `ThreadPoolExecutor` to process multiple records concurrently; the per-record logic is unchanged.


In [None]:
# === Imports & path constants ===
from __future__ import annotations  # allows forward references in type hints (Python <3.11)
from typing import Dict, Any, List, Optional, Tuple  # for clear type annotations
import os, json, csv, re, time  # stdlib utilities: file ops, JSON, CSV, regex, timing
from datetime import datetime  # for parsing/computing dates and ages
from concurrent.futures import ThreadPoolExecutor, as_completed  # simple thread-based parallelism

# LangChain / LangGraph building blocks
from langchain_core.tools import tool  # decorator to expose Python functions as LLM-callable tools
from langchain_core.messages import SystemMessage, HumanMessage  # message primitives for agent I/O
from langgraph.prebuilt import create_react_agent  # prebuilt ReAct-style agent constructor

# --- Relative paths to all project JSONs/outputs ---
DATA_DIR = './Data'  # project convention: every input/output JSON/CSV file lives in this folder

# Resolve every project file relative to DATA_DIR (same order as the names on the left):
#   reference_codes.json     -> ICD/CPT lookups (optional)
#   insurance_policies.json  -> plan benefit rules
#   validation_records.json  -> small set for preview/debug
#   test_records.json        -> full test set used for submission
#   submission.csv           -> required output file
(
    REFERENCE_CODES_PATH,
    POLICIES_PATH,
    VALIDATION_PATH,
    TEST_PATH,
    SUBMISSION_PATH,
) = (
    os.path.join(DATA_DIR, file_name)
    for file_name in (
        'reference_codes.json',
        'insurance_policies.json',
        'validation_records.json',
        'test_records.json',
        'submission.csv',
    )
)


In [None]:
# === Load JSON helpers & reference dictionaries ===
def _safe_load(path: str, default):
    """Best-effort JSON loader.
    If loading fails (file missing, bad JSON, etc.), we print a warning and return `default`.
    This keeps the notebook resilient to environment differences.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)  # parse JSON into Python dict/list
    except Exception as e:
        print(f"⚠️ Could not load {path}: {e}")
        return default

# Load all inputs up-front so later cells can rely on them.
# Inputs, each falling back to an empty container when the file cannot be loaded.
reference_codes: Dict[str, Any] = _safe_load(REFERENCE_CODES_PATH, {})  # optional ICD/CPT name maps
policies_raw: Dict[str, Any] = _safe_load(POLICIES_PATH, {})  # raw policy JSON; shape varies by source
val_records: List[Dict[str, Any]] = _safe_load(VALIDATION_PATH, [])  # small sample for preview/debug
test_records: List[Dict[str, Any]] = _safe_load(TEST_PATH, [])  # records scored for the submission

# Code -> name maps. Two key spellings are tried; if the first is missing/empty the second
# lookup's result is used. They only serve to drop unknown codes, so an empty map just
# disables that filtering.
ICD_TO_NAME: Dict[str, str] = (
    reference_codes.get('diagnosis_codes', {})
    or reference_codes.get('icd10', {})
)
CPT_TO_NAME: Dict[str, str] = (
    reference_codes.get('procedure_codes', {})
    or reference_codes.get('cpt', {})
)

print(
    "Loaded: "
    + ", ".join(
        f"{label}={value}"
        for label, value in {
            'reference_codes': bool(reference_codes),
            'policies': bool(policies_raw),
            'validation_records': len(val_records),
            'test_records': len(test_records),
        }.items()
    )
)


In [None]:
# === Helpers for normalization + building a fast policy index ===

# We accept many date formats because upstream feeds are inconsistent.
DATE_PATTERNS = ['%Y-%m-%d','%m/%d/%Y','%d/%m/%Y','%Y/%m/%d']  # tried in order: ambiguous "03/04/2025" is read month-first

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort date parser that never raises.

    Accepted inputs, in order of precedence:
    - ``datetime`` objects (returned unchanged) and ``date`` objects (promoted to midnight);
    - strings matching one of ``DATE_PATTERNS`` (the first pattern that parses wins);
    - compact ``YYYYMMDD`` strings such as "20250130";
    - ISO-8601 date-times such as "2025-01-30T10:15:00" (any UTC offset is dropped,
      so the result is naive like every other path).

    Returns ``None`` for empty input or anything that cannot be parsed.
    """
    from datetime import date  # local: the module-level import only brings in `datetime`

    # datetime is a subclass of date, so test it first.
    if isinstance(s, datetime):
        return s
    if isinstance(s, date):
        return datetime(s.year, s.month, s.day)
    if not s:
        return None  # early exit for None/empty values

    text = str(s).strip()
    for pat in DATE_PATTERNS:
        try:
            return datetime.strptime(text, pat)
        except ValueError:
            continue  # try the next pattern

    # Compact YYYYMMDD (e.g. "20250130"); an impossible month/day means "not a date".
    m = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if m:
        try:
            return datetime(int(m.group(1)), int(m.group(2)), int(m.group(3)))
        except ValueError:
            return None

    # Feeds often carry full timestamps. Without this fallback they parse as None and the
    # caller silently computes age against today instead of the service date.
    try:
        return datetime.fromisoformat(text).replace(tzinfo=None)
    except ValueError:
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Compute age in whole years at `ref` (defaults to the current UTC date).

    Returns None if `dob` is missing or the computed age is not sane (negative or 150+).
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12. An aware UTC "now" has the same
        # calendar date, and only year/month/day are used below, so results are unchanged.
        from datetime import timezone
        ref = datetime.now(timezone.utc)
    # Subtract years, minus 1 if the birthday has not yet occurred in the reference year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None  # sanity bounds

def sex_normalize(s: Any) -> str:
    """Map a free-form sex/gender value onto the canonical codes 'M', 'F', or 'U'.

    Matching ignores case and surrounding whitespace; only 'M'/'MALE' and 'F'/'FEMALE'
    are recognized. None and any other value map to 'U' (unknown).
    """
    aliases = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    return 'U' if s is None else aliases.get(str(s).strip().upper(), 'U')

def icd_dot_normalize(code: Any) -> str:
    """Normalize an ICD-10 code to upper case with a dot after the 3rd character.

    ICD-10-CM codes are 3-7 characters: a letter, a digit, then alphanumerics, with the
    decimal after the third character. So 'I200' -> 'I20.0', 's72001a' -> 'S72.001A',
    'C4A0' -> 'C4A.0'. Codes that are already dotted, 3-character categories, or do not
    have that shape are returned upper-cased and stripped but otherwise unchanged.
    None becomes '' rather than the fake code 'NONE'.
    """
    if code is None:
        return ''
    c = str(code).strip().upper()
    # Positions 3-7 may be letters (e.g. 'C4A', the 7th-character extension 'A'); the old
    # digits-only pattern left such codes undotted, so they could never match a rule.
    if re.fullmatch(r'[A-Z]\d[A-Z0-9]{2,5}', c):
        return c[:3] + '.' + c[3:]
    return c

def _norm_pid(pid: Optional[str]) -> Optional[str]:
    """Normalize policy IDs by stripping dashes/underscores/spaces and uppercasing.
    This avoids lookup mismatches like 'pol-123' vs 'POL123'.
    """
    if not pid:
        return None
    return re.sub(r'[-_\s]', '', str(pid)).upper()

def _extract_rules(obj: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Given one policy object, extract a list of normalized rule dicts.

    Reads ``obj['covered_procedures']``; each dict entry with a procedure code yields
    ``{'procedure', 'diagnoses', 'age_min', 'age_max', 'sex', 'preauth_required'}``:

    - ``diagnoses``: list (or ';'/','/whitespace-delimited string) of ICD codes,
      each passed through ``icd_dot_normalize``; empty means no diagnosis constraint.
    - ``age_min``/``age_max``: ints from a 2+-element list/tuple ``age_range``; a missing
      or unparseable bound is None (no constraint on that side).
    - ``sex``: 'M', 'F', or 'ANY' from the first letter of ``gender``/``sex``.
    - ``preauth_required``: truthiness of ``requires_preauthorization``/``preauth_required``.

    Non-dict entries and entries without a procedure code are skipped; a non-list
    ``covered_procedures`` yields no rules.
    """
    def _int_or_none(value: Any) -> Optional[int]:
        # Open-ended or malformed bounds (None, "", "n/a", inf) mean "no constraint".
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    rules: List[Dict[str, Any]] = []
    items = obj.get('covered_procedures') or []  # common field for rule list
    if not isinstance(items, list):
        return rules

    for it in items:
        if not isinstance(it, dict):
            continue  # skip malformed entries defensively
        # The procedure code is the anchor for the rule
        proc = str(it.get('procedure_code') or it.get('procedure') or '').strip()
        if not proc:
            continue  # rule without a procedure makes no sense

        diags = it.get('covered_diagnoses') or it.get('diagnoses') or []
        if isinstance(diags, str):
            # Split on list separators only. '.' must NOT be a separator: it is part of
            # dotted ICD-10 codes, and splitting "I20.0" into "I20"/"0" would produce
            # fragments that never match a claim's normalized diagnosis.
            diags = [d for d in re.split(r'[;,\s]+', diags) if d]
        diags = [icd_dot_normalize(d) for d in diags]

        # Optional age constraints, given as [min, max]
        ar = it.get('age_range') or []
        age_min = age_max = None
        if isinstance(ar, (list, tuple)) and len(ar) >= 2:
            age_min = _int_or_none(ar[0])
            age_max = _int_or_none(ar[1])

        # Optional sex constraint; anything not starting with M/F (e.g. 'Any') is unrestricted
        sex_raw = str(it.get('gender') or it.get('sex') or 'Any').strip().upper()
        if sex_raw.startswith('M'):
            sex_rule = 'M'
        elif sex_raw.startswith('F'):
            sex_rule = 'F'
        else:
            sex_rule = 'ANY'

        preauth_required = bool(it.get('requires_preauthorization') or it.get('preauth_required') or False)

        rules.append({
            'procedure': proc,
            'diagnoses': diags,
            'age_min': age_min,
            'age_max': age_max,
            'sex': sex_rule,
            'preauth_required': preauth_required,
        })
    return rules

def build_policy_index(raw: Any) -> Dict[str, Dict[str, Any]]:
    """Turn raw policies (various shapes) into a dict keyed by normalized policy_id.

    Accepted shapes:
      1. ``{"policies": [ {...}, ... ]}``
      2. ``{"policies": {"POL123": {...}, ...}}`` -- an envelope around a dict of policies
      3. ``{"POL123": {...}, "POL456": {...}}`` -- dict of dicts; the key becomes the id
         when the policy has no non-empty ``policy_id``/``policyId``/``id`` of its own
      4. ``[ {...}, {...} ]``
    Any other shape yields an empty index. Policies with no id are skipped; if two
    policies normalize to the same id, the later one wins.

    Each value is ``{'policy_id', 'plan_name', 'rules'}`` with rules from ``_extract_rules``.
    """
    index: Dict[str, Dict[str, Any]] = {}

    source: Any = raw
    if isinstance(source, dict) and 'policies' in source:
        inner = source['policies']
        # Unwrap a list envelope, or a dict envelope that looks like id -> policy (every
        # value a dict). A single policy object stored under "policies" has non-dict values
        # (its id string, its rule list) and keeps the dict-of-dicts handling below.
        if isinstance(inner, list) or (
            isinstance(inner, dict) and inner and all(isinstance(v, dict) for v in inner.values())
        ):
            source = inner

    if isinstance(source, list):
        src_list = source
    elif isinstance(source, dict):
        src_list = []
        for k, v in source.items():
            if not isinstance(v, dict):
                continue
            v = dict(v)  # copy to avoid mutating input
            # Fill the id when it is missing OR empty/None. setdefault() would keep an
            # explicit None, and the dict key would then never serve as the fallback id.
            if not v.get('policy_id'):
                v['policy_id'] = v.get('policyId') or v.get('id') or k
            src_list.append(v)
    else:
        src_list = []  # unknown shape -> empty

    # Normalize each policy and insert into index
    for obj in src_list:
        if not isinstance(obj, dict):
            continue
        pid = _norm_pid(
            obj.get('policy_id') or obj.get('policyId') or obj.get('id') or obj.get('code') or obj.get('policy_code')
        )
        if not pid:
            continue  # cannot index without an id
        index[pid] = {
            'policy_id': pid,
            'plan_name': obj.get('plan_name') or obj.get('title') or '',  # human-friendly optional
            'rules': _extract_rules(obj),
        }
    return index

# Build once at import time for fast lookups during processing
# Index every policy once, up-front; tools then resolve a policy with a single dict lookup.
POLICY_INDEX: Dict[str, Dict[str, Any]] = build_policy_index(policies_raw)
print("Policy index contains %d policies." % len(POLICY_INDEX))


In [None]:
# === Tools exposed to the agent ===
# Tools are plain Python functions decorated with @tool, which lets the agent call them.

# String values that mean "no" in yes/no claim fields (compared after strip + lower-case).
_FALSE_STRINGS = frozenset({'', '0', 'false', 'f', 'no', 'n', 'none', 'null'})


def _as_code_list(value: Any) -> List[str]:
    """Coerce a code field into a list of stripped, non-empty strings.

    Accepts a list/tuple of codes, a single scalar code, or a delimited string such as
    "I20.0; I25.10" (split on ';', ',' and whitespace -- never on '.', which belongs to
    dotted ICD codes). None/empty input and None elements are dropped.
    """
    if not value:
        return []
    items = value if isinstance(value, (list, tuple)) else re.split(r'[;,\s]+', str(value))
    cleaned = (str(item).strip() for item in items if item is not None)
    return [code for code in cleaned if code]


def _as_flag(value: Any) -> bool:
    """Interpret a yes/no field; strings such as "false", "no" or "0" count as False."""
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


@tool
def summarize_patient_record(record_str: str) -> Dict[str, Any]:
    """
    Normalize a raw patient claim record into a consistent dict the checker can use.
    - Parses JSON
    - Computes age if missing (using DOB + service date)
    - Normalizes sex and diagnosis/procedure codes
    - Flattens various preauth flags into one boolean
    """
    # NOTE: the docstring above is the tool description shown to the LLM; keep it stable.
    try:
        rec = json.loads(record_str)
    except Exception:
        return {'error': 'invalid JSON'}
    # Valid JSON that is not an object (e.g. a list) cannot be a claim; report it instead
    # of crashing on .get() below.
    if not isinstance(rec, dict):
        return {'error': 'record is not a JSON object'}

    # Prefer an explicit age, coerced to int: feeds often send "45", and a string age
    # would make the checker's numeric comparisons raise TypeError.
    age = rec.get('age')
    if age is not None:
        try:
            age = int(float(age))
        except (TypeError, ValueError, OverflowError):
            age = None  # unusable -> fall back to DOB below
    if age is None:
        dob = rec.get('date_of_birth') or rec.get('dob')
        svc = rec.get('date_of_service') or rec.get('service_date') or rec.get('claim_date')
        if dob:
            age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))

    # Normalize sex to one of 'M','F','U'
    sex = sex_normalize(rec.get('gender') or rec.get('sex'))

    # Diagnoses: accept list or delimited string, normalize ICD formatting, then keep only
    # codes known to the reference map (if one was loaded). Reference keys may be stored
    # dotted ("I20.0") or undotted ("I200"); accept either so valid codes are not dropped.
    dx = [icd_dot_normalize(d) for d in _as_code_list(rec.get('diagnosis_codes') or rec.get('diagnoses'))]
    if ICD_TO_NAME:
        dx = [d for d in dx if d in ICD_TO_NAME or d.replace('.', '') in ICD_TO_NAME]

    # Procedures: keep known CPT codes; if none are known, fall back to any 5-digit codes.
    cpt = _as_code_list(rec.get('procedure_codes') or rec.get('procedures'))
    if CPT_TO_NAME:
        keep = [c for c in cpt if c in CPT_TO_NAME]
        cpt = keep if keep else [c for c in cpt if re.fullmatch(r'\d{5}', c)]

    # Any preauth flag set to a "yes" value counts; "false"/"no" strings no longer read as True.
    preauth = any(
        _as_flag(rec.get(key))
        for key in ('preauthorization_obtained', 'preauth_provided', 'authorization_provided')
    )

    # Return a clean summary; policy_id is normalized for reliable index lookup
    return {
        'patient_id': rec.get('patient_id') or rec.get('id'),
        'policy_id': _norm_pid(rec.get('insurance_policy_id') or rec.get('policy_id') or rec.get('plan_id')),
        'age': age,
        'sex': sex,
        'diagnoses': dx,
        'procedures': cpt,
        'preauth': preauth
    }

@tool
def summarize_policy_guideline(policy_id: str) -> Dict[str, Any]:
    """
    Fetch normalized rules for the given policy by ID.
    Returns a dict with 'policy_id' and 'rules', or an error if the policy is unknown.
    """
    # The docstring above is the tool description the LLM sees, so it is left verbatim.
    # Index keys are normalized IDs; normalize the caller's ID the same way before lookup.
    key = _norm_pid(policy_id)
    policy = POLICY_INDEX.get(key)
    if policy:
        return {'policy_id': key, 'rules': policy.get('rules', [])}
    # Echo the caller's original spelling so the agent can see what it asked for.
    return {'error': f'Unknown policy_id {policy_id}'}

def _rule_covers(record: Dict[str, Any], rule: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Evaluate whether a single policy rule covers this claim summary.
    Returns (True, None) if covered, else (False, reason) for the first failing condition.
    """
    # Convert lists to sets for efficient membership/intersection checks
    cpts = set(record.get('procedures', []))
    dxs  = set(record.get('diagnoses', []))
    age  = record.get('age')
    sex  = (record.get('sex') or 'U').upper()
    pre  = bool(record.get('preauth', False))

    # If the billed CPT isn't the one this rule covers, it's a non-match (not an error)
    if rule['procedure'] not in cpts:
        return False, None

    # If rule specifies diagnoses, at least one must be present in the claim
    rdx = set(rule.get('diagnoses', []))
    if rdx and not (dxs & rdx):  # empty rdx means "no diagnosis constraint"
        return False, f"Diagnosis mismatch for CPT {rule['procedure']}."

    # Respect age bounds when the claim has a computable age
    if age is not None:
        if rule.get('age_min') is not None and age < rule['age_min']:
            return False, f"Patient age {age} below covered minimum for CPT {rule['procedure']}."
        if rule.get('age_max') is not None and age > rule['age_max']:
            return False, f"Patient age {age} exceeds covered maximum for CPT {rule['procedure']}."

    # Enforce sex restriction when specified by the rule
    if rule.get('sex') in {'M','F'} and sex != rule['sex']:
        return False, f"Sex restriction {rule['sex']} for CPT {rule['procedure']}."

    # If preauth is required by policy, it must be present on the claim
    if rule.get('preauth_required') and not pre:
        return False, f"Preauthorization required for CPT {rule['procedure']} but not provided."

    # All constraints satisfied → this rule covers the claim for that CPT
    return True, None

@tool
def check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """
    Decide **APPROVE** vs **ROUTE FOR REVIEW** for a claim by comparing the record summary
    to the set of rules returned for the policy.
    """
    # NOTE: the docstring above is the tool description shown to the LLM; keep it stable.

    # If policy lookup failed, we cannot safely approve
    if policy_summary.get('error'):
        return {'decision':'ROUTE FOR REVIEW','reason':'Unknown policy_id.'}

    # A record that failed to summarize carries only an 'error' key and no procedures;
    # without this guard the per-CPT loop below would be skipped and the claim APPROVED.
    if record_summary.get('error'):
        return {'decision': 'ROUTE FOR REVIEW', 'reason': 'Patient record could not be parsed.'}

    # Extract rule list; an empty rule set is ambiguous -> route for review
    rules = policy_summary.get('rules') or []
    if not rules:
        return {'decision':'ROUTE FOR REVIEW','reason':'Policy does not define covered procedures.'}

    # Billed codes as stripped strings: the LLM re-serializes the summary when calling this
    # tool and may send numeric-looking CPT codes as JSON numbers or a single scalar.
    raw_cpts = record_summary.get('procedures') or []
    if isinstance(raw_cpts, (str, int)):
        raw_cpts = [raw_cpts]
    cpts = [code for code in (str(c).strip() for c in raw_cpts) if code]

    # No billable procedure (e.g. every code was unknown and filtered out) means nothing
    # was actually checked; "every CPT is covered" would be vacuously true, so route it.
    if not cpts:
        return {'decision': 'ROUTE FOR REVIEW', 'reason': 'No recognized procedure codes on the claim.'}
    record = {**record_summary, 'procedures': cpts}

    # For every billed CPT in the claim, at least one rule must cover it fully
    for cpt in cpts:
        cand = [r for r in rules if str(r.get('procedure')).strip() == cpt]
        if not cand:
            return {'decision':'ROUTE FOR REVIEW','reason': f'The claim for CPT code {cpt} is not covered by the policy.'}

        # Try each candidate rule until one covers; remember the first failure reason
        first_reason: Optional[str] = None
        for r in cand:
            ok, why = _rule_covers(record, r)
            if ok:
                break
            if first_reason is None:
                first_reason = why
        else:
            # No candidate rule covered this CPT
            return {'decision':'ROUTE FOR REVIEW','reason': first_reason or f'The claim for CPT code {cpt} does not meet policy requirements.'}

    # Every billed CPT is covered by at least one rule -> approve
    return {'decision':'APPROVE','reason': f'The claim for CPT code {cpts[0]} is approved.'}


In [None]:
# === Agent wiring + parallelized runner ===

# Toolbox given to the agent: the only three functions it is allowed to call.
TOOLS = [
    summarize_patient_record,    # step 1: raw claim JSON -> normalized record summary
    summarize_policy_guideline,  # step 2: policy_id -> normalized rule list
    check_claim_coverage,        # step 3: record summary + rules -> decision and reason
]

# System prompt: this acts like the "SOP" for the agent. It enforces tool order and a fixed output shape.
AGENT_SYS_TEXT = "\n".join([
    "You are an insurance claims agent.",
    "You MUST always call tools in this order for each record:",
    "  1) summarize_patient_record -> on the raw JSON string of the record",
    "  2) summarize_policy_guideline -> with the policy_id from step 1",
    "  3) check_claim_coverage -> with outputs of steps 1 & 2",
    "When you have enough information, output the FINAL answer in exactly this format:",
    "- Decision: APPROVE | ROUTE FOR REVIEW",
    "- Reason: ",  # trailing space is intentional: the model fills in the reason here
    "Rules:",
    "- Do not invent policy data. If policy_id is missing/unknown, route for review.",
    "- Keep the reason short and specific.",
    "",  # empty final element -> the prompt ends with a newline
])

# Create the ReAct agent from the existing chat client.
# Assumption: the user already defined `chat_client` (e.g., OpenAI client) in this runtime.
# Check for the client explicitly rather than wrapping the whole call in
# `try/except NameError`. That wrapper also relabelled any NameError raised *inside*
# create_react_agent as "missing chat_client", and chained a confusing traceback.
if 'chat_client' not in globals():
    # Provide a clearer error for users: they must define `chat_client` in their environment.
    raise RuntimeError("Expected an existing `chat_client` instance. Define it before running this cell.")
agent = create_react_agent(chat_client, TOOLS)

def _message_text(content: Any) -> str:
    """Flatten chat-message content into plain text.

    Message content may be a plain string or, for some chat models, a list of content
    blocks: strings, or dicts such as {"type": "text", "text": ...}. Text blocks are
    concatenated in order; other blocks (e.g. tool-use) are skipped. Any other value
    is passed through str().
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: List[str] = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(str(block.get("text", "")))
        return "".join(parts)
    return str(content)


def _process_record(agent, rec: Dict[str, Any]) -> Dict[str, str]:
    """Run the agent on a **single** record and return ``{patient_id, generated_response}``.

    ``generated_response`` is always a string: list-style message content is flattened
    to its text, so it cannot leak a Python list repr into the CSV or break callers that
    call ``.splitlines()`` on it.
    """
    rec_str = json.dumps(rec, ensure_ascii=False)  # JSON text becomes the human message payload
    messages = [
        SystemMessage(content=AGENT_SYS_TEXT),  # strict instructions to keep the agent on-rails
        HumanMessage(content=f"Process this record:\n{rec_str}")  # the actual record to process
    ]
    out = agent.invoke({"messages": messages})  # synchronous call; agent internally calls tools
    # Depending on versions, `out` can be a dict with messages or an object with .content; handle both
    final = (
        out["messages"][-1].content
        if isinstance(out, dict) and "messages" in out
        else getattr(out, "content", str(out))
    )
    return {
        "patient_id": rec.get("patient_id") or rec.get("id"),  # keep id for output mapping
        "generated_response": _message_text(final)  # the decision text (Decision/Reason)
    }

def _fallback_row(rec: Dict[str, Any], reason: str) -> Dict[str, str]:
    """Build the ROUTE FOR REVIEW row emitted for a record the agent could not process."""
    return {
        "patient_id": rec.get("patient_id") or rec.get("id"),
        "generated_response": f"- Decision: ROUTE FOR REVIEW\n- Reason: {reason}",
    }


def run_agent_on_records(agent, records: List[Dict[str, Any]], max_workers: int = 6) -> List[Dict[str, str]]:
    """Process every record with the agent concurrently, returning rows in input order.

    The first record runs on the calling thread as a warm-up (it opens network
    connections before the fan-out). The rest go to a thread pool of `max_workers`
    threads; tune it to the API's rate limits. A record whose processing raises yields a
    ROUTE FOR REVIEW row carrying the error text instead of aborting the whole run.
    """
    if not records:
        return []
    ordered: List[Optional[Dict[str, str]]] = [None] * len(records)

    # Warm-up: one synchronous call before starting the pool.
    try:
        ordered[0] = _process_record(agent, records[0])
    except Exception as err:
        ordered[0] = _fallback_row(records[0], f"processing error: {err}")

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to its record's position so results land in input order.
        position_of = {
            pool.submit(_process_record, agent, rec): pos
            for pos, rec in enumerate(records[1:], start=1)
        }
        for future in as_completed(position_of):
            pos = position_of[future]
            try:
                ordered[pos] = future.result()
            except Exception as err:
                ordered[pos] = _fallback_row(records[pos], f"processing error: {err}")

    # Defensive: any slot that somehow stayed empty still gets a well-formed row.
    return [
        row if row is not None else _fallback_row(records[pos], "unknown error.")
        for pos, row in enumerate(ordered)
    ]


In [None]:
# === Validation & test runs ===

# Run on validation set (if present) and print a tiny preview for sanity checking.
if val_records:
    start = time.time()
    val_out = run_agent_on_records(agent, val_records, max_workers=6)
    print(f"Validation processed {len(val_out)} records in {time.time()-start:.2f}s (showing first 5):")
    for row in val_out[:5]:
        # Show the patient id and the first non-blank line of the response (usually the
        # Decision line). str() guards against non-string message content, so a preview
        # glitch cannot abort the cell before the submission file below is written.
        response_text = str(row.get("generated_response") or "")
        first_line = next((ln for ln in response_text.splitlines() if ln.strip()), "")
        print(row.get("patient_id"), "=>", first_line)
else:
    # Keep behavior explicit to avoid silent no-ops for new users
    print("No validation records found.")

# Ensure the output directory exists, then run on the test set and write the CSV required
# by the grader. os.makedirs('') raises, so skip it when the path has no directory part.
submission_dir = os.path.dirname(SUBMISSION_PATH)
if submission_dir:
    os.makedirs(submission_dir, exist_ok=True)
if test_records:
    start = time.time()
    test_results = run_agent_on_records(agent, test_records, max_workers=6)
    elapsed = time.time() - start
    print(f"Test processed {len(test_results)} records in {elapsed:.2f} seconds")
    with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
        w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])  # required two columns
        w.writeheader()
        w.writerows(test_results)
    print(f"Wrote {SUBMISSION_PATH}")
else:
    print("No test records found.")
