# Capstone 2 — Deterministic Claims Pipeline (Targeted Fixes)
This version reads `insurancePolicyId`, `diagnosis_codes`, and `procedure_codes` from each record, uses a stricter policy-ID regex, and leaves debug printing enabled.


In [None]:
# Deterministic Claims Pipeline — Targeted Fixes (code_16)
# -------------------------------------------------------
# Fixes based on your sample:
# - Pull `policy_id` from `insurancePolicyId`
# - Pull diagnoses from `diagnosis_codes`
# - Pull procedures from `procedure_codes`
# - Harden policy-id regex to avoid false matches (won't misread 'patient_id')
# - Keep debug prints for the first 3 validation rows

from __future__ import annotations
from typing import Dict, Any, List, Optional, Iterable
import json, re, os, csv
from datetime import datetime

# ----------------------------- Paths -----------------------------
# Every input and output file lives directly under DATA_DIR.
DATA_DIR = './Data'
VALIDATION_PATH = DATA_DIR + '/validation_records.json'   # labelled sample records
TEST_PATH = DATA_DIR + '/test_records.json'               # records to score
POLICIES_PATH = DATA_DIR + '/insurance_policies.json'     # policy criteria
REF_CODES_PATH = DATA_DIR + '/reference_codes.json'       # code -> name tables
SUBMISSION_PATH = DATA_DIR + '/submission.csv'            # output written by this script

# ------------------------- Utilities -----------------------------
# strptime formats tried in this order; for slash-separated dates the
# month-first form is attempted before the day-first form.
DATE_PATTERNS = [
    '%Y-%m-%d',
    '%m/%d/%Y',
    '%d/%m/%Y',
    '%Y/%m/%d',
]

def _safe_load_json(path: str, default):
    """Load JSON from ``path``, returning ``default`` when that is not possible.

    A missing file or malformed JSON is reported with ``print`` and
    ``default`` is returned; any other error propagates to the caller.
    """
    try:
        handle = open(path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f'Missing file: {path}. Using default.')
        return default

    with handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError as e:
            print(f'JSON parse error in {path}: {e}. Using default.')
            return default

def parse_date_maybe(s):
    """Best-effort conversion of ``s`` to a ``datetime``; ``None`` on failure.

    Falsy input yields ``None``. Otherwise the stripped string form of ``s``
    is tried against each format in ``DATE_PATTERNS`` in order, and finally
    against a compact eight-digit ``YYYYMMDD`` form.
    """
    if not s:
        return None
    text = str(s).strip()

    for fmt in DATE_PATTERNS:
        try:
            parsed = datetime.strptime(text, fmt)
        except Exception:
            continue  # format did not fit; try the next one
        return parsed

    compact = re.fullmatch(r'(\d{4})(\d{2})(\d{2})', text)
    if compact is None:
        return None
    try:
        year, month, day = (int(part) for part in compact.groups())
        return datetime(year, month, day)
    except Exception:
        # Eight digits that do not form a real calendar date.
        return None

def age_from_dob(dob, ref=None):
    """Return the age in whole years on ``ref`` for someone born on ``dob``.

    Args:
        dob: Object with ``year``/``month``/``day`` attributes; falsy -> ``None``.
        ref: Reference date with the same attributes. When falsy, the current
            UTC date is used.

    Returns:
        The age as an int, or ``None`` when ``dob`` is falsy or the computed
        age falls outside ``0 <= age < 150``.
    """
    if not dob:
        return None
    if not ref:
        # datetime.utcnow() is deprecated since Python 3.12; an aware "now" in
        # UTC gives the same calendar fields, which is all that is read below.
        from datetime import timezone
        ref = datetime.now(timezone.utc)
    # Subtract one year when the birthday has not yet occurred in ref's year.
    years = ref.year - dob.year - ((ref.month, ref.day) < (dob.month, dob.day))
    return years if 0 <= years < 150 else None

def coerce_int(val):
    """Convert ``val`` to an int, falling back to the first run of digits.

    Returns ``None`` for ``None`` and for values that neither convert with
    ``int()`` nor contain any digit in their string form.
    """
    if val is None:
        return None
    try:
        return int(val)
    except Exception:
        digits = re.search(r'\d+', str(val))
    if digits:
        return int(digits.group())
    return None

def sex_normalize(s):
    """Map a sex/gender value to ``'M'``, ``'F'`` or ``'U'`` (unknown).

    Matching is case-insensitive and ignores surrounding whitespace; only
    ``M``/``MALE`` and ``F``/``FEMALE`` are recognised.
    """
    if s is None:
        return 'U'
    known = {'M': 'M', 'MALE': 'M', 'F': 'F', 'FEMALE': 'F'}
    token = str(s).strip().upper()
    return known.get(token, 'U')

def walk_items(obj):
    """Yield every ``(key, value)`` pair found in nested dicts and lists.

    Traversal is depth-first: a dict's pair is yielded before the pairs
    nested inside its value. Lists contribute only the pairs of their
    elements; any other value yields nothing.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield key, value
            yield from walk_items(value)
    elif isinstance(obj, list):
        for element in obj:
            yield from walk_items(element)

def deep_find_first(data, keys):
    """Return the value of the first nested key matching any of ``keys``.

    Keys are compared case-insensitively, in the order ``walk_items`` yields
    them. Returns ``None`` when nothing matches (a matching key whose value
    is ``None`` is indistinguishable from no match).
    """
    wanted = {name.lower() for name in keys}
    matches = (
        value
        for key, value in walk_items(data)
        if isinstance(key, str) and key.lower() in wanted
    )
    return next(matches, None)

def collect_codes_from_mixed(val):
    """Flatten a code field of varying shape into a list of strings.

    - ``None``  -> empty list
    - list      -> for dict elements, every non-None ``code``/``id``/``cpt``/
                   ``icd`` value followed by every truthy ``name``/``desc``/
                   ``description`` value; other elements are stringified
    - dict      -> the first truthy of ``code``/``id``, then the first truthy
                   of ``name``/``desc``/``description``
    - otherwise -> the stringified value as a single item
    """
    if val is None:
        return []

    if isinstance(val, list):
        collected = []
        for entry in val:
            if not isinstance(entry, dict):
                collected.append(str(entry))
                continue
            collected.extend(
                str(entry[key]) for key in ('code', 'id', 'cpt', 'icd')
                if entry.get(key) is not None
            )
            collected.extend(
                str(entry[key]) for key in ('name', 'desc', 'description')
                if entry.get(key)
            )
        return collected

    if isinstance(val, dict):
        collected = []
        code = val.get('code') or val.get('id')
        if code:
            collected.append(str(code))
        label = val.get('name') or val.get('desc') or val.get('description')
        if label:
            collected.append(str(label))
        return collected

    return [str(val)]

# ------------------------- Load References -----------------------
# Reference data is loaded once, when this cell/module runs. Missing or
# unparseable files fall back to an empty dict via _safe_load_json.
policies  = _safe_load_json(POLICIES_PATH, {})
ref_codes = _safe_load_json(REF_CODES_PATH, {})
if not isinstance(ref_codes, dict):
    # A non-object JSON root (e.g. a list) has no .get(); rather than crash
    # here, continue with no reference codes.
    print(f'Unexpected JSON root at {REF_CODES_PATH}; expecting an object. Got {type(ref_codes).__name__}.')
    ref_codes = {}

# Code -> display-name tables; the second key in each line is an alias.
ICD_TO_NAME = ref_codes.get('diagnosis_codes', {}) or ref_codes.get('icd10', {})
CPT_TO_NAME = ref_codes.get('procedure_codes', {}) or ref_codes.get('cpt', {})

# Reverse tables (lower-cased, stripped name -> code). Only a mapping can be
# inverted; any other shape yields an empty reverse table instead of raising.
ICD_NAME_TO_CODE = ({str(v).strip().lower(): k for k, v in ICD_TO_NAME.items()}
                    if isinstance(ICD_TO_NAME, dict) else {})
CPT_NAME_TO_CODE = ({str(v).strip().lower(): k for k, v in CPT_TO_NAME.items()}
                    if isinstance(CPT_TO_NAME, dict) else {})

def name_to_code_unique(name, name_to_code):
    """Resolve a free-text name to a code, but only when unambiguous.

    Lookup order against ``name_to_code`` (lower-cased name -> code):
      1. exact match on the stripped, lower-cased name;
      2. names that start with it, if they all share one code;
      3. for inputs of at least 4 characters, names that contain it, if they
         all share one code.

    Args:
        name: Text to resolve; non-strings are stringified.
        name_to_code: Mapping of lower-cased names to codes.

    Returns:
        The matching code, or ``None`` when the input is empty/blank, nothing
        matches, or the candidates disagree on the code.
    """
    n = '' if name is None else str(name).strip().lower()
    # An empty string is a prefix of every name; without this guard a blank
    # input would resolve to the only entry of a single-entry table.
    if not n or not name_to_code:
        return None
    if n in name_to_code:
        return name_to_code[n]

    # Compare distinct codes, not matching names: several names that map to
    # the same code are not an ambiguity.
    prefix_codes = {code for nm, code in name_to_code.items() if nm.startswith(n)}
    if len(prefix_codes) == 1:
        return next(iter(prefix_codes))

    if len(n) >= 4:  # very short fragments are too noisy for substring matching
        substring_codes = {code for nm, code in name_to_code.items() if n in nm}
        if len(substring_codes) == 1:
            return next(iter(substring_codes))
    return None

# ------------------------- Core Functions ------------------------
def summarize_patient_record(record):
    """Extract a normalized claim summary from one patient record.

    A dict record is read field-by-field (top-level keys first, then a
    case-insensitive deep search over nested keys). Any record - dict or not -
    is additionally rendered to text and scanned with regexes, which fill in
    values the structured pass did not find and harvest code-like tokens.

    Candidate diagnosis/procedure items are kept only if they are keys of
    ``ICD_TO_NAME`` / ``CPT_TO_NAME`` or resolve uniquely by name through
    ``ICD_NAME_TO_CODE`` / ``CPT_NAME_TO_CODE``.

    Returns:
        dict with keys ``patient_id``, ``age``, ``sex`` (``'M'``/``'F'``/``'U'``),
        ``diagnoses`` and ``procedures`` (sorted lists of codes), ``policy_id``
        and ``preauth_provided`` (bool).
    """
    if isinstance(record, dict):
        # patient & policy ids
        patient_id = (record.get('patient_id') or record.get('id')
                      or deep_find_first(record, ['patient_id','member_id','patientId','pid']))
        policy_id = (record.get('policy_id') or record.get('policy') or record.get('plan_id')
                     or record.get('insurancePolicyId')
                     or deep_find_first(record, ['policy_id','policy','plan_id','policyId','insurancePolicyId','insurance_policy_id','plan']))

        # demographics
        age = coerce_int(record.get('age') or deep_find_first(record, ['age','age_years','patient_age']))
        sex = sex_normalize(record.get('sex') or record.get('gender') or deep_find_first(record, ['sex','gender','patient_sex']))
        dob_raw = (record.get('dob') or record.get('date_of_birth') or record.get('birthdate')
                   or deep_find_first(record, ['dob','date_of_birth','birthdate']))
        svc_raw = (record.get('service_date') or deep_find_first(record, ['service_date','date_of_service']))
        dob_dt = parse_date_maybe(dob_raw)
        svc_dt = parse_date_maybe(svc_raw)
        if age is None:
            # No explicit age: derive it from the birth date as of the service
            # date (age_from_dob falls back to "now" when svc_dt is None).
            age = age_from_dob(dob_dt, svc_dt)

        # preauth
        preauth_val = (record.get('preauth_provided')
                       or deep_find_first(record, ['preauth','preauthorization','authorization','prior_auth','prior_authorization','preauth_provided']))
        preauth_provided = str(preauth_val).strip().lower() in {'1','y','yes','true','approved','provided'} if preauth_val is not None else False

        # codes (explicit keys + aliases)
        raw_dx = (record.get('diagnosis_codes') or record.get('diagnoses') or record.get('dx') or record.get('icd') or record.get('icd_codes')
                  or deep_find_first(record, ['diagnosis_codes','diagnoses','diagnosis','dx','icd','icd_codes']))
        raw_cpt = (record.get('procedure_codes') or record.get('procedures') or record.get('cpt') or record.get('cpt_codes') or record.get('services')
                   or deep_find_first(record, ['procedure_codes','procedures','procedure','cpt','cpt_codes','services']))
        dx_items  = collect_codes_from_mixed(raw_dx)
        cpt_items = collect_codes_from_mixed(raw_cpt)

        record_str = json.dumps(record, ensure_ascii=False)
    else:
        # Unstructured input: everything comes from the regex pass below.
        record_str = str(record)
        patient_id = policy_id = None
        age = None
        sex = 'U'
        preauth_provided = False
        dx_items, cpt_items = [], []

    # Regex extraction from text; each hit only fills a value still unknown.
    m = re.search(r'\bpatient[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', record_str, re.I)
    if m and not patient_id:
        patient_id = m.group(1)
    m = re.search(r'\bage[:\s-]*([0-9]{1,3})\b', record_str, re.I)
    if m and age is None:
        age = int(m.group(1))
    # Accept full words as well as initials: a single-letter capture followed
    # by \b can never match "Sex: Male" / "gender: female".
    m = re.search(r'\b(sex|gender)[:\s-]*(male|female|m|f)\b', record_str, re.I)
    if m and (not isinstance(record, dict) or sex == 'U'):
        sex = sex_normalize(m.group(2))
    m = re.search(r'(?<!patient)\bpolicy[_\s]*id[:\s-]*([A-Za-z0-9_-]+)', record_str, re.I)
    if m and not policy_id:
        policy_id = m.group(1)
    # Trailing \b keeps the short 'y' alternative from matching the start of
    # an unrelated word (e.g. "preauthorization yet to be obtained").
    if re.search(r'pre-?auth(?:orization)?[:\s-]*(yes|y|true|provided|approved)\b', record_str, re.I):
        preauth_provided = True

    # Add code-like tokens from the text; they are validated below, so stray
    # matches (ids, years, ...) are dropped unless they are known codes.
    dx_items += re.findall(r'\b([A-Z][0-9][A-Z0-9.\-]{1,6})\b', record_str)
    cpt_items += re.findall(r'\b(\d{4,5})\b', record_str)

    # Normalize / map to valid codes
    dx_codes = set()
    for item in dx_items:
        s = str(item).strip()
        if not s:
            continue
        if isinstance(ICD_TO_NAME, dict) and s.upper() in ICD_TO_NAME:
            dx_codes.add(s.upper())
            continue
        code = name_to_code_unique(s, ICD_NAME_TO_CODE)
        if code:
            dx_codes.add(code)

    cpt_codes = set()
    for item in cpt_items:
        s = str(item).strip()
        if not s:
            continue
        if isinstance(CPT_TO_NAME, dict) and s in CPT_TO_NAME:
            cpt_codes.add(s)
            continue
        code = name_to_code_unique(s, CPT_NAME_TO_CODE)
        if code:
            cpt_codes.add(code)

    return {
        'patient_id': patient_id,
        'age': age,
        'sex': sex,
        'diagnoses': sorted(dx_codes),
        'procedures': sorted(cpt_codes),
        'policy_id': policy_id,
        'preauth_provided': preauth_provided
    }

def summarize_policy_guideline(policy_id: str) -> Dict[str, Any]:
    """Look up ``policy_id`` in the loaded ``policies`` and flatten its criteria.

    Returns ``{'error': ...}`` when ``policies`` is not a dict or has no
    truthy entry under ``str(policy_id)``; otherwise a dict with the allowed
    diagnoses/procedures, age bounds, sex rule, preauth flag, title and notes.
    """
    entry = None
    if isinstance(policies, dict):
        entry = policies.get(str(policy_id))
    if not entry:
        return {'error': f"Unknown policy_id '{policy_id}'"}

    # A missing or null 'criteria' section is treated as "no criteria".
    criteria = entry.get('criteria', {}) or {}

    summary: Dict[str, Any] = {'policy_id': policy_id}
    summary['allowed_diagnoses'] = criteria.get('diagnoses', []) or []
    summary['allowed_procedures'] = criteria.get('procedures', []) or []
    summary['age_min'] = criteria.get('age_min')
    summary['age_max'] = criteria.get('age_max')
    summary['sex'] = criteria.get('sex')
    summary['preauth_required'] = bool(criteria.get('preauth_required', False))
    summary['policy_title'] = entry.get('title', '')
    summary['notes'] = entry.get('notes', '')
    return summary

def check_claim_coverage(record_summary: Dict[str, Any], policy_summary: Dict[str, Any]) -> Dict[str, str]:
    """Decide whether a summarized claim satisfies a summarized policy.

    Args:
        record_summary: Dict with ``age``, ``sex``, ``diagnoses``,
            ``procedures``, ``policy_id`` and optionally ``preauth_provided``.
        policy_summary: Dict with ``allowed_diagnoses``, ``allowed_procedures``,
            ``age_min``, ``age_max``, ``sex`` and ``preauth_required``, or a
            dict carrying an ``error`` message when the policy lookup failed.

    Returns:
        ``{'decision': 'APPROVE' | 'ROUTE FOR REVIEW', 'reason': str}``. The
        claim is approved only when every field is present, the policy was
        found, and no criterion is violated.
    """
    def CI(v):
        try:
            return int(v)
        except (TypeError, ValueError, OverflowError):
            return None

    def bound(v):
        # Age limits may arrive as strings; numbers are used as-is.
        return v if isinstance(v, (int, float)) else CI(v)

    def is_missing(key):
        value = record_summary.get(key)
        if key == 'age':
            # Age 0 (an infant) is a real value, not a missing one.
            return value is None or value == ''
        return not value

    missing = [k for k in ['age','sex','diagnoses','procedures','policy_id'] if is_missing(k)]
    if missing:
        return {'decision':'ROUTE FOR REVIEW', 'reason': f"Missing/insufficient fields: {', '.join(missing)}."}

    # A failed policy lookup carries no criteria at all; without this check
    # every rule below would pass vacuously and the claim would be approved.
    policy_error = policy_summary.get('error')
    if policy_error:
        return {'decision':'ROUTE FOR REVIEW', 'reason': f'Policy lookup failed: {policy_error}.'[:500]}

    age = CI(record_summary.get('age'))
    sex = str(record_summary.get('sex','U')).upper()
    # Compare codes as strings so numeric codes in either source still match.
    dxs  = {str(d).strip().upper() for d in record_summary.get('diagnoses',[])}
    cpts = {str(c).strip() for c in record_summary.get('procedures',[])}

    allowed_dx  = {str(d).strip().upper() for d in policy_summary.get('allowed_diagnoses',[]) or []}
    allowed_cpt = {str(c).strip() for c in policy_summary.get('allowed_procedures',[]) or []}
    age_min = bound(policy_summary.get('age_min'))
    age_max = bound(policy_summary.get('age_max'))
    raw_sex_rule = policy_summary.get('sex')
    sex_rule = str(raw_sex_rule).strip().upper() if raw_sex_rule is not None else None
    preauth_required = bool(policy_summary.get('preauth_required', False))
    preauth_provided = bool(record_summary.get('preauth_provided', False))

    reasons = []
    # An empty allow-list means "no restriction" for that dimension.
    if allowed_cpt and not (cpts & allowed_cpt):
        reasons.append('Claimed procedure not covered.')
    if allowed_dx and not (dxs & allowed_dx):
        reasons.append('Diagnosis not covered.')
    if age is not None:
        if age_min is not None and age < age_min:
            reasons.append(f'Patient age {age} is below {age_min}.')
        if age_max is not None and age > age_max:
            reasons.append(f'Patient age {age} exceeds {age_max}.')
    else:
        reasons.append('Age unavailable.')
    # Only an explicit M/F rule restricts; any other value is ignored.
    if sex_rule in {'M','F'} and sex != sex_rule:
        reasons.append(f'Policy restricted to sex {sex_rule}.')
    if preauth_required and not preauth_provided:
        reasons.append('Preauthorization required but not provided.')

    if reasons:
        return {'decision':'ROUTE FOR REVIEW', 'reason': '; '.join(reasons)[:500]}
    return {'decision':'APPROVE', 'reason':'Meets policy criteria.'}

# --------------------------- IO Helpers --------------------------
def load_records(path: str):
    """Load the record list stored at ``path``.

    Accepts either a JSON list or a JSON object holding the list under
    ``'records'``. Any other root is reported with ``print`` and an empty
    list is returned (as it is when the file cannot be loaded).
    """
    payload = _safe_load_json(path, [])
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict) and 'records' in payload:
        return payload['records']
    root_type = type(payload).__name__
    print(f'Unexpected JSON root at {path}; expecting list or dict.records. Got {root_type}.')
    return []

def run_on_record(rec: Dict[str, Any]) -> Dict[str,str]:
    """Run the full pipeline on one record and format a submission row.

    The record's ``record_str`` field is summarized when present, otherwise
    the record itself. IDs missing from the summary are back-filled from the
    record's top-level ``patient_id`` / ``policy_id`` / ``insurancePolicyId``.

    Returns:
        ``{'patient_id': ..., 'generated_response': 'Decision: X. Reason: Y'}``.
    """
    # The summarizer accepts non-dict records (plain text); guard the .get()
    # calls so such a record does not raise AttributeError here.
    fields = rec if isinstance(rec, dict) else {}
    raw = fields.get('record_str', rec)
    rs = summarize_patient_record(raw)
    rs['patient_id'] = rs.get('patient_id') or fields.get('patient_id')
    rs['policy_id'] = rs.get('policy_id') or fields.get('policy_id') or fields.get('insurancePolicyId')
    # An absent policy id is looked up as '' so the lookup reports it as unknown.
    pol_sum = summarize_policy_guideline(str(rs.get('policy_id') or ''))
    decision = check_claim_coverage(rs, pol_sum)
    line = f"Decision: {decision['decision']}. Reason: {decision['reason']}"
    return {'patient_id': rs.get('patient_id'), 'generated_response': line}

# ---------------------- Validation debug print -------------------
# Debug aid: run the pipeline on the first three validation records and print
# every intermediate result (raw record, parsed summary, policy, decision).
val = load_records(VALIDATION_PATH)
if val:
    print('Validation sample:')
    for rec in val[:3]:
        # Prefer the embedded text form when present, else the record itself.
        raw = rec.get('record_str', rec)
        parsed = summarize_patient_record(raw)
        # Back-fill the policy id from the record's top-level field if parsing missed it.
        parsed['policy_id'] = parsed.get('policy_id') or rec.get('insurancePolicyId')
        # Only the first 700 characters of the pretty-printed record are shown.
        print('RAW RECORD (truncated):\n', json.dumps(rec, indent=2)[:700], '...')
        print('PARSED SUMMARY:', parsed)
        pol_sum = summarize_policy_guideline(str(parsed.get('policy_id') or ''))
        print('POLICY SUMMARY:', pol_sum)
        decision = check_claim_coverage(parsed, pol_sum)
        print('DECISION:', decision)
        print('-'*90)
else:
    print('No validation records found.')

# --------------------- Submission for test set -------------------
# Score every test record and write the results as a two-column CSV.
test = load_records(TEST_PATH)
rows = [run_on_record(r) for r in test]
# Create the output directory if it does not exist yet.
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
# newline='' lets the csv module control line endings itself.
with open(SUBMISSION_PATH, 'w', newline='', encoding='utf-8') as f:
    w = csv.DictWriter(f, fieldnames=['patient_id','generated_response'])
    w.writeheader(); w.writerows(rows)
print(f'Wrote {SUBMISSION_PATH} with {len(rows)} rows.')
