# Capstone 2 — Claims Agent (LangGraph ReAct) with Robust Policy Index + DEBUG
One agent with three tools. Policy IDs are now resolved from dict- or list-shaped policy files, with flexible normalization (uppercasing; hyphens, spaces, and underscores removed).

In [None]:
# Capstone 2 — Claims Agent (LangGraph ReAct) with Robust Policy Index + DEBUG
# ---------------------------------------------------------------------------
# One agent with three tools (same design as the previous version), but now it:
# - Builds POLICY_INDEX from dict *or* list sources
# - Normalizes policy ids by uppercasing and removing hyphens/spaces/underscores
# - Keeps DEBUG logging of tool calls and LLM replies
#
# Define your model before running:
#   from langchain_openai import ChatOpenAI
#   chat_client = ChatOpenAI(model="gpt-4o-mini", temperature=0)

from __future__ import annotations
from typing import Annotated, TypedDict, Dict, Any, Optional, List
import json, csv, os, re, textwrap
from datetime import datetime

# -------- LangChain / LangGraph imports --------
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition

# ----------------------------- Config -----------------------------
# Debug configuration.
DEBUG: bool = True            # master switch for _dbg() console output
DEBUG_TRUNC: int = 800        # _dbg() cuts rendered payloads to this many characters
VALIDATION_DEBUG_N: int = 3   # number of validation records run before the test set

def _dbg(title: str, payload: Any = None):
    """Print a debug banner, plus an optional indented payload, when DEBUG is on.

    Strings are printed as-is. Other payloads are rendered as pretty JSON,
    falling back to ``str()`` when they are not JSON-serializable. Renderings
    longer than DEBUG_TRUNC characters are cut and tagged as truncated.
    """
    if DEBUG:
        print(f"[DEBUG] {title}")
        if payload is not None:
            if isinstance(payload, str):
                rendered = payload
            else:
                try:
                    rendered = json.dumps(payload, ensure_ascii=False, indent=2)
                except Exception:
                    rendered = str(payload)
            if len(rendered) > DEBUG_TRUNC:
                rendered = rendered[:DEBUG_TRUNC] + " ... [truncated]"
            print(textwrap.indent(rendered, "  "))

# ----------------------------- Paths -----------------------------
# All inputs and the submission file live under one data directory.
DATA_DIR = './Data'
VALIDATION_PATH = DATA_DIR + '/validation_records.json'   # records used for the debug preview
TEST_PATH = DATA_DIR + '/test_records.json'               # records scored into the submission
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'     # policy definitions (dict or list root)
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'       # ICD-10 / CPT code -> name maps
SUBMISSION_PATH = DATA_DIR + '/submission.csv'            # output CSV

# ------------------------- Safe JSON Loading ----------------------
def _safe_load_json(path, default):
    """Load JSON from ``path`` (UTF-8), returning ``default`` if unavailable.

    A missing file or a JSON syntax error is reported with a warning and
    yields ``default``. Any other error (e.g. permission denied) propagates.
    """
    try:
        with open(path, 'r', encoding='utf-8') as fh:
            loaded = json.load(fh)
    except FileNotFoundError:
        print(f"⚠️ Missing file: {path}. Using default.")
        return default
    except json.JSONDecodeError as err:
        print(f"⚠️ JSON parse error in {path}: {err}. Using default.")
        return default
    return loaded

policies_raw  = _safe_load_json(POLICIES_PATH, {})
ref_codes     = _safe_load_json(REF_CODES_PATH, {})
# The code maps below are read with .get(). A reference file whose JSON root is
# not an object (e.g. a bare list) would otherwise raise AttributeError while
# the notebook is still loading.
if not isinstance(ref_codes, dict):
    print(f"⚠️ Unexpected JSON root in {REF_CODES_PATH}; expected an object. Ignoring reference codes.")
    ref_codes = {}
# Code -> description maps; either key spelling is accepted. When a map is
# empty, summarize_patient_record skips filtering codes against it.
ICD_TO_NAME   = ref_codes.get('diagnosis_codes', {}) or ref_codes.get('icd10', {}) or {}
CPT_TO_NAME   = ref_codes.get('procedure_codes', {}) or ref_codes.get('cpt', {}) or {}

# ------------------------------ Helpers ---------------------------
def _norm_pid(pid: Optional[str]) -> Optional[str]:
    """Canonicalize a policy id for lookup.

    Uppercases the value and removes hyphens, underscores and whitespace, so
    'pol-001', 'POL 001' and 'Pol_001' all become 'POL001'. Non-string ids
    (e.g. ints) are stringified first.

    Returns None for falsy input, and also when nothing is left after
    normalization (e.g. '  ' or '--'). Callers therefore get a real id or
    None, never an empty string.
    """
    if not pid:
        return None
    normalized = re.sub(r'[-_\s]', '', str(pid)).upper()
    return normalized or None

# strptime formats tried in order. The first that parses wins, so an ambiguous
# 'NN/NN/YYYY' is read as month/day before day/month.
DATE_PATTERNS = ['%Y-%m-%d', '%m/%d/%Y', '%d/%m/%Y', '%Y/%m/%d']

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort date parse; returns a naive datetime or None.

    Tries each format in DATE_PATTERNS, then a compact 'YYYYMMDD' form.
    Falsy input and unparseable or invalid dates give None.
    """
    if not s:
        return None
    text = str(s).strip()
    for fmt in DATE_PATTERNS:
        try:
            return datetime.strptime(text, fmt)
        except Exception:
            continue
    compact = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if compact is None:
        return None
    year, month, day = (int(part) for part in compact.groups())
    try:
        return datetime(year, month, day)
    except Exception:
        # e.g. month 13 or day 32
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return whole years elapsed from ``dob`` to ``ref``.

    Args:
        dob: date of birth; a falsy value yields None.
        ref: reference date (e.g. the service date). Defaults to the
            current UTC time as a naive datetime.

    Returns:
        The age in completed years, or None when it falls outside
        [0, 150), which is treated as a data error.
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12. This expression
        # gives the same naive-UTC value without the DeprecationWarning.
        from datetime import timezone
        ref = datetime.now(timezone.utc).replace(tzinfo=None)
    # Subtract one year if the birthday has not yet occurred in ref's year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to 'M', 'F', or 'U' (unknown).

    Matching is case-insensitive and ignores surrounding whitespace. Only
    'm'/'male' and 'f'/'female' are recognized; None and anything else is 'U'.
    """
    if s is None:
        return 'U'
    token = str(s).strip().upper()
    return {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}.get(token, 'U')

def icd_dot_normalize(code: str) -> str:
    """Upper-case an ICD-10 code and insert the dot after its 3-character category.

    Undotted 4-7 character codes shaped like ICD-10-CM (a letter, a digit,
    then alphanumerics) are dotted: 'e119' -> 'E11.9', 'S72001A' ->
    'S72.001A', 'T7840XA' -> 'T78.40XA'. Everything else comes back stripped
    and upper-cased but otherwise unchanged, including codes that already
    have a dot, bare 3-character categories like 'I10', and CPT codes.
    """
    c = str(code).strip().upper()
    # Positions 3-7 may be letters, e.g. 7th-character extensions ('A') or the
    # 'X' placeholder. The old pattern r'[A-Z]\d{3,6}' allowed only digits, so
    # such codes stayed undotted and never matched dotted reference codes.
    # The pattern itself excludes '.', so already-dotted codes are left as-is.
    if re.fullmatch(r'[A-Z]\d[A-Z0-9]{2,5}', c):
        return c[:3] + '.' + c[3:]
    return c

def coerce_int(val) -> Optional[int]:
    """Best-effort integer conversion.

    ``int(val)`` is tried first. If that fails, the first run of digits in
    ``str(val)`` is used, ignoring any sign. Returns None for None or when no
    digits are present.
    """
    try:
        return int(val)
    except Exception:
        pass
    if val is None:
        return None
    digits = re.search(r'\d+', str(val))
    return int(digits.group()) if digits else None

def walk_items(obj):
    """Yield every (key, value) pair in a nested dict/list structure.

    The traversal is depth-first pre-order: a dict entry is yielded before
    the entries nested inside its value. Lists are descended into but
    contribute no keys of their own. Scalars yield nothing.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield key, value
            yield from walk_items(value)
    elif isinstance(obj, list):
        for element in obj:
            yield from walk_items(element)

def deep_get_first(d: Any, keys: List[str]):
    """Return the value of the first key (case-insensitive) found anywhere in ``d``.

    Keys are searched in walk_items order (depth-first pre-order). The value
    is returned even if it is None. Returns None when no key matches.
    """
    wanted = {name.lower() for name in keys}
    return next(
        (value for key, value in walk_items(d)
         if isinstance(key, str) and key.lower() in wanted),
        None,
    )

def collect_mixed(val):
    """Flatten a code field of mixed shape into a list of strings.

    * None gives [].
    * A dict contributes its non-None code-like fields ('code', 'id', 'icd',
      'cpt'), then its non-empty label fields ('name', 'desc', 'description',
      'text'), in that order.
    * A list contributes each dict element as above and every other element
      as ``str(element)``.
    * Any other value becomes ``[str(val)]``.
    """
    def _from_mapping(mapping):
        # Codes count whenever they are not None (0 is a valid id); labels
        # count only when truthy.
        codes = [str(mapping[k]) for k in ('code', 'id', 'icd', 'cpt') if mapping.get(k) is not None]
        labels = [str(mapping[k]) for k in ('name', 'desc', 'description', 'text') if mapping.get(k)]
        return codes + labels

    if val is None:
        return []
    if isinstance(val, dict):
        return _from_mapping(val)
    if isinstance(val, list):
        flattened = []
        for element in val:
            flattened.extend(_from_mapping(element) if isinstance(element, dict) else [str(element)])
        return flattened
    return [str(val)]

# ------------------------- Build robust POLICY_INDEX ---------------
# Id-like fields checked inside a policy object, in priority order. Dict-root
# and list-root files use the same list so both shapes resolve the same way.
_POLICY_ID_FIELDS = ('policy_id', 'policyId', 'id', 'code', 'policy_code')


def _first_policy_id(obj: Dict[str, Any]) -> Optional[str]:
    """Return the first id-like field of ``obj`` that normalizes to a non-empty id."""
    for field in _POLICY_ID_FIELDS:
        nid = _norm_pid(obj.get(field))
        if nid:
            return nid
    return None


def _build_policy_index(raw):
    """Map normalized policy ids to policy objects.

    Supported shapes of the parsed policies file:
      * dict keyed by policy id: ``{"POL-1": {...}}``. Both the outer key and
        the first id-like field inside the value are indexed; the outer key
        wins on conflict.
      * list of policy objects: ``[{"policy_id": "POL-1", ...}]``. The first
        id-like field is indexed; later duplicates overwrite earlier ones.
      * dict wrapping lists: ``{"policies": [...]}``. Any list-valued entry
        is indexed like a list root.

    Only dict-valued entries are indexed. summarize_policy_guideline calls
    ``pol.get(...)`` on whatever it finds, so a scalar or list stored under
    an id would crash there.
    """
    index = {}
    if isinstance(raw, dict):
        for k, v in raw.items():
            if isinstance(v, list):
                # Wrapper key such as "policies". Previously the list itself was
                # stored under "POLICIES" and the real policies were unreachable.
                for nk, obj in _build_policy_index(v).items():
                    index.setdefault(nk, obj)
                continue
            if not isinstance(v, dict):
                continue
            nk = _norm_pid(k)
            if nk:
                index[nk] = v
            ni = _first_policy_id(v)
            if ni:
                index.setdefault(ni, v)
    elif isinstance(raw, list):
        for obj in raw:
            if not isinstance(obj, dict):
                continue
            nc = _first_policy_id(obj)
            if nc:
                index[nc] = obj
    return index

# Normalized policy id -> policy object. Used by summarize_policy_guideline.
POLICY_INDEX = _build_policy_index(policies_raw)
_dbg("Policy index keys (sample)", [*POLICY_INDEX][:10])

# ------------------------------- TOOLS ----------------------------
def _decode_record_payload(raw):
    """Parse a summarize_patient_record payload into ``(rec, text)``.

    ``rec`` is the decoded JSON value, or None when the payload is not JSON.
    ``text`` is the string the regex fallbacks scan.

    A double-encoded payload decodes to a ``str``. This is a JSON string
    literal wrapping the record, which run_on_record produces when a
    record's 'record_str' is already text. Re-dumping that string would
    escape every newline into a backslash plus 'n' glued to the next token.
    Word-boundary regexes would then miss codes and policy ids at line
    starts. So one level is unwrapped and the decoded text is scanned.
    """
    try:
        rec = json.loads(raw)
    except Exception:
        return None, raw
    if not isinstance(rec, str):
        return rec, json.dumps(rec, ensure_ascii=False)
    inner = rec
    try:
        rec = json.loads(inner)
    except Exception:
        return None, inner
    if isinstance(rec, str):
        return None, inner
    return rec, json.dumps(rec, ensure_ascii=False)


@tool
def summarize_patient_record(record_str_or_json: str) -> str:
    """Extract patient_id, policy_id, age, sex, diagnoses (ICD-10), procedures (CPT), preauth flag from JSON/text. Returns a JSON string."""
    # NOTE: the docstring above is the tool description @tool shows the LLM.
    # Edit it deliberately.
    _dbg("CALL summarize_patient_record.input", record_str_or_json)
    rec, text = _decode_record_payload(record_str_or_json)

    patient_id = None
    policy_id = None
    age = None
    sex = 'U'
    preauth_provided = False
    dx_codes, cpt_codes = set(), set()

    if isinstance(rec, dict):
        patient_id = (rec.get('patient_id') or rec.get('id') or rec.get('member_id') or rec.get('patientId') or rec.get('pid'))
        raw_pid = (rec.get('policy_id') or rec.get('policy') or rec.get('plan_id')
                   or rec.get('insurancePolicyId') or deep_get_first(rec, ['insurance_policy_id'])
                   # Nested ids such as {"insurance": {"policy_id": ...}} are otherwise
                   # unreachable: the quoted JSON key never matches the text regex
                   # below. Checked last so the original priority is kept.
                   or deep_get_first(rec, ['policy_id', 'policyId']))
        policy_id = _norm_pid(raw_pid)
        age = coerce_int(rec.get('age') or deep_get_first(rec, ['age','age_years','patient_age']))
        if age is None:
            # No explicit age: derive it from date of birth at the service date
            # (age_from_dob falls back to "now" when no service date is found).
            dob = (rec.get('dob') or rec.get('date_of_birth') or rec.get('birthdate')
                   or deep_get_first(rec, ['dob','date_of_birth','birthdate']))
            svc = (rec.get('service_date') or rec.get('claim_date') or rec.get('date_of_service')
                   or deep_get_first(rec, ['service_date','claim_date','date_of_service']))
            age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))
        sex = sex_normalize(rec.get('sex') or rec.get('gender') or deep_get_first(rec, ['sex','gender','patient_sex']))
        preauth_val = (rec.get('preauth_provided') or deep_get_first(rec, ['preauth','preauthorization','authorization','prior_auth','prior_authorization','preauth_provided']))
        if preauth_val is not None:
            preauth_provided = str(preauth_val).strip().lower() in {'1','y','yes','true','approved','provided'}
        raw_dx  = (rec.get('diagnosis_codes') or rec.get('diagnoses') or rec.get('dx') or rec.get('icd') or rec.get('icd_codes')
                   or deep_get_first(rec, ['diagnosis_codes','diagnoses','diagnosis','dx','icd','icd_codes']))
        raw_cpt = (rec.get('procedure_codes') or rec.get('procedures') or rec.get('cpt') or rec.get('cpt_codes') or rec.get('services')
                   or deep_get_first(rec, ['procedure_codes','procedures','procedure','cpt','cpt_codes','services']))
        for s in collect_mixed(raw_dx):  dx_codes.add(icd_dot_normalize(s))
        for s in collect_mixed(raw_cpt): cpt_codes.add(str(s))

    # Regex fallback for a free-text "policy id: XYZ" (normalized too). It is
    # skipped when the text before the match ends in "patient".
    m = re.search(r'\bpolicy[_\s-]*id[:\s-]*([A-Za-z0-9_-]+)', text, re.I)
    if m and not policy_id:
        pre = text[:m.start()]
        if not re.search(r'patient\s*$', pre, re.I):
            policy_id = _norm_pid(m.group(1))

    # Regex sweeps over the whole text add ICD-like and 5-digit CPT-like
    # tokens; the reference maps below prune false positives when available.
    dx_codes |= {icd_dot_normalize(s) for s in re.findall(r'\b([A-Z][0-9][A-Z0-9.\-]{1,6})\b', text)}
    cpt_codes |= set(re.findall(r'\b(\d{5})\b', text))

    if ICD_TO_NAME: dx_codes = {c for c in dx_codes if c in ICD_TO_NAME}
    if CPT_TO_NAME:
        known = {c for c in cpt_codes if c in CPT_TO_NAME}
        cpt_codes = known if known else {c for c in cpt_codes if re.fullmatch(r'\d{5}', c)}

    out = {
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(dx_codes),
        'procedures': sorted(cpt_codes),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    }
    _dbg("CALL summarize_patient_record.output", out)
    return json.dumps(out)

# Strings meaning "not required" for a policy's preauth flag. Any other string
# (e.g. "required", "yes") counts as required, erring toward review.
_POLICY_PREAUTH_NOT_REQUIRED = {'', '0', 'n', 'no', 'false', 'none', 'null'}


def _policy_code_list(val) -> List[str]:
    """Coerce a criteria code field to a list of non-empty, stripped strings.

    Accepts a list/tuple/set of codes, a dict keyed by code, a single code,
    or None. Numeric codes are stringified so they can match the string
    codes produced by summarize_patient_record.
    """
    if val is None:
        items = []
    elif isinstance(val, (list, tuple, set, dict)):
        items = list(val)
    else:
        items = [val]
    return [str(c).strip() for c in items if c is not None and str(c).strip()]


def _policy_age_bound(val):
    """Return a numeric age bound. Numeric strings like '18' are converted; unparseable values give None."""
    if val is None or isinstance(val, (int, float)):
        return val
    return coerce_int(val)


@tool
def summarize_policy_guideline(policy_id: str) -> str:
    """Look up an insurance policy guideline by policy_id (case-insensitive; dashes/spaces ignored). Returns criteria JSON."""
    # NOTE: the docstring above is the tool description @tool shows the LLM.
    _dbg("CALL summarize_policy_guideline.input", policy_id)
    pid = _norm_pid(policy_id)
    pol = POLICY_INDEX.get(pid)
    # A non-dict entry cannot carry criteria; treat it as unknown instead of
    # crashing on pol.get(...).
    if not pol or not isinstance(pol, dict):
        out = {'error': f"Unknown policy_id '{policy_id}'"}
        _dbg("CALL summarize_policy_guideline.output", out)
        return json.dumps(out)
    crit = pol.get('criteria', {}) or {}
    if not isinstance(crit, dict):
        crit = {}

    # A JSON string such as "false" must not become True via bool().
    raw_required = crit.get('preauth_required', False)
    if isinstance(raw_required, str):
        preauth_required = raw_required.strip().lower() not in _POLICY_PREAUTH_NOT_REQUIRED
    else:
        preauth_required = bool(raw_required)

    # Spellings like "female" become 'F' so the coverage check enforces them.
    # Values that map to neither sex (None, "any", ...) are passed through.
    raw_sex = crit.get('sex')
    sex_rule = sex_normalize(raw_sex)

    out = {
        'policy_id': pid,
        # Dot-normalized like the record's diagnoses so 'E119' matches 'E11.9'.
        'allowed_diagnoses': [icd_dot_normalize(c) for c in _policy_code_list(crit.get('diagnoses'))],
        'allowed_procedures': _policy_code_list(crit.get('procedures')),
        'age_min': _policy_age_bound(crit.get('age_min')),
        'age_max': _policy_age_bound(crit.get('age_max')),
        'sex': sex_rule if sex_rule in {'M', 'F'} else raw_sex,
        'preauth_required': preauth_required,
        'policy_title': pol.get('title', ''),
        'notes': pol.get('notes', '')
    }
    _dbg("CALL summarize_policy_guideline.output", out)
    return json.dumps(out)

# Record-side strings meaning "preauth provided". This is the same set that
# summarize_patient_record uses.
_PREAUTH_PROVIDED_STRINGS = {'1', 'y', 'yes', 'true', 'approved', 'provided'}
# Policy-side strings meaning "preauth not required". Any other string counts
# as required, erring toward review.
_PREAUTH_NOT_REQUIRED_STRINGS = {'', '0', 'n', 'no', 'false', 'none', 'null'}
# Keys a summarize_policy_guideline result carries. A summary with none of
# them cannot be checked at all.
_POLICY_SUMMARY_KEYS = ('policy_id', 'allowed_diagnoses', 'allowed_procedures',
                        'age_min', 'age_max', 'sex', 'preauth_required')


def _as_code_set(val) -> set:
    """Coerce a code field (list/tuple/set, dict keyed by code, scalar, None) to a set of non-empty stripped strings."""
    if val is None:
        items = []
    elif isinstance(val, (list, tuple, set, dict)):
        items = list(val)
    else:
        items = [val]  # a bare "E11.9" is one code, not five characters
    return {str(c).strip() for c in items if c is not None and str(c).strip()}


def _is_preauth_provided(val) -> bool:
    """Strict truthiness for a record's preauth flag. The string 'false' is False."""
    if isinstance(val, str):
        return val.strip().lower() in _PREAUTH_PROVIDED_STRINGS
    return bool(val)


def _is_preauth_required(val) -> bool:
    """Conservative truthiness for a policy's preauth flag. Unrecognized strings count as required."""
    if isinstance(val, str):
        return val.strip().lower() not in _PREAUTH_NOT_REQUIRED_STRINGS
    return bool(val)


def _as_age_bound(val):
    """Return a numeric age bound. Numeric strings like '18' are converted; unparseable values give None."""
    if val is None or isinstance(val, (int, float)):
        return val
    return coerce_int(val)


def _route_for_review(reason: str) -> str:
    """Build, log and serialize a ROUTE FOR REVIEW decision."""
    out = {'decision': 'ROUTE FOR REVIEW', 'reason': reason}
    _dbg("CALL check_claim_coverage.output", out)
    return json.dumps(out)


@tool
def check_claim_coverage(record_summary: str, policy_summary: str) -> str:
    """Compare a patient record summary against a policy summary and return a decision JSON."""
    # NOTE: the docstring above is the tool description @tool shows the LLM.
    # The result is APPROVE only when every criterion can be verified and is
    # met; anything unverifiable or failing is routed for review.
    _dbg("CALL check_claim_coverage.input.record_summary", record_summary)
    _dbg("CALL check_claim_coverage.input.policy_summary", policy_summary)
    try:
        rs = json.loads(record_summary) if isinstance(record_summary, str) else record_summary
        ps = json.loads(policy_summary) if isinstance(policy_summary, str) else policy_summary
    except Exception as e:
        return _route_for_review(f'Malformed inputs: {e}')
    if not isinstance(rs, dict) or not isinstance(ps, dict):
        return _route_for_review('Malformed inputs: both summaries must be JSON objects.')

    # An unknown policy yields {'error': ...} with no criteria. Without this
    # guard every check below is skipped and the claim would be APPROVED.
    if ps.get('error'):
        return _route_for_review(f"Policy lookup failed: {ps['error']}")
    if not any(k in ps for k in _POLICY_SUMMARY_KEYS):
        return _route_for_review('Policy summary contains no policy_id or criteria.')

    # Normalize both sides the same way: stringified codes, dotted ICD codes,
    # canonical sex, numeric ages.
    age = coerce_int(rs.get('age'))
    sex = sex_normalize(rs.get('sex'))
    dxs = {icd_dot_normalize(d) for d in _as_code_set(rs.get('diagnoses'))}
    cpts = _as_code_set(rs.get('procedures'))

    allowed_dx = {icd_dot_normalize(d) for d in _as_code_set(ps.get('allowed_diagnoses'))}
    allowed_cpt = _as_code_set(ps.get('allowed_procedures'))
    age_min = _as_age_bound(ps.get('age_min'))
    age_max = _as_age_bound(ps.get('age_max'))
    sex_rule = sex_normalize(ps.get('sex'))  # 'U' means no restriction
    preauth_required = _is_preauth_required(ps.get('preauth_required', False))
    preauth_provided = _is_preauth_provided(rs.get('preauth_provided', False))

    missing = []
    if not rs.get('policy_id'): missing.append('policy_id')
    if allowed_dx and not dxs: missing.append('diagnoses')
    if allowed_cpt and not cpts: missing.append('procedures')
    # Use the coerced age: an unparseable age must not silently skip the bounds.
    if (age_min is not None or age_max is not None) and age is None:
        missing.append('age')
    # summarize_patient_record reports unknown sex as 'U', never empty.
    if sex_rule in {'M', 'F'} and sex not in {'M', 'F'}: missing.append('sex')
    if missing:
        return _route_for_review(f"Missing/insufficient fields: {', '.join(missing)}.")

    # Never approve against a guideline for a different policy than the record's.
    rec_pid, pol_pid = _norm_pid(rs.get('policy_id')), _norm_pid(ps.get('policy_id'))
    if pol_pid and rec_pid != pol_pid:
        return _route_for_review(f'Policy mismatch: record policy {rec_pid} vs guideline {pol_pid}.')

    reasons = []
    if allowed_cpt and not (cpts & allowed_cpt): reasons.append('Claimed procedure not covered.')
    if allowed_dx and not (dxs & allowed_dx):   reasons.append('Diagnosis not covered.')
    if age is not None:
        if age_min is not None and age < age_min: reasons.append(f'Patient age {age} is below {age_min}.')
        if age_max is not None and age > age_max: reasons.append(f'Patient age {age} exceeds {age_max}.')
    if sex_rule in {'M', 'F'} and sex != sex_rule: reasons.append(f'Policy restricted to sex {sex_rule}.')
    if preauth_required and not preauth_provided: reasons.append('Preauthorization required but not provided.')

    if reasons:
        return _route_for_review('; '.join(reasons)[:500])
    out = {'decision':'APPROVE','reason':'Meets policy criteria.'}
    _dbg("CALL check_claim_coverage.output", out)
    return json.dumps(out)

# --------------------------- Bind tools to LLM ---------------------
# IMPORTANT: `chat_client` must already be defined (see the setup note at the
# top of this cell). Binding the tools lets the model emit calls to them.
_AGENT_TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]
llm_with_tools = chat_client.bind_tools(tools=_AGENT_TOOLS)

# --------------------------- Agent Prompt --------------------------
# System prompt prepended to every LLM call by tool_calling_llm.
# The old rule "Keep tool inputs minimal JSON" invited the model to trim the
# step 1/2 outputs before check_claim_coverage. A summary missing fields can
# silently skip criteria checks, so the outputs must be passed verbatim.
AGENT_PROMPT_TXT = """You are an expert claims agent for health insurance. Use ReAct: reason, call tools, observe, continue.

Strict sequence for each record:
  1) Call summarize_patient_record on the raw record JSON (stringified).
  2) Call summarize_policy_guideline using the policy_id from step 1 (or the record if needed).
  3) Call check_claim_coverage with the outputs from steps 1 & 2, passed verbatim (do not edit, trim, or re-summarize them).
When you have enough information, output exactly one final line:
  "Decision: <APPROVE|ROUTE FOR REVIEW>. Reason: <short reason>"

Rules:
- Never fabricate data. If information is missing or criteria cannot be verified, route for review and specify why.
- If summarize_policy_guideline returns an error (for example, an unknown policy_id), route for review.
- Base the final decision on the check_claim_coverage result. Keep the final reason short and specific.
"""
AGENT_SYS_PROMPT = SystemMessage(content=AGENT_PROMPT_TXT)

# ----------------------- LangGraph State & Node --------------------
class State(TypedDict):
    """LangGraph state shared by every node: the running conversation."""
    # add_messages is this field's reducer. A node's {'messages': [...]} update
    # is merged into the existing list instead of replacing it.
    messages: Annotated[list, add_messages]

def tool_calling_llm(state: State) -> State:
    """LangGraph node: ask the tool-bound LLM for its next step.

    The system prompt is prepended on every call (it is not stored in state).
    The return value is a partial state update holding the single reply,
    which the add_messages reducer merges into the conversation.
    """
    conversation = [AGENT_SYS_PROMPT] + state['messages']
    reply = llm_with_tools.invoke(conversation)
    _dbg("LLM reply", getattr(reply, 'content', str(reply)))
    return {'messages': [reply]}

# Build the graph
# ReAct loop: the LLM node runs first. tools_condition sends the flow to the
# ToolNode when the latest reply requests tool calls, and to END otherwise.
# Tool results always return to the LLM.
_tool_executor = ToolNode(tools=[
    summarize_patient_record,
    summarize_policy_guideline,
    check_claim_coverage,
])

graph = StateGraph(State)
graph.add_node('tool_calling_llm', tool_calling_llm)
graph.add_node('tools', _tool_executor)
graph.add_edge(START, 'tool_calling_llm')
graph.add_conditional_edges('tool_calling_llm', tools_condition, ['tools', END])
graph.add_edge('tools', 'tool_calling_llm')

claims_agent = graph.compile()

# ----------------------------- IO helpers --------------------------
def _load_records(path):
    """Load claim records from a JSON file.

    The root may be a list of records or an object with a 'records' key.
    A missing or unparseable file gives []. Any other root shape prints a
    warning and gives [].
    """
    data = _safe_load_json(path, [])
    if isinstance(data, list):
        return data
    if isinstance(data, dict) and 'records' in data:
        return data['records']
    print(f'⚠️ Unexpected JSON root at {path}; expecting list or dict.records.')
    return []

def _message_text(msg) -> str:
    """Return the plain text of a message.

    ``content`` may be a str or a list of content blocks (strings or dicts
    with a 'text' field). Objects without ``content`` are stringified.
    """
    content = getattr(msg, 'content', msg)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get('text'), str):
                parts.append(part['text'])
        return ''.join(parts)
    return str(content)


def run_on_record(record: Dict[str, Any]) -> Dict[str, str]:
    """Run the claims agent on one record.

    Returns ``{'patient_id': ..., 'generated_response': <final agent line>}``.
    If the agent raises, a warning is printed and a ROUTE FOR REVIEW line is
    returned instead, so one bad record does not abort the batch.
    """
    payload = record.get('record_str', record)
    # A string record_str is passed through as-is. json.dumps would wrap it
    # in quotes and escape it, handing the tools a double-encoded payload.
    record_json = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
    messages = [HumanMessage(content=(
        "Process this claim by following the 3-tool sequence and return exactly one line at the end.\n"
        "Record JSON:\n" + record_json
    ))]
    patient_id = record.get('patient_id') or record.get('id')
    try:
        result = claims_agent.invoke({'messages': messages})
    except Exception as e:
        # Per-record boundary: e.g. API errors or the graph recursion limit.
        print(f'⚠️ Agent failed on record {patient_id}: {e}')
        return {'patient_id': patient_id,
                'generated_response': f'Decision: ROUTE FOR REVIEW. Reason: Agent error ({type(e).__name__}).'}
    final_text = _message_text(result['messages'][-1])
    return {'patient_id': patient_id, 'generated_response': final_text.strip()}

# ------------------------- Run & Write Output ----------------------
# Preview a few validation records, then score the test set into the CSV.
val = _load_records(VALIDATION_PATH)
if val:
    # Report the actual DEBUG state instead of always claiming "ON".
    print(f"Validation sample (DEBUG {'ON' if DEBUG else 'OFF'}):")
    for rec in val[:VALIDATION_DEBUG_N]:
        _dbg("--- VALIDATION RECORD ---", rec)
        out = run_on_record(rec)
        print(out['patient_id'], '->', out['generated_response'])
else:
    print('No validation records found.')

test = _load_records(TEST_PATH)
rows = [run_on_record(r) for r in test]
# os.makedirs('') raises, which would happen for a bare filename with no
# directory part, so only create a directory when there is one.
submission_dir = os.path.dirname(SUBMISSION_PATH)
if submission_dir:
    os.makedirs(submission_dir, exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
    w.writeheader()
    w.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
