# Capstone 2 — Deterministic Claims Pipeline (No Agent Required)
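This notebook implements the three required pipeline steps (patient-record summary, policy-guideline summary, and claim-coverage check) in pure Python, avoiding the prompt/placeholder issues of LangChain agents. It reads inputs from `./Data` and writes `./Data/submission.csv`.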


In [None]:
# Deterministic Healthcare Claim Coverage Pipeline
# ------------------------------------------------
# This notebook intentionally avoids LangChain agent wiring (and its prompt/placeholder
# requirements) and runs the 3 required steps deterministically in Python:
#   1) summarize_patient_record
#   2) summarize_policy_guideline
#   3) check_claim_coverage
# It reads from ./Data and writes ./Data/submission.csv exactly as the capstone expects.

from __future__ import annotations
from typing import Dict, Any, List, Optional, Tuple
import json, re, os, csv, sys
from dataclasses import dataclass

# ----------------------------- Paths -----------------------------
# Every input file and the submission output live directly under DATA_DIR,
# which is a relative path (resolved against the current working directory).
DATA_DIR = './Data'
VALIDATION_PATH = DATA_DIR + '/validation_records.json'
TEST_PATH = DATA_DIR + '/test_records.json'
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'
SUBMISSION_PATH = DATA_DIR + '/submission.csv'

# ------------------------- Utilities -----------------------------
def _safe_load_json(path: str, default):
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f'⚠️ Missing file: {path}. Using default.')
        return default
    except json.JSONDecodeError as e:
        print(f'⚠️ JSON parse error in {path}: {e}. Using default.')
        return default

def coerce_int(val) -> Optional[int]:
    """Best-effort conversion of *val* to an int.

    ``int(val)`` is tried first (so floats truncate). If that raises, the first
    run of digits in ``str(val)`` is used instead; note this fallback drops any
    sign (``'-5 yrs'`` -> 5). Returns None for None or when no digits are found.
    """
    try:
        return int(val)
    except Exception:
        pass
    if val is None:
        return None
    digits = re.search(r'\d+', str(val))
    return None if digits is None else int(digits.group())

def sex_normalize(s: Any) -> str:
    """Map a free-form sex/gender value to 'M', 'F', or 'U' (unknown).

    Comparison is case-insensitive after trimming whitespace; only
    'M'/'MALE' and 'F'/'FEMALE' are recognised, everything else
    (including None) becomes 'U'.
    """
    canonical = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    return canonical.get(str(s).strip().upper(), 'U')

def dict_get(d, *keys, default=None):
    """Walk *keys* through nested dicts and return the value reached.

    Returns *default* when an intermediate value is not a dict, a key is
    absent, or the final value is None. Falsy non-None values (0, '', [])
    are returned as-is.
    """
    node = d
    for key in keys:
        if isinstance(node, dict):
            node = node.get(key)
            continue
        return default
    return node if node is not None else default

# ------------------------- Load References -----------------------
policies  = _safe_load_json(POLICIES_PATH, {})
ref_codes = _safe_load_json(REF_CODES_PATH, {})
if not isinstance(ref_codes, dict):
    # ref_codes.get() below would raise AttributeError on a list/str root and
    # abort the whole notebook; degrade to empty lookup tables with a warning.
    print(f'⚠️ Unexpected JSON root at {REF_CODES_PATH}; expecting an object. Got {type(ref_codes).__name__}.')
    ref_codes = {}
# Code lookup tables. Downstream code only tests membership (`code in table`).
# Either key spelling is accepted; the trailing `or {}` also covers an explicit
# JSON null, which would otherwise make every membership test raise TypeError.
ICD_TO_NAME = ref_codes.get('diagnosis_codes') or ref_codes.get('icd10') or {}
CPT_TO_NAME = ref_codes.get('procedure_codes') or ref_codes.get('cpt') or {}

# ------------------------- Core Functions ------------------------
def _parse_flag(value: Any) -> bool:
    """Interpret a structured yes/no field value as a boolean.

    Strings are true only for an affirmative token (case-insensitive), so
    'no'/'false' are not mistaken for True as a plain ``bool(str)`` would be.
    Booleans and numbers use their truth value; anything else is False.
    """
    if isinstance(value, str):
        return value.strip().lower() in {'yes', 'y', 'true', '1', 'approved', 'provided'}
    if isinstance(value, (bool, int, float)):
        return bool(value)
    return False


def _code_candidates(vals: Any, pattern: str) -> List[Any]:
    """Normalise a code field (None, free-text string, single number, or iterable) to a list."""
    if vals is None:
        return []
    if isinstance(vals, str):
        return re.findall(pattern, vals.upper())
    if isinstance(vals, (int, float)):
        # A single code stored as a JSON number (e.g. CPT 99213) is not iterable.
        return [vals]
    return list(vals)


def _summarize_dict_record(record: Dict[str, Any]) -> Dict[str, Any]:
    """Extract summary fields from a structured (JSON object) claim record."""
    diagnoses = {
        str(code).strip().upper()
        for key in ('diagnoses', 'dx', 'icd_codes')
        for code in _code_candidates(record.get(key), r'\b[A-Z]\d[\w.\-]*\b')
    }
    procedures = {
        str(code).strip()
        for key in ('procedures', 'cpt_codes')
        for code in _code_candidates(record.get(key), r'\b\d{4,5}\b')
    }
    return {
        'patient_id': record.get('patient_id') or record.get('id'),
        'age': coerce_int(record.get('age')),
        # Text records accept "gender" too; mirror that for structured input.
        'sex': sex_normalize(record.get('sex') or record.get('gender')),
        # Only codes present in the reference tables are kept.
        'diagnoses': sorted(c for c in diagnoses if c in ICD_TO_NAME),
        'procedures': sorted(c for c in procedures if c in CPT_TO_NAME),
        'policy_id': record.get('policy_id') or record.get('policy'),
        'preauth_provided': (_parse_flag(record.get('preauth_provided'))
                             or _parse_flag(record.get('preauthorization'))),
    }


def _summarize_text_record(text: str) -> Dict[str, Any]:
    """Regex-extract summary fields from a free-text claim record (labels matched case-insensitively)."""
    def first(pattern: str, group: int = 1) -> Optional[str]:
        match = re.search(pattern, text, re.I)
        return match.group(group) if match else None

    age_txt = first(r'\bage[:\s-]*([0-9]{1,3})\b')
    diagnoses = {c.upper() for c in re.findall(r'\b([A-Z]\d{1,2}[A-Z0-9.\-]+)\b', text, re.I)}
    procedures = set(re.findall(r'\b(\d{4,5})\b', text))
    return {
        'patient_id': first(r'patient[_\s]*id[:\s-]*([A-Za-z0-9_-]+)'),
        'age': int(age_txt) if age_txt is not None else None,
        # sex_normalize(None) -> 'U', so an absent label yields 'U'.
        'sex': sex_normalize(first(r'\b(sex|gender)[:\s-]*([A-Za-z]+)', 2)),
        'diagnoses': sorted(c for c in diagnoses if c in ICD_TO_NAME),
        'procedures': sorted(c for c in procedures if c in CPT_TO_NAME),
        'policy_id': first(r'\bpolicy[_\s]*id[:\s-]*([A-Za-z0-9_-]+)'),
        # Trailing \b stops e.g. "Preauth: yesterday" from matching "yes".
        'preauth_provided': bool(re.search(
            r'pre-?auth(?:orization)?[:\s-]*(yes|y|true|provided|approved)\b', text, re.I)),
    }


def summarize_patient_record(record: Dict[str, Any] | str) -> Dict[str, Any]:
    """Extract key fields from a claim record (dict or free text).

    Dict records are read by field name; anything else is converted to a
    string and parsed with regular expressions. Diagnosis and procedure codes
    are kept only if they appear in ICD_TO_NAME / CPT_TO_NAME, and are
    returned sorted.

    Returns:
        dict with patient_id, age (int or None), sex ('M'/'F'/'U'),
        diagnoses, procedures, policy_id, preauth_provided (bool).
    """
    if isinstance(record, dict):
        return _summarize_dict_record(record)
    return _summarize_text_record(str(record))

def _find_policy(policy_id: str) -> Optional[Dict[str, Any]]:
    """Return the policy object for *policy_id* from the loaded `policies`, or None.

    `policies` may be a dict keyed by policy id, or a list of policy objects.
    In the list case each object's 'policy_id' (or 'id') field is compared as
    a string. NOTE(review): those field names are an assumption about the list
    layout. Non-object entries are treated as not found.
    """
    if isinstance(policies, dict):
        found = policies.get(policy_id)
    elif isinstance(policies, list):
        wanted = str(policy_id)
        found = None
        for candidate in policies:
            if not isinstance(candidate, dict):
                continue
            cand_id = candidate.get('policy_id', candidate.get('id'))
            if cand_id is not None and str(cand_id) == wanted:
                found = candidate
                break
    else:
        found = None
    return found if isinstance(found, dict) else None


def summarize_policy_guideline(policy_id: str) -> Dict[str, Any]:
    """Summarise the coverage criteria for *policy_id*.

    Criteria are read from the policy's nested 'criteria' object. A missing,
    empty, or non-object policy returns ``{'error': "Unknown policy_id '...'"}``
    instead of a summary.
    """
    pol = _find_policy(policy_id)
    if not pol:
        return {'error': f"Unknown policy_id '{policy_id}'"}
    return {
        'policy_id': policy_id,
        'allowed_diagnoses': dict_get(pol, 'criteria', 'diagnoses', default=[]) or [],
        'allowed_procedures': dict_get(pol, 'criteria', 'procedures', default=[]) or [],
        'age_min': dict_get(pol, 'criteria', 'age_min', default=None),
        'age_max': dict_get(pol, 'criteria', 'age_max', default=None),
        'sex': dict_get(pol, 'criteria', 'sex', default=None),
        'preauth_required': bool(dict_get(pol, 'criteria', 'preauth_required', default=False)),
        'policy_title': pol.get('title', ''),
        'notes': pol.get('notes', '')
    }

def _code_set(values: Any, *, upper: bool = False) -> set:
    """Normalise a code collection (None, single code, or iterable) to a set of trimmed strings.

    Converting to str lets an int CPT code in a policy (99213) match the
    string code extracted from a record ('99213').
    """
    if values is None:
        return set()
    if isinstance(values, (str, int, float)):
        # A lone code must not be iterated character by character.
        values = [values]
    codes = {str(v).strip() for v in values}
    return {c.upper() for c in codes} if upper else codes


def _policy_bound(value: Any) -> Any:
    """Return a numeric age bound, parsing non-numeric values (e.g. '18') via coerce_int; None if absent."""
    if value is None or isinstance(value, (int, float)):
        return value
    return coerce_int(value)


def check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Decide whether a summarised claim meets its policy's criteria.

    Returns ``{'decision': 'APPROVE' | 'ROUTE FOR REVIEW', 'reason': str}``.
    A claim is routed for review when any of these hold:
    - a required record field is missing;
    - the policy summary carries an 'error' (unknown policy);
    - no claimed procedure or diagnosis is in a non-empty allowed list;
    - the age is outside the bounds or not numeric;
    - the sex does not match an M/F policy restriction;
    - preauthorization is required but was not provided.
    """
    def _absent(key: str) -> bool:
        value = record_summary.get(key)
        if key in ('age', 'policy_id'):
            # Age 0 (a newborn) is a real value; only None/'' count as missing.
            return value is None or value == ''
        return not value

    missing = [k for k in ('age', 'sex', 'diagnoses', 'procedures', 'policy_id') if _absent(k)]
    if missing:
        return {'decision': 'ROUTE FOR REVIEW', 'reason': f"Missing/insufficient fields: {', '.join(missing)}."}

    # Without this guard an unknown policy has no criteria, every check below
    # passes vacuously, and the claim would be APPROVED.
    if policy_summary.get('error'):
        return {'decision': 'ROUTE FOR REVIEW', 'reason': f"{policy_summary['error']}."[:500]}

    age = coerce_int(record_summary.get('age'))
    sex = sex_normalize(record_summary.get('sex'))
    dxs = _code_set(record_summary.get('diagnoses'), upper=True)
    cpts = _code_set(record_summary.get('procedures'))

    allowed_dx = _code_set(policy_summary.get('allowed_diagnoses'), upper=True)
    allowed_cpt = _code_set(policy_summary.get('allowed_procedures'))
    age_min = _policy_bound(policy_summary.get('age_min'))
    age_max = _policy_bound(policy_summary.get('age_max'))
    # Normalise so 'f'/'Female' rules are enforced; unknown/None rules become 'U' (no restriction).
    sex_rule = sex_normalize(policy_summary.get('sex'))
    preauth_required = bool(policy_summary.get('preauth_required', False))
    preauth_provided = bool(record_summary.get('preauth_provided', False))

    reasons = []
    # An empty allowed list means the policy does not restrict that dimension.
    if allowed_cpt and not (cpts & allowed_cpt):
        reasons.append('Claimed procedure not covered.')
    if allowed_dx and not (dxs & allowed_dx):
        reasons.append('Diagnosis not covered.')
    if age is not None:
        if age_min is not None and age < age_min:
            reasons.append(f'Patient age {age} is below {age_min}.')
        if age_max is not None and age > age_max:
            reasons.append(f'Patient age {age} exceeds {age_max}.')
    else:
        # Reached when an age value is present but has no digits (e.g. 'unknown').
        reasons.append('Age unavailable.')
    if sex_rule in {'M', 'F'} and sex != sex_rule:
        reasons.append(f'Policy restricted to sex {sex_rule}.')
    if preauth_required and not preauth_provided:
        reasons.append('Preauthorization required but not provided.')

    if reasons:
        return {'decision': 'ROUTE FOR REVIEW', 'reason': '; '.join(reasons)[:500]}
    return {'decision': 'APPROVE', 'reason': 'Meets policy criteria.'}

# --------------------------- IO Helpers --------------------------
def load_records(path: str) -> List[Dict[str, Any]]:
    """Load claim records from the JSON file at *path*.

    A top-level list is returned as-is. An object with a 'records' key
    returns that value. Any other root prints a warning and yields [].
    A missing or unparseable file also yields [] (the loader's default).
    """
    payload = _safe_load_json(path, [])
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict) and 'records' in payload:
        return payload['records']
    print(f'⚠️ Unexpected JSON root at {path}; expecting list or dict.records. Got {type(payload).__name__}.')
    return []

def run_on_record(rec: Dict[str, Any]) -> Dict[str,str]:
    """Run the three pipeline steps on one input record and format the result.

    *rec* is normally a dict. It is either a full structured record, or a
    wrapper whose 'record_str' holds free text; the wrapper's own
    'patient_id'/'policy_id' fill in when extraction does not find them.
    A non-dict record (e.g. a bare string) is treated as free text.

    Returns ``{'patient_id': ..., 'generated_response': 'Decision: X. Reason: Y'}``.
    """
    # Guard: a records file whose list holds plain strings would crash on rec.get().
    meta = rec if isinstance(rec, dict) else {}
    raw = meta.get('record_str', rec)
    summary = summarize_patient_record(raw)
    pid = summary.get('patient_id') or meta.get('patient_id')
    if not summary.get('policy_id'):
        summary['policy_id'] = meta.get('policy_id')
    policy_id = summary.get('policy_id')
    pol_sum = summarize_policy_guideline('' if policy_id is None else str(policy_id))
    decision = check_claim_coverage(summary, pol_sum)
    line = f"Decision: {decision['decision']}. Reason: {decision['reason']}"
    return {'patient_id': pid, 'generated_response': line}

# ---------------------- Validation quick print -------------------
# Sanity check: print the decision for the first three validation records.
val = load_records(VALIDATION_PATH)
if not val:
    print('No validation records found.')
else:
    print('Validation sample:')
    for rec in val[:3]:
        out = run_on_record(rec)
        print(out['patient_id'], '->', out['generated_response'])

# --------------------- Submission for test set -------------------
# Run every test record through the pipeline and write patient_id,generated_response rows.
test = load_records(TEST_PATH)
rows = [run_on_record(r) for r in test]
if not rows:
    print(f'⚠️ No test records loaded from {TEST_PATH}; submission will contain only the header.')
# os.path.dirname() is '' for a bare filename and os.makedirs('') raises,
# so only create the directory when the path actually has one.
out_dir = os.path.dirname(SUBMISSION_PATH)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id', 'generated_response'])
    w.writeheader()
    w.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
