# Capstone 2 – Healthcare Claim Coverage Agent (Prompt-based ReAct Agent)
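This version ensures ASCII-only docstrings for all tools. Define an authenticated `chat_client` (the LLM client passed to `create_react_agent`) before running the code cell. The cell reads `validation_records.json`, `test_records.json`, `insurance_policies.json`, and `reference_codes.json` from `./Data` and writes its predictions to `./Data/submission.csv`.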

In [None]:
import csv
import json
import os
import re
from typing import Any, Dict, List, Optional, Set

from langchain.agents import create_react_agent
from langchain.tools import tool
from langchain_core.prompts import ChatPromptTemplate

# ----------------------------- Paths -----------------------------
DATA_DIR = './Data'
VALIDATION_PATH = f'{DATA_DIR}/validation_records.json'
TEST_PATH       = f'{DATA_DIR}/test_records.json'
POLICIES_PATH   = f'{DATA_DIR}/insurance_policies.json'
REF_CODES_PATH  = f'{DATA_DIR}/reference_codes.json'
SUBMISSION_PATH = f'{DATA_DIR}/submission.csv'

SYSTEM_PROMPT = (
    'You are a careful, compliance-oriented claims AI. '
    'Use tools to extract facts deterministically, then decide APPROVE or ROUTE FOR REVIEW '
    'with a short factual reason based strictly on policy rules.'
)


# ------------------------- Load references -----------------------
def _safe_load_json(path, default):
    """Load JSON from *path*; print a notice and return *default* if the file is missing.

    Only a missing file is tolerated. Malformed JSON still raises so that a
    corrupt reference file is not silently treated as empty.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f'Missing file: {path}. Skipping.')
        return default


# NOTE(review): the code below assumes insurance_policies.json is an object
# keyed by policy_id and reference_codes.json is an object with
# 'diagnosis_codes' / 'procedure_codes' keys - confirm against the data files.
policies  = _safe_load_json(POLICIES_PATH, {})
ref_codes = _safe_load_json(REF_CODES_PATH, {})
ICD_TO_NAME = ref_codes.get('diagnosis_codes', {})
CPT_TO_NAME = ref_codes.get('procedure_codes', {})

# Patterns used by summarize_patient_record, compiled once at import time.
_PATIENT_ID_RE = re.compile(r'patient[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', re.I)
_AGE_RE        = re.compile(r'\bage[:\s-]*([0-9]{1,3})\b', re.I)
_SEX_RE        = re.compile(r'\b(sex|gender)[:\s-]*([A-Za-z]+)', re.I)
_POLICY_ID_RE  = re.compile(r'\bpolicy[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', re.I)
# Case-insensitive: candidates are upper-cased before the reference lookup.
_DX_CODE_RE    = re.compile(r'\b([A-Z]\d{1,2}[A-Z0-9.\-]+)\b', re.I)
_CPT_CODE_RE   = re.compile(r'\b(\d{4,5})\b')
# Trailing \b keeps a bare "y" from matching words such as "yet".
_PREAUTH_RE    = re.compile(
    r'pre-?auth(?:orization)?[:\s-]*(yes|y|true|provided|approved)\b', re.I
)


# ----------------------------- Helpers ---------------------------
def normalize_age(age_val) -> Optional[int]:
    """Coerce *age_val* to an int.

    Tries ``int(age_val)`` first, then falls back to the first run of digits
    in ``str(age_val)``. Returns None when no number can be recovered.
    """
    try:
        return int(age_val)
    except (TypeError, ValueError, OverflowError):
        m = re.search(r'\d+', str(age_val))
        return int(m.group()) if m else None


def sex_normalize(s) -> str:
    """Map a free-form sex/gender value to 'M', 'F', or 'U' (anything else)."""
    s = str(s).strip().upper()
    if s in {'M', 'MALE'}:
        return 'M'
    if s in {'F', 'FEMALE'}:
        return 'F'
    return 'U'


def dict_get(d, *keys, default=None):
    """Walk nested dicts along *keys*.

    Returns *default* if any level is not a dict or the final value is None.
    """
    cur = d
    for k in keys:
        if not isinstance(cur, dict):
            return default
        cur = cur.get(k)
    return default if cur is None else cur


def _as_code_set(value) -> Set[str]:
    """Return *value* as a set of stripped, upper-cased code strings.

    Accepts a list of codes or a single scalar code, so that a summary carrying
    one code as a plain string is not iterated character by character.
    """
    if value is None:
        return set()
    if isinstance(value, (str, int, float)):
        value = [value]
    return {str(v).strip().upper() for v in value}


def _review(reason: str) -> str:
    """Build the JSON payload for a ROUTE FOR REVIEW decision."""
    return json.dumps({'decision': 'ROUTE FOR REVIEW', 'reason': reason})


# ------------------------------ Tools ----------------------------
# The @tool docstrings below are kept short and ASCII-only on purpose; details
# live in '#' comments instead.

@tool
def summarize_patient_record(record_str: str) -> str:
    """Extract key fields from a claim record text.

    Returns JSON string with: patient_id, age, sex, diagnoses, procedures, policy_id, preauth_provided.
    """
    pid_m = _PATIENT_ID_RE.search(record_str)
    age_m = _AGE_RE.search(record_str)
    sex_m = _SEX_RE.search(record_str)
    pol_m = _POLICY_ID_RE.search(record_str)

    # Only codes present in the reference tables are kept. Results are sorted
    # because set iteration order for strings is not stable across runs.
    diagnoses = sorted({
        c.upper() for c in _DX_CODE_RE.findall(record_str)
        if c.upper() in ICD_TO_NAME
    })
    procedures = sorted({
        c for c in _CPT_CODE_RE.findall(record_str) if c in CPT_TO_NAME
    })

    return json.dumps({
        'patient_id': pid_m.group(1) if pid_m else None,
        'age': int(age_m.group(1)) if age_m else None,
        'sex': sex_normalize(sex_m.group(2)) if sex_m else 'U',
        'diagnoses': diagnoses,
        'procedures': procedures,
        'policy_id': pol_m.group(1) if pol_m else None,
        'preauth_provided': _PREAUTH_RE.search(record_str) is not None,
    })


@tool
def summarize_policy_guideline(policy_id: str) -> str:
    """Return a compact JSON rule set for the given policy_id using insurance_policies.json."""
    # Tool arguments may arrive with stray whitespace or quotes around the id.
    key = str(policy_id).strip().strip('\'"')
    pol = policies.get(key) if isinstance(policies, dict) else None
    if not pol:
        return json.dumps({'error': f"Unknown policy_id '{key}'"})

    return json.dumps({
        'policy_id': key,
        'allowed_diagnoses': dict_get(pol, 'criteria', 'diagnoses', default=[]),
        'allowed_procedures': dict_get(pol, 'criteria', 'procedures', default=[]),
        'age_min': dict_get(pol, 'criteria', 'age_min', default=None),
        'age_max': dict_get(pol, 'criteria', 'age_max', default=None),
        'sex': dict_get(pol, 'criteria', 'sex', default=None),
        'preauth_required': bool(dict_get(pol, 'criteria', 'preauth_required', default=False)),
        # dict_get instead of pol.get so a non-dict policy entry cannot crash here.
        'policy_title': dict_get(pol, 'title', default=''),
        'notes': dict_get(pol, 'notes', default=''),
    })


@tool
def check_claim_coverage(record_summary: str, policy_summary: str) -> str:
    """Compare record summary vs. policy summary deterministically.

    Returns JSON with decision and reason.
    """
    try:
        rs = json.loads(record_summary)
        ps = json.loads(policy_summary)
    except (TypeError, ValueError) as e:
        return _review(f'Malformed inputs: {e}')
    if not isinstance(rs, dict) or not isinstance(ps, dict):
        return _review('Malformed inputs: expected JSON objects.')

    # An unknown policy is summarized as {'error': ...}. Without this guard it
    # would look like a policy with no restrictions and be approved.
    if ps.get('error'):
        return _review(f"Policy lookup failed: {ps['error']}")

    missing = [
        k for k in ['age', 'sex', 'diagnoses', 'procedures', 'policy_id']
        if rs.get(k) in (None, [], '')
    ]
    if missing:
        return _review(f"Missing/insufficient fields: {', '.join(missing)}.")

    age = normalize_age(rs.get('age'))
    sex = sex_normalize(rs.get('sex', 'U'))
    dxs = _as_code_set(rs.get('diagnoses'))
    cpts = _as_code_set(rs.get('procedures'))

    allowed_dx = _as_code_set(ps.get('allowed_diagnoses'))
    allowed_cpt = _as_code_set(ps.get('allowed_procedures'))
    # Limits are normalized too, so a limit stored as a string compares as a number.
    age_min = normalize_age(ps.get('age_min'))
    age_max = normalize_age(ps.get('age_max'))
    # Any rule other than M/F (missing, "any", a list, ...) becomes 'U' = unrestricted.
    sex_rule = sex_normalize(ps.get('sex'))
    preauth_required = bool(ps.get('preauth_required', False))
    preauth_provided = bool(rs.get('preauth_provided', False))

    reasons: List[str] = []
    # An empty allow-list means "no restriction" for that dimension.
    if allowed_cpt and not (cpts & allowed_cpt):
        reasons.append('Claimed procedure not covered.')
    if allowed_dx and not (dxs & allowed_dx):
        reasons.append('Diagnosis not covered.')
    if age is None:
        reasons.append('Age unavailable.')
    else:
        if age_min is not None and age < age_min:
            reasons.append(f'Patient age {age} is below {age_min}.')
        if age_max is not None and age > age_max:
            reasons.append(f'Patient age {age} exceeds {age_max}.')
    if sex_rule in {'M', 'F'} and sex != sex_rule:
        reasons.append(f'Policy restricted to sex {sex_rule}.')
    if preauth_required and not preauth_provided:
        reasons.append('Preauthorization required but not provided.')

    if reasons:
        # Cap the combined reason text at 500 characters.
        return _review('; '.join(reasons)[:500])
    return json.dumps({'decision': 'APPROVE', 'reason': 'Meets policy criteria.'})


# ---------------------- Agent (prompt-based) ----------------------
TOOLS = [summarize_patient_record, summarize_policy_guideline, check_claim_coverage]
PROMPT = ChatPromptTemplate.from_messages([
    ('system', SYSTEM_PROMPT),
    ('human', '{input}'),
])
# `chat_client` is not defined in this cell; it must already exist in the session.
# NOTE(review): langchain.agents.create_react_agent appears to require the
# prompt variables `tools`, `tool_names` and `agent_scratchpad`, and is normally
# run through an AgentExecutor - confirm against the installed LangChain version.
agent = create_react_agent(chat_client, TOOLS, prompt=PROMPT)


# ----------------------------- Runners ----------------------------
def run_on_record(record: Dict[str, Any]) -> Dict[str, str]:
    """Run the agent on one claim record.

    Reads 'record_str', 'policy_id' and 'patient_id' from *record* and returns
    a dict with 'patient_id' and 'generated_response' (the agent's final text,
    stripped).
    """
    record_str = record.get('record_str', '')
    policy_id  = record.get('policy_id', '')
    user_input = (
        'Use the tools to:\n'
        '1) summarize_patient_record from the record,\n'
        f'2) summarize_policy_guideline for policy_id={policy_id},\n'
        '3) check_claim_coverage,\n'
        'then return one line:\n'
        '"Decision: <APPROVE|ROUTE FOR REVIEW>. Reason: <reason>"\n\n'
        'Record:\n' + record_str + '\n\nPolicy ID: ' + str(policy_id)
    )
    result = agent.invoke({'input': user_input})
    # Use the 'output' entry of a dict result; otherwise stringify the result.
    final_text = result.get('output', str(result)) if isinstance(result, dict) else str(result)
    return {'patient_id': record.get('patient_id'), 'generated_response': final_text.strip()}


def load_records(path) -> List[Any]:
    """Load a list of records from a JSON file.

    Accepts either a top-level list or an object with a 'records' key. Returns
    an empty list when the file is missing or has any other shape.
    """
    data = _safe_load_json(path, [])
    if isinstance(data, dict) and 'records' in data:
        return data['records']
    return data if isinstance(data, list) else []


# -------------------- Validation quick print ---------------------
val = load_records(VALIDATION_PATH)
if val:
    print('Validation sample:')
    for rec in val[:3]:
        out = run_on_record(rec)
        print(out['patient_id'], '->', out['generated_response'])
else:
    print('No validation records found.')

# --------------------- Submission for test set -------------------
test = load_records(TEST_PATH)
rows = [run_on_record(r) for r in test]
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id', 'generated_response'])
    w.writeheader()
    w.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')