# Capstone — ReAct LLM Agent (One Agent, Three Tools)

**This version avoids `state_modifier`** and injects the system prompt as a `SystemMessage` on each call.

- Uses your existing `chat_client` (do **not** recreate it here).
- Reads from `./Data`.
- Tools: `summarize_patient_record`, `summarize_policy_guideline`, `check_claim_coverage`.
- Prints step‑by‑step debug output for the first few validation records (`VALIDATION_DEBUG_N`, default 3).
- Writes `./Data/submission.csv` with columns `patient_id` and `generated_response`; each response uses the required two‑line `Decision` / `Reason` format.

In [None]:

from __future__ import annotations
from typing import Dict, Any, List, Optional, Tuple
import os, json, csv, re, textwrap
from datetime import datetime

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.prebuilt import create_react_agent

# ---- Config ----
# Every input file and the submission file sit directly under DATA_DIR.
DATA_DIR = './Data'
VALIDATION_PATH = DATA_DIR + '/validation_records.json'
TEST_PATH = DATA_DIR + '/test_records.json'
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'
SUBMISSION_PATH = DATA_DIR + '/submission.csv'

# DEBUG switches the _dbg output and the per-record trace on or off;
# VALIDATION_DEBUG_N is how many validation records get a traced run.
DEBUG = True
VALIDATION_DEBUG_N = 3

def _dbg(title: str, payload: Any = None, width: int = 1400):
    """Print a ``[DEBUG]`` header and, optionally, an indented payload dump.

    Nothing is printed when the module-level ``DEBUG`` flag is falsy.
    String payloads are shown as-is; anything else is rendered as indented
    JSON, falling back to ``str()`` when it cannot be serialized. Text longer
    than ``width`` characters is cut and marked as truncated.
    """
    if DEBUG:
        print(f"[DEBUG] {title}")
        if payload is not None:
            if isinstance(payload, str):
                text = payload
            else:
                try:
                    text = json.dumps(payload, ensure_ascii=False, indent=2)
                except Exception:
                    # Debug output is best-effort: never fail on odd payloads.
                    text = str(payload)
            if len(text) > width:
                text = text[:width] + " ... [truncated]"
            print(textwrap.indent(text, "  "))


In [None]:

def _safe_load(path, default):
    """Parse the JSON file at ``path``; on any failure warn and return ``default``.

    The broad ``except`` is deliberate: a missing or malformed data file is
    reported on stdout and the caller carries on with ``default``.
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
    except Exception as exc:
        print(f"⚠️ Could not read {path}: {exc}")
        loaded = default
    return loaded

# Load the policy file and the reference-code file once; each falls back to
# an empty dict when the file cannot be read.
policies_raw = _safe_load(POLICIES_PATH, {})
ref_codes = _safe_load(REF_CODES_PATH, {})

# Each lookup table is taken from the first of two alternative keys that
# holds a non-empty value, else an empty dict.
ICD_TO_NAME = ref_codes.get('diagnosis_codes') or ref_codes.get('icd10', {})
CPT_TO_NAME = ref_codes.get('procedure_codes') or ref_codes.get('cpt', {})


In [None]:

# strptime formats tried in order; the first one that parses wins, so an
# ambiguous slash date such as 01/02/2020 is read month-first.
DATE_PATTERNS = ['%Y-%m-%d', '%m/%d/%Y', '%d/%m/%Y', '%Y/%m/%d']


def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort date parser; returns ``None`` instead of raising.

    Falsy input yields ``None``. Otherwise the stripped string form of ``s``
    is tried against each entry of ``DATE_PATTERNS`` and finally against a
    compact eight-digit ``YYYYMMDD`` form.
    """
    if not s:
        return None
    text = str(s).strip()
    for fmt in DATE_PATTERNS:
        try:
            return datetime.strptime(text, fmt)
        except Exception:
            continue
    compact = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if compact is None:
        return None
    try:
        year, month, day = (int(part) for part in compact.groups())
        return datetime(year, month, day)
    except Exception:
        # e.g. month 13 or day 32
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return the number of completed years between ``dob`` and ``ref``.

    Args:
        dob: Date of birth. A falsy value yields ``None``.
        ref: Reference date to measure against. When omitted, the current
            UTC date is used.

    Returns:
        The age in whole years, or ``None`` when ``dob`` is missing or the
        computed age falls outside the range 0..149.
    """
    if not dob:
        return None
    if not ref:
        # datetime.utcnow() is deprecated since Python 3.12; read the current
        # UTC time from an aware clock instead. Only year/month/day are used
        # below, so the tzinfo never meets the naive ``dob``.
        from datetime import timezone
        ref = datetime.now(timezone.utc)
    # Subtract one year when this year's birthday has not happened yet.
    birthday_pending = (ref.month, ref.day) < (dob.month, dob.day)
    years = ref.year - dob.year - birthday_pending
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to 'M', 'F' or 'U' (unknown).

    Matching ignores case and surrounding whitespace; only the spellings
    M/MALE and F/FEMALE are recognised.
    """
    if s is None:
        return 'U'
    canonical = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    return canonical.get(str(s).strip().upper(), 'U')

def icd_dot_normalize(code: str) -> str:
    """Upper-case and trim ``code``, inserting a dot after the third character.

    The dot is only added when the cleaned code is one letter followed by
    three to six digits (e.g. ``E119`` -> ``E11.9``); any other shape is
    returned cleaned but otherwise untouched.
    """
    cleaned = str(code).strip().upper()
    # A full match already guarantees there is no '.' and at least four
    # characters, so no further checks are needed.
    if re.fullmatch(r'[A-Z]\d{3,6}', cleaned):
        return f"{cleaned[:3]}.{cleaned[3:]}"
    return cleaned

def _norm_pid(pid: Optional[str]) -> Optional[str]:
    """Canonicalize a policy id: drop hyphens, underscores and whitespace, then upper-case.

    Falsy input (``None``, empty string, 0) yields ``None``.
    """
    if pid:
        compact = re.sub(r'[-_\s]', '', str(pid))
        return compact.upper()
    return None

def walk_items(obj):
    """Yield every ``(key, value)`` pair found in nested dicts and lists.

    Traversal is depth-first: a dict entry is yielded before the pairs found
    inside its value. Lists contribute no pairs of their own; values that are
    neither dict nor list end the descent.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield key, value
            yield from walk_items(value)
    elif isinstance(obj, list):
        for element in obj:
            yield from walk_items(element)

def deep_get_first(d: Any, keys: List[str]):
    """Return the value of the first nested key matching any of ``keys``.

    Keys are compared case-insensitively and in ``walk_items`` traversal
    order; non-string keys never match. Returns ``None`` when nothing
    matches.
    """
    wanted = {name.lower() for name in keys}
    hits = (
        value
        for key, value in walk_items(d)
        if isinstance(key, str) and key.lower() in wanted
    )
    return next(hits, None)

def _extract_rules(obj: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn a policy's 'covered_procedures' list into normalized rule dicts.

    Entries that are not dicts, or that carry no procedure code, are skipped.
    Each rule has the keys ``procedure``, ``diagnoses`` (upper-cased, dotted
    when they look like an undotted ICD code), ``age_min`` / ``age_max``
    (``None`` when absent or not convertible to int), ``sex`` ('M', 'F' or
    'ANY') and ``preauth_required``.
    """
    def _int_or_none(value):
        # Age bounds are best-effort: an unparseable bound means "no bound".
        try:
            return int(value)
        except Exception:
            return None

    entries = obj.get('covered_procedures') or []
    if not isinstance(entries, list):
        return []

    normalized = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        cpt_code = str(entry.get('procedure_code') or entry.get('procedure') or '').strip()
        if not cpt_code:
            continue

        raw_dx = entry.get('covered_diagnoses') or entry.get('diagnoses') or []
        if isinstance(raw_dx, str):
            # A delimited string is split on ';', ',' or whitespace.
            raw_dx = [tok for tok in re.split(r'[;,\s]+', raw_dx) if tok]
        dx_codes = []
        for item in raw_dx:
            dx = str(item).strip().upper()
            if re.fullmatch(r'[A-Z]\d{3,6}', dx):
                dx = dx[:3] + '.' + dx[3:]
            dx_codes.append(dx)

        bounds = entry.get('age_range') or []
        low = high = None
        if isinstance(bounds, list) and len(bounds) >= 2:
            low, high = _int_or_none(bounds[0]), _int_or_none(bounds[1])

        # Only the first letter decides: anything starting with M or F is a
        # restriction, everything else means no restriction.
        gender = str(entry.get('gender') or entry.get('sex') or 'Any').strip().upper()
        sex_rule = gender[0] if gender[:1] in ('M', 'F') else 'ANY'

        needs_preauth = bool(
            entry.get('requires_preauthorization') or entry.get('preauth_required') or False
        )

        normalized.append({
            'procedure': cpt_code,
            'diagnoses': dx_codes,
            'age_min': low,
            'age_max': high,
            'sex': sex_rule,
            'preauth_required': needs_preauth,
        })
    return normalized

def build_policy_index(raw) -> Dict[str, Dict[str, Any]]:
    """Index policies by normalized policy id.

    Accepts three layouts: a dict with a ``'policies'`` list, a dict mapping
    ids to policy dicts (the mapping key is used as the id when the policy
    carries none), or a plain list of policy dicts. Anything else yields an
    empty index. Entries without a usable id are dropped; a later entry with
    the same id replaces an earlier one.
    """
    if isinstance(raw, dict):
        listed = raw.get('policies')
        if isinstance(listed, list):
            candidates = listed
        else:
            candidates = []
            for key, value in raw.items():
                if not isinstance(value, dict):
                    continue
                # Copy so the caller's data is not modified by setdefault.
                entry = dict(value)
                entry.setdefault('policy_id', entry.get('policyId') or entry.get('id') or key)
                candidates.append(entry)
    elif isinstance(raw, list):
        candidates = raw
    else:
        candidates = []

    index = {}
    for policy in candidates:
        if not isinstance(policy, dict):
            continue
        raw_id = (
            policy.get('policy_id')
            or policy.get('policyId')
            or policy.get('id')
            or policy.get('code')
            or policy.get('policy_code')
        )
        pid = _norm_pid(raw_id)
        if pid:
            index[pid] = {
                'policy_id': pid,
                'plan_name': policy.get('plan_name') or policy.get('title') or '',
                'rules': _extract_rules(policy),
            }
    return index

# Build the policy lookup once, then show the rule count for up to eight
# policies as a quick sanity check.
POLICY_INDEX = build_policy_index(policies_raw)
_dbg(
    "Policy index (sample counts)",
    {pid: len(POLICY_INDEX[pid]['rules']) for pid in list(POLICY_INDEX)[:8]},
)


In [None]:

@tool
def summarize_patient_record(record_str: str) -> Dict[str, Any]:
    """Summarize patient record JSON string into key fields (patient_id, policy_id, age, sex, diagnoses, procedures, preauth)."""
    # NOTE: the docstring above is left untouched on purpose; the @tool
    # decorator may expose it to the model as the tool description.
    #
    # Returns a dict with the keys listed in the docstring, or
    # {'error': ...} when the input is not a JSON object.

    def _code_list(value):
        # Accept a single code as well as a list: a bare string is split on
        # ';', ',' or whitespace (iterating it would yield characters) and a
        # bare integer is wrapped (iterating it would raise TypeError).
        if isinstance(value, str):
            return [tok for tok in re.split(r'[;,\s]+', value) if tok]
        if isinstance(value, int):
            return [value]
        return value

    try:
        rec = json.loads(record_str)
    except Exception:
        return {'error': 'invalid JSON'}
    # Valid JSON that is not an object (list, string, number, null) has no
    # fields to read; report it instead of failing on .get() below.
    if not isinstance(rec, dict):
        return {'error': 'record must be a JSON object'}

    # Prefer an explicit age; otherwise derive it from the date of birth,
    # measured at the service/claim date when one is present.
    age = rec.get('age')
    if age is None:
        dob = rec.get('date_of_birth') or rec.get('dob')
        svc = rec.get('date_of_service') or rec.get('service_date') or rec.get('claim_date')
        if dob:
            age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))

    sex = sex_normalize(rec.get('gender') or rec.get('sex'))

    dx = _code_list(rec.get('diagnosis_codes') or rec.get('diagnoses') or [])
    dx = [icd_dot_normalize(d) for d in dx]
    if ICD_TO_NAME:
        # NOTE(review): this keeps only codes present in ICD_TO_NAME and
        # presumes its keys use the dotted form produced above - confirm
        # against the reference file.
        dx = [d for d in dx if d in ICD_TO_NAME]

    cpt = _code_list(rec.get('procedure_codes') or rec.get('procedures') or [])
    cpt = [str(x) for x in cpt]
    if CPT_TO_NAME:
        # Prefer known codes; if none are known, keep anything shaped like a
        # five-digit code.
        keep = [c for c in cpt if c in CPT_TO_NAME]
        cpt = keep if keep else [c for c in cpt if re.fullmatch(r'\d{5}', c)]

    preauth = bool(
        rec.get('preauthorization_obtained')
        or rec.get('preauth_provided')
        or rec.get('authorization_provided')
    )

    out = {
        'patient_id': rec.get('patient_id') or rec.get('id'),
        'policy_id': _norm_pid(rec.get('insurance_policy_id') or rec.get('policy_id') or rec.get('plan_id')),
        'age': age,
        'sex': sex,
        'diagnoses': dx,
        'procedures': cpt,
        'preauth': preauth
    }
    _dbg("summarize_patient_record.output", out)
    return out

@tool
def summarize_policy_guideline(policy_id: str) -> Dict[str, Any]:
    """Retrieve policy rules by policy_id. Returns {'policy_id', 'rules':[...]} or {'error': ...}."""
    # The id is normalized the same way the index keys were, so formatting
    # differences (hyphens, case, spaces) do not cause a miss.
    key = _norm_pid(policy_id)
    policy = POLICY_INDEX.get(key)
    if policy:
        result = {'policy_id': key, 'rules': policy.get('rules', [])}
        # Log only the count; the full rule list can be long.
        _dbg("summarize_policy_guideline.output", {'policy_id': key, 'rules_count': len(result['rules'])})
    else:
        result = {'error': f'Unknown policy_id {policy_id}'}
        _dbg("summarize_policy_guideline.output", result)
    return result

def _rule_covers(record: Dict[str, Any], rule: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Check one record summary against one policy rule.

    Returns ``(True, None)`` when every condition of the rule is met,
    otherwise ``(False, reason)`` for the first failed condition. Checks run
    in this order: procedure, diagnosis, age bounds, sex, preauthorization.
    An unknown age (``None``) skips the age check.
    """
    claimed_cpts = set(record.get('procedures', []))
    claimed_dx = set(record.get('diagnoses', []))
    patient_age = record.get('age')
    patient_sex = (record.get('sex') or 'U').upper()
    has_preauth = bool(record.get('preauth', False))

    cpt = rule['procedure']
    if cpt not in claimed_cpts:
        return False, f"Rule covers CPT {cpt} only."

    # An empty diagnosis list on the rule means "any diagnosis".
    allowed_dx = set(rule.get('diagnoses', []))
    if allowed_dx and claimed_dx.isdisjoint(allowed_dx):
        return False, f"Diagnosis mismatch for CPT {cpt} (needs one of {sorted(allowed_dx)})."

    if patient_age is not None:
        lower = rule.get('age_min')
        if lower is not None and patient_age < lower:
            return False, f"Age {patient_age} < minimum {lower} for CPT {cpt}."
        upper = rule.get('age_max')
        if upper is not None and patient_age > upper:
            return False, f"Age {patient_age} > maximum {upper} for CPT {cpt}."

    required_sex = rule.get('sex')
    if required_sex in {'M', 'F'} and patient_sex != required_sex:
        return False, f"Sex must be {required_sex} for CPT {cpt}."

    if rule.get('preauth_required') and not has_preauth:
        return False, f"Preauthorization required for CPT {cpt} but not provided."

    return True, None

@tool
def check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Compare patient record summary with policy rules and return decision + reason."""
    # NOTE: the docstring above is left untouched on purpose; the @tool
    # decorator may expose it to the model as the tool description.
    #
    # Returns {'decision': 'APPROVE' | 'ROUTE FOR REVIEW', 'reason': str}.
    # Every claimed procedure must satisfy at least one matching rule for the
    # claim to be approved.

    # A record that failed to parse must never fall through to approval.
    if record_summary.get('error'):
        return {'decision':'ROUTE FOR REVIEW','reason':'Patient record could not be parsed.'}
    if policy_summary.get('error'):
        return {'decision':'ROUTE FOR REVIEW','reason':'Unknown policy_id.'}
    rules = policy_summary.get('rules', [])
    if not rules:
        return {'decision':'ROUTE FOR REVIEW','reason':'Policy does not define covered procedures.'}

    # With no procedures the loop below would check nothing and the claim
    # would be approved unverified; send it to a human instead.
    procedures = record_summary.get('procedures') or []
    if not procedures:
        return {'decision':'ROUTE FOR REVIEW','reason':'No procedure codes found on the claim.'}

    for cpt in procedures:
        candidates = [r for r in rules if r.get('procedure') == cpt]
        if not candidates:
            return {'decision':'ROUTE FOR REVIEW','reason': f'The claim for CPT code {cpt} is not covered by the policy.'}
        first_reason = None
        for rule in candidates:
            ok, why = _rule_covers(record_summary, rule)
            if ok:
                break
            if first_reason is None:
                first_reason = why
        else:
            # No candidate rule accepted this procedure; report the first
            # failure reason encountered.
            return {'decision':'ROUTE FOR REVIEW','reason': first_reason or f'The claim for CPT code {cpt} does not meet policy requirements.'}

    return {'decision':'APPROVE','reason': f'The claim for CPT code {procedures[0]} is approved.'}


In [None]:

# Tools exposed to the agent, listed in the order the system prompt tells it
# to call them.
TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]

# System prompt, sent as a SystemMessage with every record.
AGENT_SYS_TEXT = """You are an insurance claims agent.
You MUST always call tools in this order for each record:
  1) summarize_patient_record -> on the raw JSON string of the record
  2) summarize_policy_guideline -> with the policy_id from step 1
  3) check_claim_coverage -> with outputs of steps 1 & 2
When you have enough information, output the FINAL answer in exactly this format:
- Decision: APPROVE | ROUTE FOR REVIEW
- Reason: <short reason>
Rules:
- Do not invent policy data. If policy_id is missing/unknown, route for review.
- Keep the reason short and specific (e.g., CPT not covered / diagnosis mismatch / age out of range / preauth missing).
"""

# uses your existing `chat_client` (must be defined earlier in your environment)
# Only the name lookup is guarded: a NameError raised from inside
# create_react_agent must not be misreported as a missing client.
try:
    _existing_chat_client = chat_client
except NameError as err:
    raise RuntimeError("Expected an existing `chat_client` instance. Define it before running this cell.") from err

agent = create_react_agent(_existing_chat_client, TOOLS)  # NOTE: no state_modifier (older API compatibility)

def _message_text(content) -> str:
    """Flatten message content to plain text.

    Content may be a plain string or a list of parts (strings, or dicts
    carrying a ``'text'`` entry). Parts without text are ignored; any other
    content type is passed through ``str()``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def run_agent_on_records(agent, records, n_debug=0):
    """Run the agent once per record and collect its final answers.

    Args:
        agent: Object with an ``invoke({"messages": [...]})`` method.
        records: Iterable of record dicts; each is sent to the agent as JSON.
        n_debug: Number of leading records whose full message trace is
            printed (only when ``DEBUG`` is on).

    Returns:
        A list of ``{"patient_id", "generated_response"}`` dicts, one per
        record. A failure while invoking the agent is recorded as a
        ROUTE FOR REVIEW response rather than raised.
    """
    results = []
    for i, rec in enumerate(records):
        patient_id = rec.get("patient_id") or rec.get("id")
        rec_str = json.dumps(rec, ensure_ascii=False)
        messages = [
            SystemMessage(content=AGENT_SYS_TEXT),
            HumanMessage(content=f"Process this record:\n{rec_str}")
        ]
        try:
            out = agent.invoke({"messages": messages})
            if isinstance(out, dict) and "messages" in out:
                raw = out["messages"][-1].content
            else:
                raw = getattr(out, "content", str(out))
            # Content is not always a str; normalize it so the CSV row and
            # the debug print below always receive text.
            final = _message_text(raw)
        except Exception as e:
            final = f"- Decision: ROUTE FOR REVIEW\n- Reason: Error: {e}"
            out = None

        results.append({
            "patient_id": patient_id,
            "generated_response": final
        })

        if i < n_debug and DEBUG:
            # Use the same id fallback as the result row.
            print("\n[DEBUG] ===== Record", patient_id, "=====")
            if isinstance(out, dict) and "messages" in out:
                for m in out["messages"]:
                    print(f"  {m.__class__.__name__}: {m.content}")
            print("  => Final:", final.replace("\n", " | "))

    return results


In [None]:

# Both record sets default to an empty list when their file cannot be read.
val_records = _safe_load(VALIDATION_PATH, [])
test_records = _safe_load(TEST_PATH, [])

# Optional traced dry run over the first few validation records; the results
# are discarded.
if DEBUG and val_records:
    print("Validation debug run:")
    _ = run_agent_on_records(agent, val_records[:VALIDATION_DEBUG_N], n_debug=VALIDATION_DEBUG_N)

# Score the test set and write one CSV row per record.
if not test_records:
    print("No test records found.")
else:
    test_results = run_agent_on_records(agent, test_records)
    os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
    with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['patient_id', 'generated_response'])
        writer.writeheader()
        writer.writerows(test_results)
    print(f"Wrote {SUBMISSION_PATH} with {len(test_results)} rows.")
