# Capstone — Healthcare Claim Coverage Agent (Rebuild with Tool Docstrings)
- One agent + three tools (ReAct) using your existing `chat_client`
- Loads `./Data/validation_records.json` (preview) and `./Data/test_records.json` (final)
- Writes `./Data/submission.csv`


In [None]:
from __future__ import annotations
from typing import Dict, Any, List, Optional, Tuple
import os, json, csv, re
from datetime import datetime

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.prebuilt import create_react_agent

# ----------------- Paths -----------------
# Every input file and the output CSV live under DATA_DIR. The path is relative,
# so it is resolved against the current working directory at run time.
DATA_DIR = './Data'
REFERENCE_CODES_PATH = os.path.join(DATA_DIR, 'reference_codes.json')    # diagnosis / procedure code lookups
POLICIES_PATH        = os.path.join(DATA_DIR, 'insurance_policies.json') # raw policy definitions
VALIDATION_PATH      = os.path.join(DATA_DIR, 'validation_records.json') # records for the preview run
TEST_PATH            = os.path.join(DATA_DIR, 'test_records.json')       # records for the final run
SUBMISSION_PATH      = os.path.join(DATA_DIR, 'submission.csv')          # output of the final run


In [None]:
def _safe_load(path, default):
    """Read the UTF-8 JSON file at `path` and return the parsed value.

    This is a best-effort loader: any failure (missing file, unreadable
    content, malformed JSON, ...) is reported with a printed warning and
    `default` is returned instead of raising.
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        parsed = json.loads(raw_text)
    except Exception as exc:
        print(f"⚠️ Could not load {path}: {exc}")
        return default
    else:
        return parsed

# Load every input up front. A missing or malformed file degrades to the empty
# default given here (after a printed warning) instead of stopping the run.
reference_codes = _safe_load(REFERENCE_CODES_PATH, {})
policies_raw    = _safe_load(POLICIES_PATH, {})
val_records     = _safe_load(VALIDATION_PATH, [])
test_records    = _safe_load(TEST_PATH, [])

# Code lookups: use 'diagnosis_codes' / 'procedure_codes' when present and
# non-empty, otherwise fall back to 'icd10' / 'cpt'.
# NOTE(review): `.get` assumes the reference file's top level is a JSON object — confirm.
ICD_TO_NAME = reference_codes.get('diagnosis_codes', {}) or reference_codes.get('icd10', {})
CPT_TO_NAME = reference_codes.get('procedure_codes', {}) or reference_codes.get('cpt', {})

# Sanity report: whether the two lookup files were non-empty, and how many
# validation / test records were loaded.
print(f"Loaded: reference_codes={bool(reference_codes)}, policies={bool(policies_raw)}, "
      f"validation_records={len(val_records)}, test_records={len(test_records)}")


In [None]:
# ---------- Helpers & Policy Index ----------
# strptime formats tried in this order; the first one that parses wins, so an
# ambiguous value such as '03/04/2020' is read month-first.
DATE_PATTERNS = ['%Y-%m-%d', '%m/%d/%Y', '%d/%m/%Y', '%Y/%m/%d']


def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Parse a loosely formatted date value.

    The value is stringified and stripped, then matched against each entry of
    DATE_PATTERNS in order. If none fits, a compact 8-digit ``YYYYMMDD`` form
    is tried. Returns the parsed datetime, or None for falsy input or when
    nothing matches.
    """
    if not s:
        return None
    text = str(s).strip()

    for fmt in DATE_PATTERNS:
        try:
            return datetime.strptime(text, fmt)
        except Exception:
            continue

    compact = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if compact is None:
        return None
    try:
        year, month, day = (int(part) for part in compact.groups())
        return datetime(year, month, day)
    except Exception:
        # Eight digits that do not form a real calendar date (e.g. month 13).
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Compute whole-year age at a reference date.

    Args:
        dob: Date of birth. A missing (falsy) value yields None.
        ref: Date at which the age is evaluated. Defaults to the current
            UTC date when omitted.

    Returns:
        The number of completed years between `dob` and `ref`, or None when
        `dob` is missing or the result is outside 0..149 (for example when
        `dob` lies after `ref`).
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12. Only the calendar
        # fields are read below, so a timezone-aware "now" is a drop-in replacement.
        from datetime import timezone
        ref = datetime.now(timezone.utc)

    # Subtract one year if this year's birthday has not been reached yet.
    birthday_pending = (ref.month, ref.day) < (dob.month, dob.day)
    years = ref.year - dob.year - (1 if birthday_pending else 0)
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value onto 'M', 'F', or 'U' (unknown).

    Matching is case-insensitive and ignores surrounding whitespace; only
    'M'/'MALE' and 'F'/'FEMALE' are recognized, everything else is 'U'.
    """
    if s is None:
        return 'U'
    aliases = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    token = str(s).strip().upper()
    return aliases.get(token, 'U')

def icd_dot_normalize(code: str) -> str:
    """Return `code` stripped and upper-cased, with a dot inserted after the
    third character when it is one letter followed by 3 to 6 digits.

    Anything else (already dotted, too short, other shapes) is returned
    stripped and upper-cased but otherwise unchanged.
    """
    cleaned = str(code).strip().upper()
    # A full match of "letter + 3-6 digits" already guarantees there is no dot
    # and that at least four characters are present.
    if re.fullmatch(r'[A-Z]\d{3,6}', cleaned) is None:
        return cleaned
    category, detail = cleaned[:3], cleaned[3:]
    return category + '.' + detail

def _norm_pid(pid: Optional[str]) -> Optional[str]:
    """Canonicalize a policy id so differently formatted ids compare equal.

    Dashes, underscores and whitespace are dropped and the rest is
    upper-cased. A missing (falsy) id yields None.
    """
    if pid:
        compact = re.sub(r'[-_\s]', '', str(pid))
        return compact.upper()
    return None

def _extract_rules(obj: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Extract and normalize rule dicts from a single policy object.

    Args:
        obj: Raw policy dict. Rules are read from its 'covered_procedures'
            list; any other shape for that field yields no rules.

    Returns:
        One dict per usable entry with the keys 'procedure', 'diagnoses',
        'age_min', 'age_max', 'sex' ('M', 'F' or 'ANY') and
        'preauth_required'. Entries that are not dicts or that carry no
        procedure code are skipped.
    """
    rules: List[Dict[str, Any]] = []
    items = obj.get('covered_procedures') or []
    if not isinstance(items, list):
        return rules

    for it in items:
        if not isinstance(it, dict):
            continue
        proc = str(it.get('procedure_code') or it.get('procedure') or '').strip()
        if not proc:
            continue

        diags = it.get('covered_diagnoses') or it.get('diagnoses') or []
        if isinstance(diags, str):
            # Split on list separators only. '.' is part of an ICD code
            # ("E11.9"), so treating it as a separator would shred the code
            # into "E11" and "9".
            diags = [d for d in re.split(r'[;,\s]+', diags) if d]
        diags = [icd_dot_normalize(d) for d in diags]

        # age_range is honored only as a list [min, max]; a bound that cannot
        # be read as an integer is left open (None).
        ar = it.get('age_range') or []
        age_min = None
        age_max = None
        if isinstance(ar, list) and len(ar) >= 2:
            try:
                age_min = int(ar[0])
            except (TypeError, ValueError, OverflowError):
                pass
            try:
                age_max = int(ar[1])
            except (TypeError, ValueError, OverflowError):
                pass

        # Anything not starting with M or F (including the default 'Any')
        # means the rule has no sex restriction.
        sex_raw = str(it.get('gender') or it.get('sex') or 'Any').strip().upper()
        if sex_raw.startswith('M'):
            sex_rule = 'M'
        elif sex_raw.startswith('F'):
            sex_rule = 'F'
        else:
            sex_rule = 'ANY'

        preauth_required = bool(
            it.get('requires_preauthorization') or it.get('preauth_required') or False
        )

        rules.append({
            'procedure': proc,
            'diagnoses': diags,
            'age_min': age_min,
            'age_max': age_max,
            'sex': sex_rule,
            'preauth_required': preauth_required,
        })
    return rules

def build_policy_index(raw) -> Dict[str, Dict[str, Any]]:
    """Build a lookup of policies keyed by normalized policy id.

    Accepted shapes for `raw`:
      * a list of policy objects;
      * a dict with a 'policies' list;
      * a dict mapping some key to a policy object, where the key serves as
        the id when the object does not already carry 'policy_id',
        'policyId' or 'id'.
    Any other input produces an empty index. Non-dict entries and entries
    without a usable id are skipped; a later entry with the same normalized
    id replaces an earlier one.

    Returns:
        {policy_id: {'policy_id', 'plan_name', 'rules'}}.
    """
    # Step 1: flatten the supported input shapes into a list of candidates.
    if isinstance(raw, list):
        candidates = raw
    elif isinstance(raw, dict):
        wrapped = raw.get('policies')
        if isinstance(wrapped, list):
            candidates = wrapped
        else:
            candidates = []
            for key, value in raw.items():
                if not isinstance(value, dict):
                    continue
                entry = dict(value)  # copy so the caller's data is not mutated
                entry.setdefault('policy_id', entry.get('policyId') or entry.get('id') or key)
                candidates.append(entry)
    else:
        candidates = []

    # Step 2: normalize each candidate into an index entry.
    index = {}
    for policy in candidates:
        if not isinstance(policy, dict):
            continue
        raw_id = (
            policy.get('policy_id')
            or policy.get('policyId')
            or policy.get('id')
            or policy.get('code')
            or policy.get('policy_code')
        )
        pid = _norm_pid(raw_id)
        if not pid:
            continue
        plan_name = policy.get('plan_name') or policy.get('title') or ''
        index[pid] = {
            'policy_id': pid,
            'plan_name': plan_name,
            'rules': _extract_rules(policy),
        }
    return index

# Build the policy lookup once at load time; summarize_policy_guideline reads it.
# An empty or unrecognized policies file leaves the index empty (0 policies).
POLICY_INDEX = build_policy_index(policies_raw)
print(f"Policy index contains {len(POLICY_INDEX)} policies.")


In [None]:
# ---------- Tools ----------

@tool
def summarize_patient_record(record_str: str) -> Dict[str, Any]:
    """
    Normalize a raw patient claim record.
    
    Args:
        record_str: JSON string for a single patient/claim record.
    Returns:
        Dict containing: patient_id, policy_id, age, sex, diagnoses, procedures, preauth.
        - Applies tolerant date parsing and age calculation (if age missing).
        - Normalizes sex and ICD/CPT codes; filters to known reference codes when available.
    """
    # The docstring above is kept verbatim: LangChain's @tool uses it as the
    # tool description shown to the model, so editing it changes the prompt.

    def _as_list(value):
        # A lone scalar code would otherwise be iterated character by
        # character (str) or fail to iterate at all (int).
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    try:
        rec = json.loads(record_str)
    except Exception:
        return {'error': 'invalid JSON'}
    if not isinstance(rec, dict):
        # Valid JSON that is not an object (list, number, ...) has no fields
        # to read; report it instead of failing on `.get`.
        return {'error': 'record must be a JSON object'}

    age = rec.get('age')
    if isinstance(age, str):
        # A textual age would break the numeric comparisons against rule
        # bounds later on; coerce it, or discard it and fall back to the DOB.
        try:
            age = int(age.strip())
        except ValueError:
            age = None
    if age is None:
        dob = rec.get('date_of_birth') or rec.get('dob')
        svc = rec.get('date_of_service') or rec.get('service_date') or rec.get('claim_date')
        if dob:
            # Age is evaluated at the service date when one parses, otherwise today.
            age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))

    sex = sex_normalize(rec.get('gender') or rec.get('sex'))

    dx = _as_list(rec.get('diagnosis_codes') or rec.get('diagnoses') or [])
    dx = [icd_dot_normalize(d) for d in dx]
    if ICD_TO_NAME:
        # Keep only diagnoses present in the reference table.
        dx = [d for d in dx if d in ICD_TO_NAME]

    cpt = _as_list(rec.get('procedure_codes') or rec.get('procedures') or [])
    # Strip like the policy side does so ' 99213' still matches '99213'.
    cpt = [str(x).strip() for x in cpt]
    if CPT_TO_NAME:
        # Prefer codes known to the reference table; if none are known, fall
        # back to anything shaped like a 5-digit code.
        keep = [c for c in cpt if c in CPT_TO_NAME]
        cpt = keep if keep else [c for c in cpt if re.fullmatch(r'\d{5}', c)]

    preauth = rec.get('preauthorization_obtained') or rec.get('preauth_provided') or rec.get('authorization_provided')
    preauth = bool(preauth)

    return {
        'patient_id': rec.get('patient_id') or rec.get('id'),
        'policy_id': _norm_pid(rec.get('insurance_policy_id') or rec.get('policy_id') or rec.get('plan_id')),
        'age': age,
        'sex': sex,
        'diagnoses': dx,
        'procedures': cpt,
        'preauth': preauth
    }

@tool
def summarize_policy_guideline(policy_id: str) -> Dict[str, Any]:
    """
    Load normalized rules for a given policy.
    
    Args:
        policy_id: Policy identifier from the claim record.
    Returns:
        Dict with 'policy_id' and 'rules' list (procedure, diagnoses, age bounds, sex, preauth_required).
        Returns {'error': '...'} when the policy is unknown.
    """
    # The docstring above is kept verbatim: LangChain's @tool uses it as the
    # tool description shown to the model.
    # The lookup uses the normalized id, so formatting differences
    # (dashes, case, spaces) do not cause a miss.
    lookup_key = _norm_pid(policy_id)
    entry = POLICY_INDEX.get(lookup_key)
    if entry:
        return {'policy_id': lookup_key, 'rules': entry.get('rules', [])}
    # The error message echoes the id exactly as the caller supplied it.
    return {'error': f'Unknown policy_id {policy_id}'}

def _rule_covers(record: Dict[str, Any], rule: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """
    Check whether a single policy rule covers this record.

    Checks run in a fixed order: procedure, diagnosis, age, sex,
    preauthorization. The first failing check determines the reason.

    Args:
        record: Normalized claim. Keys read: 'procedures', 'diagnoses',
            'age', 'sex', 'preauth'.
        rule: Normalized policy rule. Keys read: 'procedure' (required),
            'diagnoses', 'age_min', 'age_max', 'sex', 'preauth_required'.

    Returns:
        (True, None) when covered; otherwise (False, reason). The reason is
        None when the rule's procedure is not on the claim at all.
    """
    proc = rule['procedure']
    if proc not in set(record.get('procedures') or []):
        return False, None

    # An empty diagnosis list on the rule means "any diagnosis".
    rule_dx = set(rule.get('diagnoses') or [])
    if rule_dx and not (rule_dx & set(record.get('diagnoses') or [])):
        return False, f"Diagnosis mismatch for CPT {proc}."

    age_min = rule.get('age_min')
    age_max = rule.get('age_max')
    if age_min is not None or age_max is not None:
        age = record.get('age')
        if isinstance(age, str):
            # Tolerate a numeric string; anything else counts as unknown.
            try:
                age = int(age.strip())
            except ValueError:
                age = None
        if isinstance(age, bool) or not isinstance(age, (int, float)):
            # An age-restricted rule cannot be verified without an age.
            # Failing here mirrors the sex check, where an unknown value does
            # not satisfy a restriction, and avoids approving unverifiable claims.
            return False, f"Patient age unknown; cannot verify age limits for CPT {proc}."
        if age_min is not None and age < age_min:
            return False, f"Patient age {age} below covered minimum for CPT {proc}."
        if age_max is not None and age > age_max:
            return False, f"Patient age {age} exceeds covered maximum for CPT {proc}."

    # Only 'M' / 'F' restrict; any other rule value (e.g. 'ANY') accepts everyone.
    rule_sex = rule.get('sex')
    sex = (record.get('sex') or 'U').upper()
    if rule_sex in {'M', 'F'} and sex != rule_sex:
        return False, f"Sex restriction {rule_sex} for CPT {proc}."

    if rule.get('preauth_required') and not bool(record.get('preauth', False)):
        return False, f"Preauthorization required for CPT {proc} but not provided."

    return True, None

@tool
def check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """
    Decide APPROVE vs ROUTE FOR REVIEW for a claim using policy rules.
    
    Args:
        record_summary: Output of summarize_patient_record.
        policy_summary: Output of summarize_policy_guideline.
    Returns:
        Dict with keys:
          - decision: 'APPROVE' | 'ROUTE FOR REVIEW'
          - reason: concise reason
    """
    # The docstring above is kept verbatim: LangChain's @tool uses it as the
    # tool description shown to the model.

    # A record that could not be summarized must never fall through to APPROVE.
    if record_summary.get('error'):
        return {'decision':'ROUTE FOR REVIEW','reason':'Patient record could not be parsed.'}

    if policy_summary.get('error'):
        return {'decision':'ROUTE FOR REVIEW','reason':'Unknown policy_id.'}

    rules = policy_summary.get('rules', [])
    if not rules:
        return {'decision':'ROUTE FOR REVIEW','reason':'Policy does not define covered procedures.'}

    cpts = record_summary.get('procedures') or []
    if not cpts:
        # With nothing to check, the loop below would be skipped and the claim
        # approved by default; send it to a human instead.
        return {'decision':'ROUTE FOR REVIEW','reason':'No recognized procedure codes on the claim.'}

    # Every procedure on the claim must be covered by at least one rule.
    for cpt in cpts:
        cand = [r for r in rules if r.get('procedure') == cpt]
        if not cand:
            return {'decision':'ROUTE FOR REVIEW','reason': f'The claim for CPT code {cpt} is not covered by the policy.'}
        first_reason = None
        for r in cand:
            ok, why = _rule_covers(record_summary, r)
            if ok:
                break
            if first_reason is None:
                first_reason = why
        else:
            # No candidate rule accepted the claim: report the first concrete
            # reason seen, or a generic one if none was given.
            return {'decision':'ROUTE FOR REVIEW','reason': first_reason or f'The claim for CPT code {cpt} does not meet policy requirements.'}

    # The approval message names the first procedure only.
    return {'decision':'APPROVE','reason': f'The claim for CPT code {cpts[0]} is approved.'}


In [None]:
# ---------- Agent ----------
# The three tools the agent may call; the system text below tells the model
# to call them in this order.
TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]

# System instructions for the agent. The text is not bound to the agent at
# construction time; run_agent_on_records sends it as a SystemMessage with
# every record. It also fixes the final answer format ("- Decision:" / "- Reason:").
AGENT_SYS_TEXT = (
    "You are an insurance claims agent.\n"
    "You MUST always call tools in this order for each record:\n"
    "  1) summarize_patient_record -> on the raw JSON string of the record\n"
    "  2) summarize_policy_guideline -> with the policy_id from step 1\n"
    "  3) check_claim_coverage -> with outputs of steps 1 & 2\n"
    "When you have enough information, output the FINAL answer in exactly this format:\n"
    "- Decision: APPROVE | ROUTE FOR REVIEW\n"
    "- Reason: \n"
    "Rules:\n"
    "- Do not invent policy data. If policy_id is missing/unknown, route for review.\n"
    "- Keep the reason short and specific.\n"
)

# `chat_client` is not defined in this file; it must already exist in the
# session (e.g. from an earlier notebook cell). A missing name is turned into
# a clearer RuntimeError.
# NOTE(review): the try also wraps create_react_agent, so a NameError raised
# inside that call would be reported with the same message.
try:
    agent = create_react_agent(chat_client, TOOLS)
except NameError:
    raise RuntimeError("Expected an existing `chat_client` instance. Define it before running this cell.")

def run_agent_on_records(agent, records, preview_n: int = 3):
    """Invoke the agent once per record and collect its final answers.

    Args:
        agent: Object exposing ``invoke({"messages": [...]})``.
        records: Iterable of raw record dicts; each is serialized to JSON and
            sent together with the system instructions.
        preview_n: How many of the first results to also print.

    Returns:
        A list of ``{"patient_id", "generated_response"}`` dicts in input
        order, where the patient id is taken from 'patient_id' or 'id'.
    """
    collected = []
    for position, rec in enumerate(records):
        payload = json.dumps(rec, ensure_ascii=False)
        conversation = [
            SystemMessage(content=AGENT_SYS_TEXT),
            HumanMessage(content=f"Process this record:\n{payload}"),
        ]
        out = agent.invoke({"messages": conversation})

        # Usual shape: a dict whose last message holds the final answer.
        # Otherwise fall back to a `.content` attribute or the plain string form.
        if isinstance(out, dict) and "messages" in out:
            final = out["messages"][-1].content
        else:
            final = getattr(out, "content", str(out))

        patient = rec.get("patient_id") or rec.get("id")
        if position < preview_n:
            print(f"[Validation preview] {patient}: {final}")
        collected.append({"patient_id": patient, "generated_response": final})
    return collected


In [None]:
# ---------- Validation run (uses validation_records.json) ----------
# Preview only: the first 5 answers are printed and the results are discarded
# (assigned to `_`).
if val_records:
    _ = run_agent_on_records(agent, val_records, preview_n=5)
else:
    print("No validation records found.")

# ---------- Test run -> submission.csv ----------
# Make sure the output directory exists before the CSV is opened for writing.
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
if test_records:
    # preview_n=0: nothing is printed per record; all answers go to the CSV.
    test_results = run_agent_on_records(agent, test_records, preview_n=0)
    # newline='' lets the csv module control line endings itself.
    with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
        w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
        w.writeheader(); w.writerows(test_results)
    print(f"Wrote {SUBMISSION_PATH} with {len(test_results)} rows.")
else:
    print("No test records found.")
