# Capstone 2 — Single Agent + 3 Tools (Submission Clean)
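Clean, submission-ready notebook. Reads inputs from `./Data` and writes predictions to `./Data/submission.csv`. Requires a LangChain chat model named `chat_client` to be defined before the code cell runs.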


In [None]:
# Capstone 2 — Single Agent + 3 Tools (Submission Clean)
# ------------------------------------------------------
# One agent, exactly three tools. Clean output suitable for submission.
# Reads from ./Data and writes ./Data/submission.csv

from __future__ import annotations

import csv
import json
import os
import re
from datetime import datetime, timezone
from typing import Any, Dict, Optional

# ---- LangChain imports ----
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain.tools import tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# ----------------------------- Paths -----------------------------
DATA_DIR = './Data'
# All inputs/outputs live directly under DATA_DIR (POSIX-style separators).
VALIDATION_PATH = DATA_DIR + '/validation_records.json'
TEST_PATH       = DATA_DIR + '/test_records.json'
POLICIES_PATH   = DATA_DIR + '/insurance_policies.json'
REF_CODES_PATH  = DATA_DIR + '/reference_codes.json'
SUBMISSION_PATH = DATA_DIR + '/submission.csv'

# System message for the agent: fixes the tool order and the one-line answer format.
SYSTEM_PROMPT = "\n".join([
    "You are a careful, compliance-oriented claims agent.",
    "You have THREE tools only and must use them deterministically in this order:",
    "  (1) summarize_patient_record, (2) summarize_policy_guideline, (3) check_claim_coverage.",
    "Return exactly one line at the end:",
    '"Decision: <APPROVE|ROUTE FOR REVIEW>. Reason: <short reason>"',
    "Follow policy rules strictly; do not invent data.",
])

# ------------------------- Load references -----------------------
def _safe_load_json(path, default):
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Missing file: {path}. Using default.")
        return default
    except json.JSONDecodeError as e:
        print(f"JSON parse error in {path}: {e}. Using default.")
        return default

policies  = _safe_load_json(POLICIES_PATH, {})
ref_codes = _safe_load_json(REF_CODES_PATH, {})
# The code maps below are read with .get(), so the reference file needs an object root.
# Any other shape disables reference filtering instead of crashing the notebook,
# mirroring _safe_load_json's "warn and fall back" behaviour.
if not isinstance(ref_codes, dict):
    print(f'Unexpected JSON root at {REF_CODES_PATH}; expecting an object. Ignoring reference codes.')
    ref_codes = {}
# Used later only for code-membership filtering; an empty map means "no reference available".
ICD_TO_NAME = ref_codes.get('diagnosis_codes', {}) or ref_codes.get('icd10', {})
CPT_TO_NAME = ref_codes.get('procedure_codes', {}) or ref_codes.get('cpt', {})

# ----------------------------- Helpers ---------------------------
DATE_PATTERNS = ['%Y-%m-%d','%m/%d/%Y','%d/%m/%Y','%Y/%m/%d']

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort parse of a date-like value; returns None when it cannot be parsed.

    Accepted inputs:
      * ``datetime`` instances (returned unchanged);
      * strings matching one of ``DATE_PATTERNS``, tried in order, so an
        ambiguous ``01/02/2020`` is read month-first;
      * compact ``YYYYMMDD`` strings or ints;
      * ISO-8601 timestamps such as ``2020-01-31T10:30:00``.
    """
    if isinstance(s, datetime):
        return s
    if not s:
        return None
    s = str(s).strip()
    for pat in DATE_PATTERNS:
        try:
            return datetime.strptime(s, pat)
        except ValueError:
            continue
    m = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', s)
    if m:
        try:
            return datetime(*(int(g) for g in m.groups()))
        except ValueError:  # e.g. month 13 / day 40
            return None
    try:
        # Timestamps (date + time) fail every strptime pattern above.
        return datetime.fromisoformat(s)
    except ValueError:
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return completed years between ``dob`` and ``ref`` (default: current UTC date).

    Returns None when ``dob`` is missing or the result falls outside 0-149,
    which screens out future or implausible birth dates.
    """
    if not dob:
        return None
    # datetime.utcnow() is deprecated since Python 3.12; only the calendar
    # fields are used below, so an aware "now" is equivalent.
    ref = ref or datetime.now(timezone.utc)
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to 'M', 'F', or 'U' (unknown/other/missing)."""
    if s is None:
        return 'U'
    token = str(s).strip().upper()
    return {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}.get(token, 'U')

def icd_dot_normalize(code: str) -> str:
    """Trim/upper-case an ICD-10 code and insert the dot after the 3-character category.

    Undotted 4-7 character codes shaped like ICD-10 (letter, digit, then
    alphanumerics) are dotted: 'E119' -> 'E11.9', 'S72001A' -> 'S72.001A',
    'C7A010' -> 'C7A.010'. Anything else (already dotted, bare 3-char
    categories, non-ICD tokens) is returned trimmed and upper-cased only.
    """
    c = str(code).strip().upper()
    # The pattern excludes '.', so already-dotted codes never match.
    if re.fullmatch(r'[A-Z]\d[A-Z0-9]{2,5}', c):
        return f'{c[:3]}.{c[3:]}'
    return c

# ------------------------------ Tools ----------------------------
# Structured-record keys that may carry a preauthorization flag (checked in order).
_PREAUTH_RECORD_KEYS = ('preauth_provided', 'preauth', 'preauthorization',
                        'pre_authorization', 'prior_auth', 'prior_authorization')
_PREAUTH_YES_VALUES = {'yes', 'y', 'true', 'provided', 'approved', '1'}


def _collect_code_strings(val: Any) -> list:
    """Flatten a code field (scalar, dict, or list of scalars/dicts) into strings.

    For dict entries both code-like keys ('code', 'id', 'icd', 'cpt') and
    description-like keys ('name', 'desc', 'description', 'text') are collected.
    """
    out = []
    if val is None:
        return out
    for item in (val if isinstance(val, list) else [val]):
        if isinstance(item, dict):
            out.extend(str(item[k]) for k in ('code', 'id', 'icd', 'cpt') if item.get(k) is not None)
            out.extend(str(item[k]) for k in ('name', 'desc', 'description', 'text') if item.get(k))
        else:
            out.append(str(item))
    return out


def _coerce_age(value: Any) -> Optional[int]:
    """Convert an int/float/numeric-string age to int; None if absent, non-numeric or outside 0-149."""
    if value is None or isinstance(value, bool):
        return None
    try:
        age = int(float(str(value).strip()))
    except (ValueError, OverflowError):
        return None
    return age if 0 <= age < 150 else None


def _record_preauth_flag(value: Any) -> bool:
    """Interpret a structured preauthorization field (bool, 0/1, or yes/true-style string)."""
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value == 1
    if isinstance(value, str):
        return value.strip().lower() in _PREAUTH_YES_VALUES
    return False


@tool
def summarize_patient_record(record_str_or_json: str) -> str:
    """Extract fields from a claim record (raw text or JSON string).
    Output JSON with: patient_id, age, sex, diagnoses, procedures, policy_id, preauth_provided."""
    # --- Decide what to parse: a JSON object/list, JSON-encoded free text, or raw text ---
    raw = record_str_or_json
    if isinstance(raw, (dict, list)):
        rec = raw  # already-parsed input (direct Python calls)
    else:
        raw = '' if raw is None else str(raw)
        try:
            rec = json.loads(raw)
        except ValueError:
            rec = None
    if isinstance(rec, str):
        # A JSON string literal wrapping free text (e.g. a 'record_str' value): scan the
        # decoded text, not its escaped dump, so newlines still separate fields.
        text = rec
    elif rec:
        text = json.dumps(rec, ensure_ascii=False)
    else:
        text = raw if isinstance(raw, str) else ''

    patient_id = None
    policy_id = None
    age = None
    sex = 'U'
    preauth_provided = False
    dx_codes, cpt_codes = set(), set()

    # --- Structured extraction from a top-level JSON object ---
    if isinstance(rec, dict):
        patient_id = (rec.get('patient_id') or rec.get('id') or rec.get('member_id') or rec.get('patientId') or rec.get('pid'))
        policy_id = (rec.get('policy_id') or rec.get('policy') or rec.get('plan_id') or rec.get('insurancePolicyId'))
        age = rec.get('age')
        if age is None:
            dob = rec.get('dob') or rec.get('date_of_birth') or rec.get('birthdate')
            svc = rec.get('service_date') or rec.get('claim_date') or rec.get('date_of_service')
            age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))
        sex = sex_normalize(rec.get('sex') or rec.get('gender') or rec.get('patient_sex'))
        preauth_provided = any(_record_preauth_flag(rec.get(k)) for k in _PREAUTH_RECORD_KEYS)

        raw_dx  = (rec.get('diagnosis_codes') or rec.get('diagnoses') or rec.get('dx') or rec.get('icd') or rec.get('icd_codes'))
        raw_cpt = (rec.get('procedure_codes') or rec.get('procedures') or rec.get('cpt') or rec.get('cpt_codes') or rec.get('services'))
        dx_codes.update(icd_dot_normalize(s) for s in _collect_code_strings(raw_dx))
        cpt_codes.update(_collect_code_strings(raw_cpt))

    # Normalize before the fallbacks so "" / "45" / 45.0 all behave consistently.
    age = _coerce_age(age)

    # --- Regex fallbacks over the text (free-text records or fields missed above) ---
    m = re.search(r'\bpatient[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', text, re.I)
    if m and not patient_id:
        patient_id = m.group(1)
    m = re.search(r'(?<!patient)\bpolicy[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', text, re.I)
    if m and not policy_id:
        policy_id = m.group(1)
    if age is None:
        # Separators include quotes/'=' so nested JSON such as "age": 45 is also found.
        m = re.search(r'\bage["\'\s:=-]*([0-9]{1,3})\b', text, re.I)
        if m:
            age = _coerce_age(m.group(1))
    if sex == 'U':
        # Match whole values ("Female", "M"); words like "sexual" are skipped.
        m = re.search(r'\b(?:sex|gender)["\'\s:=-]+(female|male|f|m)\b', text, re.I)
        if m:
            sex = sex_normalize(m.group(1))
    if not preauth_provided and re.search(
            r'pre-?auth(?:orization)?(?:_provided)?["\'\s:=-]*(?:yes|y|true|provided|approved)\b',
            text, re.I):
        preauth_provided = True

    # --- Codes mentioned anywhere in the text ---
    dx_codes |= {icd_dot_normalize(s) for s in re.findall(r'\b([A-Z][0-9][A-Z0-9.\-]{1,6})\b', text)}
    cpt_codes |= set(re.findall(r'\b(\d{5})\b', text))

    # Keep only reference-known ICD codes when a reference map exists. For CPT, keep
    # reference-known codes; if none are known (but a reference exists) keep 5-digit
    # codes. Without a reference map, codes are left unfiltered.
    if isinstance(ICD_TO_NAME, dict) and ICD_TO_NAME:
        dx_codes = {c for c in dx_codes if c in ICD_TO_NAME}
    if isinstance(CPT_TO_NAME, dict) and CPT_TO_NAME:
        cpt_known = {c for c in cpt_codes if c in CPT_TO_NAME}
        cpt_codes = cpt_known if cpt_known else {c for c in cpt_codes if re.fullmatch(r'\d{5}', c)}

    return json.dumps({
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(dx_codes),
        'procedures': sorted(cpt_codes),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    })

def _policy_flag(value: Any) -> bool:
    """Read a yes/no policy flag that may be stored as a JSON bool, number, or string."""
    if isinstance(value, str):
        # bool("false") would be True, so strings are compared explicitly.
        return value.strip().lower() in {'true', 'yes', 'y', '1', 'required'}
    return bool(value)


def _policy_list(value: Any) -> list:
    """Return a policy code field as a list (wrapping a single bare code)."""
    if not value:
        return []
    if isinstance(value, (list, tuple, set)):
        return list(value)
    return [value]


def _find_policy(policy_id: Any) -> Optional[Dict[str, Any]]:
    """Return the policy dict for ``policy_id`` from the loaded ``policies``, else None.

    ``policies`` may be a mapping keyed by policy id, or a list of policy
    objects carrying a 'policy_id' (or 'id') field.
    """
    pid = str(policy_id).strip()
    if isinstance(policies, dict):
        pol = policies.get(str(policy_id)) or policies.get(pid)
    elif isinstance(policies, list):
        pol = next(
            (p for p in policies
             if isinstance(p, dict)
             and (p.get('policy_id') or p.get('id')) is not None
             and str(p.get('policy_id') or p.get('id')).strip() == pid),
            None,
        )
    else:
        pol = None
    return pol if isinstance(pol, dict) and pol else None


@tool
def summarize_policy_guideline(policy_id: str) -> str:
    """Look up policy rules by policy_id from insurance_policies.json.
    Returns JSON with allowed_diagnoses, allowed_procedures, age_min, age_max, sex, preauth_required.
    Returns {"error": ...} when the policy_id is unknown."""
    pol = _find_policy(policy_id)
    if pol is None:
        return json.dumps({'error': f"Unknown policy_id '{policy_id}'"})
    crit = pol.get('criteria') or {}
    if not isinstance(crit, dict):
        crit = {}
    return json.dumps({
        'policy_id': policy_id,
        'allowed_diagnoses': _policy_list(crit.get('diagnoses')),
        'allowed_procedures': _policy_list(crit.get('procedures')),
        'age_min': crit.get('age_min'),
        'age_max': crit.get('age_max'),
        'sex': crit.get('sex'),
        'preauth_required': _policy_flag(crit.get('preauth_required', False)),
        'policy_title': pol.get('title', ''),
        'notes': pol.get('notes', '')
    })

def _coverage_codes(values: Any) -> set:
    """Normalize a code list (or one bare code) into comparison keys.

    Keys are trimmed, upper-cased and dot-free, so 'E11.9' and 'e119' compare
    equal, and a policy code stored as the number 99213 matches the record's
    '99213'. Integral floats (99213.0) are treated as ints.
    """
    if not values:
        return set()
    if isinstance(values, (str, int, float)):
        values = [values]  # a bare string must not be iterated character by character
    keys = set()
    for value in values:
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        key = str(value).strip().upper().replace('.', '')
        if key:
            keys.add(key)
    return keys


def _coverage_number(value: Any) -> Optional[float]:
    """Parse an age or age bound; integral values come back as int.

    Returns None for missing, boolean, NaN, or non-numeric values.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        num = value
    else:
        try:
            num = float(str(value).strip())
        except ValueError:
            return None
    if isinstance(num, float):
        if num != num:  # NaN never compares, which would silently skip the age rules
            return None
        if num.is_integer():
            return int(num)
    return num


def _coverage_flag(value: Any) -> bool:
    """Interpret a yes/no flag given as bool, number, or string ('false' is False)."""
    if isinstance(value, str):
        return value.strip().lower() in {'true', 'yes', 'y', '1', 'required', 'provided', 'approved'}
    return bool(value)


@tool
def check_claim_coverage(record_summary: str, policy_summary: str) -> str:
    """Deterministically compare record vs. policy. Return JSON with decision & reason.
    Unknown policies, malformed inputs and missing required fields are routed for review."""

    def _route(reason: str) -> str:
        return json.dumps({'decision': 'ROUTE FOR REVIEW', 'reason': reason})

    try:
        rs = json.loads(record_summary) if isinstance(record_summary, str) else record_summary
        ps = json.loads(policy_summary) if isinstance(policy_summary, str) else policy_summary
    except ValueError as e:
        return _route(f'Malformed inputs: {e}')
    if not isinstance(rs, dict) or not isinstance(ps, dict):
        return _route('Malformed inputs: expected JSON objects for record and policy summaries.')
    # summarize_policy_guideline reports an unknown policy as {'error': ...}. Without
    # this check the empty rule set would fall through to APPROVE.
    if ps.get('error'):
        return _route(f"Policy lookup failed: {ps['error']}")

    age = _coverage_number(rs.get('age'))
    sex = sex_normalize(rs.get('sex'))
    dxs = _coverage_codes(rs.get('diagnoses'))
    cpts = _coverage_codes(rs.get('procedures'))

    allowed_dx = _coverage_codes(ps.get('allowed_diagnoses'))
    allowed_cpt = _coverage_codes(ps.get('allowed_procedures'))
    age_min = _coverage_number(ps.get('age_min'))
    age_max = _coverage_number(ps.get('age_max'))
    sex_rule = sex_normalize(ps.get('sex'))  # 'U' means no sex restriction
    preauth_required = _coverage_flag(ps.get('preauth_required', False))
    preauth_provided = _coverage_flag(rs.get('preauth_provided', False))

    # Fields a rule needs but the record lacks (checked after parsing, so an
    # unparsable age counts as missing instead of silently skipping age rules).
    missing = []
    if not rs.get('policy_id'):
        missing.append('policy_id')
    if allowed_dx and not dxs:
        missing.append('diagnoses')
    if allowed_cpt and not cpts:
        missing.append('procedures')
    if (age_min is not None or age_max is not None) and age is None:
        missing.append('age')
    if sex_rule in {'M', 'F'} and sex == 'U':
        missing.append('sex')
    if missing:
        return _route(f"Missing/insufficient fields: {', '.join(missing)}.")

    reasons = []
    if allowed_cpt and not (cpts & allowed_cpt):
        reasons.append('Claimed procedure not covered.')
    if allowed_dx and not (dxs & allowed_dx):
        reasons.append('Diagnosis not covered.')
    if age is not None:
        if age_min is not None and age < age_min:
            reasons.append(f'Patient age {age} is below {age_min}.')
        if age_max is not None and age > age_max:
            reasons.append(f'Patient age {age} exceeds {age_max}.')
    if sex_rule in {'M', 'F'} and sex != sex_rule:
        reasons.append(f'Policy restricted to sex {sex_rule}.')
    if preauth_required and not preauth_provided:
        reasons.append('Preauthorization required but not provided.')

    if reasons:
        return _route('; '.join(reasons)[:500])
    return json.dumps({'decision': 'APPROVE', 'reason': 'Meets policy criteria.'})

# ---------------------- Build single agent -----------------------
TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]

PROMPT = ChatPromptTemplate.from_messages([
    ('system', SYSTEM_PROMPT),
    ('human', '{input}'),
    # create_openai_functions_agent requires this slot for prior tool calls/results.
    MessagesPlaceholder(variable_name='agent_scratchpad'),
])

# NOTE(review): `chat_client` is not defined in this notebook; it must be an
# OpenAI-function-calling chat model created before this cell runs.
#
# create_openai_functions_agent builds only the planning runnable. Each invoke
# returns the *next* step (tool call or finish), expects an `intermediate_steps`
# input, and never executes tools. AgentExecutor runs the tool loop and returns a
# dict with the final 'output', so `agent` is bound to the executor.
agent = AgentExecutor(
    agent=create_openai_functions_agent(chat_client, TOOLS, prompt=PROMPT),
    tools=TOOLS,
)

# ----------------------------- Runner ----------------------------
def run_on_record(record: Dict[str, Any]) -> Dict[str, str]:
    record_json = json.dumps(record.get('record_str', record), ensure_ascii=False)
    user_input = (
        "Step 1: Call summarize_patient_record on the provided record JSON.\n"
        "Step 2: Call summarize_policy_guideline using the policy_id from step 1 (or the record).\n"
        "Step 3: Call check_claim_coverage with outputs from steps 1 and 2.\n"
        "Return exactly one line: \"Decision: <APPROVE|ROUTE FOR REVIEW>. Reason: <short reason>\".\n\n"
        f"Record JSON:\n{record_json}"
    )
    result = agent.invoke({'input': user_input, 'agent_scratchpad': []})
    if hasattr(result, 'output'):
        final_text = result.output
    elif isinstance(result, dict):
        final_text = result.get('output') or str(result)
    else:
        final_text = str(result)
    return {'patient_id': record.get('patient_id') or record.get('id'), 'generated_response': final_text.strip()}

# ----------------------------- IO -------------------------------
def _load_records(path):
    """Load records from a JSON file whose root is a list or an object with a 'records' key.

    Any other root prints a notice and yields an empty list.
    """
    data = _safe_load_json(path, [])
    if isinstance(data, list):
        return data
    if isinstance(data, dict) and 'records' in data:
        return data['records']
    print(f'Unexpected JSON root at {path}; expecting list or dict.records. Got {type(data).__name__}.')
    return []

# Smoke-run the agent on up to three validation records; results are discarded
# on purpose so the notebook output stays clean.
val = _load_records(VALIDATION_PATH)
for sample in (val[:3] if val else []):
    run_on_record(sample)

# Build one submission row per test record and write them as CSV.
test = _load_records(TEST_PATH)
rows = [run_on_record(r) for r in test]
submission_dir = os.path.dirname(SUBMISSION_PATH)
os.makedirs(submission_dir, exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=['patient_id', 'generated_response'])
    writer.writeheader()
    writer.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
