# Capstone 2 — LangGraph ReAct Agent (sample-aligned)
Uses your exact agent pattern (StateGraph + ToolNode) with 3 domain tools.

In [None]:
# Capstone 2 — LangGraph ReAct Agent (based on your sample), One Agent + Three Tools
# ----------------------------------------------------------------------------------
# This notebook mirrors your provided sample structure (StateGraph + ToolNode + tools_condition)
# but adapts the domain to insurance claims with exactly three tools.
#
# Assumptions:
# - You already created `chat_client` (authenticated LLM) in the kernel.
# - All data files live under ./Data:
#     validation_records.json, test_records.json, insurance_policies.json, reference_codes.json
# - Output written to ./Data/submission.csv
#
# Tools:
#   1) summarize_patient_record(record_str_or_json) -> JSON string
#   2) summarize_policy_guideline(policy_id) -> JSON string
#   3) check_claim_coverage(record_summary, policy_summary) -> JSON string

from __future__ import annotations
from typing import Annotated, TypedDict, Dict, Any, Optional, List, Iterable
import json, csv, os, re
from datetime import datetime

# ---- LangChain / LangGraph imports (same style as your sample) ----
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition

# ----------------------------- Paths -----------------------------
# All input and output files are addressed relative to DATA_DIR.
DATA_DIR = './Data'
VALIDATION_PATH = DATA_DIR + '/validation_records.json'  # records for the validation sample run
TEST_PATH       = DATA_DIR + '/test_records.json'        # records that produce the submission rows
POLICIES_PATH   = DATA_DIR + '/insurance_policies.json'  # policy definitions
REF_CODES_PATH  = DATA_DIR + '/reference_codes.json'     # diagnosis / procedure reference codes
SUBMISSION_PATH = DATA_DIR + '/submission.csv'           # CSV written at the end of the run

# ------------------------- Safe JSON Loading ----------------------
def _safe_load_json(path, default):
    """Parse the JSON file at ``path``; fall back to ``default`` on failure.

    A missing file or a JSON syntax error is reported with a printed warning
    and ``default`` is returned unchanged. Any other error propagates.
    """
    try:
        handle = open(path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"⚠️ Missing file: {path}. Using default.")
        return default
    with handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError as e:
            print(f"⚠️ JSON parse error in {path}: {e}. Using default.")
            return default

# Policy definitions and reference-code tables, loaded once when the cell runs.
# A missing or unparsable file degrades to an empty mapping (see _safe_load_json).
policies  = _safe_load_json(POLICIES_PATH, {})
ref_codes = _safe_load_json(REF_CODES_PATH, {})
if not isinstance(ref_codes, dict):
    # A list/scalar JSON root has no .get(); treat it as "no reference codes"
    # instead of crashing the whole cell with AttributeError.
    print(f"⚠️ Unexpected JSON root in {REF_CODES_PATH}; expected an object. Ignoring reference codes.")
    ref_codes = {}
# Each table is looked up under two alternative key names; the first non-empty one wins.
ICD_TO_NAME = ref_codes.get('diagnosis_codes') or ref_codes.get('icd10') or {}
CPT_TO_NAME = ref_codes.get('procedure_codes') or ref_codes.get('cpt') or {}

# ------------------------------ Helpers ---------------------------
# strptime formats tried in order. '%m/%d/%Y' precedes '%d/%m/%Y', so an
# ambiguous value such as '03/04/2020' is read as March 4th.
DATE_PATTERNS = ['%Y-%m-%d','%m/%d/%Y','%d/%m/%Y','%Y/%m/%d']

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort conversion of a date-like value to ``datetime``.

    Args:
        s: A ``datetime`` (returned unchanged), or any value whose ``str()``
           is a date in one of ``DATE_PATTERNS``, a compact ``YYYYMMDD``
           string, or an ISO timestamp such as ``2020-01-02T10:30:00`` (only
           the date part is used).

    Returns:
        The parsed ``datetime``, or ``None`` when the value is empty or does
        not match any supported format.
    """
    if isinstance(s, datetime):
        return s
    if not s:
        return None
    s = str(s).strip()
    # ISO timestamps carry a time component the date-only patterns reject;
    # keep just the leading calendar date.
    stamp = re.match(r'(\d{4}-\d{2}-\d{2})[T ]\d{2}:\d{2}', s)
    if stamp:
        s = stamp.group(1)
    for pat in DATE_PATTERNS:
        try:
            return datetime.strptime(s, pat)
        except ValueError:
            # strptime signals a format mismatch with ValueError; try the next pattern.
            continue
    m = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', s)
    if m:
        try:
            return datetime(int(m.group(1)), int(m.group(2)), int(m.group(3)))
        except ValueError:
            # Eight digits that are not a real calendar date (e.g. month 13).
            return None
    return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return the age in whole years at ``ref`` for someone born on ``dob``.

    Args:
        dob: Date of birth; ``None`` yields ``None``.
        ref: Reference date. Defaults to the current UTC date.

    Returns:
        The completed years, or ``None`` when ``dob`` is missing or the result
        falls outside ``0 <= age < 150`` (e.g. a birth date after ``ref``).
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12; an aware "now" in
        # UTC yields the same calendar date, and only year/month/day are read.
        from datetime import timezone
        ref = datetime.now(timezone.utc)
    # Subtract one year when the birthday has not yet occurred in ref's year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to 'M', 'F' or 'U' (unknown).

    Matching ignores case and surrounding whitespace; ``None`` and any
    unrecognised value give 'U'.
    """
    if s is None:
        return 'U'
    token = str(s).strip().upper()
    return {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}.get(token, 'U')

def icd_dot_normalize(code: str) -> str:
    """Upper-case and trim ``code``; insert the dot into an undotted code.

    Only a value made of one letter followed by 3-6 digits is changed: the dot
    goes after the third character ('E1165' -> 'E11.65'). Anything else is
    returned trimmed and upper-cased.
    """
    cleaned = str(code).strip().upper()
    # A full match already implies "no dot" and "at least four characters".
    if re.fullmatch(r'[A-Z]\d{3,6}', cleaned) is None:
        return cleaned
    return f'{cleaned[:3]}.{cleaned[3:]}'

def coerce_int(val) -> Optional[int]:
    """Convert ``val`` to ``int``, falling back to the first run of digits.

    ``int(val)`` is tried first. If it fails, the first digit sequence in
    ``str(val)`` is used. Returns ``None`` for ``None`` or when no digits are
    present.
    """
    try:
        return int(val)
    except Exception:
        # Direct conversion failed: look for digits embedded in the text form.
        digits = None if val is None else re.search(r'\d+', str(val))
    return int(digits.group()) if digits else None

def walk_items(obj):
    """Yield every ``(key, value)`` pair found in nested dicts and lists.

    Traversal is depth-first: a dict entry is yielded before the pairs nested
    inside its value. Lists are descended into but contribute no pairs of
    their own; other values yield nothing.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield key, value
            yield from walk_items(value)
    elif isinstance(obj, list):
        for element in obj:
            yield from walk_items(element)

def deep_get_first(d: Any, keys: List[str]):
    """Return the value of the first nested key whose name is in ``keys``.

    Key names are compared case-insensitively. Pairs are visited in the order
    produced by ``walk_items``. Returns ``None`` when nothing matches (a match
    whose stored value is ``None`` also returns ``None``).
    """
    wanted = frozenset(name.lower() for name in keys)
    return next(
        (value for key, value in walk_items(d)
         if isinstance(key, str) and key.lower() in wanted),
        None,
    )

def collect_mixed(val):
    """Flatten a code field of mixed shape into a list of strings.

    ``val`` may be ``None`` (empty list), a scalar (one string), a dict, or a
    list of scalars and dicts. For a dict, the values under
    'code'/'id'/'icd'/'cpt' are taken when not ``None``, followed by the values
    under 'name'/'desc'/'description'/'text' when truthy.
    """
    def _from_mapping(entry):
        identifiers = [str(entry[k]) for k in ('code', 'id', 'icd', 'cpt')
                       if entry.get(k) is not None]
        labels = [str(entry[k]) for k in ('name', 'desc', 'description', 'text')
                  if entry.get(k)]
        return identifiers + labels

    if val is None:
        return []
    if isinstance(val, dict):
        return _from_mapping(val)
    if not isinstance(val, list):
        return [str(val)]

    flattened = []
    for element in val:
        if isinstance(element, dict):
            flattened.extend(_from_mapping(element))
        else:
            flattened.append(str(element))
    return flattened

# ------------------------------- TOOLS ----------------------------
@tool
def summarize_patient_record(record_str_or_json: str) -> str:
    """Extract patient_id, policy_id, age, sex, diagnoses (ICD-10), procedures (CPT, 5-digit), preauth flag from a claim JSON/text. Returns a JSON string."""
    # NOTE: the docstring above is kept verbatim on purpose. LangChain's @tool
    # decorator uses it as the tool description given to the model, so it is
    # runtime text rather than plain documentation.
    #
    # Extraction runs in three passes:
    #   1. structured: when the input decodes to a JSON object, read the fields
    #      from known top-level aliases, then from the first matching nested key;
    #   2. regex fallbacks over the text for anything still unset;
    #   3. harvest code-shaped tokens from the whole text, then filter them
    #      against the reference tables when those are loaded.
    # The result is a JSON object with keys patient_id, age, sex, diagnoses,
    # procedures, policy_id and preauth_provided.
    text = record_str_or_json
    try:
        rec = json.loads(text)
    except Exception:
        rec = None  # not JSON at all: handle the input as free text
    # A double-encoded payload (a JSON string literal wrapping the record)
    # decodes to a plain str. Unwrap it so that structured parsing gets a chance
    # and the regex fallbacks see real newlines instead of escaped "\n"
    # sequences, which would defeat the \b anchors below.
    while isinstance(rec, str):
        text = rec
        try:
            rec = json.loads(text)
        except Exception:
            rec = None
    if rec is not None:
        text = json.dumps(rec, ensure_ascii=False)

    patient_id = None
    policy_id = None
    age = None
    sex = 'U'
    preauth_provided = False
    dx_codes, cpt_codes = set(), set()

    if isinstance(rec, dict):
        patient_id = (rec.get('patient_id') or rec.get('id') or rec.get('member_id') or rec.get('patientId') or rec.get('pid'))
        policy_id = (rec.get('policy_id') or rec.get('policy') or rec.get('plan_id')
                     or rec.get('insurancePolicyId') or deep_get_first(rec, ['insurance_policy_id']))
        age = coerce_int(rec.get('age') or deep_get_first(rec, ['age','age_years','patient_age']))
        if age is None:
            # No explicit age: derive it from the birth date, measured at the
            # service/claim date when one is present (otherwise at "now").
            dob = (rec.get('dob') or rec.get('date_of_birth') or rec.get('birthdate')
                   or deep_get_first(rec, ['dob','date_of_birth','birthdate']))
            svc = (rec.get('service_date') or rec.get('claim_date') or rec.get('date_of_service')
                   or deep_get_first(rec, ['service_date','claim_date','date_of_service']))
            age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))
        sex = sex_normalize(rec.get('sex') or rec.get('gender') or deep_get_first(rec, ['sex','gender','patient_sex']))
        preauth_val = (rec.get('preauth_provided') or deep_get_first(rec, ['preauth','preauthorization','authorization','prior_auth','prior_authorization','preauth_provided']))
        if preauth_val is not None:
            preauth_provided = str(preauth_val).strip().lower() in {'1','y','yes','true','approved','provided'}
        raw_dx  = (rec.get('diagnosis_codes') or rec.get('diagnoses') or rec.get('dx') or rec.get('icd') or rec.get('icd_codes')
                   or deep_get_first(rec, ['diagnosis_codes','diagnoses','diagnosis','dx','icd','icd_codes']))
        raw_cpt = (rec.get('procedure_codes') or rec.get('procedures') or rec.get('cpt') or rec.get('cpt_codes') or rec.get('services')
                   or deep_get_first(rec, ['procedure_codes','procedures','procedure','cpt','cpt_codes','services']))
        for s in collect_mixed(raw_dx):  dx_codes.add(icd_dot_normalize(s))
        for s in collect_mixed(raw_cpt): cpt_codes.add(str(s))

    # Regex fallbacks: each one only fills a field the structured pass left unset.
    m = re.search(r'\bpatient[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', text, re.I)
    if m and not patient_id: patient_id = m.group(1)
    m = re.search(r'\bpolicy[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', text, re.I)
    if m and not policy_id:
        pre = text[:m.start()]
        # Skip a match that is immediately preceded by the word "patient".
        if not re.search(r'patient\s*$', pre, re.I):
            policy_id = m.group(1)
    m = re.search(r'\bage[:\s-]*([0-9]{1,3})\b', text, re.I)
    if m and age is None: age = int(m.group(1))
    m = re.search(r'\b(sex|gender)[:\s-]*([A-Za-z]+)\b', text, re.I)
    if m and sex == 'U': sex = sex_normalize(m.group(2))
    if re.search(r'pre-?auth(?:orization)?[:\s-]*(yes|y|true|provided|approved)', text, re.I):
        preauth_provided = True

    # The token harvest scans the whole text, which also contains the patient
    # and policy identifiers. An ID such as "P001" is shaped like an undotted
    # ICD-10 code (it would normalise to "P00.1") and a 5-digit ID looks like a
    # CPT code, so identifier tokens are excluded from the harvest.
    id_tokens = {str(v).strip().upper() for v in (patient_id, policy_id) if v is not None}
    dx_codes |= {icd_dot_normalize(s)
                 for s in re.findall(r'\b([A-Z][0-9][A-Z0-9.\-]{1,6})\b', text)
                 if s.upper() not in id_tokens}
    cpt_codes |= {s for s in re.findall(r'\b(\d{5})\b', text) if s not in id_tokens}

    # When reference tables are loaded, keep only codes they know about. For
    # procedures with no known match, fall back to any 5-digit candidate.
    if ICD_TO_NAME:
        dx_codes = {c for c in dx_codes if (c in ICD_TO_NAME)}
    if CPT_TO_NAME:
        known = {c for c in cpt_codes if c in CPT_TO_NAME}
        cpt_codes = known if known else {c for c in cpt_codes if re.fullmatch(r'\d{5}', c)}

    return json.dumps({
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(dx_codes),
        'procedures': sorted(cpt_codes),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    })

@tool
def summarize_policy_guideline(policy_id: str) -> str:
    """Look up an insurance policy guideline by policy_id and return its criteria as a JSON string."""
    # The docstring is reproduced verbatim: @tool uses it as the tool description.
    #
    # Looks `policy_id` up (as a string key) in the module-level `policies`
    # mapping. An unknown ID, or a `policies` value that is not a dict, yields
    # {"error": ...}; otherwise the policy's `criteria` are flattened into the
    # summary keys used by the coverage check.
    entry = None
    if isinstance(policies, dict):
        entry = policies.get(str(policy_id))
    if not entry:
        return json.dumps({'error': f"Unknown policy_id '{policy_id}'"})

    criteria = entry.get('criteria', {}) or {}
    summary = {
        'policy_id': policy_id,
        'allowed_diagnoses': criteria.get('diagnoses', []) or [],
        'allowed_procedures': criteria.get('procedures', []) or [],
        'age_min': criteria.get('age_min'),
        'age_max': criteria.get('age_max'),
        'sex': criteria.get('sex'),
        'preauth_required': bool(criteria.get('preauth_required', False)),
        'policy_title': entry.get('title', ''),
        'notes': entry.get('notes', ''),
    }
    return json.dumps(summary)

@tool
def check_claim_coverage(record_summary: str, policy_summary: str) -> str:
    """Compare a patient record summary against a policy summary and return a decision JSON."""
    # The docstring is reproduced verbatim: @tool uses it as the tool description.
    #
    # Inputs are the JSON outputs of the two summary tools (a str or an already
    # decoded dict). The result is always a JSON object with keys `decision`
    # ("APPROVE" or "ROUTE FOR REVIEW") and `reason`.
    #
    # Because the summaries travel through the model, values are normalised
    # defensively before comparison: codes are compared as trimmed strings,
    # age bounds go through coerce_int, and string flags such as "false" are
    # parsed instead of being treated as truthy.

    def _review(reason):
        return json.dumps({'decision':'ROUTE FOR REVIEW','reason': reason})

    def _codes(value, upper=False):
        # Accept a list of codes or a single scalar code.
        if value is None:
            return set()
        if isinstance(value, (str, int)):
            value = [value]
        cleaned = {str(v).strip() for v in value}
        return {c.upper() for c in cleaned} if upper else cleaned

    def _flag(value):
        # Same affirmative tokens the record summariser recognises.
        if isinstance(value, str):
            return value.strip().lower() in {'1','y','yes','true','approved','provided'}
        return bool(value)

    try:
        rs = json.loads(record_summary) if isinstance(record_summary, str) else record_summary
        ps = json.loads(policy_summary) if isinstance(policy_summary, str) else policy_summary
    except Exception as e:
        return _review(f'Malformed inputs: {e}')
    if not isinstance(rs, dict) or not isinstance(ps, dict):
        # Valid JSON that is not an object (list, string, number) cannot be compared.
        return _review('Malformed inputs: summaries must be JSON objects.')

    # An error summary (e.g. an unknown policy) carries no criteria. Without
    # this guard every check below would pass vacuously and the claim would be
    # approved against a policy that was never found.
    for label, summary in (('record', rs), ('policy', ps)):
        if summary.get('error'):
            return _review(f"Could not evaluate {label} summary: {summary['error']}"[:500])

    age = coerce_int(rs.get('age'))
    sex = sex_normalize(rs.get('sex'))
    dxs  = _codes(rs.get('diagnoses'), upper=True)
    cpts = _codes(rs.get('procedures'))

    allowed_dx  = _codes(ps.get('allowed_diagnoses'), upper=True)
    allowed_cpt = _codes(ps.get('allowed_procedures'))
    age_min = coerce_int(ps.get('age_min')); age_max = coerce_int(ps.get('age_max'))
    sex_rule = sex_normalize(ps.get('sex'))  # 'U' means the policy has no sex restriction
    preauth_required = _flag(ps.get('preauth_required', False))
    preauth_provided = _flag(rs.get('preauth_provided', False))

    # Fields the policy needs but the record does not provide.
    missing = []
    if not rs.get('policy_id'): missing.append('policy_id')
    if allowed_dx and not dxs: missing.append('diagnoses')
    if allowed_cpt and not cpts: missing.append('procedures')
    if (age_min is not None or age_max is not None) and age is None:
        missing.append('age')
    if sex_rule in {'M','F'} and not rs.get('sex'): missing.append('sex')
    if missing:
        return _review(f"Missing/insufficient fields: {', '.join(missing)}.")

    # Rule checks: a code list passes when at least one claimed code is allowed.
    reasons = []
    if allowed_cpt and not (cpts & allowed_cpt): reasons.append('Claimed procedure not covered.')
    if allowed_dx and not (dxs & allowed_dx):   reasons.append('Diagnosis not covered.')
    if age is not None:
        if age_min is not None and age < age_min: reasons.append(f'Patient age {age} is below {age_min}.')
        if age_max is not None and age > age_max: reasons.append(f'Patient age {age} exceeds {age_max}.')
    if sex_rule in {'M','F'} and sex != sex_rule: reasons.append(f'Policy restricted to sex {sex_rule}.')
    if preauth_required and not preauth_provided: reasons.append('Preauthorization required but not provided.')

    if reasons:
        return _review('; '.join(reasons)[:500])
    return json.dumps({'decision':'APPROVE','reason':'Meets policy criteria.'})

# ---------------------- Bind tools to your chat client -------------
# IMPORTANT: You must have defined `chat_client` before running this cell.
# Example:
#   from langchain_openai import ChatOpenAI
#   chat_client = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# Model handle with the three claim tools bound to it.
llm_with_tools = chat_client.bind_tools(
    tools=[
        summarize_patient_record,
        summarize_policy_guideline,
        check_claim_coverage,
    ],
)

# --------------------------- Agent Prompt --------------------------
# System prompt for the agent. It fixes the tool order (record summary ->
# policy summary -> coverage check) and the exact one-line format of the final
# answer. This is runtime text sent to the model: edit with care.
AGENT_PROMPT_TXT = """You are an expert claims agent for health insurance. Use ReAct: reason, decide whether to
call a tool, observe the result, and continue until you can produce the final answer.

Rules you MUST follow for each record:
  1) Call summarize_patient_record on the raw record JSON.
  2) Then call summarize_policy_guideline using the policy_id from step 1 (or the record if needed).
  3) Then call check_claim_coverage with the outputs from steps 1 & 2.
  4) Finally, return exactly one line:
     "Decision: <APPROVE|ROUTE FOR REVIEW>. Reason: <short reason>"

Notes:
- Never fabricate data; if required info is missing, route for review and state the missing pieces.
- Keep tool inputs concise JSON. Keep the final reason short and specific.
"""
# Wrapped once as a SystemMessage; tool_calling_llm prepends it on every model call.
AGENT_SYS_PROMPT = SystemMessage(content=AGENT_PROMPT_TXT)

# ----------------------- LangGraph State & Node --------------------
class State(TypedDict):
    """Graph state passed between nodes: the conversation's message list."""
    # `add_messages` is attached to this key through Annotated.
    # NOTE(review): presumably LangGraph uses it as the reducer that merges a
    # node's returned messages into the existing list instead of replacing it;
    # confirm against the LangGraph documentation.
    messages: Annotated[list, add_messages]

def tool_calling_llm(state: State) -> State:
    """Agent node: call the tool-bound model on the conversation so far.

    The system prompt is prepended to the messages held in ``state`` for this
    call only. The model's reply is returned as a one-element ``messages``
    update.
    """
    prompt_messages = [AGENT_SYS_PROMPT, *state['messages']]
    reply = llm_with_tools.invoke(prompt_messages)
    return {'messages': [reply]}

# ------------------------- Build the Graph -------------------------
# Two-node loop: the model node runs first; each time it requests a tool the
# tools node executes it and hands control back to the model, until the
# conditional edge routes to END.
LLM_NODE = 'tool_calling_llm'
# NOTE(review): tools_condition is expected to route to a node called 'tools'
# (or to END), so this name presumably has to stay as is.
TOOLS_NODE = 'tools'

builder = StateGraph(State)
builder.add_node(LLM_NODE, tool_calling_llm)
builder.add_node(
    TOOLS_NODE,
    ToolNode(tools=[summarize_patient_record, summarize_policy_guideline, check_claim_coverage]),
)

builder.add_edge(START, LLM_NODE)
builder.add_conditional_edges(LLM_NODE, tools_condition, [TOOLS_NODE, END])
builder.add_edge(TOOLS_NODE, LLM_NODE)

# Compile the ReAct-style agent graph
claims_agent = builder.compile()

# ----------------------------- IO helpers --------------------------
def _load_records(path):
    """Load the claim records stored at ``path``.

    Accepts either a JSON list of records or an object holding the list under
    the 'records' key. Any other root prints a warning and yields an empty
    list. A missing or unparsable file also yields an empty list, because that
    is the default passed to ``_safe_load_json``.
    """
    payload = _safe_load_json(path, [])
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict) and 'records' in payload:
        return payload['records']
    print(f'⚠️ Unexpected JSON root at {path}; expecting list or dict.records.')
    return []

def run_on_record(record: Dict[str, Any]) -> Dict[str, str]:
    """Run the claims agent on one record and return its submission row.

    Args:
        record: One claim record. If it has a 'record_str' entry, that value
            is what the agent sees; otherwise the whole record is sent.

    Returns:
        A dict with 'patient_id' (the record's 'patient_id', else its 'id')
        and 'generated_response' (the agent's final message text, stripped).
    """
    payload = record.get('record_str', record)
    # A payload that is already text is passed through unchanged. Running
    # json.dumps on it would wrap it in quotes and escape its newlines, so the
    # model (and the tools after it) would see a double-encoded record.
    record_json = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
    # Initialize conversation: Human asks to process a single record
    messages = [HumanMessage(content=(
        "Process this claim by following the 3-tool sequence and return exactly one line at the end.\n"
        "Record JSON:\n" + record_json
    ))]
    result = claims_agent.invoke({'messages': messages})
    # The last message is the model's reply
    final_msg = result['messages'][-1]
    content = getattr(final_msg, 'content', final_msg)
    if isinstance(content, list):
        # Some chat models return the content as a list of blocks (plain
        # strings or dicts with a 'text' entry); join the textual parts rather
        # than failing on list.strip().
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get('text'), str):
                parts.append(part['text'])
        final_text = ''.join(parts)
    else:
        final_text = str(content)
    return {'patient_id': record.get('patient_id') or record.get('id'), 'generated_response': final_text.strip()}

# ------------------------- Run & Write Output ----------------------
# Quick check: run the agent on the first three validation records.
val = _load_records(VALIDATION_PATH)
if not val:
    print('No validation records found.')
else:
    print('Validation sample:')
    for rec in val[:3]:
        out = run_on_record(rec)
        print(out['patient_id'], '->', out['generated_response'])

# Full run: one submission row per test record, written as CSV.
test = _load_records(TEST_PATH)
rows = list(map(run_on_record, test))
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
    w.writeheader()
    w.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
