# Capstone 2 — Deterministic (per‑CPT rules) + TEST Debug
Prints a preview of the first TEST records and the decision counts for the TEST set, and writes `debug_test.csv` containing the parsed fields and the decision for each TEST record.

In [None]:
# Capstone 2 — Deterministic (per‑CPT rules) + DEBUG for TEST set
# ----------------------------------------------------------------
# What changed vs code_35:
#   • Also prints a DEBUG preview for the FIRST 3 TEST records (not just validation).
#   • Prints decision counts (APPROVE vs ROUTE) for the TEST set.
#   • Writes an optional ./Data/debug_test.csv with parsed fields and the decision.
#
# Toggle DEBUG flags below if you want less output.

from __future__ import annotations
from typing import Dict, Any, Optional, List, Tuple
import json, csv, os, re, textwrap
from datetime import datetime

# ----------------------------- Config -----------------------------
DEBUG = True  # master switch read by _dbg(); when False, _dbg() prints nothing
VALIDATION_DEBUG_N = 3  # number of validation records run through the preview loop
TEST_DEBUG_N = 3  # number of leading TEST records whose final response is echoed
WRITE_DEBUG_CSV = True  # when True, DEBUG_TEST_PATH is written after the TEST run
TRUNC = 1200  # max characters of a _dbg() payload printed before it is truncated

def _dbg(title: str, payload: Any = None):
    """Print a ``[DEBUG]`` line and, optionally, an indented payload.

    Does nothing when the module-level ``DEBUG`` flag is false. A string
    payload is printed as-is; anything else is rendered with ``json.dumps``
    and falls back to ``str()`` if it cannot be serialized. The rendered
    text is cut to ``TRUNC`` characters before printing.
    """
    if not DEBUG:
        return
    print(f"[DEBUG] {title}")
    if payload is None:
        return
    if isinstance(payload, str):
        rendered = payload
    else:
        try:
            rendered = json.dumps(payload, ensure_ascii=False, indent=2)
        except Exception:
            # Not JSON-serializable: show the plain string form instead.
            rendered = str(payload)
    if len(rendered) > TRUNC:
        rendered = rendered[:TRUNC] + " ... [truncated]"
    print(textwrap.indent(rendered, "  "))

# ----------------------------- Paths -----------------------------
DATA_DIR = './Data'  # relative path; resolved against the current working directory
VALIDATION_PATH = f'{DATA_DIR}/validation_records.json'  # input: records used for the validation preview
TEST_PATH       = f'{DATA_DIR}/test_records.json'  # input: records scored in the full TEST run
POLICIES_PATH   = f'{DATA_DIR}/insurance_policies.json'  # input: raw policy definitions
REF_CODES_PATH  = f'{DATA_DIR}/reference_codes.json'  # input: diagnosis/procedure reference tables
SUBMISSION_PATH = f'{DATA_DIR}/submission.csv'  # output: patient_id + generated_response
DEBUG_TEST_PATH = f'{DATA_DIR}/debug_test.csv'  # output: parsed fields + decision (only if WRITE_DEBUG_CSV)

# ------------------------- Safe JSON Loading ----------------------
def _safe_load_json(path, default):
    """Read and parse a UTF-8 JSON file, returning ``default`` on any failure.

    Any exception (missing file, permission error, malformed JSON, ...) is
    reported with a printed warning and swallowed.
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except Exception as err:
        print(f"⚠️ Could not read {path}: {err}")
        return default
    return parsed

# Loaded once at import time. If a file is missing or unreadable,
# _safe_load_json prints a warning and the empty-dict default is used.
policies_raw  = _safe_load_json(POLICIES_PATH, {})
ref_codes     = _safe_load_json(REF_CODES_PATH, {})
# Reference tables, each looked up under two alternative key names; {} when
# neither key holds a non-empty value. The rest of this file only uses them
# for membership tests on codes.
ICD_TO_NAME   = ref_codes.get('diagnosis_codes', {}) or ref_codes.get('icd10', {}) or {}
CPT_TO_NAME   = ref_codes.get('procedure_codes', {}) or ref_codes.get('cpt', {}) or {}

# ------------------------------ Helpers ---------------------------
def _norm_pid(pid: Optional[str]) -> Optional[str]:
    """Return a canonical policy id, or ``None`` for a falsy input.

    The value is stringified, stripped of hyphens, underscores and all
    whitespace, and upper-cased (``'pol-12 3_a'`` -> ``'POL123A'``).
    """
    if pid:
        return re.sub(r'[-_\s]', '', str(pid)).upper()
    return None

def _build_policy_index(raw):
    """Build a ``{normalized_policy_id: policy}`` lookup from raw policy JSON.

    Accepted shapes of ``raw``:
      * a dict with a ``'policies'`` list of policy objects;
      * a dict mapping policy ids to policy objects (the mapping key is used
        as the id when the object carries none of its own);
      * a plain list of policy objects.

    Any other shape yields an empty index. Entries that are not dicts, or
    that have no usable id, are skipped in every shape.

    Returns:
        dict whose values have the keys ``'policy_id'`` (normalized with
        ``_norm_pid``), ``'plan_name'`` and ``'rules'``. Each rule is a dict
        with ``'procedure'``, ``'diagnoses'``, ``'age_min'``, ``'age_max'``,
        ``'sex'`` (``'M'``, ``'F'`` or ``'ANY'``) and ``'preauth_required'``.
    """
    def _as_flag(value) -> bool:
        # bool("false") is True, so textual negatives are handled explicitly.
        if isinstance(value, str):
            return value.strip().lower() not in {'', '0', 'n', 'no', 'false', 'none', 'null'}
        return bool(value)

    def _to_int(value) -> Optional[int]:
        # Unparseable bounds are treated as "no bound" rather than an error.
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    def _extract_rules(src) -> List[Dict[str, Any]]:
        """Turn a policy's ``covered_procedures`` list into normalized rules."""
        rules = []
        items = src.get('covered_procedures') or []
        if not isinstance(items, list):
            return rules
        for it in items:
            if not isinstance(it, dict):
                continue
            proc = str(it.get('procedure_code') or it.get('procedure') or '').strip()
            if not proc:
                continue

            diags = it.get('covered_diagnoses') or it.get('diagnoses') or []
            if isinstance(diags, str):
                diags = [d for d in re.split(r'[;,\s]+', diags) if d]
            diags = [str(d).strip().upper() for d in diags]
            # Insert the dot into undotted codes (E119 -> E11.9). The pattern
            # itself guarantees "no dot" and "at least four characters".
            diags = [
                (d[:3] + '.' + d[3:]) if re.fullmatch(r'[A-Z]\d{3,6}', d) else d
                for d in diags
            ]

            age_min = age_max = None
            ar = it.get('age_range') or []
            if isinstance(ar, (list, tuple)) and len(ar) >= 2:
                age_min = _to_int(ar[0])
                age_max = _to_int(ar[1])

            sex_raw = str(it.get('gender') or it.get('sex') or 'Any').strip().upper()
            if sex_raw.startswith('M'):
                sex = 'M'
            elif sex_raw.startswith('F'):
                sex = 'F'
            else:
                sex = 'ANY'

            # Either key may carry the flag; one affirmative value is enough.
            preauth_required = any(
                _as_flag(it.get(key))
                for key in ('requires_preauthorization', 'preauth_required')
            )

            rules.append({
                'procedure': proc,
                'diagnoses': diags,
                'age_min': age_min,
                'age_max': age_max,
                'sex': sex,
                'preauth_required': preauth_required
            })
        return rules

    def _policy_objects(src):
        """Return the candidate policy objects for any supported raw shape."""
        if isinstance(src, list):
            return src
        if not isinstance(src, dict):
            return []
        listed = src.get('policies')
        if isinstance(listed, list):
            return listed
        keyed = []
        for key, value in src.items():
            if isinstance(value, dict):
                value = dict(value)  # copy so the caller's data is not mutated
                value.setdefault('policy_id', value.get('policyId') or value.get('id') or key)
                keyed.append(value)
        return keyed

    index = {}
    for obj in _policy_objects(raw):
        # Stray scalars/strings inside a policy list are ignored, not fatal.
        if not isinstance(obj, dict):
            continue
        pid = _norm_pid(obj.get('policy_id') or obj.get('policyId') or obj.get('id'))
        if not pid:
            continue
        index[pid] = {
            'policy_id': pid,
            'plan_name': obj.get('plan_name') or obj.get('title') or '',
            'rules': _extract_rules(obj)
        }
    return index

# Built once at import time from the raw policies loaded earlier in this file.
POLICY_INDEX = _build_policy_index(policies_raw)

# strptime formats tried in order by parse_date_maybe(). Because '%m/%d/%Y'
# comes before '%d/%m/%Y', an ambiguous value such as 03/04/2020 is read as
# month=3, day=4; the day-first format only wins when month-first fails.
DATE_PATTERNS = ['%Y-%m-%d','%m/%d/%Y','%d/%m/%Y','%Y/%m/%d']

def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort date parser; returns ``None`` when nothing matches.

    The stringified, stripped input is tried against each format in
    ``DATE_PATTERNS`` in order, then against a compact ``YYYYMMDD`` form.
    Falsy inputs return ``None`` immediately.
    """
    if not s:
        return None
    text = str(s).strip()

    for fmt in DATE_PATTERNS:
        try:
            parsed = datetime.strptime(text, fmt)
        except Exception:
            continue
        return parsed

    compact = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if compact is None:
        return None
    try:
        year, month, day = (int(part) for part in compact.groups())
        return datetime(year, month, day)
    except Exception:
        # Eight digits that do not form a real calendar date (e.g. month 13).
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return the age in whole years at ``ref``, or ``None`` if implausible.

    Args:
        dob: Date of birth; a falsy value yields ``None``.
        ref: Reference date. Defaults to the current UTC date.

    Returns:
        Completed years between ``dob`` and ``ref``. ``None`` when ``dob`` is
        missing or the result falls outside ``0 <= age < 150``.
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12; an aware "now" in
        # UTC yields the same year/month/day, which is all that is read here.
        from datetime import timezone
        ref = datetime.now(timezone.utc)
    # The boolean subtracts 1 when the birthday has not yet occurred in ref's year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to ``'M'``, ``'F'`` or ``'U'``.

    Only ``M``/``MALE`` and ``F``/``FEMALE`` are recognized (case-insensitive,
    surrounding whitespace ignored); everything else, including ``None``,
    is ``'U'``.
    """
    if s is None:
        return 'U'
    token = str(s).strip().upper()
    return {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}.get(token, 'U')

def icd_dot_normalize(code: str) -> str:
    """Upper-case and strip a code, dotting undotted letter+digits codes.

    A value made of one letter followed by three to six digits gets a dot
    after its third character (``'e119'`` -> ``'E11.9'``). Anything else is
    returned stripped and upper-cased but otherwise untouched.
    """
    normalized = str(code).strip().upper()
    # A full match already implies "contains no dot" and "length >= 4".
    if re.fullmatch(r'[A-Z]\d{3,6}', normalized) is None:
        return normalized
    return normalized[:3] + '.' + normalized[3:]

def coerce_int(val) -> Optional[int]:
    """Convert ``val`` to an int, leniently.

    ``int(val)`` is tried first. If that fails, the first run of digits in
    ``str(val)`` is used. ``None`` is returned for ``None`` or when no digits
    are present.
    """
    try:
        return int(val)
    except Exception:
        pass
    if val is None:
        return None
    digits = re.search(r'\d+', str(val))
    if digits is None:
        return None
    return int(digits.group())

def walk_items(obj):
    """Yield every ``(key, value)`` pair found in nested dicts and lists.

    Traversal is depth-first: a dict's pair is yielded, then the pairs
    nested inside its value, before moving on to the next key. Lists are
    walked but contribute no pairs themselves; other types yield nothing.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield key, value
            yield from walk_items(value)
    elif isinstance(obj, list):
        for element in obj:
            yield from walk_items(element)

def deep_get_first(d: Any, keys: List[str]):
    """Return the value of the first matching key found anywhere in ``d``.

    Key names are compared case-insensitively, in the order produced by
    ``walk_items``. Returns ``None`` when no key matches.
    """
    wanted = {name.lower() for name in keys}
    matches = (
        value
        for key, value in walk_items(d)
        if isinstance(key, str) and key.lower() in wanted
    )
    return next(matches, None)

def collect_mixed(val):
    """Flatten a scalar, dict, or list of scalars/dicts into a list of strings.

    For a dict, values under ``code``/``id``/``icd``/``cpt`` are taken when
    not ``None``, followed by truthy values under
    ``name``/``desc``/``description``/``text``. Scalars are stringified.
    ``None`` gives an empty list.
    """
    def _from_mapping(entry):
        # Code-like fields first, then descriptive fields.
        found = [str(entry[k]) for k in ('code', 'id', 'icd', 'cpt') if entry.get(k) is not None]
        found += [str(entry[k]) for k in ('name', 'desc', 'description', 'text') if entry.get(k)]
        return found

    if val is None:
        return []
    if isinstance(val, dict):
        return _from_mapping(val)
    if not isinstance(val, list):
        return [str(val)]

    collected = []
    for element in val:
        if isinstance(element, dict):
            collected.extend(_from_mapping(element))
        else:
            collected.append(str(element))
    return collected

# --------------------- Tools (deterministic versions) ---------------------
def tool_summarize_patient_record(record: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the fields needed for a coverage decision from a raw record.

    Each field is looked up under several alternative top-level key names
    and, for most fields, by a case-insensitive search through nested
    structures. Age falls back to a date-of-birth calculation against the
    service date. Diagnosis and procedure codes are restricted to the
    reference tables when those tables are non-empty.

    Returns:
        dict with ``patient_id``, ``age``, ``sex``, ``diagnoses`` (sorted),
        ``procedures`` (sorted), ``policy_id`` (normalized) and
        ``preauth_provided`` (bool).
    """
    def _first(top_keys, deep_keys=None):
        # Same result as ``record.get(a) or record.get(b) or ... or <fallback>``:
        # first truthy top-level value, else the deep search when deep keys
        # are given, else the last top-level value looked up.
        value = None
        for key in top_keys:
            value = record.get(key)
            if value:
                return value
        if deep_keys is None:
            return value
        return deep_get_first(record, deep_keys)

    patient_id = _first(['patient_id', 'id', 'member_id', 'patientId', 'pid'])
    policy_id = _norm_pid(
        _first(['policy_id', 'policy', 'plan_id', 'insurancePolicyId'], ['insurance_policy_id'])
    )

    age = coerce_int(_first(['age'], ['age', 'age_years', 'patient_age']))
    if age is None:
        # No usable explicit age: derive it from date of birth and service date.
        dob = _first(['dob', 'date_of_birth', 'birthdate'],
                     ['dob', 'date_of_birth', 'birthdate'])
        svc = _first(['service_date', 'claim_date', 'date_of_service'],
                     ['service_date', 'claim_date', 'date_of_service'])
        age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))

    sex = sex_normalize(_first(['sex', 'gender'], ['sex', 'gender', 'patient_sex']))

    preauth_val = _first(
        ['preauth_provided'],
        ['preauth', 'preauthorization', 'authorization', 'prior_auth', 'prior_authorization', 'preauth_provided'],
    )
    preauth_provided = (
        preauth_val is not None
        and str(preauth_val).strip().lower() in {'1', 'y', 'yes', 'true', 'approved', 'provided'}
    )

    raw_dx = _first(
        ['diagnosis_codes', 'diagnoses', 'dx', 'icd', 'icd_codes'],
        ['diagnosis_codes', 'diagnoses', 'diagnosis', 'dx', 'icd', 'icd_codes'],
    )
    raw_cpt = _first(
        ['procedure_codes', 'procedures', 'cpt', 'cpt_codes', 'services'],
        ['procedure_codes', 'procedures', 'procedure', 'cpt', 'cpt_codes', 'services'],
    )
    dx_codes = {icd_dot_normalize(item) for item in collect_mixed(raw_dx)}
    cpt_codes = set(collect_mixed(raw_cpt))

    # Keep only codes the reference tables know about (when tables exist).
    if ICD_TO_NAME:
        dx_codes = {code for code in dx_codes if code in ICD_TO_NAME}
    if CPT_TO_NAME:
        recognized = {code for code in cpt_codes if code in CPT_TO_NAME}
        if recognized:
            cpt_codes = recognized
        else:
            # Nothing recognized: fall back to anything shaped like five digits.
            cpt_codes = {code for code in cpt_codes if re.fullmatch(r'\d{5}', code)}

    out = {
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(dx_codes),
        'procedures': sorted(cpt_codes),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    }
    _dbg("Record summary", out)
    return out

def tool_summarize_policy_guideline(policy_id: Optional[str]) -> Dict[str, Any]:
    """Look up a policy in ``POLICY_INDEX`` by its normalized id.

    Returns the stored policy dict, or ``{'error': 'Unknown policy_id'}``
    when the id is falsy or not present in the index.
    """
    policy = POLICY_INDEX.get(policy_id) if policy_id else None
    if not policy:
        return {'error': 'Unknown policy_id'}
    rule_count = len(policy.get('rules', []))
    _dbg("Policy rules count", {'policy_id': policy_id, 'rules': rule_count})
    return policy

def _rule_covers(record: Dict[str, Any], rule: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Check one policy rule against a summarized record.

    Args:
        record: Summary with ``procedures``, ``diagnoses``, ``age``, ``sex``
            and ``preauth_provided``.
        rule: Rule with ``procedure`` and optional ``diagnoses``,
            ``age_min``, ``age_max``, ``sex`` and ``preauth_required``.

    Returns:
        ``(True, None)`` when every condition of the rule is satisfied,
        otherwise ``(False, reason)`` for the first condition that fails.
        Checks run in the order: procedure, diagnosis, age, sex, preauth.
    """
    cpts = set(record.get('procedures', []))
    dxs  = set(record.get('diagnoses', []))
    age  = record.get('age')
    # Tolerate an explicit None as well as a missing key.
    sex  = str(record.get('sex') or 'U').upper()
    pre  = bool(record.get('preauth_provided', False))
    proc = rule['procedure']

    if proc not in cpts:
        return False, f"Rule covers CPT {proc} only."

    rdx = set(rule.get('diagnoses', []))
    if rdx and not (dxs & rdx):
        return False, f"Diagnosis mismatch for CPT {proc}. Required any of {sorted(rdx)}; have {sorted(dxs)}."

    age_min = rule.get('age_min')
    age_max = rule.get('age_max')
    if age is None:
        # An age limit that cannot be verified must not pass silently; this
        # mirrors how an unknown sex already fails a sex-restricted rule.
        if age_min is not None or age_max is not None:
            return False, f"Patient age is unknown but CPT {proc} has an age restriction."
    else:
        if age_min is not None and age < age_min:
            return False, f"Age {age} < minimum {age_min} for CPT {proc}."
        if age_max is not None and age > age_max:
            return False, f"Age {age} > maximum {age_max} for CPT {proc}."

    sex_rule = rule.get('sex', 'ANY')
    if sex_rule in {'M','F'} and sex != sex_rule:
        return False, f"Sex mismatch for CPT {proc}: requires {sex_rule}."

    if rule.get('preauth_required') and not pre:
        return False, f"Preauthorization required for CPT {proc} but not provided."
    return True, None

def tool_check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Decide APPROVE vs ROUTE FOR REVIEW for a summarized claim.

    Every procedure on the claim must be satisfied by at least one policy
    rule for that procedure; the first procedure that fails ends the check.

    Args:
        record_summary: Output of ``tool_summarize_patient_record``.
        policy_summary: Output of ``tool_summarize_policy_guideline``.

    Returns:
        dict with ``'decision'`` (``'APPROVE'`` or ``'ROUTE FOR REVIEW'``)
        and a human-readable ``'reason'``.
    """
    if policy_summary.get('error'):
        return {'decision':'ROUTE FOR REVIEW','reason':'Unknown policy_id.'}
    rules = policy_summary.get('rules', [])
    if not rules:
        return {'decision':'ROUTE FOR REVIEW','reason':'Policy does not define covered procedures.'}

    cpts = record_summary.get('procedures', [])
    if not cpts:
        # With nothing to evaluate no rule has been satisfied, so the claim
        # must not be auto-approved.
        return {'decision':'ROUTE FOR REVIEW','reason':'No recognized procedure codes found on the claim.'}

    for cpt in cpts:
        cand = [r for r in rules if r.get('procedure') == cpt]
        _dbg("Evaluate CPT", {'cpt': cpt, 'candidate_rules': cand})
        if not cand:
            return {'decision':'ROUTE FOR REVIEW','reason': f"The claim for CPT code {cpt} is not covered by the policy."}
        first_reason = None
        for r in cand:
            ok, why = _rule_covers(record_summary, r)
            _dbg("Rule check", {'rule': r, 'ok': ok, 'why': why})
            if ok:
                break
            if first_reason is None:
                first_reason = why
        else:
            # No candidate rule accepted this CPT: report the first failure.
            return {'decision':'ROUTE FOR REVIEW','reason': first_reason or f"The claim for CPT code {cpt} does not meet policy requirements."}

    return {'decision':'APPROVE','reason': f"The claim for CPT code {cpts[0]} is approved."}

# ----------------------------- Runner -----------------------------
def _load_records(path):
    """Load a list of records from a JSON file.

    Accepts either a top-level list or a dict with a ``'records'`` entry.
    Any other shape, or a failed read, gives an empty list.
    """
    payload = _safe_load_json(path, [])
    if isinstance(payload, dict):
        return payload['records'] if 'records' in payload else []
    return payload if isinstance(payload, list) else []

def format_sample_style(decision: str, reason: str) -> str:
    """Render a decision and its reason as the two-line bullet response."""
    decision_line = f"- Decision: {decision}"
    reason_line = f"- Reason: {reason}"
    return decision_line + "\n" + reason_line

def run_record_det(record):
    """Run the full deterministic pipeline for one raw record.

    Summarizes the record, looks up its policy, checks coverage and formats
    the response.

    Returns:
        dict with ``patient_id``, ``generated_response`` (formatted text),
        ``decision_only``, ``reason_only``, ``procedures`` and ``diagnoses``
        (each ``';'``-joined) and ``policy_id`` (``''`` when unknown).
    """
    rs = tool_summarize_patient_record(record)
    ps = tool_summarize_policy_guideline(rs.get('policy_id'))
    dec = tool_check_claim_coverage(rs, ps)
    final_text = format_sample_style(dec['decision'], dec['reason'])
    _dbg("FINAL", {'patient_id': rs.get('patient_id'), 'final': final_text})
    # The summary accepts more id spellings (member_id, patientId, pid) than
    # the two checked here; fall back to it so such records keep their id.
    patient_id = record.get('patient_id') or record.get('id') or rs.get('patient_id')
    return {'patient_id': patient_id,
            'generated_response': final_text,
            'decision_only': dec['decision'],
            'reason_only': dec['reason'],
            'procedures': ';'.join(rs.get('procedures', [])),
            'diagnoses': ';'.join(rs.get('diagnoses', [])),
            'policy_id': rs.get('policy_id') or ''}

# Validation preview
# Run the first VALIDATION_DEBUG_N validation records through the
# deterministic pipeline and print one summary line per record.
val = _load_records(VALIDATION_PATH)
if val:
    print('Validation sample (deterministic + DEBUG):')
    for rec in val[:VALIDATION_DEBUG_N]:
        out = run_record_det(rec)
        # Flatten the two-line response onto one line for the console.
        print(out['patient_id'], '->', out['generated_response'].replace('\n', ' | '))
else:
    # _load_records gives [] when the file is missing, unreadable, or not a
    # list / {'records': [...]} structure.
    print('No validation records found.')

# Full TEST run + preview + counts + optional debug CSV
# Each TEST record is evaluated exactly once; the submission rows and the
# debug rows are both captured in this single pass so the two files always
# agree and the DEBUG output is not emitted twice.
test = _load_records(TEST_PATH)
rows = []        # rows for submission.csv
debug_rows = []  # rows for debug_test.csv (parsed fields + decision)
approve_cnt = 0
review_cnt = 0
for i, r in enumerate(test):
    out = run_record_det(r)
    if i < TEST_DEBUG_N:
        # One-line preview of the first few final responses.
        print(out['patient_id'], '->', out['generated_response'].replace('\n', ' | '))
    rows.append({'patient_id': out['patient_id'], 'generated_response': out['generated_response']})
    debug_rows.append({'patient_id': out['patient_id'],
                       'decision': out['decision_only'],
                       'reason': out['reason_only'],
                       'procedures': out['procedures'],
                       'diagnoses': out['diagnoses'],
                       'policy_id': out['policy_id']})
    if out['decision_only'] == 'APPROVE':
        approve_cnt += 1
    else:
        review_cnt += 1

os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
    w.writeheader()
    w.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
print(f'Decision counts on TEST -> APPROVE: {approve_cnt}, ROUTE FOR REVIEW: {review_cnt}')

if WRITE_DEBUG_CSV:
    with open(DEBUG_TEST_PATH, 'w', newline='', encoding='utf-8') as f:
        w = csv.DictWriter(f, fieldnames=['patient_id','decision','reason','procedures','diagnoses','policy_id'])
        w.writeheader()
        w.writerows(debug_rows)
    print(f'Wrote {DEBUG_TEST_PATH} (parsed fields + decision for each TEST record).')
