# Capstone 2 — LLM Agent (LangGraph ReAct) + Tool Debug
A single agent with three tools and a strict output format; it prints each tool's inputs and outputs for the first few validation records to aid debugging.

In [None]:
# Capstone 2 — LLM Agent (LangGraph ReAct) with Tool Debug
# --------------------------------------------------------
# One agent + three tools, enforced sequence, strict output format.
# Debug: prints each tool call input/output for first few validation records.

from __future__ import annotations
from typing import Dict, Any, List, Optional
import os, json, csv

from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

# ------------------- Config -------------------
# Directory holding every input file; the submission CSV is written here too.
DATA_DIR = './Data'

# Input files (JSON).
VALIDATION_PATH = DATA_DIR + '/validation_records.json'
TEST_PATH = DATA_DIR + '/test_records.json'
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'

# Output file (CSV) produced from the test run.
SUBMISSION_PATH = DATA_DIR + '/submission.csv'

# When True, load failures and the agent's message trace are printed.
DEBUG = True
# Number of leading validation records used for the debug pass.
VALIDATION_DEBUG_N = 3

# ------------------- Load Data ----------------
def _safe_load(path, default):
    """Return the parsed JSON content of *path*, or *default* on any failure.

    Loading is best-effort: a missing file, unreadable file, or malformed
    JSON never raises. When DEBUG is on, the failure is reported on stdout
    before *default* is returned.
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except Exception as err:
        if DEBUG:
            print(f"⚠️ Could not read {path}: {err}")
        return default
    return parsed

# Reference data loaded once at import time; each falls back to an empty dict
# when its file cannot be read.
policies = _safe_load(POLICIES_PATH, {})
ref_codes = _safe_load(REF_CODES_PATH, {})
# Lookup tables taken from the reference file; two alternative key spellings
# are accepted for each (presumably code -> display name; verify against the data).
# NOTE(review): assumes ref_codes is a dict; a JSON list here would fail on .get().
ICD_TO_NAME = ref_codes.get('diagnosis_codes', {}) or ref_codes.get('icd10', {})
CPT_TO_NAME = ref_codes.get('procedure_codes', {}) or ref_codes.get('cpt', {})

# ------------------- Tools -------------------
@tool
def summarize_patient_record(record_str: str) -> Dict[str, Any]:
    """Summarize patient record JSON string into key fields (patient_id, policy_id, age, sex, diagnoses, procedures, preauth)."""
    # NOTE: @tool uses the docstring above as the tool description shown to
    # the model, so it is kept verbatim; maintainer notes live in comments.
    #
    # Input:  a JSON string holding one patient record (a JSON object).
    # Output: a dict with the seven keys listed in the docstring, or
    #         {'error': 'invalid JSON'} when the input cannot be parsed into
    #         a JSON object.
    try:
        record = json.loads(record_str)
    except (TypeError, ValueError, RecursionError):
        # JSONDecodeError is a ValueError; TypeError covers non-string input;
        # RecursionError covers pathologically nested documents.
        return {'error': 'invalid JSON'}
    if not isinstance(record, dict):
        # Valid JSON that is not an object (list, number, ...) has no fields
        # to read; report it the same way as unparseable input.
        return {'error': 'invalid JSON'}

    # Each field accepts two alternative key spellings.
    policy_id = record.get('insurance_policy_id') or record.get('policy_id')
    return {
        'patient_id': record.get('patient_id') or record.get('id'),
        # Only stringify a real id: str(None) would yield the truthy string
        # 'None' and slip past "missing policy id" checks downstream.
        'policy_id': str(policy_id) if policy_id else None,
        'age': record.get('age'),
        'sex': record.get('gender') or record.get('sex'),
        'diagnoses': record.get('diagnosis_codes') or [],
        'procedures': record.get('procedure_codes') or [],
        'preauth': record.get('preauthorization_obtained') or record.get('preauth_provided'),
    }

@tool
def summarize_policy_guideline(policy_id: str) -> Dict[str, Any]:
    """Retrieve policy guideline by policy_id (returns coverage rules, procedures, diagnoses, requirements)."""
    # Looks the id up in the module-level `policies` table.
    # An empty id yields an error dict; an unknown id (or an empty entry)
    # yields an empty dict.
    if policy_id:
        entry = policies.get(policy_id)
        return entry if entry else {}
    return {'error': 'no policy id'}

@tool
def check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Compare patient record summary with policy rules and return decision + reason."""
    # NOTE: @tool uses the docstring above as the tool description shown to
    # the model, so it is kept verbatim; maintainer notes live in comments.
    #
    # Reads 'procedures' and 'preauth' from record_summary and
    # 'covered_procedures' (a list of rule dicts with 'procedure_code' and
    # 'requires_preauthorization') from policy_summary.
    # Returns {'decision': 'APPROVE' | 'ROUTE FOR REVIEW', 'reason': <text>}.
    review = 'ROUTE FOR REVIEW'

    if not record_summary or not policy_summary:
        return {'decision': review, 'reason': 'Missing data.'}
    if 'error' in record_summary:
        # A failed record summary is a non-empty dict with no procedures;
        # without this guard it would skip the loop below and be approved.
        return {'decision': review, 'reason': 'Missing data.'}

    # `or []` also covers keys that are present but set to None.
    procedures = record_summary.get('procedures') or []
    preauth = record_summary.get('preauth', False)
    covered = policy_summary.get('covered_procedures') or []
    if not covered:
        return {'decision': review, 'reason': 'No covered procedures in policy.'}

    for code in procedures:
        # Codes are compared as strings so 123 and "123" match.
        rules = [
            rule for rule in covered
            if isinstance(rule, dict) and str(rule.get('procedure_code')) == str(code)
        ]
        if not rules:
            return {'decision': review, 'reason': f'CPT {code} not covered by policy.'}
        if not preauth and any(rule.get('requires_preauthorization') for rule in rules):
            return {
                'decision': review,
                'reason': f'Preauthorization required for CPT {code} but not provided.',
            }

    if not procedures:
        reason = 'Meets policy criteria.'
    elif len(procedures) == 1:
        reason = f'The claim for CPT code {procedures[0]} is approved.'
    else:
        # Every code was checked, so name every code rather than only the first.
        codes = ', '.join(str(code) for code in procedures)
        reason = f'The claim for CPT codes {codes} is approved.'
    return {'decision': 'APPROVE', 'reason': reason}

# Tools handed to the agent, listed in the order the system prompt requires them to be called.
TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]

# ------------------- Agent -------------------
# System prompt for the agent: fixes the tool call order and the two-line
# answer format. The text is sent to the model as-is, so edits here change
# runtime behaviour.
AGENT_SYS_PROMPT = SystemMessage(content="""You are an insurance claims agent. 
You MUST always use the tools in this order:
1. summarize_patient_record
2. summarize_policy_guideline
3. check_claim_coverage

Then output in this exact format:
- Decision: APPROVE | ROUTE FOR REVIEW
- Reason: <short reason>
""")

def create_agent(chat_client):
    """Build the ReAct agent wired to TOOLS and the system prompt.

    Args:
        chat_client: chat model instance passed straight to
            ``create_react_agent``.

    Returns:
        The agent object produced by ``create_react_agent``.
    """
    try:
        # Current LangGraph releases take the system prompt as ``prompt``;
        # ``state_modifier`` is the deprecated spelling of the same argument.
        return create_react_agent(chat_client, TOOLS, prompt=AGENT_SYS_PROMPT)
    except TypeError:
        # Older releases reject the unknown ``prompt`` keyword; fall back to
        # the keyword this file originally used.
        return create_react_agent(chat_client, TOOLS, state_modifier=AGENT_SYS_PROMPT)

# ------------------- Runner -------------------
def run_agent_on_records(agent, records, n_debug=0):
    """Run the agent on every record and collect its final answer.

    Args:
        agent: object exposing ``invoke({'messages': [...]})``.
        records: iterable of record dicts; each is serialised with
            ``json.dumps`` and embedded in the prompt.
        n_debug: number of leading records whose message trace is printed
            (only when the module-level DEBUG flag is on).

    Returns:
        A list with one ``{'patient_id', 'generated_response'}`` dict per
        record, in input order. A record whose agent call fails gets a
        "ROUTE FOR REVIEW" response carrying the error text instead of
        aborting the run.
    """
    results = []
    for i, rec in enumerate(records):
        patient_id = rec.get('patient_id') or rec.get('id')
        rec_str = json.dumps(rec)
        msgs = [HumanMessage(content=f"Process this record: {rec_str}")]

        # Reset per record so a failed call can never leave the previous
        # record's output (or no binding at all) for the debug trace below.
        out = None
        try:
            out = agent.invoke({'messages': msgs})
            final = out['messages'][-1].content if 'messages' in out else str(out)
        except Exception as e:
            final = f"- Decision: ROUTE FOR REVIEW\n- Reason: Error {e}"
        results.append({'patient_id': patient_id, 'generated_response': final})

        if i < n_debug and DEBUG:
            print("\n[DEBUG] Record", patient_id)
            # No trace is available when the call failed or returned
            # something without a 'messages' entry.
            trace = out['messages'] if out is not None and 'messages' in out else []
            for m in trace:
                role = m.__class__.__name__
                print(f"  {role}: {m.content}")
            print(f"  => Final: {final}")
    return results

# ------------------- Execution -------------------
# Load both record sets; an unreadable file yields an empty list, which
# disables the corresponding run below.
val_records = _safe_load(VALIDATION_PATH, [])
test_records = _safe_load(TEST_PATH, [])

# One chat client and one agent are shared by both runs.
# NOTE(review): ChatOpenAI presumably needs API credentials in the environment; confirm.
chat_client = ChatOpenAI(model="gpt-4o-mini", temperature=0)
agent = create_agent(chat_client)

# Debug pass: only the first VALIDATION_DEBUG_N validation records are run,
# their traces are printed, and the results are discarded.
if val_records and DEBUG:
    print("Validation debug run:")
    _ = run_agent_on_records(agent, val_records[:VALIDATION_DEBUG_N], n_debug=VALIDATION_DEBUG_N)

# Test pass: run every test record (no trace) and write one CSV row per record.
if test_records:
    test_results = run_agent_on_records(agent, test_records)
    with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
        w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
        w.writeheader(); w.writerows(test_results)
    print(f"Wrote {SUBMISSION_PATH} with {len(test_results)} rows.")
