# Capstone 2 — Deterministic 3-Tool Runner
Uses the same three-tool logic, but computes the final decision deterministically with rules instead of asking an LLM to decide.

In [None]:
# Capstone 2 — Deterministic Runner (One Agent + Three Tools API, but rule-based final check)
# -------------------------------------------------------------------------------------------
# Why this version? Your LLM was returning APPROVE too often. This cell runs a *deterministic*
# 3-step pipeline (the same three tools) and produces the final decision without asking the LLM to judge.
# You can still keep the agent cell above if you want it, but this runner guarantees strict behavior.
#
# Output format enforced:
#   Decision: <APPROVE|ROUTE FOR REVIEW> Reason: <short reason>
#
# Files under ./Data are used, and submission is written to ./Data/submission.csv

from __future__ import annotations
from typing import Dict, Any, Optional, List
import json, csv, os, re
from datetime import datetime

# ----------------------------- Paths -----------------------------
DATA_DIR = './Data'
VALIDATION_PATH = f'{DATA_DIR}/validation_records.json'
TEST_PATH       = f'{DATA_DIR}/test_records.json'
POLICIES_PATH   = f'{DATA_DIR}/insurance_policies.json'
REF_CODES_PATH  = f'{DATA_DIR}/reference_codes.json'
SUBMISSION_PATH = f'{DATA_DIR}/submission.csv'

# ------------------------- Safe JSON Loading ----------------------
def _safe_load_json(path, default):
    """Parse the UTF-8 JSON file at ``path``, or return ``default`` on any failure.

    Every exception (missing file, unreadable file, malformed JSON, ...) is
    deliberately swallowed so the caller can keep going with ``default``.
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            parsed = json.loads(handle.read())
    except Exception:
        return default
    return parsed

# Load the policy catalogue and the code reference tables. A missing or
# malformed file degrades to an empty mapping instead of stopping the cell.
policies_raw  = _safe_load_json(POLICIES_PATH, {})
ref_codes     = _safe_load_json(REF_CODES_PATH, {})
# A reference file whose top level is not a JSON object (e.g. a bare list) has
# no named tables; treat it as empty so the .get() lookups below cannot raise.
if not isinstance(ref_codes, dict):
    ref_codes = {}
# Either naming convention is accepted for each table. An empty table switches
# off the corresponding code filtering in tool_summarize_patient_record.
ICD_TO_NAME   = ref_codes.get('diagnosis_codes') or ref_codes.get('icd10') or {}
CPT_TO_NAME   = ref_codes.get('procedure_codes') or ref_codes.get('cpt') or {}

# ------------------------------ Helpers ---------------------------
def _norm_pid(pid: Optional[str]) -> Optional[str]:
    """Return ``pid`` upper-cased with hyphens, underscores and whitespace removed.

    Returns ``None`` when ``pid`` is falsy, and also when nothing is left after
    the separators are stripped (e.g. ``'--'``), so callers never end up with
    an empty-string identifier.
    """
    if not pid:
        return None
    # `or None` turns a separator-only id (normalized to '') into "no id".
    return re.sub(r'[-_\s]', '', str(pid)).upper() or None

def _build_policy_index(raw):
    """Build a ``{normalized_policy_id: policy}`` lookup from the loaded policies.

    ``raw`` may be a dict (outer key -> policy) or a list of policy dicts.
    For a dict, each entry is indexed under its normalized outer key and, when
    the value is itself a dict, also under its embedded id unless that id is
    already present. For a list, each policy dict is indexed under the first
    id field that normalizes to a non-empty value. Anything else yields ``{}``.
    """
    id_fields = ('policy_id', 'policyId', 'id', 'code', 'policy_code')
    lookup = {}

    if isinstance(raw, dict):
        for outer_key, policy in raw.items():
            key_id = _norm_pid(outer_key)
            if key_id:
                lookup[key_id] = policy
            if not isinstance(policy, dict):
                continue
            # First truthy raw id wins; only that one value is normalized.
            embedded = None
            for field in id_fields:
                embedded = policy.get(field)
                if embedded:
                    break
            embedded_id = _norm_pid(embedded)
            if embedded_id and embedded_id not in lookup:
                lookup[embedded_id] = policy
        return lookup

    if isinstance(raw, list):
        for policy in raw:
            if not isinstance(policy, dict):
                continue
            normalized = (_norm_pid(policy.get(field)) for field in id_fields)
            first_id = next((pid for pid in normalized if pid), None)
            if first_id:
                lookup[first_id] = policy
    return lookup

# Policies keyed by normalized id, built once when this cell runs.
POLICY_INDEX = _build_policy_index(policies_raw)

# strptime formats tried in order by parse_date_maybe. An ambiguous value such
# as 03/04/2020 is read month-first because '%m/%d/%Y' precedes '%d/%m/%Y'.
DATE_PATTERNS = [
    '%Y-%m-%d',
    '%m/%d/%Y',
    '%d/%m/%Y',
    '%Y/%m/%d',
]

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort conversion of ``s`` to a ``datetime``; ``None`` if it fails.

    The stripped string form of ``s`` is tried against each ``DATE_PATTERNS``
    format in order, then against a compact ``YYYYMMDD`` form.
    """
    if not s:
        return None
    text = str(s).strip()

    for fmt in DATE_PATTERNS:
        try:
            parsed = datetime.strptime(text, fmt)
        except Exception:
            continue
        return parsed

    compact = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if compact is None:
        return None
    try:
        year, month, day = (int(part) for part in compact.groups())
        return datetime(year, month, day)
    except Exception:
        # Eight digits that do not form a real calendar date.
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return the age in completed years on ``ref``, or ``None`` if unusable.

    Args:
        dob: Date of birth; a missing value yields ``None``.
        ref: Date the age is measured on. Defaults to the current UTC time.

    Returns:
        Whole years between ``dob`` and ``ref``, or ``None`` when ``dob`` is
        missing or the result is outside ``0 <= age < 150`` (which also
        rejects a ``ref`` that precedes ``dob``).
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12. Only the
        # year/month/day fields are read below, so an aware value is safe.
        from datetime import timezone
        ref = datetime.now(timezone.utc)
    # Subtract one year when the birthday has not yet occurred in ref's year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to ``'M'``, ``'F'`` or ``'U'`` (unknown)."""
    if s is None:
        return 'U'
    aliases = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    return aliases.get(str(s).strip().upper(), 'U')

def icd_dot_normalize(code: str) -> str:
    """Return ``code`` stripped, upper-cased and with the ICD-10 dot inserted.

    An undotted value shaped like an ICD-10-CM code (a letter, a digit, then
    2-5 further letters or digits, i.e. 4-7 characters in total) gets a dot
    after its third character: ``'E119'`` -> ``'E11.9'``,
    ``'S72001A'`` -> ``'S72.001A'``. Anything else, including already dotted
    and three-character codes, is returned unchanged apart from the
    strip/upper-case step.
    """
    c = str(code).strip().upper()
    # Letters are allowed after the second character so that 7th-character
    # extensions (A/D/S), 'X' placeholders and categories such as Z3A are
    # dotted too; a digits-only tail would skip them.
    if re.fullmatch(r'[A-Z]\d[A-Z0-9]{2,5}', c):
        return c[:3] + '.' + c[3:]
    return c

def coerce_int(val) -> Optional[int]:
    """Convert ``val`` to ``int``, falling back to the first digit run in its text.

    Returns ``None`` when ``val`` is ``None`` or contains no digits. The
    fallback ignores any sign, so ``'age -5'`` gives ``5``.
    """
    try:
        return int(val)
    except Exception:
        pass
    if val is None:
        return None
    digits = re.search(r'\d+', str(val))
    if digits is None:
        return None
    return int(digits.group())

def walk_items(obj):
    """Yield every ``(key, value)`` pair found in nested dicts and lists.

    Traversal is depth-first: a dict entry is yielded before the pairs nested
    inside its value. Lists are descended into but contribute no pairs of
    their own; any other object yields nothing.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield key, value
            yield from walk_items(value)
    elif isinstance(obj, list):
        for element in obj:
            yield from walk_items(element)

def deep_get_first(d: Any, keys: List[str]):
    """Return the value of the first key in ``d`` (at any depth) named in ``keys``.

    Key names are compared case-insensitively and the search follows
    ``walk_items`` order. The first match is returned even if its value is
    falsy; ``None`` is returned when no key matches.
    """
    wanted = {name.lower() for name in keys}
    matches = (
        value
        for key, value in walk_items(d)
        if isinstance(key, str) and key.lower() in wanted
    )
    return next(matches, None)

def collect_mixed(val):
    """Flatten a scalar, dict, or list of scalars/dicts into a list of strings.

    For a dict, the values under ``code``/``id``/``icd``/``cpt`` are taken when
    they are not ``None``, followed by the values under
    ``name``/``desc``/``description``/``text`` when they are truthy. Any other
    item is converted with ``str``. ``None`` produces an empty list.
    """
    code_keys = ('code', 'id', 'icd', 'cpt')
    text_keys = ('name', 'desc', 'description', 'text')

    def from_mapping(entry):
        codes = [str(entry[k]) for k in code_keys if entry.get(k) is not None]
        texts = [str(entry[k]) for k in text_keys if entry.get(k)]
        return codes + texts

    if val is None:
        return []
    if isinstance(val, dict):
        return from_mapping(val)
    if not isinstance(val, list):
        return [str(val)]

    collected = []
    for item in val:
        if isinstance(item, dict):
            collected.extend(from_mapping(item))
        else:
            collected.append(str(item))
    return collected

# --------------------- Tools (same logic as earlier) ---------------------
def tool_summarize_patient_record(record: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a raw claim record to the fields the coverage check needs.

    Each field is looked up under several alias keys at the top level of
    ``record`` and, for most fields, anywhere in the nested structure as a
    fallback.

    Returns:
        A dict with ``patient_id``, ``age`` (int or ``None``; derived from a
        date of birth when no age is given), ``sex`` (``'M'``/``'F'``/``'U'``),
        sorted ``diagnoses`` and ``procedures`` code lists, the normalized
        ``policy_id`` and the boolean ``preauth_provided``.
    """
    def lookup(top_level, nested=None):
        # First truthy top-level alias wins. Otherwise fall back to the deep
        # search, or to the last alias's raw value when no deep keys are given.
        value = None
        for key in top_level:
            value = record.get(key)
            if value:
                return value
        if nested is None:
            return value
        return deep_get_first(record, nested)

    patient_id = lookup(['patient_id', 'id', 'member_id', 'patientId', 'pid'])
    policy_id = _norm_pid(
        lookup(['policy_id', 'policy', 'plan_id', 'insurancePolicyId'], ['insurance_policy_id'])
    )

    age = coerce_int(lookup(['age'], ['age', 'age_years', 'patient_age']))
    if age is None:
        # No usable age: derive it from the birth date as of the service date.
        born = lookup(['dob', 'date_of_birth', 'birthdate'],
                      ['dob', 'date_of_birth', 'birthdate'])
        served = lookup(['service_date', 'claim_date', 'date_of_service'],
                        ['service_date', 'claim_date', 'date_of_service'])
        age = age_from_dob(parse_date_maybe(born), parse_date_maybe(served))

    sex = sex_normalize(lookup(['sex', 'gender'], ['sex', 'gender', 'patient_sex']))

    preauth_raw = lookup(
        ['preauth_provided'],
        ['preauth', 'preauthorization', 'authorization', 'prior_auth',
         'prior_authorization', 'preauth_provided'],
    )
    affirmative = {'1', 'y', 'yes', 'true', 'approved', 'provided'}
    preauth_provided = (
        preauth_raw is not None and str(preauth_raw).strip().lower() in affirmative
    )

    dx_source = lookup(
        ['diagnosis_codes', 'diagnoses', 'dx', 'icd', 'icd_codes'],
        ['diagnosis_codes', 'diagnoses', 'diagnosis', 'dx', 'icd', 'icd_codes'],
    )
    cpt_source = lookup(
        ['procedure_codes', 'procedures', 'cpt', 'cpt_codes', 'services'],
        ['procedure_codes', 'procedures', 'procedure', 'cpt', 'cpt_codes', 'services'],
    )
    diagnoses = {icd_dot_normalize(item) for item in collect_mixed(dx_source)}
    procedures = set(collect_mixed(cpt_source))

    # When reference tables are loaded, keep only codes they list.
    if ICD_TO_NAME:
        diagnoses = {code for code in diagnoses if code in ICD_TO_NAME}
    if CPT_TO_NAME:
        recognised = {code for code in procedures if code in CPT_TO_NAME}
        if recognised:
            procedures = recognised
        else:
            # Nothing recognised: fall back to anything shaped like a 5-digit code.
            procedures = {code for code in procedures if re.fullmatch(r'\d{5}', code)}

    return {
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(diagnoses),
        'procedures': sorted(procedures),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    }

def tool_summarize_policy_guideline(policy_id: Optional[str]) -> Dict[str, Any]:
    """Look up a policy in ``POLICY_INDEX`` and flatten its coverage criteria.

    Args:
        policy_id: Normalized policy identifier (as produced by ``_norm_pid``).

    Returns:
        ``{'error': 'Unknown policy_id'}`` when the id is empty, is not
        indexed, or the indexed entry is not a mapping. Otherwise a dict with
        ``policy_id``, ``allowed_diagnoses`` (dot-normalized),
        ``allowed_procedures``, ``age_min``/``age_max`` (ints or ``None``),
        ``sex`` (``'M'``/``'F'`` when recognisable, else the raw value) and the
        boolean ``preauth_required``.
    """
    pol = POLICY_INDEX.get(policy_id) if policy_id else None
    # A non-mapping entry cannot carry criteria; report it like an unknown
    # policy so the claim is routed for review instead of raising.
    if not pol or not isinstance(pol, dict):
        return {'error': 'Unknown policy_id'}
    crit = pol.get('criteria')
    if not isinstance(crit, dict):
        crit = {}

    def _norm_list(v):
        # Accept a list as-is or a delimiter-separated string of codes.
        if v is None: return []
        if isinstance(v, str):
            parts = re.split(r'[;,\s]+', v.strip())
            return [p for p in parts if p]
        if isinstance(v, list): return v
        return []

    allowed_dx  = [icd_dot_normalize(str(x)) for x in _norm_list(crit.get('diagnoses'))]
    allowed_cpt = [str(x) for x in _norm_list(crit.get('procedures'))]

    # Normalize the sex rule the same way patient records are normalized, so
    # 'Female'/'female' restricts exactly like 'F'. Unrecognised values (e.g.
    # 'Any', None) are passed through unchanged.
    raw_sex = crit.get('sex')
    sex_rule = sex_normalize(raw_sex)
    if sex_rule not in {'M', 'F'}:
        sex_rule = raw_sex

    # bool('false') is True, so textual negatives are handled explicitly.
    # Any other non-empty string still counts as "required" (fail closed).
    raw_preauth = crit.get('preauth_required', False)
    if isinstance(raw_preauth, str):
        preauth_required = raw_preauth.strip().lower() not in {
            '', '0', 'n', 'no', 'false', 'none', 'not required'}
    else:
        preauth_required = bool(raw_preauth)

    return {
        'policy_id': policy_id,
        'allowed_diagnoses': allowed_dx,
        'allowed_procedures': allowed_cpt,
        # Coerced to int so age comparisons cannot fail on values like "18".
        'age_min': coerce_int(crit.get('age_min')),
        'age_max': coerce_int(crit.get('age_max')),
        'sex': sex_rule,
        'preauth_required': preauth_required
    }

def tool_check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Decide deterministically whether a summarized claim meets its policy.

    Args:
        record_summary: Output of ``tool_summarize_patient_record``.
        policy_summary: Output of ``tool_summarize_policy_guideline``.

    Returns:
        ``{'decision': ..., 'reason': ...}``. The decision is ``'APPROVE'``
        only when no rule fails. Otherwise it is ``'ROUTE FOR REVIEW'`` with
        the reasons: unknown policy, uninterpretable policy age limit, missing
        record fields, or one or more unmet criteria (truncated to 500 chars).
    """
    review = 'ROUTE FOR REVIEW'
    # Unknown policy
    if policy_summary.get('error'):
        return {'decision': review, 'reason': 'Unknown policy_id.'}

    # Normalize both sides up front so every later comparison is like-for-like.
    age = coerce_int(record_summary.get('age'))
    sex = sex_normalize(record_summary.get('sex'))
    dxs = [icd_dot_normalize(d) for d in record_summary.get('diagnoses') or []]
    cpts = record_summary.get('procedures') or []

    allowed_dx = {icd_dot_normalize(d) for d in policy_summary.get('allowed_diagnoses') or []}
    allowed_cpt = policy_summary.get('allowed_procedures') or []
    raw_min = policy_summary.get('age_min')
    raw_max = policy_summary.get('age_max')
    # Bounds may arrive as strings ("18"); comparing int with str would raise.
    age_min = coerce_int(raw_min)
    age_max = coerce_int(raw_max)
    sex_rule = sex_normalize(policy_summary.get('sex'))
    preauth_required = bool(policy_summary.get('preauth_required', False))
    preauth_provided = bool(record_summary.get('preauth_provided', False))

    # Fail closed: a limit that is present but unreadable must not be ignored.
    for label, raw, bound in (('age_min', raw_min, age_min), ('age_max', raw_max, age_max)):
        if raw is not None and bound is None:
            return {'decision': review, 'reason': f'Policy {label} is not numeric.'}

    missing = []
    if not record_summary.get('policy_id'): missing.append('policy_id')
    if allowed_dx and not dxs: missing.append('diagnoses')
    if allowed_cpt and not cpts: missing.append('procedures')
    # Uses the coerced age, so a non-numeric age counts as missing rather than
    # silently skipping the age rule.
    if (age_min is not None or age_max is not None) and age is None:
        missing.append('age')
    # An unknown sex ('U') cannot satisfy a sex-restricted policy; report it as
    # missing data instead of as a mismatch.
    if sex_rule in {'M', 'F'} and sex == 'U': missing.append('sex')
    if missing:
        return {'decision': review, 'reason': f"Missing/insufficient fields: {', '.join(missing)}."}

    reasons = []
    if allowed_cpt:
        not_allowed_cpt = [c for c in cpts if c not in allowed_cpt]
        if not_allowed_cpt:
            reasons.append(f"Not covered procedure(s): {', '.join(not_allowed_cpt)}.")
    if allowed_dx:
        not_allowed_dx = [d for d in dxs if d not in allowed_dx]
        if not_allowed_dx:
            reasons.append(f"Not covered diagnosis code(s): {', '.join(not_allowed_dx)}.")
    if age is not None:
        if age_min is not None and age < age_min: reasons.append(f'Patient age {age} is below {age_min}.')
        if age_max is not None and age > age_max: reasons.append(f'Patient age {age} exceeds {age_max}.')
    if sex_rule in {'M', 'F'} and sex != sex_rule: reasons.append(f'Policy restricted to sex {sex_rule}.')
    if preauth_required and not preauth_provided: reasons.append('Preauthorization required but not provided.')

    if reasons:
        return {'decision': review, 'reason': ' '.join(reasons)[:500]}
    return {'decision': 'APPROVE', 'reason': 'Meets policy criteria.'}

# ----------------------------- Runner -----------------------------
def _load_records(path):
    """Load claim records from ``path``.

    Accepts either a top-level JSON list or an object holding the records
    under a ``'records'`` key; anything else (including a load failure)
    yields an empty list.
    """
    payload = _safe_load_json(path, [])
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict) and 'records' in payload:
        return payload['records']
    return []

def run_record_det(record):
    """Run the three tools on one record and format the submission row.

    Returns:
        ``{'patient_id': ..., 'generated_response': ...}`` where the response
        is ``'Decision: <APPROVE|ROUTE FOR REVIEW> Reason: <reason>'`` with all
        whitespace runs collapsed to single spaces.
    """
    rs = tool_summarize_patient_record(record)
    ps = tool_summarize_policy_guideline(rs.get('policy_id'))
    dec = tool_check_claim_coverage(rs, ps)
    final_text = f"Decision: {dec['decision']} Reason: {dec['reason']}"
    final_text = re.sub(r'\s+', ' ', final_text).strip()
    # Prefer the id resolved by the summarizer, which also understands the
    # member_id / patientId / pid aliases, so such records do not end up with
    # an empty patient_id in the submission.
    patient_id = rs.get('patient_id') or record.get('patient_id') or record.get('id')
    return {'patient_id': patient_id, 'generated_response': final_text}

# Preview: run the first three validation records and print their decisions.
val = _load_records(VALIDATION_PATH)
if not val:
    print('No validation records found.')
else:
    print('Validation sample (deterministic):')
    for preview in map(run_record_det, val[:3]):
        print(preview['patient_id'], '->', preview['generated_response'])

# Full run: decide every test record and write the submission CSV.
test = _load_records(TEST_PATH)
rows = list(map(run_record_det, test))
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as csv_file:
    writer = csv.DictWriter(csv_file, fieldnames=['patient_id', 'generated_response'])
    writer.writeheader()
    writer.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
