# Capstone — ReAct Agent v2 (Faster, With Tests)

**One agent + three tools** using your existing `chat_client`. Reads from `./Data` and writes to `./Data/submission.csv`.

### Speed optimizations
- Compact JSON for fewer tokens
- Precompiled regex and cached lookups
- Pre-indexed rules for faster validation
- Lean system prompt

Includes quick smoke tests (printed output, no assertions) for the policy index and each tool.

In [None]:

from __future__ import annotations
from typing import Dict, Any, List, Optional, Tuple
import os, json, csv, re
from datetime import datetime
from functools import lru_cache

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.prebuilt import create_react_agent

# Input/output locations. DATA_DIR is a relative path, so everything below
# resolves against the current working directory.
DATA_DIR = './Data'
TEST_PATH = DATA_DIR + '/test_records.json'
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'
SUBMISSION_PATH = DATA_DIR + '/submission.csv'


In [None]:

def _safe_load(path, default):
    """Load JSON from *path*, returning *default* when it cannot be loaded.

    A missing file is treated as an optional input and returns *default* quietly.
    A file that exists but cannot be read or decoded also returns *default*, but
    prints a warning so a corrupt input is not mistaken for an empty one.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return default
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[warn] could not load {path}: {type(exc).__name__}: {exc}")
        return default

# Raw inputs; _safe_load returns the given default when a file cannot be loaded.
policies_raw = _safe_load(POLICIES_PATH, {})
ref_codes    = _safe_load(REF_CODES_PATH, {})
# A reference file that is valid JSON but not an object would make .get() raise.
if not isinstance(ref_codes, dict):
    ref_codes = {}
# Code tables may be dicts (code -> description) or plain lists of codes. The
# trailing `or {}` covers keys that are present but null.
ICD_TO_NAME  = ref_codes.get('diagnosis_codes') or ref_codes.get('icd10') or {}
CPT_TO_NAME  = ref_codes.get('procedure_codes') or ref_codes.get('cpt') or {}
ICD_SET = frozenset(ICD_TO_NAME.keys()) if isinstance(ICD_TO_NAME, dict) else frozenset(ICD_TO_NAME)
CPT_SET = frozenset(CPT_TO_NAME.keys()) if isinstance(CPT_TO_NAME, dict) else frozenset(CPT_TO_NAME)


In [None]:

# Date formats tried in order by parse_date_maybe. '%m/%d/%Y' comes before
# '%d/%m/%Y', so an ambiguous value such as '03/04/2020' is read month-first.
DATE_PATTERNS = ['%Y-%m-%d','%m/%d/%Y','%d/%m/%Y','%Y/%m/%d']
# Compact YYYYMMDD dates.
RE_YMD = re.compile(r'^(\d{4})(\d{2})(\d{2})$')
# Undotted ICD-10 code: a letter, a digit, then 1-5 alphanumerics. Letters are
# valid after the second position (e.g. 'C7A010', 'S72001A'); the previous
# digits-only pattern left such codes undotted.
RE_ICD_UNDOTTED = re.compile(r'^[A-Z]\d[A-Z0-9]{1,5}$')
# Five-digit CPT code.
RE_5DIGITS = re.compile(r'^\d{5}$')
# Separators used in delimited code lists ("A, B;C").
RE_SPLIT = re.compile(r'[;,\s]+')

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort parse of a date value into a naive datetime; None if unparseable.

    Tries, in order:
      1. each format in DATE_PATTERNS;
      2. compact YYYYMMDD (RE_YMD);
      3. ISO-8601 date-times such as '1980-01-01T00:00:00' or '...Z'.
    Any UTC offset on an ISO value is dropped, keeping its calendar fields.
    """
    if not s:
        return None
    text = str(s).strip()
    for pat in DATE_PATTERNS:
        try:
            return datetime.strptime(text, pat)
        except ValueError:
            pass
    m = RE_YMD.fullmatch(text)
    if m:
        try:
            return datetime(int(m.group(1)), int(m.group(2)), int(m.group(3)))
        except ValueError:
            # Eight digits that are not a real calendar date (e.g. month 13).
            return None
    # Older Pythons' fromisoformat rejects a trailing 'Z', so spell it as +00:00.
    iso = text[:-1] + '+00:00' if text.endswith('Z') else text
    try:
        return datetime.fromisoformat(iso).replace(tzinfo=None)
    except ValueError:
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return the age in whole years at *ref* for someone born on *dob*.

    *ref* defaults to the current UTC time. Returns None when *dob* is missing or
    the computed age falls outside [0, 150), e.g. a birth date after *ref*.
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
        # time and drop tzinfo; only calendar fields are compared below.
        from datetime import timezone
        ref = datetime.now(timezone.utc).replace(tzinfo=None)
    had_birthday = (ref.month, ref.day) >= (dob.month, dob.day)
    years = ref.year - dob.year - (0 if had_birthday else 1)
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to 'M', 'F' or 'U' (unknown).

    Only the first character after stripping and upper-casing matters. None, an
    empty value, or anything not starting with M/F yields 'U'.
    """
    if s is None:
        return 'U'
    lead = str(s).strip().upper()[:1]
    return lead if lead in ('M', 'F') else 'U'

def icd_dot_normalize(code: str) -> str:
    """Canonicalise an ICD-10 code: upper-case, with a dot after the first 3 characters.

    The dot is inserted only when the code has no dot, is at least 4 characters
    long, and matches RE_ICD_UNDOTTED. Otherwise the stripped, upper-cased value
    is returned as is.
    """
    normalized = str(code).strip().upper()
    needs_dot = (
        RE_ICD_UNDOTTED.fullmatch(normalized) is not None
        and '.' not in normalized
        and len(normalized) >= 4
    )
    if not needs_dot:
        return normalized
    category, subcode = normalized[:3], normalized[3:]
    return f'{category}.{subcode}'

# Characters ignored when comparing policy ids: hyphens, underscores, whitespace.
_RE_PID_SEPARATORS = re.compile(r'[-_\s]')


def _norm_pid(pid: Optional[str]) -> Optional[str]:
    """Canonical policy id: separators removed and upper-cased; None for falsy input."""
    return _RE_PID_SEPARATORS.sub('', str(pid)).upper() if pid else None


In [None]:

def _rule_age_bounds(age_range: Any) -> Tuple[Optional[int], Optional[int]]:
    """Return (min, max) from a list of at least two items; unparseable ends become None."""
    def as_int(value):
        try:
            return int(value)
        except Exception:
            return None

    if isinstance(age_range, list) and len(age_range) >= 2:
        return as_int(age_range[0]), as_int(age_range[1])
    return None, None


def _extract_rules(obj: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten a policy's ``covered_procedures`` list into rule dicts.

    Each rule has:
      * procedure: code string
      * diagnoses: tuple of dotted ICD codes
      * age_min / age_max: int or None
      * sex: 'M', 'F' or 'ANY'
      * preauth_required: bool
    Non-dict entries and entries without a procedure code are skipped. A
    non-list ``covered_procedures`` yields no rules.
    """
    entries = obj.get('covered_procedures') or []
    if not isinstance(entries, list):
        return []
    extracted = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        code = str(entry.get('procedure_code') or entry.get('procedure') or '').strip()
        if not code:
            continue
        dx_field = entry.get('covered_diagnoses') or entry.get('diagnoses') or []
        if isinstance(dx_field, str):
            # Diagnoses may be given as one delimited string.
            dx_field = [tok for tok in RE_SPLIT.split(dx_field) if tok]
        low, high = _rule_age_bounds(entry.get('age_range') or [])
        gender = str(entry.get('gender') or entry.get('sex') or 'Any').strip().upper()
        if gender in ('M', 'MALE'):
            sex_rule = 'M'
        elif gender in ('F', 'FEMALE'):
            sex_rule = 'F'
        else:
            sex_rule = 'ANY'
        extracted.append({
            'procedure': code,
            'diagnoses': tuple(icd_dot_normalize(d) for d in dx_field),
            'age_min': low,
            'age_max': high,
            'sex': sex_rule,
            'preauth_required': bool(entry.get('requires_preauthorization') or entry.get('preauth_required') or False),
        })
    return extracted

def build_policy_index(raw) -> Dict[str, Dict[str, Any]]:
    """Index raw policy JSON by normalized policy id.

    Accepted shapes:
      * an envelope ``{"policies": [...]}`` or ``{"policies": {id: {...}}}``;
      * a list of policy objects;
      * a dict keyed by policy id (the key is the fallback id).
    Entries without a usable id are skipped. A later entry with the same
    normalized id replaces an earlier one.

    Returns:
        ``{policy_id: {'policy_id', 'plan_name', 'rules'}}``; empty for
        unrecognized input.
    """
    # Unwrap the envelope whether it holds a list or a keyed dict. A dict-valued
    # envelope left wrapped would be read as one policy keyed 'policies'.
    if isinstance(raw, dict) and isinstance(raw.get('policies'), (list, dict)):
        raw = raw['policies']

    if isinstance(raw, list):
        src_list = raw
    elif isinstance(raw, dict):
        src_list = []
        for key, val in raw.items():
            if not isinstance(val, dict):
                continue
            val = dict(val)  # copy so the caller's data is not mutated
            # Fill a missing or empty policy_id; setdefault would keep an explicit null.
            if not val.get('policy_id'):
                val['policy_id'] = val.get('policyId') or val.get('id') or key
            src_list.append(val)
    else:
        src_list = []

    index: Dict[str, Dict[str, Any]] = {}
    for obj in src_list:
        if not isinstance(obj, dict):
            continue
        pid = _norm_pid(obj.get('policy_id') or obj.get('policyId') or obj.get('id') or obj.get('code') or obj.get('policy_code'))
        if not pid:
            continue
        index[pid] = {
            'policy_id': pid,
            'plan_name': obj.get('plan_name') or obj.get('title') or '',
            'rules': _extract_rules(obj)
        }
    return index

# Build the normalized policy index once, at load time; the policy tool reads it.
POLICY_INDEX = build_policy_index(policies_raw)
print("[index] policies loaded: {}".format(len(POLICY_INDEX)))


In [None]:

# Strings accepted as an affirmative flag (case-insensitive). Anything else, e.g.
# 'false', 'no' or '0', counts as not set, whereas bool('false') would be True.
_TRUTHY_FLAG_STRINGS = frozenset({'true', 't', 'yes', 'y', '1'})


def _flag_value(value: Any) -> bool:
    """Interpret a JSON flag that may be a bool, a number or a yes/no string."""
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_FLAG_STRINGS
    return bool(value)


def _code_list(value: Any) -> List[str]:
    """Coerce a code field into a list of non-empty code strings.

    Accepts a list/tuple of codes, one delimited string ('I20.0, E11.9') or a
    single scalar code. None entries are dropped.
    """
    if value is None:
        return []
    if isinstance(value, str):
        # Iterating a string would yield single characters; split on separators.
        return [tok for tok in RE_SPLIT.split(value) if tok]
    if isinstance(value, (list, tuple)):
        codes = (str(v).strip() for v in value if v is not None)
        return [c for c in codes if c]
    return [str(value).strip()]


def _coerce_age(value: Any) -> Optional[int]:
    """Return *value* as an int age ('45' -> 45), or None if absent or non-numeric."""
    if value is None or isinstance(value, bool):
        return None
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return None


@tool
def summarize_patient_record(record_str: str) -> Dict[str, Any]:
    """Summarize a raw patient claim record JSON string into a normalized dict."""
    # The docstring stays one line because @tool sends it to the model as the
    # tool description. Output keys:
    #   patient_id, policy_id (normalized), age (int or None),
    #   sex ('M'/'F'/'U'), diagnoses and procedures (tuples), preauth (bool).
    try:
        rec = json.loads(record_str)
    except Exception:
        return {'error': 'invalid JSON'}
    if not isinstance(rec, dict):
        return {'error': 'record must be a JSON object'}

    # A string age such as "45" would raise TypeError when compared with the
    # policy's integer bounds, so coerce it; fall back to DOB if unusable.
    age = _coerce_age(rec.get('age'))
    if age is None:
        dob = rec.get('date_of_birth') or rec.get('dob')
        svc = rec.get('date_of_service') or rec.get('service_date') or rec.get('claim_date')
        if dob:
            # Age at the service date if given, else age_from_dob's default reference.
            age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))
    sex = sex_normalize(rec.get('gender') or rec.get('sex'))

    dx = [icd_dot_normalize(d) for d in _code_list(rec.get('diagnosis_codes') or rec.get('diagnoses'))]
    if ICD_SET:
        # Keep only diagnoses present in the reference table.
        dx = [d for d in dx if d in ICD_SET]

    cpt = _code_list(rec.get('procedure_codes') or rec.get('procedures'))
    if CPT_SET:
        # Prefer known CPT codes; if none are known, fall back to any 5-digit code.
        keep = [c for c in cpt if c in CPT_SET]
        cpt = keep if keep else [c for c in cpt if RE_5DIGITS.fullmatch(c)]

    # Preauthorization counts as obtained if any accepted field affirms it.
    preauth = any(
        _flag_value(rec.get(key))
        for key in ('preauthorization_obtained', 'preauth_provided', 'authorization_provided')
    )
    return {
        'patient_id': rec.get('patient_id') or rec.get('id'),
        'policy_id': _norm_pid(rec.get('insurance_policy_id') or rec.get('policy_id') or rec.get('plan_id')),
        'age': age,
        'sex': sex,
        'diagnoses': tuple(dx),
        'procedures': tuple(cpt),
        'preauth': preauth
    }

def _cached_policy_rules(pid_norm: str):
    """Return ``{'policy_id', 'rules'}`` for a normalized policy id, or None if unknown.

    POLICY_INDEX is the precomputed cache and is read at call time. This lookup
    is deliberately not memoized: a dict lookup is already O(1), and memoizing
    it kept returning stale results after the index cell was re-run. A fresh
    outer dict is built per call so callers never share a cached object.
    """
    pol = POLICY_INDEX.get(pid_norm)
    if not pol:
        return None
    return {'policy_id': pid_norm, 'rules': pol.get('rules', [])}


@tool
def summarize_policy_guideline(policy_id: str) -> Dict[str, Any]:
    """Return normalized policy rules for the given policy_id (cached)."""
    # The docstring above is the tool description shown to the model; keep it short.
    pid = _norm_pid(policy_id)
    res = _cached_policy_rules(pid)
    return res if res is not None else {'error': f'Unknown policy_id {policy_id}'}

def _as_str_set(value: Any) -> set:
    """Set of code strings from a list/tuple, a single code (str or int), or None."""
    if value is None:
        return set()
    if isinstance(value, (str, int)):
        # A bare string must stay one code, not a set of its characters.
        return {str(value)}
    return {str(v) for v in value}


def _rule_covers(record: Dict[str, Any], rule: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Evaluate one policy rule against a normalized claim summary.

    The summary may have been re-serialized by the LLM between tool calls, so
    string ages, 'Female'-style sexes, bare-string code fields and 'false'
    strings are tolerated.

    Returns:
        (False, None) if the rule is for a CPT code not on the claim;
        (False, reason) if that CPT matches but a diagnosis, age, sex or
        preauthorization criterion fails;
        (True, None) if the rule is satisfied.
    """
    cpts = _as_str_set(record.get('procedures'))
    dxs = _as_str_set(record.get('diagnoses'))

    age = record.get('age')
    if age is not None and not isinstance(age, int):
        try:
            age = int(float(age))
        except (TypeError, ValueError, OverflowError):
            age = None
    # NOTE(review): a missing or unparseable age skips the age-range checks below.

    sex = str(record.get('sex') or 'U').strip().upper()[:1] or 'U'

    pre = record.get('preauth', False)
    if isinstance(pre, str):
        pre = pre.strip().lower() in ('true', 't', 'yes', 'y', '1')
    pre = bool(pre)

    if rule['procedure'] not in cpts:
        return False, None
    rdx = set(rule.get('diagnoses', []))
    if rdx and not (dxs & rdx):
        return False, f"Diagnosis mismatch for CPT {rule['procedure']}."
    if age is not None:
        if rule.get('age_min') is not None and age < rule['age_min']:
            return False, f"Patient age {age} below covered minimum for CPT {rule['procedure']}."
        if rule.get('age_max') is not None and age > rule['age_max']:
            return False, f"Patient age {age} exceeds covered maximum for CPT {rule['procedure']}."
    if rule.get('sex') in {'M','F'} and sex != rule['sex']:
        return False, f"Sex restriction {rule['sex']} for CPT {rule['procedure']}."
    if rule.get('preauth_required') and not pre:
        return False, f"Preauthorization required for CPT {rule['procedure']} but not provided."
    return True, None

@tool
def check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Determine APPROVE vs ROUTE FOR REVIEW for a claim based on policy rules."""
    # The docstring above is the tool description shown to the model; keep it short.
    # Every procedure on the claim must be satisfied by at least one rule for its
    # CPT code; the first failing procedure determines the review reason.
    if policy_summary.get('error'):
        return {'decision':'ROUTE FOR REVIEW','reason':'Unknown policy_id.'}
    rules = policy_summary.get('rules', [])
    if not rules:
        return {'decision':'ROUTE FOR REVIEW','reason':'Policy does not define covered procedures.'}

    procs = record_summary.get('procedures') or []
    # Summaries can round-trip through the LLM; a bare string or number is one
    # code, not a sequence to iterate.
    if not isinstance(procs, (list, tuple)):
        procs = [procs]
    procs = [str(p) for p in procs if p is not None]
    if not procs:
        # Nothing was matched against the policy (e.g. every code was filtered
        # out as unknown), so never auto-approve.
        return {'decision':'ROUTE FOR REVIEW','reason':'No procedure codes found on the claim.'}
    record = dict(record_summary, procedures=tuple(procs))

    for cpt in procs:
        cand = [r for r in rules if r.get('procedure') == cpt]
        if not cand:
            return {'decision':'ROUTE FOR REVIEW','reason': f'The claim for CPT code {cpt} is not covered by the policy.'}
        first_reason = None
        for r in cand:
            ok, why = _rule_covers(record, r)
            if ok:
                break
            if first_reason is None:
                first_reason = why
        else:
            # No candidate rule was satisfied for this CPT code.
            return {'decision':'ROUTE FOR REVIEW','reason': first_reason or f'The claim for CPT code {cpt} does not meet policy requirements.'}
    return {'decision':'APPROVE','reason': f'The claim for CPT code {procs[0]} is approved.'}


In [None]:

# Tools available to the agent; the system prompt below fixes the call order.
TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]

# Kept terse to save prompt tokens; run_agent_on_records sends it as the system
# message for every record.
AGENT_SYS_TEXT = (
    "Claims agent. Always use tools in this order per record:\n"
    "1) summarize_patient_record (JSON string)\n"
    "2) summarize_policy_guideline (policy_id)\n"
    "3) check_claim_coverage (record, policy)\n"
    "Return ONLY:\n"
    "- Decision: APPROVE | ROUTE FOR REVIEW\n"
    "- Reason: <short reason>"
)

# `chat_client` must come from an earlier cell. Check for it explicitly rather
# than catching NameError: this gives one clear error without a chained
# traceback, and no longer masks a NameError raised inside create_react_agent.
if 'chat_client' not in globals():
    raise RuntimeError("Expected an existing `chat_client` instance. Define it before running this cell.")
agent = create_react_agent(chat_client, TOOLS)


## Quick self-tests for policy index and tools

In [None]:

# Smoke-test the tool chain on up to two validation records, or on one synthetic
# record when no validation file is available. Results are printed, not asserted.
val = _safe_load(f'{DATA_DIR}/validation_records.json', [])
samples = val[:2] if isinstance(val, list) and val else [{
    "patient_id": "PX1",
    "date_of_birth": "1980-01-01",
    "gender": "Female",
    "insurance_policy_id": next(iter(POLICY_INDEX), "POL0000"),
    "diagnosis_codes": ["I20.0"],
    "procedure_codes": ["83036"],
    "preauthorization_obtained": True
}]

print("[test:index] policies:", len(POLICY_INDEX))
for idx, sample in enumerate(samples, 1):
    print(f"[test:sample {idx}/{len(samples)}]")
    rec_str = json.dumps(sample, separators=(',',':'))
    ps = summarize_patient_record.invoke({"record_str": rec_str})
    print("[test:patient] ->", ps)
    if not ps.get("policy_id"):
        print("[test:policy] -> skipped (no policy_id in patient summary)")
        continue
    pol = summarize_policy_guideline.invoke({"policy_id": ps["policy_id"]})
    print("[test:policy] ->", pol)
    cov = check_claim_coverage.invoke({"record_summary": ps, "policy_summary": pol})
    print("[test:coverage] ->", cov)


In [None]:

def _final_text(out: Any) -> str:
    """Return the final assistant text from an ``agent.invoke`` result, stripped."""
    if isinstance(out, dict) and "messages" in out:
        content = out["messages"][-1].content
    else:
        content = getattr(out, "content", str(out))
    # Message content can be a list of blocks (plain strings or
    # {'type': 'text', 'text': ...} dicts) rather than a string. Keep the text
    # only, so the CSV does not receive a Python list repr.
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                parts.append(str(part.get("text", "")))
        content = "".join(parts)
    return str(content).strip()


def run_agent_on_records(agent, records):
    """Run the agent once per record; return ``{'patient_id', 'generated_response'}`` rows.

    If the agent raises on a record, the error is printed and the record is
    recorded as ROUTE FOR REVIEW (the conservative outcome), so one failure does
    not discard the whole batch.
    """
    results = []
    for rec in records:
        pid = rec.get("patient_id") or rec.get("id")
        rec_str = json.dumps(rec, ensure_ascii=False, separators=(',',':'))
        messages = [
            SystemMessage(content=AGENT_SYS_TEXT),
            HumanMessage(content=f"Process this record:\n{rec_str}")
        ]
        try:
            final = _final_text(agent.invoke({"messages": messages}))
        except Exception as exc:
            print(f"[warn] agent failed for patient {pid}: {type(exc).__name__}: {exc}")
            final = (
                "- Decision: ROUTE FOR REVIEW\n"
                f"- Reason: Automated processing failed ({type(exc).__name__})."
            )
        results.append({
            "patient_id": pid,
            "generated_response": final
        })
    return results

# Process the test set (if any) and always write a submission file with the header.
test_records = _safe_load(TEST_PATH, [])
if not isinstance(test_records, list):
    print(f"[warn] expected a list of records in {TEST_PATH}; got {type(test_records).__name__}")
    test_records = []
# dirname() is '' for a bare filename, and os.makedirs('') raises.
os.makedirs(os.path.dirname(SUBMISSION_PATH) or '.', exist_ok=True)
test_results = run_agent_on_records(agent, test_records) if test_records else []
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
    w.writeheader()
    w.writerows(test_results)
if test_records:
    print(f"Wrote {SUBMISSION_PATH} with {len(test_results)} rows.")
else:
    print("No test records found.")
