# Capstone 2 — Single Agent + 3 Tools (Submission Clean, Regex Fix)

# In [None]:
# Capstone 2 — Single Agent + 3 Tools (Submission Clean, Regex Fix)
# -----------------------------------------------------------------
# Fix: removed lookbehind (?<!patient) and replaced with a safe manual check.

from __future__ import annotations
from typing import Dict, Any, Optional
import json, csv, os, re
from datetime import datetime, timezone

from langchain.tools import tool
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

DATA_DIR = './Data'
# Every input and output file lives directly under DATA_DIR.
VALIDATION_PATH = DATA_DIR + '/validation_records.json'
TEST_PATH = DATA_DIR + '/test_records.json'
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'
SUBMISSION_PATH = DATA_DIR + '/submission.csv'

# System instructions for the claims agent: fixed tool order and a single
# "Decision: ... Reason: ..." line as the final answer.
SYSTEM_PROMPT = "\n".join([
    "You are a careful, compliance-oriented claims agent.",
    "You have THREE tools only and must use them deterministically in this order:",
    "  (1) summarize_patient_record, (2) summarize_policy_guideline, (3) check_claim_coverage.",
    "Return exactly one line at the end:",
    "\"Decision: <APPROVE|ROUTE FOR REVIEW>. Reason: <short reason>\"",
    "Follow policy rules strictly; do not invent data.",
])

def _safe_load_json(path, default):
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception:
        return default

policies  = _safe_load_json(POLICIES_PATH, {})
ref_codes = _safe_load_json(REF_CODES_PATH, {})
# A reference file whose top level isn't an object (e.g. a JSON array) would
# make the .get() calls below raise at import time; treat it as empty instead.
if not isinstance(ref_codes, dict):
    ref_codes = {}
# Code -> name lookups; either key spelling is accepted. (Not used by the code
# in this cell.)
ICD_TO_NAME = ref_codes.get('diagnosis_codes', {}) or ref_codes.get('icd10', {})
CPT_TO_NAME = ref_codes.get('procedure_codes', {}) or ref_codes.get('cpt', {})

# Formats tried in order by parse_date_maybe; the first that parses wins, so an
# ambiguous date such as 03/04/2020 is read month-first.
DATE_PATTERNS = ['%Y-%m-%d','%m/%d/%Y','%d/%m/%Y','%Y/%m/%d']

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Parse *s* with the first format in DATE_PATTERNS that accepts it.

    Falsy input, or text that matches none of the formats, yields None.
    """
    if not s:
        return None
    candidate = str(s).strip()
    for fmt in DATE_PATTERNS:
        try:
            parsed = datetime.strptime(candidate, fmt)
        except Exception:
            continue
        return parsed
    return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return the age in whole years at *ref* of someone born on *dob*.

    *ref* defaults to the current UTC date. Returns None when *dob* is missing
    or the computed age falls outside [0, 150) (future or implausible dates).
    """
    if not dob:
        return None
    # datetime.utcnow() is deprecated since Python 3.12. Only the calendar
    # fields are compared below, so an aware "now" is safe against a naive dob.
    ref = ref or datetime.now(timezone.utc)
    # Subtract a year when the birthday hasn't come around yet in ref's year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to 'M', 'F', or 'U' (unknown/other)."""
    if s is None:
        return 'U'
    token = str(s).strip().upper()
    aliases = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    return aliases.get(token, 'U')

def icd_dot_normalize(code: str) -> str:
    """Upper-case an ICD-10 code and insert the dot after its 3-character category.

    Undotted 4-7 character codes shaped like ICD-10-CM (letter, digit, then
    alphanumerics) get a dot, e.g. 'e1165' -> 'E11.65', 'S72001A' -> 'S72.001A',
    'M1A0' -> 'M1A.0'. Anything else (already dotted, 3-character categories,
    unrecognized shapes) is only stripped and upper-cased.
    """
    c = str(code).strip().upper()
    # A full match already implies no '.' is present and len(c) is 4-7.
    if re.fullmatch(r'[A-Z]\d[A-Z0-9]{2,5}', c):
        return c[:3] + '.' + c[3:]
    return c

@tool
def summarize_patient_record(record_str_or_json: str) -> str:
    """Extract a structured summary from a patient record (JSON or free text).

    Returns a JSON string with keys: patient_id, age, sex, diagnoses,
    procedures, policy_id, preauth_provided. Pass the record text exactly as
    given in the request.
    """
    # NOTE: @tool sends this docstring to the model as the tool description.
    try:
        rec = json.loads(record_str_or_json)
    except Exception:
        rec = None
    if isinstance(rec, str):
        # A record JSON-encoded twice decodes to a string; unwrap it once more
        # so its fields can be read as a dict.
        try:
            rec = json.loads(rec)
        except ValueError:
            pass
    text = str(record_str_or_json) if not rec else json.dumps(rec)

    patient_id, policy_id, age, sex = None, None, None, 'U'
    preauth_provided = False
    dx_codes, cpt_codes = set(), set()

    if isinstance(rec, dict):
        patient_id = rec.get('patient_id') or rec.get('id')
        policy_id = rec.get('policy_id') or rec.get('insurancePolicyId')
        age = rec.get('age')
        if isinstance(age, str):
            # Numeric strings ("45") become ints; anything else falls back to DOB.
            age = int(age.strip()) if age.strip().isdecimal() else None
        if age is None:
            dob = rec.get('dob') or rec.get('date_of_birth')
            age = age_from_dob(parse_date_maybe(dob))
        sex = sex_normalize(rec.get('sex') or rec.get('gender'))
        # NOTE(review): assumes the record uses the same key as this summary's
        # output field; confirm against the record schema.
        flag = rec.get('preauth_provided')
        preauth_provided = str(flag).strip().lower() in {'true', 'yes', 'y', '1'}
        raw_dx  = rec.get('diagnosis_codes') or rec.get('diagnoses')
        raw_cpt = rec.get('procedure_codes') or rec.get('procedures')
        if isinstance(raw_dx, list): dx_codes |= {icd_dot_normalize(str(x)) for x in raw_dx}
        if isinstance(raw_cpt, list): cpt_codes |= {str(x) for x in raw_cpt}

    if not policy_id:
        # Optional quotes let both 'policy id: P1' and '"policy_id": "P1"'
        # (e.g. a nested JSON key) match. Take the first occurrence that is not
        # preceded by "patient"; unlike a single re.search, keep scanning past a
        # rejected occurrence instead of giving up.
        pattern = r'\bpolicy[_\s]*id["\']?[:\s=-]*["\']?([A-Za-z0-9_-]+)'
        for m in re.finditer(pattern, text, re.I):
            value = m.group(1)
            if value.lower() in {'null', 'none'}:
                continue
            if re.search(r'patient\s*$', text[:m.start()], re.I):
                continue
            policy_id = value
            break

    return json.dumps({
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(dx_codes),
        'procedures': sorted(cpt_codes),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    })

@tool
def summarize_policy_guideline(policy_id: str) -> str:
    """Look up an insurance policy by policy_id and summarize its coverage criteria.

    Returns a JSON string with keys: policy_id, allowed_diagnoses,
    allowed_procedures, age_min, age_max, sex, preauth_required; or
    {"error": ...} when the policy_id is unknown.
    """
    # NOTE: @tool sends this docstring to the model as the tool description.
    key = str(policy_id).strip()
    pol = None
    if isinstance(policies, dict):
        pol = policies.get(str(policy_id)) or policies.get(key)
    elif isinstance(policies, list):
        # Also accept a list of policy objects that each carry their own policy_id.
        pol = next(
            (p for p in policies
             if isinstance(p, dict) and str(p.get('policy_id', '')).strip() == key),
            None,
        )
    if not pol or not isinstance(pol, dict):
        return json.dumps({'error': f"Unknown policy_id '{policy_id}'"})
    crit = pol.get('criteria')
    if not isinstance(crit, dict):
        crit = {}
    return json.dumps({
        'policy_id': key,
        'allowed_diagnoses': crit.get('diagnoses') or [],
        'allowed_procedures': crit.get('procedures') or [],
        'age_min': crit.get('age_min'),
        'age_max': crit.get('age_max'),
        'sex': crit.get('sex'),
        'preauth_required': bool(crit.get('preauth_required', False))
    })

@tool
def check_claim_coverage(record_summary: str, policy_summary: str) -> str:
    """Decide whether a claim is covered, given the two summaries.

    record_summary: the JSON returned by summarize_patient_record.
    policy_summary: the JSON returned by summarize_policy_guideline.
    Returns a JSON string {"decision": "APPROVE" or "ROUTE FOR REVIEW", "reason": ...}.
    """
    # NOTE: @tool sends this docstring to the model as the tool description.
    def _route(reason):
        return json.dumps({'decision': 'ROUTE FOR REVIEW', 'reason': reason})

    def _num(value):
        # Best-effort numeric coercion; None for missing or non-numeric values.
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    try:
        rs = json.loads(record_summary) if isinstance(record_summary, str) else record_summary
        ps = json.loads(policy_summary) if isinstance(policy_summary, str) else policy_summary
    except Exception as e:
        return _route(f'Malformed: {e}')
    if not isinstance(rs, dict) or not isinstance(ps, dict):
        return _route('Malformed: summaries must be JSON objects')

    if not rs.get('policy_id'):
        return _route('Missing policy_id')
    if ps.get('error'):
        # Unknown policy: never approve against criteria we don't have.
        return _route(str(ps['error']))
    if ps.get('policy_id') is not None and \
            str(ps['policy_id']).strip() != str(rs['policy_id']).strip():
        return _route('Policy summary is for a different policy_id')

    problems = []

    # NOTE(review): exact-code matching, and one covered diagnosis is treated
    # as sufficient; confirm against the policy rules.
    allowed_dx = {icd_dot_normalize(c) for c in ps.get('allowed_diagnoses') or []}
    record_dx = {icd_dot_normalize(c) for c in rs.get('diagnoses') or []}
    if allowed_dx and not (record_dx & allowed_dx):
        problems.append('no covered diagnosis')

    # Every billed procedure must be on the policy's allowed list.
    allowed_cpt = {str(c).strip() for c in ps.get('allowed_procedures') or []}
    record_cpt = {str(c).strip() for c in rs.get('procedures') or []}
    if allowed_cpt:
        uncovered = sorted(record_cpt - allowed_cpt)
        if not record_cpt:
            problems.append('no procedure on record')
        elif uncovered:
            problems.append('procedure not covered: ' + ', '.join(uncovered))

    age = _num(rs.get('age'))
    age_min, age_max = _num(ps.get('age_min')), _num(ps.get('age_max'))
    if age_min is not None or age_max is not None:
        if age is None:
            problems.append('age unknown')
        elif age_min is not None and age < age_min:
            problems.append(f"age below minimum {ps.get('age_min')}")
        elif age_max is not None and age > age_max:
            problems.append(f"age above maximum {ps.get('age_max')}")

    # sex_normalize maps unrestricted values (None, 'any', ...) to 'U'.
    required_sex = sex_normalize(ps.get('sex'))
    if required_sex != 'U' and sex_normalize(rs.get('sex')) != required_sex:
        problems.append(f'policy requires sex {required_sex}')

    if ps.get('preauth_required') and not rs.get('preauth_provided'):
        problems.append('pre-authorization required but not provided')

    if problems:
        return _route('; '.join(problems))
    return json.dumps({'decision':'APPROVE','reason':'Meets policy criteria.'})

TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]
PROMPT = ChatPromptTemplate.from_messages([
    ('system', SYSTEM_PROMPT),
    ('human', '{input}'),
    MessagesPlaceholder(variable_name='agent_scratchpad'),
])
# create_openai_functions_agent returns only the planning runnable. It requires
# an 'intermediate_steps' input and emits the next action without executing
# any tool. AgentExecutor drives the plan -> tool -> observe loop and returns a
# dict with an 'output' key; any extra 'agent_scratchpad' input is rebuilt
# from the steps. `chat_client` is presumably defined in an earlier notebook
# cell (not visible here).
_planner = create_openai_functions_agent(chat_client, TOOLS, prompt=PROMPT)
agent = AgentExecutor(agent=_planner, tools=TOOLS)

def run_on_record(record: Dict[str, Any]) -> Dict[str,str]:
    """Run the claims agent on one record and return a submission row.

    Returns {'patient_id': ..., 'generated_response': <agent's final text>}.
    If the agent call raises, the row gets a ROUTE FOR REVIEW response naming
    the error type, so one failing record doesn't abort the whole batch.
    """
    payload = record.get('record_str', record)
    # A string payload is passed through as-is: json.dumps would wrap it in
    # quotes and escape it, so the tool would receive a JSON string literal.
    record_json = payload if isinstance(payload, str) else json.dumps(payload)
    user_input = f"Process this record and return decision. Record: {record_json}"
    patient_id = record.get('patient_id') or record.get('id')
    try:
        result = agent.invoke({'input': user_input, 'agent_scratchpad': []})
    except Exception as e:
        print(f'Agent failed for patient_id={patient_id!r}: {e}')
        return {
            'patient_id': patient_id,
            'generated_response': f'Decision: ROUTE FOR REVIEW. Reason: Agent error ({type(e).__name__}).',
        }
    if hasattr(result, 'output'):
        final_text = result.output
    elif isinstance(result, dict):
        final_text = result.get('output') or str(result)
    elif hasattr(result, 'return_values'):
        # AgentFinish, returned when `agent` is the bare planning runnable.
        final_text = result.return_values.get('output') or str(result)
    else:
        final_text = str(result)
    # str() guards against a None output, which has no .strip().
    return {'patient_id': patient_id, 'generated_response': str(final_text).strip()}

def _load_records(path):
    """Return the list of records stored in the JSON file at *path*.

    Accepts either a top-level list or an object with a 'records' key; any
    other shape, or a file _safe_load_json can't read, yields an empty list.
    """
    payload = _safe_load_json(path, [])
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict) and 'records' in payload:
        return payload['records']
    return []

# Run the agent over every test record and write the submission CSV.
test = _load_records(TEST_PATH)
rows = list(map(run_on_record, test))

os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as out_file:
    writer = csv.DictWriter(out_file, fieldnames=['patient_id', 'generated_response'])
    writer.writeheader()
    writer.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
