# Capstone 2 — Deterministic Runner (per‑CPT policy rules)
Implements per-CPT coverage rules from the policy `covered_procedures` schema and emits responses in the sample `- Decision / - Reason` text format.

In [None]:
# Capstone 2 — Deterministic Runner (policy schema: `covered_procedures` with per-CPT rules)
# -----------------------------------------------------------------------------------------
# This version expects policy JSON of the following shape:
#   {
#     "policy_id": "POL1001",
#     "plan_name": "...",
#     "covered_procedures": [
#       {
#         "procedure_code": "83036",
#         "covered_diagnoses": ["I20.0", "M54.5"],
#         "age_range": [72, 82],
#         "gender": "Any",
#         "requires_preauthorization": true,
#         "notes": "..."
#       },
#       ...
#     ]
#   }
#
# Decision logic (strict, deterministic):
#   • For each claimed CPT in the record, we need at least one matching rule entry in the policy:
#       - same CPT
#       - record has at least one diagnosis in rule.covered_diagnoses (if present)
#       - patient's age is within rule.age_range (if provided)
#       - patient's sex matches rule.gender (if provided / not 'Any')
#       - preauth is provided if rule.requires_preauthorization is true
#   • If ANY claimed CPT fails to find a matching rule, we ROUTE FOR REVIEW and explain the first failure.
#   • If policy_id unknown -> ROUTE FOR REVIEW.
#   • Output style in CSV matches the sample:
#       - Decision: ...
#       - Reason: ...
#
# Input/Output files under ./Data. Submission written to ./Data/submission.csv

from __future__ import annotations
from typing import Dict, Any, Optional, List, Tuple
import json, csv, os, re
from datetime import datetime

# ----------------------------- Paths -----------------------------
# All input and output files live in a single data directory.
DATA_DIR = './Data'
# Inputs: validation records, test records, policy definitions, reference code tables.
VALIDATION_PATH = DATA_DIR + '/validation_records.json'
TEST_PATH = DATA_DIR + '/test_records.json'
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'
# Output: the CSV written at the end of the run.
SUBMISSION_PATH = DATA_DIR + '/submission.csv'

# ------------------------- Safe JSON Loading ----------------------
def _safe_load_json(path, default):
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception:
        return default

# Raw policy definitions; normalized later by _build_policy_index.
policies_raw  = _safe_load_json(POLICIES_PATH, {})
# Reference code tables. Anything that is not a JSON object is treated as
# "no reference data" instead of crashing on `.get` below.
ref_codes     = _safe_load_json(REF_CODES_PATH, {})
if not isinstance(ref_codes, dict):
    ref_codes = {}
# Either key spelling is accepted; an empty table disables reference filtering.
ICD_TO_NAME   = ref_codes.get('diagnosis_codes', {}) or ref_codes.get('icd10', {}) or {}
CPT_TO_NAME   = ref_codes.get('procedure_codes', {}) or ref_codes.get('cpt', {}) or {}

# ------------------------------ Helpers ---------------------------
def _norm_pid(pid: Optional[str]) -> Optional[str]:
    if not pid: return None
    return re.sub(r'[-_\s]', '', str(pid)).upper()

def _build_policy_index(raw):
    """Return dict[policy_id] = policy_obj with normalized structure:
    policy_obj = {
      'policy_id': 'POL1001',
      'plan_name': '...',
      'rules': [ {'procedure':'83036','diagnoses':[...],'age_min':72,'age_max':82,'sex':'M|F|ANY','preauth_required':True}, ... ]
    }
    """
    def _extract_rules(src) -> List[Dict[str, Any]]:
        rules = []
        items = src.get('covered_procedures') or []
        if isinstance(items, list):
            for it in items:
                if not isinstance(it, dict): 
                    continue
                proc = str(it.get('procedure_code') or it.get('procedure') or '').strip()
                if not proc: 
                    continue
                diags = it.get('covered_diagnoses') or it.get('diagnoses') or []
                if isinstance(diags, str):
                    diags = [d for d in re.split(r'[;,\s]+', diags) if d]
                diags = [str(d).strip().upper() for d in diags]
                # normalize ICD: insert dot if missing
                diags = [ (d[:3] + '.' + d[3:]) if re.fullmatch(r'[A-Z]\d{3,6}', d) and '.' not in d and len(d)>=4 else d for d in diags ]
                ar = it.get('age_range') or []
                age_min = None; age_max = None
                if isinstance(ar, list) and len(ar) >= 2:
                    try: age_min = int(ar[0])
                    except Exception: pass
                    try: age_max = int(ar[1])
                    except Exception: pass
                sex_raw = str(it.get('gender') or it.get('sex') or 'Any').strip().upper()
                sex = 'ANY'
                if sex_raw.startswith('M'): sex = 'M'
                elif sex_raw.startswith('F'): sex = 'F'
                preauth_required = bool(it.get('requires_preauthorization') or it.get('preauth_required') or False)
                rules.append({
                    'procedure': proc,
                    'diagnoses': diags,
                    'age_min': age_min,
                    'age_max': age_max,
                    'sex': sex,
                    'preauth_required': preauth_required
                })
        return rules

    index = {}
    if isinstance(raw, dict):
        # could be dict of policy_id -> object, or wrapper with 'policies'
        if 'policies' in raw and isinstance(raw['policies'], list):
            src_list = raw['policies']
        else:
            # dict entries
            src_list = []
            for k, v in raw.items():
                if isinstance(v, dict):
                    v = dict(v)
                    v.setdefault('policy_id', v.get('policyId') or v.get('id') or k)
                    src_list.append(v)
        for obj in src_list:
            pid = _norm_pid(obj.get('policy_id') or obj.get('policyId') or obj.get('id'))
            if not pid: 
                continue
            policy = {
                'policy_id': pid,
                'plan_name': obj.get('plan_name') or obj.get('title') or '',
                'rules': _extract_rules(obj)
            }
            index[pid] = policy
    elif isinstance(raw, list):
        for obj in raw:
            if not isinstance(obj, dict): 
                continue
            pid = _norm_pid(obj.get('policy_id') or obj.get('policyId') or obj.get('id'))
            if not pid: 
                continue
            policy = {
                'policy_id': pid,
                'plan_name': obj.get('plan_name') or obj.get('title') or '',
                'rules': _extract_rules(obj)
            }
            index[pid] = policy
    return index

# Normalized policy lookup keyed by normalized policy id; built once at import
# time from the raw policies JSON.
POLICY_INDEX = _build_policy_index(policies_raw)

# strptime layouts tried in order; for ambiguous slashed dates the month-first
# layout is attempted before the day-first one.
DATE_PATTERNS = ['%Y-%m-%d', '%m/%d/%Y', '%d/%m/%Y', '%Y/%m/%d']


def parse_date_maybe(s: Any) -> Optional[datetime]:
    """Best-effort date parser.

    Tries each layout in ``DATE_PATTERNS`` and then a compact ``YYYYMMDD``
    form. Returns ``None`` for falsy input or when nothing matches.
    """
    if not s:
        return None
    text = str(s).strip()
    for layout in DATE_PATTERNS:
        try:
            parsed = datetime.strptime(text, layout)
        except Exception:
            continue
        return parsed
    compact = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if compact is None:
        return None
    year, month, day = (int(part) for part in compact.groups())
    try:
        return datetime(year, month, day)
    except Exception:
        # Digits matched the pattern but do not form a real calendar date.
        return None

def age_from_dob(dob: Optional[datetime], ref: Optional[datetime] = None) -> Optional[int]:
    """Return the age in whole years at ``ref`` for someone born on ``dob``.

    Only the calendar year/month/day of both values are compared. When ``ref``
    is omitted, the current UTC date is used. Returns ``None`` when ``dob`` is
    missing or the result falls outside ``0 <= age < 150`` (e.g. a birth date
    after the reference date).
    """
    if not dob:
        return None
    if ref is None:
        # datetime.utcnow() is deprecated since Python 3.12; this yields the
        # same UTC calendar date without the deprecation warning.
        from datetime import timezone
        ref = datetime.now(timezone.utc)
    # Subtract one year if the birthday has not yet occurred in the reference year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def sex_normalize(s):
    """Map a free-form sex/gender value to 'M', 'F', or 'U' (unknown).

    Matching is case-insensitive and ignores surrounding whitespace; only
    'M'/'MALE' and 'F'/'FEMALE' are recognized.
    """
    if s is None:
        return 'U'
    token = str(s).strip().upper()
    known = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    return known.get(token, 'U')

def icd_dot_normalize(code: str) -> str:
    """Upper-case and trim an ICD code, inserting the dot after the third character.

    The dot is added only for undotted codes shaped like one letter followed
    by 3-6 digits (e.g. 'I200' -> 'I20.0'); anything else is returned as-is
    after trimming and upper-casing.
    """
    normalized = str(code).strip().upper()
    # A full match already guarantees there is no dot and at least 4 characters.
    undotted = re.fullmatch(r'[A-Z]\d{3,6}', normalized)
    if undotted is None:
        return normalized
    return f'{normalized[:3]}.{normalized[3:]}'

def coerce_int(val) -> Optional[int]:
    """Convert ``val`` to int, falling back to the first run of digits in its text.

    Returns ``None`` for ``None`` or when no digits can be found.
    """
    try:
        return int(val)
    except Exception:
        pass
    if val is None:
        return None
    digits = re.search(r'\d+', str(val))
    if digits is None:
        return None
    return int(digits.group())

def walk_items(obj):
    """Yield every (key, value) pair found in nested dicts and lists.

    Traversal is depth-first and pre-order: a dict entry is yielded before the
    pairs nested inside its value. Lists contribute only the pairs of their
    elements; any other type yields nothing.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield key, value
            yield from walk_items(value)
    elif isinstance(obj, list):
        for element in obj:
            yield from walk_items(element)

def deep_get_first(d: Any, keys: List[str]):
    """Return the value of the first nested key matching any of ``keys``.

    Matching is case-insensitive and follows the traversal order of
    ``walk_items``; non-string keys never match. Returns ``None`` when no key
    matches.
    """
    wanted = {name.lower() for name in keys}
    matches = (
        value
        for key, value in walk_items(d)
        if isinstance(key, str) and key.lower() in wanted
    )
    return next(matches, None)

def collect_mixed(val):
    """Flatten a code field of mixed shape into a list of strings.

    ``val`` may be ``None`` (-> ``[]``), a scalar (-> ``[str(val)]``), a dict,
    or a list of scalars/dicts. From a dict, the identifier fields
    ('code', 'id', 'icd', 'cpt') are taken when not ``None`` and the label
    fields ('name', 'desc', 'description', 'text') when truthy, in that order.
    """
    def _from_mapping(entry):
        # Identifier fields keep falsy-but-present values such as 0 or '';
        # label fields are taken only when truthy.
        found = [str(entry[k]) for k in ['code','id','icd','cpt'] if entry.get(k) is not None]
        found.extend(str(entry[k]) for k in ['name','desc','description','text'] if entry.get(k))
        return found

    if val is None:
        return []
    if isinstance(val, list):
        collected = []
        for element in val:
            if isinstance(element, dict):
                collected.extend(_from_mapping(element))
            else:
                collected.append(str(element))
        return collected
    if isinstance(val, dict):
        return _from_mapping(val)
    return [str(val)]

# --------------------- Tools (deterministic versions) ---------------------
def tool_summarize_patient_record(record: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the fields the coverage check needs from one raw claim record.

    Each field is read from several alternative top-level keys, falling back
    to a case-insensitive deep search of the record where noted in the code.

    Returns a dict with:
      'patient_id'       - first truthy id-like top-level value (unmodified)
      'age'              - int from an age field, else derived from a date of
                           birth and service date; ``None`` if neither works
      'sex'              - 'M', 'F' or 'U'
      'diagnoses'        - sorted, dot-normalized ICD codes
      'procedures'       - sorted CPT codes
      'policy_id'        - normalized policy id, or ``None``
      'preauth_provided' - True only for an explicit affirmative value
    """
    patient_id = (record.get('patient_id') or record.get('id') or record.get('member_id') or record.get('patientId') or record.get('pid'))
    raw_pid = (record.get('policy_id') or record.get('policy') or record.get('plan_id')
               or record.get('insurancePolicyId') or deep_get_first(record, ['insurance_policy_id']))
    policy_id = _norm_pid(raw_pid)
    age = coerce_int(record.get('age') or deep_get_first(record, ['age','age_years','patient_age']))
    if age is None:
        # No usable age field: derive it from date of birth at the service date.
        dob = (record.get('dob') or record.get('date_of_birth') or record.get('birthdate')
               or deep_get_first(record, ['dob','date_of_birth','birthdate']))
        svc = (record.get('service_date') or record.get('claim_date') or record.get('date_of_service')
               or deep_get_first(record, ['service_date','claim_date','date_of_service']))
        age = age_from_dob(parse_date_maybe(dob), parse_date_maybe(svc))
    sex = sex_normalize(record.get('sex') or record.get('gender') or deep_get_first(record, ['sex','gender','patient_sex']))
    preauth_val = (record.get('preauth_provided') or deep_get_first(record, ['preauth','preauthorization','authorization','prior_auth','prior_authorization','preauth_provided']))
    preauth_provided = False
    if preauth_val is not None:
        # Only these spellings count as "provided"; anything else is False.
        preauth_provided = str(preauth_val).strip().lower() in {'1','y','yes','true','approved','provided'}
    raw_dx  = (record.get('diagnosis_codes') or record.get('diagnoses') or record.get('dx') or record.get('icd') or record.get('icd_codes')
               or deep_get_first(record, ['diagnosis_codes','diagnoses','diagnosis','dx','icd','icd_codes']))
    raw_cpt = (record.get('procedure_codes') or record.get('procedures') or record.get('cpt') or record.get('cpt_codes') or record.get('services')
               or deep_get_first(record, ['procedure_codes','procedures','procedure','cpt','cpt_codes','services']))
    dx_codes  = {icd_dot_normalize(s) for s in collect_mixed(raw_dx)}
    # Policy rule CPT codes are stripped when the policy index is built, so
    # strip record codes too; blank strings are never valid codes and are
    # dropped so they cannot be reported as an uncovered "CPT".
    cpt_codes = {c for c in (s.strip() for s in collect_mixed(raw_cpt)) if c}

    # When reference tables are loaded, keep only codes they know about.
    if ICD_TO_NAME: dx_codes = {c for c in dx_codes if c in ICD_TO_NAME}
    if CPT_TO_NAME:
        known = {c for c in cpt_codes if c in CPT_TO_NAME}
        # No known code at all: fall back to anything shaped like a 5-digit CPT.
        cpt_codes = known if known else {c for c in cpt_codes if re.fullmatch(r'\d{5}', c)}

    return {
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(dx_codes),
        'procedures': sorted(cpt_codes),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    }

def tool_summarize_policy_guideline(policy_id: Optional[str]) -> Dict[str, Any]:
    """Look up the normalized policy for ``policy_id``.

    Returns the policy object from ``POLICY_INDEX`` (which carries a 'rules'
    list), or ``{'error': 'Unknown policy_id'}`` when the id is missing or
    not indexed.
    """
    # The index is consulted only when an id was actually supplied.
    found = POLICY_INDEX.get(policy_id) if policy_id else None
    if found:
        return found
    return {'error': 'Unknown policy_id'}

def _rule_covers(record: Dict[str, Any], rule: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Return (ok, reason_if_not) for a single rule vs record."""
    cpts = set(record.get('procedures', []))
    dxs  = set(record.get('diagnoses', []))
    age  = record.get('age')
    sex  = record.get('sex', 'U').upper()
    pre  = bool(record.get('preauth_provided', False))

    # procedure check (rule is per-procedure)
    if rule['procedure'] not in cpts:
        return False, f"Policy rule only covers CPT {rule['procedure']}."
    # diagnosis intersection if provided
    rdx = set(rule.get('diagnoses', []))
    if rdx and not (dxs & rdx):
        return False, f"CPT {rule['procedure']} not covered for provided diagnosis codes."
    # age range
    if age is not None:
        if rule.get('age_min') is not None and age < rule['age_min']:
            return False, f"Patient age {age} below covered minimum for CPT {rule['procedure']}."
        if rule.get('age_max') is not None and age > rule['age_max']:
            return False, f"Patient age {age} exceeds covered maximum for CPT {rule['procedure']}."
    # sex
    sex_rule = rule.get('sex', 'ANY')
    if sex_rule in {'M','F'} and sex != sex_rule:
        return False, f"Policy for CPT {rule['procedure']} applies only to sex {sex_rule}."
    # preauth
    if rule.get('preauth_required') and not pre:
        return False, f"Preauthorization required for CPT {rule['procedure']} but not provided."
    return True, None

def tool_check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Decide APPROVE vs ROUTE FOR REVIEW for a summarized record and policy.

    Routes for review when the policy is unknown or has no rules, or when any
    claimed CPT has no rule that fully covers it; the reason then comes from
    the first failing rule for that CPT. Otherwise approves, naming the first
    claimed CPT (or a generic reason when no CPT was claimed).
    """
    def _review(reason: str) -> Dict[str, str]:
        return {'decision':'ROUTE FOR REVIEW','reason': reason}

    if policy_summary.get('error'):
        return _review('Unknown policy_id.')
    rules = policy_summary.get('rules', [])
    if not rules:
        return _review('Policy does not define covered procedures.')

    claimed = record_summary.get('procedures', [])
    for cpt in claimed:
        candidates = [rule for rule in rules if rule.get('procedure') == cpt]
        if not candidates:
            return _review(f"The claim for CPT code {cpt} is not covered by the policy.")
        first_failure = None
        for rule in candidates:
            passed, why = _rule_covers(record_summary, rule)
            if passed:
                break
            if first_failure is None:
                first_failure = why
        else:
            # Loop finished without a break: no candidate rule covered this CPT.
            return _review(first_failure or f"The claim for CPT code {cpt} does not meet policy requirements.")

    # Every claimed CPT is covered by at least one rule.
    if claimed:
        detail = f"The claim for CPT code {claimed[0]} is approved."
    else:
        detail = "Meets policy criteria."
    return {'decision':'APPROVE','reason': detail}

# ----------------------------- Runner -----------------------------
def _load_records(path):
    """Load the records stored at ``path``.

    Accepts either a bare JSON list or an object with a 'records' key (whose
    value is returned as-is); anything else, including a load failure, yields
    an empty list.
    """
    payload = _safe_load_json(path, [])
    if isinstance(payload, dict):
        return payload['records'] if 'records' in payload else []
    return payload if isinstance(payload, list) else []

def format_sample_style(decision: str, reason: str) -> str:
    """Render the two-line '- Decision / - Reason' response text."""
    lines = (f"- Decision: {decision}", f"- Reason: {reason}")
    return "\n".join(lines)

def run_record_det(record):
    """Run the deterministic pipeline on one raw record.

    Summarizes the record, looks up its policy, checks coverage, and returns
    ``{'patient_id': ..., 'generated_response': ...}`` where the response is
    the two-line decision/reason text.
    """
    rs = tool_summarize_patient_record(record)
    ps = tool_summarize_policy_guideline(rs.get('policy_id'))
    dec = tool_check_claim_coverage(rs, ps)
    final_text = format_sample_style(dec['decision'], dec['reason'])
    # Reuse the id resolved by the summarizer: it checks 'patient_id' and 'id'
    # first (the keys previously used here) and then the alternative id keys,
    # so records that only carry one of those no longer get an empty id.
    return {'patient_id': rs.get('patient_id'), 'generated_response': final_text}

# Preview: run the first three validation records and print one line each.
val = _load_records(VALIDATION_PATH)
if not val:
    print('No validation records found.')
else:
    print('Validation sample (deterministic + policy rules):')
    for rec in val[:3]:
        out = run_record_det(rec)
        print(out['patient_id'], '->', out['generated_response'].replace('\n', ' | '))

# Full test run: evaluate every test record and write the submission CSV.
test = _load_records(TEST_PATH)
rows = list(map(run_record_det, test))
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
    w.writeheader()
    w.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
