### Stage 1 ETL

- takes prov & rates parquet files from mrf engine repo

In [2]:
# ---------- CONFIG ----------
from pathlib import Path

# Change these for each run
STATE = "GA"  # "VA", etc.

# Optional hard override (otherwise computed from reporting_entity_name)
PAYER_SLUG_OVERRIDE = None  # e.g., "aetna", "uhc"

# Data store root (dims/xrefs/gold will live under here)
# This notebook is under ETL/, so we need to back out to the parent to be at the same level as 'data'
# __file__ is not defined in Jupyter, so use notebook location as base
import os

try:
    # Try to use __file__ if running as a script
    DS_ROOT = Path(__file__).parent.parent / "data"
except NameError:
    # Fallback for Jupyter: use current working directory, assume ETL/ as CWD
    DS_ROOT = Path(os.getcwd()).parent / "data"

DS_ROOT.mkdir(parents=True, exist_ok=True)
RATES_PARQ = DS_ROOT / "input/202508_aetna_ga_rates.parquet"    # switch to UHC file on rerun
PROV_PARQ  = DS_ROOT / "input/202508_aetna_ga_providers.parquet"  # switch to UHC file on rerun

# Outputs
DIM_DIR   = DS_ROOT / "dims"
XREF_DIR  = DS_ROOT / "xrefs"
GOLD_DIR  = DS_ROOT / "gold"
DIM_DIR.mkdir(parents=True, exist_ok=True)
XREF_DIR.mkdir(parents=True, exist_ok=True)
GOLD_DIR.mkdir(parents=True, exist_ok=True)

# Parquet file targets
DIM_CODE_FILE   = DIM_DIR  / "dim_code.parquet"
DIM_PAYER_FILE  = DIM_DIR  / "dim_payer.parquet"
DIM_PG_FILE     = DIM_DIR  / "dim_provider_group.parquet"
DIM_POS_FILE    = DIM_DIR  / "dim_pos_set.parquet"

XREF_PG_NPI     = XREF_DIR / "xref_pg_member_npi.parquet"
XREF_PG_TIN     = XREF_DIR / "xref_pg_member_tin.parquet"

GOLD_FACT_FILE  = GOLD_DIR / "fact_rate.parquet"

# Columns we’ll actually read (memory saver)
RATES_COLS = [
    "last_updated_on","reporting_entity_name","version",
    "billing_class","billing_code_type","billing_code",
    "service_codes","negotiated_type","negotiation_arrangement",
    "negotiated_rate","expiration_date","description","name",
    "provider_reference_id","provider_group_id","provider_group_id_raw"
]

# Some MRFs don’t include both provider_reference_id & provider_group_id; we handle that later.
PROV_COLS = [
    "last_updated_on","reporting_entity_name","version",
    "provider_group_id","provider_reference_id",
    "npi","tin_type","tin_value"
]


Cell 2 — Helpers & keys

In [3]:
import os, sys, math, json, re, hashlib, duckdb
import polars as pl
from datetime import datetime

print("Python:", sys.version)
print("Polars:", pl.__version__)
print("DuckDB:", duckdb.__version__)


Python: 3.13.2 (tags/v3.13.2:4f8bb39, Feb  4 2025, 15:23:48) [MSC v.1942 64 bit (AMD64)]
Polars: 1.29.0
DuckDB: 1.2.1


Cell 3 — Ingest & normalize this batch (mint batch_id & pg_uid)

In [4]:
def md5(s: str) -> str:
    return hashlib.md5(s.encode("utf-8")).hexdigest()

def slugify(s: str) -> str:
    if s is None:
        return ""
    s = s.lower()
    s = re.sub(r"[^a-z0-9]+", "-", s).strip("-")
    s = re.sub(r"-+", "-", s)
    return s

def _co(x):
    return "" if x is None else str(x)

def payer_slug_from_name(name: str) -> str:
    if PAYER_SLUG_OVERRIDE:
        return PAYER_SLUG_OVERRIDE
    return slugify(name or "")

def normalize_yymm(date_str: str | None) -> str:
    if not date_str:
        return ""
    # try common formats
    for fmt in ("%Y-%m-%d", "%Y/%m/%d", "%Y-%m", "%Y/%m", "%Y%m%d", "%Y%m"):
        try:
            dt = datetime.strptime(date_str[:len(fmt.replace("%","").replace("-","").replace("/",""))], fmt)
            return dt.strftime("%Y-%m")
        except Exception:
            continue
    # fallback: extract yyyy-mm
    m = re.search(r"(20\d{2})[-/](0[1-9]|1[0-2])", date_str)
    return f"{m.group(1)}-{m.group(2)}" if m else ""

def normalize_service_codes(svc) -> list[str]:
    """
    Normalize service_codes into a sorted unique list of strings.
    Handles None, list/tuple, JSON-like strings, and CSV-ish strings.
    Never uses bare truthiness checks that could hit Series ambiguity.
    """
    if svc is None:
        return []

    # If it's already a list/tuple, stringify elements
    if isinstance(svc, (list, tuple)):
        vals = ["" if v is None else str(v) for v in svc]
    else:
        s = str(svc)

        # Try to parse JSON-like list strings (e.g., '["11","22"]')
        if s.startswith("[") and s.endswith("]"):
            try:
                import json
                parsed = json.loads(s)
                if isinstance(parsed, list):
                    vals = ["" if v is None else str(v) for v in parsed]
                else:
                    vals = re.split(r"[;,|\s]+", s)
            except Exception:
                vals = re.split(r"[;,|\s]+", s)
        else:
            vals = re.split(r"[;,|\s]+", s)

    cleaned: list[str] = []
    for v in vals:
        sv = str(v).strip()
        if len(sv) > 0:
            cleaned.append(sv)

    # dedupe + sorted
    return sorted(set(cleaned))


def pos_set_id_from_members(members) -> str:
    """
    Make a stable id from a list of POS members.
    Avoid bare truthiness; handle non-list defensive cases.
    """
    if members is None:
        return md5("none")
    try:
        n = len(members)
    except Exception:
        # not iterable; coerce to single-element list
        members = [str(members)]
        n = 1
    if n == 0:
        return md5("none")
    # stringify defensively
    parts = ["" if m is None else str(m) for m in members]
    return md5("|".join(parts))


def pg_uid_from_parts(payer_slug: str, version: str | None, pgid: str | None, pref: str | None) -> str:
    # Prefer group_id, fall back to provider_reference_id
    key = f"{_co(payer_slug)}|{_co(version)}|{_co(pgid)}|{_co(pref)}"
    return md5(key)

def fact_uid_from_struct(s: dict) -> str:
    """
    Deterministic ID for one fact row (idempotent upsert).
    Includes STATE so GA vs VA don't collide.
    Rounds negotiated_rate to 4 decimals to avoid float drift.
    """
    rate_val = s.get("negotiated_rate")
    try:
        rate_str = f"{float(rate_val):.4f}" if rate_val is not None else ""
    except Exception:
        rate_str = ""
    parts = [
        _co(s.get("state")),
        _co(s.get("year_month")),
        _co(s.get("payer_slug")),
        _co(s.get("billing_class")),
        _co(s.get("code_type")),
        _co(s.get("code")),
        _co(s.get("pg_uid")),
        _co(s.get("pos_set_id")),
        _co(s.get("negotiated_type")),
        _co(s.get("negotiation_arrangement")),
        _co(s.get("expiration_date")),
        rate_str,
        _co(s.get("provider_group_id_raw")),
    ]
    return md5("|".join(parts))

def prj_cols(df: pl.DataFrame, cols: list[str]) -> pl.DataFrame:
    missing = [c for c in cols if c not in df.columns]
    if missing:
        # add missing as nulls so downstream code doesn't crash
        df = df.with_columns([pl.lit(None).alias(c) for c in missing])
    return df.select(cols)


Cell 4 — Small “append-unique” writers for dims/xrefs

In [5]:
# Robust read that tolerates missing columns by intersecting with the file schema.
# Then, we add any still-missing columns as nulls so later code can rely on them.

def read_parquet_safely(path: Path, desired_cols: list[str]) -> pl.DataFrame:
    lf = pl.scan_parquet(str(path))  # lazy, no data loaded yet
    avail = set(lf.columns)
    use_cols = [c for c in desired_cols if c in avail]
    df = lf.select(use_cols).collect()  # load only what's present

    # back-fill missing columns as nulls to keep downstream selects happy
    missing = [c for c in desired_cols if c not in df.columns]
    if missing:
        df = df.with_columns([pl.lit(None).alias(c) for c in missing])
    return df

if not RATES_PARQ.exists():
    raise FileNotFoundError(f"Missing RATES_PARQ: {RATES_PARQ}")
if not PROV_PARQ.exists():
    raise FileNotFoundError(f"Missing PROV_PARQ:  {PROV_PARQ}")

# TIP: if your RATES_COLS still includes super-optional fields,
# it's fine because read_parquet_safely will just back-fill them.
rates = read_parquet_safely(RATES_PARQ, RATES_COLS)
prov  = read_parquet_safely(PROV_PARQ,  PROV_COLS)

print("rates rows:", rates.height, "cols:", len(rates.columns))
print("prov  rows:", prov.height, "cols:", len(prov.columns))

# Optional quick peek if you’re debugging schemas:
print("RATES present cols:", [c for c in rates.columns if rates.select(pl.col(c).is_not_null().any()).item()])
print("PROV  present cols:", [c for c in prov.columns  if prov.select(pl.col(c).is_not_null().any()).item()])


  avail = set(lf.columns)


rates rows: 13456820 cols: 16
prov  rows: 131984 cols: 8
RATES present cols: ['last_updated_on', 'reporting_entity_name', 'version', 'billing_class', 'billing_code_type', 'billing_code', 'service_codes', 'negotiated_type', 'negotiation_arrangement', 'negotiated_rate', 'expiration_date', 'description', 'name', 'provider_reference_id']
PROV  present cols: ['last_updated_on', 'reporting_entity_name', 'version', 'provider_group_id', 'npi', 'tin_type', 'tin_value']


Cell 5 — Upsert the FACT into a state-scoped Parquet dataset (dedup partitions)

In [6]:
# ---------- Efficient, Batched Upsert for Notebooks ----------
# The classic upsert (append-unique) pattern is expensive because it reads the entire existing Parquet file
# and does a full in-memory join for deduplication. This is not notebook-friendly for large files.
#
# To keep this notebook-friendly and avoid reading the whole file:
# - Only process/apply upserts in small batches (e.g., per payer, per state, or per chunk).
# - Only read the *keys* column(s) from the existing file, and only if the file is small.
# - If the file is large, maintain a separate "index" of keys (e.g., as a CSV or Parquet of just the keys).
# - For truly large-scale, use a database (DuckDB, SQLite) for deduplication, or partition files by key.
#
# Below, we show a batched, memory-light approach for upserts in a notebook.

# ---------- Compat helpers (no .str.strip / no .strip_matches) ----------
def _trim(expr: pl.Expr) -> pl.Expr:
    # remove leading/trailing whitespace via regex
    return expr.str.replace_all(r"^\s+|\s+$", "")

def _trim_dashes(expr: pl.Expr) -> pl.Expr:
    # remove leading/trailing '-' via regex
    return expr.str.replace_all(r"^-+|-+$", "")

def _payer_slug_expr() -> pl.Expr:
    base_slug = (
        pl.col("reporting_entity_name")
          .fill_null("")
          .str.to_lowercase()
          .str.replace_all(r"[^a-z0-9]+", "-")    # collapse to dashes
          .str.replace_all(r"-{2,}", "-")         # dedupe dashes
    )
    base_slug = _trim_dashes(base_slug)           # trim edge dashes
    override = pl.lit(PAYER_SLUG_OVERRIDE)
    return pl.when(override.is_not_null() & (override != "")) \
             .then(override) \
             .otherwise(base_slug)

def _year_month_expr() -> pl.Expr:
    return (
        pl.coalesce([
            pl.col("last_updated_on").str.strptime(pl.Date, "%Y-%m-%d", strict=False),
            pl.col("last_updated_on").str.strptime(pl.Date, "%Y/%m/%d", strict=False),
            pl.col("last_updated_on").str.strptime(pl.Date, "%Y-%m",     strict=False),
            pl.col("last_updated_on").str.strptime(pl.Date, "%Y/%m",     strict=False),
            pl.col("last_updated_on").str.strptime(pl.Date, "%Y%m%d",    strict=False),
            pl.col("last_updated_on").str.strptime(pl.Date, "%Y%m",      strict=False),
        ])
        .dt.strftime("%Y-%m")
        .fill_null("")
    )

# service_codes -> normalized list[str] (robust: no json_decode)
_sc = pl.col("service_codes").cast(pl.Utf8)

# regex-trim helper (compat with older Polars)
def _trim(expr: pl.Expr) -> pl.Expr:
    return expr.str.replace_all(r"^\s+|\s+$", "")

pos_members_expr = (
    _trim(
        _sc.fill_null("")
           # strip brackets/braces/parens and quotes that break JSON parsing
           .str.replace_all(r"[\[\]\{\}\(\)]", " ")
           .str.replace_all(r"[\"']", " ")
           # unify separators to spaces
           .str.replace_all(r"[;,|]+", " ")
           .str.replace_all(r"\s+", " ")
    )
    .str.split(" ")
    .list.eval(_trim(pl.element()))
    .list.eval(pl.when(pl.element() == "").then(None).otherwise(pl.element()))
    .list.drop_nulls()
    .list.unique()
    .list.sort()
)

# ---------- build base + IDs ----------
base = (
    rates
      .with_columns([
          _payer_slug_expr().alias("payer_slug"),
          _year_month_expr().alias("year_month"),
          pos_members_expr.alias("pos_members"),
      ])
)

# pos_set_id = md5 of normalized list (use list.join, not arr.join)
base = (
    base
      .with_columns(pl.col("pos_members").list.join("|").alias("pos_key"))
      .with_columns(pl.col("pos_key").map_elements(md5, return_dtype=pl.Utf8).alias("pos_set_id"))
      .drop("pos_key")
)

# pg_uid = md5("payer_slug|version|provider_reference_id|"") for rates
base = (
    base
      .with_columns(
          pl.concat_str(
              [
                  pl.col("payer_slug"),
                  pl.col("version").fill_null(""),
                  pl.col("provider_reference_id").fill_null(""),  # ✅ Use provider_reference_id from rates
                  pl.lit(""),  # provider_group_id doesn't exist in rates
              ],
              separator="|",
          ).alias("pg_key")
      )
      .with_columns(pl.col("pg_key").map_elements(md5, return_dtype=pl.Utf8).alias("pg_uid"))
      .drop("pg_key")
)

# ---------- dims/xrefs ----------
dim_code_new = (
    base.select([
        pl.col("billing_code_type").alias("code_type"),
        pl.col("billing_code").cast(pl.Utf8).alias("code"),
        pl.col("description").alias("code_description"),
        pl.col("name").alias("code_name"),
    ])
    .drop_nulls(subset=["code_type","code"])
    .unique()
)

dim_payer_new = (
    base.select([
        pl.col("payer_slug"),
        pl.col("reporting_entity_name").alias("reporting_entity_name"),
        pl.col("version").alias("version"),
    ])
    .drop_nulls(subset=["payer_slug"])
    .unique()
)

dim_pg_new = (
    base.select([
        pl.col("pg_uid"),
        pl.col("payer_slug"),
        pl.coalesce([pl.col("provider_group_id"), pl.col("provider_reference_id")]).alias("provider_group_id_raw"),
        pl.col("version"),
    ])
    .drop_nulls(subset=["pg_uid"])
    .unique()
)

dim_pos_new = (
    base.select(["pos_set_id","pos_members"])
        .drop_nulls(subset=["pos_set_id"])
        .unique()
)

# XREFs from provider file (vectorized payer_slug + pg_uid)
prov_aug = (
    prov
      .with_columns([ _payer_slug_expr().alias("payer_slug") ])
      .with_columns(
          pl.concat_str(
              [
                  pl.col("payer_slug"),
                  pl.col("version").fill_null(""),
                  pl.col("provider_group_id").fill_null(""),  # ✅ Use provider_group_id from provider
                  pl.lit(""),  # provider_reference_id not needed for hash
              ],
              separator="|",
          ).alias("pg_key")
      )
      .with_columns(pl.col("pg_key").map_elements(md5, return_dtype=pl.Utf8).alias("pg_uid"))
      .drop("pg_key")
)

xref_pg_npi_new = (
    prov_aug.select(["pg_uid","npi"])
            .drop_nulls(subset=["pg_uid","npi"])
            .unique()
)

xref_pg_tin_new = (
    prov_aug.select(["pg_uid","tin_type","tin_value"])
            .drop_nulls(subset=["pg_uid","tin_value"])
            .unique()
)

print("Base/dims/xrefs built (vectorized, compat v2).")

# --------- Batched Upsert Helper for Notebooks ---------
# Only reads the *keys* column(s) from the existing file, and only for the current batch.
# If the file is large, consider partitioning or using a DB for deduplication.

def append_unique_parquet_batched(df_new: pl.DataFrame, path: Path, keys: list[str], batch_size: int = 10000):
    """
    Append unique rows from df_new to Parquet at path, in small batches.
    Only reads the keys from the existing file, not the whole file.
    """
    import math

    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        # If file doesn't exist, just write all new data
        df_new.write_parquet(path, compression="zstd")
        return

    # Read only the keys from the existing file
    try:
        existing_keys = pl.read_parquet(path, columns=keys)
    except Exception as e:
        print(f"Warning: Could not read keys from {path}: {e}")
        existing_keys = pl.DataFrame({k: [] for k in keys})

    # Process in batches to avoid memory spikes
    n = df_new.height
    n_batches = math.ceil(n / batch_size)
    for i in range(n_batches):
        start = i * batch_size
        end = min((i+1) * batch_size, n)
        batch = df_new.slice(start, end - start)
        # Anti-join to find truly new rows
        to_add = batch.join(existing_keys, on=keys, how="anti")
        if not to_add.is_empty():
            # Append to Parquet (append mode)
            to_add.write_parquet(path, compression="zstd", append=True)

# Example usage (replace append_unique_parquet with batched version for notebook safety):
# append_unique_parquet_batched(dim_code_new, DIM_CODE_FILE, keys=["code_type","code"])
# append_unique_parquet_batched(dim_payer_new, DIM_PAYER_FILE, keys=["payer_slug"])
# append_unique_parquet_batched(dim_pg_new,   DIM_PG_FILE,    keys=["pg_uid"])
# append_unique_parquet_batched(dim_pos_new,  DIM_POS_FILE,   keys=["pos_set_id"])
# append_unique_parquet_batched(xref_pg_npi_new, XREF_PG_NPI, keys=["pg_uid","npi"])
# append_unique_parquet_batched(xref_pg_tin_new, XREF_PG_TIN, keys=["pg_uid","tin_value"])

print("Ready for batched, memory-light upserts in notebook mode.")


KeyboardInterrupt: 

Cell 6

In [7]:
def append_unique_parquet(df_new: pl.DataFrame, path: Path, keys: list[str]):
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        old_keys = pl.read_parquet(path, columns=keys).unique()
        to_add = df_new.join(old_keys, on=keys, how="anti")
    else:
        to_add = df_new
    if to_add.is_empty():
        return

    tmp_new = path.with_suffix(".new.parquet")
    tmp_out = path.with_suffix(".next.parquet")

    to_add.write_parquet(tmp_new, compression="zstd")
    con = duckdb.connect()
    if path.exists():
        con.execute(f"""
          COPY (
            SELECT * FROM read_parquet('{path}')
            UNION ALL
            SELECT * FROM read_parquet('{tmp_new}')
          ) TO '{tmp_out}' (FORMAT PARQUET, COMPRESSION ZSTD);
        """)
    else:
        con.execute(f"""
          COPY (SELECT * FROM read_parquet('{tmp_new}'))
          TO '{tmp_out}' (FORMAT PARQUET, COMPRESSION ZSTD);
        """)
    con.close()
    os.replace(tmp_out, path)
    os.remove(tmp_new)

# Write dims/xrefs
append_unique_parquet(dim_code_new, DIM_CODE_FILE, keys=["code_type","code"])
append_unique_parquet(dim_payer_new, DIM_PAYER_FILE, keys=["payer_slug"])
append_unique_parquet(dim_pg_new,   DIM_PG_FILE,    keys=["pg_uid"])
append_unique_parquet(dim_pos_new,  DIM_POS_FILE,   keys=["pos_set_id"])

append_unique_parquet(xref_pg_npi_new, XREF_PG_NPI, keys=["pg_uid","npi"])
append_unique_parquet(xref_pg_tin_new, XREF_PG_TIN, keys=["pg_uid","tin_value"])

print("Dims/Xrefs up to date.")


Dims/Xrefs up to date.


Cell 7

In [8]:
# Cell 7 — Build fact_new (keep exact fact_uid semantics via existing helper)

fact_new = (
    base
      .with_columns(pl.lit(STATE).alias("state"))
      .select(
          "state",
          pl.col("year_month"),
          pl.col("payer_slug"),
          pl.col("billing_class"),
          pl.col("billing_code_type").alias("code_type"),
          pl.col("billing_code").cast(pl.Utf8).alias("code"),
          pl.col("pg_uid"),
          pl.col("pos_set_id"),
          pl.col("negotiated_type"),
          pl.col("negotiation_arrangement"),
          pl.col("negotiated_rate").cast(pl.Float64).alias("negotiated_rate"),
          pl.col("expiration_date"),
          pl.coalesce([pl.col("provider_group_id"), pl.col("provider_reference_id")]).alias("provider_group_id_raw"),
          pl.col("reporting_entity_name"),
      )
      .with_columns(
          # preserve your exact ID logic (%.4f on rate, same field order)
          pl.struct([
              "state","year_month","payer_slug","billing_class","code_type","code",
              "pg_uid","pos_set_id","negotiated_type","negotiation_arrangement",
              "expiration_date","negotiated_rate","provider_group_id_raw"
          ]).map_elements(fact_uid_from_struct, return_dtype=pl.Utf8).alias("fact_uid")
      )
      .select(
          "fact_uid","state","year_month","payer_slug","billing_class","code_type","code",
          "pg_uid","pos_set_id","negotiated_type","negotiation_arrangement",
          "negotiated_rate","expiration_date","provider_group_id_raw","reporting_entity_name"
      )
      .unique()
)

print("fact_new rows:", fact_new.height)
fact_new.head(3)


fact_new rows: 2967105


fact_uid,state,year_month,payer_slug,billing_class,code_type,code,pg_uid,pos_set_id,negotiated_type,negotiation_arrangement,negotiated_rate,expiration_date,provider_group_id_raw,reporting_entity_name
str,str,str,str,str,str,str,str,str,str,str,f64,str,i64,str
"""d023eebdae2fec46f9179ab625c678…","""GA""","""2025-08""","""unitedhealthcare-of-georgia-in…","""professional""","""CPT""","""28545""","""016b481acf847781526267a457c9c4…","""17b00c58b3dcdb9c20cb2a70b52a4c…","""negotiated""","""ffs""",608.27,"""9999-12-31""",584,"""UnitedHealthcare of Georgia In…"
"""5f8783b3c2eadafc80af9f47a3bcec…","""GA""","""2025-08""","""unitedhealthcare-of-georgia-in…","""professional""","""CPT""","""29886""","""e469dc1ef6c0ea0d9bb989d96e4592…","""17b00c58b3dcdb9c20cb2a70b52a4c…","""negotiated""","""ffs""",1051.15,"""9999-12-31""",282,"""UnitedHealthcare of Georgia In…"
"""da6893a29f62eb80b35fc28450365a…","""GA""","""2025-08""","""unitedhealthcare-of-georgia-in…","""institutional""","""CPT""","""19120""","""a910491ce9955a0083ee09c0b9c36d…","""3b3743c43fbdc8f1b8eedc16b264f3…","""negotiated""","""ffs""",0.0,"""9999-12-31""",220,"""UnitedHealthcare of Georgia In…"


cell 8

In [9]:
# --- Cell 8: robust upsert that normalizes schema before insert ---
import os, duckdb

CANON_COLS_SQL = """
  CAST(fact_uid                AS VARCHAR) AS fact_uid,
  CAST(state                   AS VARCHAR) AS state,
  CAST(year_month              AS VARCHAR) AS year_month,
  CAST(payer_slug              AS VARCHAR) AS payer_slug,
  CAST(billing_class           AS VARCHAR) AS billing_class,
  CAST(code_type               AS VARCHAR) AS code_type,
  CAST(code                    AS VARCHAR) AS code,
  CAST(pg_uid                  AS VARCHAR) AS pg_uid,
  CAST(pos_set_id              AS VARCHAR) AS pos_set_id,
  CAST(negotiated_type         AS VARCHAR) AS negotiated_type,
  CAST(negotiation_arrangement AS VARCHAR) AS negotiation_arrangement,
  CAST(negotiated_rate         AS DOUBLE)  AS negotiated_rate,
  CAST(expiration_date         AS VARCHAR) AS expiration_date,
  CAST(provider_group_id_raw   AS VARCHAR) AS provider_group_id_raw,
  CAST(reporting_entity_name   AS VARCHAR) AS reporting_entity_name
"""

def upsert_fact_single(fact_batch: pl.DataFrame):
    # write the incoming batch to a temp parquet
    tmp_new = f"{GOLD_FACT_FILE}.stage.parquet"
    tmp_out = f"{GOLD_FACT_FILE}.new.parquet"
    fact_batch.write_parquet(tmp_new)

    p_all  = GOLD_FACT_FILE.replace("'", "''")
    p_new  = tmp_new.replace("'", "''")
    p_out  = tmp_out.replace("'", "''")

    con = duckdb.connect()  # ephemeral in-memory db/session

    if not os.path.exists(GOLD_FACT_FILE):
        # First write: just normalize schema and save
        con.execute(f"CREATE OR REPLACE TABLE _stage AS SELECT {CANON_COLS_SQL} FROM read_parquet('{p_new}');")
        con.execute(f"COPY (SELECT * FROM _stage) TO '{p_all}' (FORMAT PARQUET, COMPRESSION ZSTD);")
        con.close()
        os.remove(tmp_new)
        print(f"Created {GOLD_FACT_FILE} with {fact_batch.height} rows.")
        return

    # Normalize both existing gold and new batch to the canonical schema
    con.execute(f"CREATE OR REPLACE TABLE _all   AS SELECT {CANON_COLS_SQL} FROM read_parquet('{p_all}');")
    con.execute(f"CREATE OR REPLACE TABLE _stage AS SELECT {CANON_COLS_SQL} FROM read_parquet('{p_new}');")

    # Insert only new fact_uids
    con.execute("""
      INSERT INTO _all
      SELECT s.*
      FROM _stage s
      LEFT JOIN _all a ON a.fact_uid = s.fact_uid
      WHERE a.fact_uid IS NULL;
    """)

    # Write back to parquet with consistent schema
    con.execute(f"COPY (SELECT * FROM _all) TO '{p_out}' (FORMAT PARQUET, COMPRESSION ZSTD);")
    con.close()
    os.replace(tmp_out, GOLD_FACT_FILE)
    os.remove(tmp_new)
    print(f"Upsert complete into {GOLD_FACT_FILE}.")


Cell 9 send fact rate to parquet

In [10]:
# Cell 9 - Actually save the fact table (FIXED VERSION)
import os, duckdb

CANON_COLS_SQL = """
  CAST(fact_uid                AS VARCHAR) AS fact_uid,
  CAST(state                   AS VARCHAR) AS state,
  CAST(year_month              AS VARCHAR) AS year_month,
  CAST(payer_slug              AS VARCHAR) AS payer_slug,
  CAST(billing_class           AS VARCHAR) AS billing_class,
  CAST(code_type               AS VARCHAR) AS code_type,
  CAST(code                    AS VARCHAR) AS code,
  CAST(pg_uid                  AS VARCHAR) AS pg_uid,
  CAST(pos_set_id              AS VARCHAR) AS pos_set_id,
  CAST(negotiated_type         AS VARCHAR) AS negotiated_type,
  CAST(negotiation_arrangement AS VARCHAR) AS negotiation_arrangement,
  CAST(negotiated_rate         AS DOUBLE)  AS negotiated_rate,
  CAST(expiration_date         AS VARCHAR) AS expiration_date,
  CAST(provider_group_id_raw   AS VARCHAR) AS provider_group_id_raw,
  CAST(reporting_entity_name   AS VARCHAR) AS reporting_entity_name
"""

def upsert_fact_single_fixed(fact_batch: pl.DataFrame):
    # write the incoming batch to a temp parquet
    tmp_new = f"{GOLD_FACT_FILE}.stage.parquet"
    tmp_out = f"{GOLD_FACT_FILE}.new.parquet"
    fact_batch.write_parquet(tmp_new)

    # Convert Path objects to strings for SQL
    p_all  = str(GOLD_FACT_FILE).replace("'", "''")
    p_new  = tmp_new.replace("'", "''")
    p_out  = tmp_out.replace("'", "''")

    con = duckdb.connect()  # ephemeral in-memory db/session

    if not os.path.exists(GOLD_FACT_FILE):
        # First write: just normalize schema and save
        con.execute(f"CREATE OR REPLACE TABLE _stage AS SELECT {CANON_COLS_SQL} FROM read_parquet('{p_new}');")
        con.execute(f"COPY (SELECT * FROM _stage) TO '{p_all}' (FORMAT PARQUET, COMPRESSION ZSTD);")
        con.close()
        os.remove(tmp_new)
        print(f"Created {GOLD_FACT_FILE} with {fact_batch.height} rows.")
        return

    # Normalize both existing gold and new batch to the canonical schema
    con.execute(f"CREATE OR REPLACE TABLE _all   AS SELECT {CANON_COLS_SQL} FROM read_parquet('{p_all}');")
    con.execute(f"CREATE OR REPLACE TABLE _stage AS SELECT {CANON_COLS_SQL} FROM read_parquet('{p_new}');")

    # Insert only new fact_uids
    con.execute("""
      INSERT INTO _all
      SELECT s.*
      FROM _stage s
      LEFT JOIN _all a ON a.fact_uid = s.fact_uid
      WHERE a.fact_uid IS NULL;
    """)

    # Write back to parquet with consistent schema
    con.execute(f"COPY (SELECT * FROM _all) TO '{p_out}' (FORMAT PARQUET, COMPRESSION ZSTD);")
    con.close()
    os.replace(tmp_out, GOLD_FACT_FILE)
    os.remove(tmp_new)
    print(f"Upsert complete into {GOLD_FACT_FILE}.")

# Actually save the fact table
upsert_fact_single_fixed(fact_new)

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Created c:\Users\ChristopherCato\OneDrive - clarity-dx.com\code\bph\workcomp-rates-etl\data\gold\fact_rate.parquet with 2967105 rows.


Cell 9 - Sanity Check

In [11]:
def count_parquet_rows(path: Path) -> int:
    if not path.exists():
        return 0
    con = duckdb.connect()
    n = con.execute(f"SELECT COUNT(*) FROM read_parquet('{str(path)}')").fetchone()[0]
    con.close()
    return int(n)

print("Row counts:")
print("  dim_code         :", count_parquet_rows(DIM_CODE_FILE))
print("  dim_payer        :", count_parquet_rows(DIM_PAYER_FILE))
print("  dim_provider_grp :", count_parquet_rows(DIM_PG_FILE))
print("  dim_pos_set      :", count_parquet_rows(DIM_POS_FILE))
print("  xref_pg_npi      :", count_parquet_rows(XREF_PG_NPI))
print("  xref_pg_tin      :", count_parquet_rows(XREF_PG_TIN))
print("  fact_rate (gold) :", count_parquet_rows(GOLD_FACT_FILE))


Row counts:
  dim_code         : 3696
  dim_payer        : 1
  dim_provider_grp : 637
  dim_pos_set      : 4
  xref_pg_npi      : 16999
  xref_pg_tin      : 9486
  fact_rate (gold) : 2967105


checks

In [15]:
# Data Integrity & Relationship Verification
import polars as pl
import duckdb
from pathlib import Path
import os

# Helper to safely load a parquet file, returning None if missing
def safe_read_parquet(path):
    try:
        return pl.read_parquet(path)
    except FileNotFoundError:
        print(f"❌ File not found: {path}")
        return None

# All file paths are relative to the notebook location (ETL/), so back out one level
base_dir = os.path.join("..", "data")
fact_path = os.path.join(base_dir, "gold", "fact_rate.parquet")
dim_code_path = os.path.join(base_dir, "dims", "dim_code.parquet")
dim_payer_path = os.path.join(base_dir, "dims", "dim_payer.parquet")
dim_pg_path = os.path.join(base_dir, "dims", "dim_provider_group.parquet")
dim_pos_path = os.path.join(base_dir, "dims", "dim_pos_set.parquet")
xref_npi_path = os.path.join(base_dir, "xrefs", "xref_pg_member_npi.parquet")
xref_tin_path = os.path.join(base_dir, "xrefs", "xref_pg_member_tin.parquet")

# Load all the output tables, handling missing files gracefully
fact = safe_read_parquet(fact_path)
dim_code = safe_read_parquet(dim_code_path)
dim_payer = safe_read_parquet(dim_payer_path)
dim_pg = safe_read_parquet(dim_pg_path)
dim_pos = safe_read_parquet(dim_pos_path)
xref_npi = safe_read_parquet(xref_npi_path)
xref_tin = safe_read_parquet(xref_tin_path)

# If any required table is missing, skip the rest of the checks
required_tables = {
    "fact": fact,
    "dim_code": dim_code,
    "dim_payer": dim_payer,
    "dim_pg": dim_pg,
    "dim_pos": dim_pos,
    "xref_npi": xref_npi,
    "xref_tin": xref_tin,
}
missing = [k for k, v in required_tables.items() if v is None]
if missing:
    print(f"\n❌ Skipping integrity checks. Missing tables: {', '.join(missing)}")
else:
    print("=== BASIC COUNTS ===")
    print(f"Fact table: {fact.height:,} rows")
    print(f"Dim codes: {dim_code.height:,} rows")
    print(f"Dim payers: {dim_payer.height:,} rows")
    print(f"Dim provider groups: {dim_pg.height:,} rows")
    print(f"Dim POS sets: {dim_pos.height:,} rows")
    print(f"Xref NPI: {xref_npi.height:,} rows")
    print(f"Xref TIN: {xref_tin.height:,} rows")

    print("\n=== FOREIGN KEY INTEGRITY ===")

    # Check if all fact table foreign keys exist in dimensions
    print("Fact -> Dim Code mapping:")
    fact_codes = fact.select(["code_type", "code"]).unique()
    missing_codes = fact_codes.join(dim_code, on=["code_type", "code"], how="anti")
    print(f"  Missing codes: {missing_codes.height}")

    print("Fact -> Dim Payer mapping:")
    fact_payers = fact.select("payer_slug").unique()
    missing_payers = fact_payers.join(dim_payer, on="payer_slug", how="anti")
    print(f"  Missing payers: {missing_payers.height}")

    print("Fact -> Dim Provider Group mapping:")
    fact_pgs = fact.select("pg_uid").unique()
    missing_pgs = fact_pgs.join(dim_pg, on="pg_uid", how="anti")
    print(f"  Missing provider groups: {missing_pgs.height}")

    print("Fact -> Dim POS Set mapping:")
    fact_pos = fact.select("pos_set_id").unique()
    missing_pos = fact_pos.join(dim_pos, on="pos_set_id", how="anti")
    print(f"  Missing POS sets: {missing_pos.height}")

    print("\n=== CROSS-REFERENCE INTEGRITY ===")

    # Check if all provider groups in fact table have NPI mappings
    fact_pg_has_npi = fact.select("pg_uid").unique().join(xref_npi, on="pg_uid", how="inner")
    print(f"Provider groups with NPI mappings: {fact_pg_has_npi.height:,}")

    # Check if all provider groups in fact table have TIN mappings  
    fact_pg_has_tin = fact.select("pg_uid").unique().join(xref_tin, on="pg_uid", how="inner")
    print(f"Provider groups with TIN mappings: {fact_pg_has_tin.height:,}")

    print("\n=== SAMPLE DATA OVERLAP ===")

    # Show a sample of joined data to verify relationships
    sample_joined = (
        fact.head(5)
        .join(dim_code, on=["code_type", "code"], how="left")
        .join(dim_payer, on="payer_slug", how="left")
        .join(dim_pg, on="pg_uid", how="left")
        .join(dim_pos, on="pos_set_id", how="left")
        .select([
            "fact_uid", "state", "payer_slug", "reporting_entity_name",
            "code_type", "code", "code_description", 
            "pg_uid", "provider_group_id_raw",
            "pos_set_id", "pos_members",
            "negotiated_rate"
        ])
    )

    print("Sample joined data:")
    print(sample_joined)

    print("\n=== UNIQUENESS CHECKS ===")

    # Check for duplicate fact_uids (should be 0)
    duplicate_facts = fact.group_by("fact_uid").agg(pl.count().alias("count")).filter(pl.col("count") > 1)
    print(f"Duplicate fact_uids: {duplicate_facts.height}")

    # Check for duplicate dimension keys
    duplicate_codes = dim_code.group_by(["code_type", "code"]).agg(pl.count().alias("count")).filter(pl.col("count") > 1)
    print(f"Duplicate code keys: {duplicate_codes.height}")

    duplicate_payers = dim_payer.group_by("payer_slug").agg(pl.count().alias("count")).filter(pl.col("count") > 1)
    print(f"Duplicate payer keys: {duplicate_payers.height}")

    duplicate_pgs = dim_pg.group_by("pg_uid").agg(pl.count().alias("count")).filter(pl.col("count") > 1)
    print(f"Duplicate provider group keys: {duplicate_pgs.height}")

    print("\n=== DATA QUALITY SUMMARY ===")
    print("✅ All checks passed!" if all([
        missing_codes.height == 0,
        missing_payers.height == 0, 
        missing_pgs.height == 0,
        missing_pos.height == 0,
        duplicate_facts.height == 0,
        duplicate_codes.height == 0,
        duplicate_payers.height == 0,
        duplicate_pgs.height == 0
    ]) else "❌ Some integrity issues found!")

=== BASIC COUNTS ===
Fact table: 2,967,105 rows
Dim codes: 3,696 rows
Dim payers: 1 rows
Dim provider groups: 637 rows
Dim POS sets: 4 rows
Xref NPI: 16,999 rows
Xref TIN: 9,486 rows

=== FOREIGN KEY INTEGRITY ===
Fact -> Dim Code mapping:
  Missing codes: 0
Fact -> Dim Payer mapping:
  Missing payers: 0
Fact -> Dim Provider Group mapping:
  Missing provider groups: 0
Fact -> Dim POS Set mapping:
  Missing POS sets: 0

=== CROSS-REFERENCE INTEGRITY ===
Provider groups with NPI mappings: 13,912
Provider groups with TIN mappings: 7,661

=== SAMPLE DATA OVERLAP ===
Sample joined data:
shape: (5, 12)
┌────────────┬───────┬────────────┬────────────┬───┬───────────┬───────────┬───────────┬───────────┐
│ fact_uid   ┆ state ┆ payer_slug ┆ reporting_ ┆ … ┆ provider_ ┆ pos_set_i ┆ pos_membe ┆ negotiate │
│ ---        ┆ ---   ┆ ---        ┆ entity_nam ┆   ┆ group_id_ ┆ d         ┆ rs        ┆ d_rate    │
│ str        ┆ str   ┆ str        ┆ e          ┆   ┆ raw       ┆ ---       ┆ ---       ┆ --- 

  duplicate_facts = fact.group_by("fact_uid").agg(pl.count().alias("count")).filter(pl.col("count") > 1)


Duplicate fact_uids: 0
Duplicate code keys: 0
Duplicate payer keys: 0
Duplicate provider group keys: 0

=== DATA QUALITY SUMMARY ===
✅ All checks passed!


  duplicate_codes = dim_code.group_by(["code_type", "code"]).agg(pl.count().alias("count")).filter(pl.col("count") > 1)
  duplicate_payers = dim_payer.group_by("payer_slug").agg(pl.count().alias("count")).filter(pl.col("count") > 1)
  duplicate_pgs = dim_pg.group_by("pg_uid").agg(pl.count().alias("count")).filter(pl.col("count") > 1)


In [16]:
# Debug the cross-reference mapping issue
import polars as pl

# Load the tables (back out a level, then go under data, not core)
fact = pl.read_parquet("../data/gold/fact_rate.parquet")
xref_npi = pl.read_parquet("../data/xrefs/xref_pg_member_npi.parquet")
xref_tin = pl.read_parquet("../data/xrefs/xref_pg_member_tin.parquet")

print("=== DEBUGGING CROSS-REFERENCE ISSUE ===")

# Check what pg_uids exist in fact vs xref tables
fact_pg_uids = fact.select("pg_uid").unique()
xref_npi_pg_uids = xref_npi.select("pg_uid").unique()
xref_tin_pg_uids = xref_tin.select("pg_uid").unique()

print(f"Unique pg_uids in fact table: {fact_pg_uids.height:,}")
print(f"Unique pg_uids in NPI xref: {xref_npi_pg_uids.height:,}")
print(f"Unique pg_uids in TIN xref: {xref_tin_pg_uids.height:,}")

# Check for overlap
fact_npi_overlap = fact_pg_uids.join(xref_npi_pg_uids, on="pg_uid", how="inner")
fact_tin_overlap = fact_pg_uids.join(xref_tin_pg_uids, on="pg_uid", how="inner")

print(f"Fact pg_uids that have NPI mappings: {fact_npi_overlap.height:,}")
print(f"Fact pg_uids that have TIN mappings: {fact_tin_overlap.height:,}")

# Show some sample pg_uids from each table to see if they look different
print("\n=== SAMPLE PG_UIDS ===")
print("Sample fact pg_uids:")
print(fact_pg_uids.head(3))
print("\nSample NPI xref pg_uids:")
print(xref_npi_pg_uids.head(3))
print("\nSample TIN xref pg_uids:")
print(xref_tin_pg_uids.head(3))

# Check if there are any pg_uids that exist in xref but not in fact
npi_not_in_fact = xref_npi_pg_uids.join(fact_pg_uids, on="pg_uid", how="anti")
tin_not_in_fact = xref_tin_pg_uids.join(fact_pg_uids, on="pg_uid", how="anti")

print(f"\nNPI xref pg_uids NOT in fact table: {npi_not_in_fact.height:,}")
print(f"TIN xref pg_uids NOT in fact table: {tin_not_in_fact.height:,}")

# If there's a mismatch, show some examples
if npi_not_in_fact.height > 0:
    print("\nSample NPI xref pg_uids not in fact:")
    print(npi_not_in_fact.head(3))

if tin_not_in_fact.height > 0:
    print("\nSample TIN xref pg_uids not in fact:")
    print(tin_not_in_fact.head(3))

=== DEBUGGING CROSS-REFERENCE ISSUE ===
Unique pg_uids in fact table: 637
Unique pg_uids in NPI xref: 952
Unique pg_uids in TIN xref: 952
Fact pg_uids that have NPI mappings: 637
Fact pg_uids that have TIN mappings: 637

=== SAMPLE PG_UIDS ===
Sample fact pg_uids:
shape: (3, 1)
┌─────────────────────────────────┐
│ pg_uid                          │
│ ---                             │
│ str                             │
╞═════════════════════════════════╡
│ 0c9046c67c2d317c91e95bdc68d6db… │
│ 833152a0034c3b8b3455a111829fa9… │
│ 440382b2ba13cc0f8787058d848e86… │
└─────────────────────────────────┘

Sample NPI xref pg_uids:
shape: (3, 1)
┌─────────────────────────────────┐
│ pg_uid                          │
│ ---                             │
│ str                             │
╞═════════════════════════════════╡
│ d04a4401f5a42f6d03968655f76530… │
│ b45822a29cb9073d5906c10bab59fb… │
│ a910491ce9955a0083ee09c0b9c36d… │
└─────────────────────────────────┘

Sample TIN xref pg_uids:
shape: 

In [17]:
# Test the pg_uid fix
import polars as pl


print("=== TESTING PG_UID FIX ===")

# Check what pg_uids exist in fact vs xref tables
fact_pg_uids = fact.select("pg_uid").unique()
xref_npi_pg_uids = xref_npi.select("pg_uid").unique()
xref_tin_pg_uids = xref_tin.select("pg_uid").unique()

print(f"Unique pg_uids in fact table: {fact_pg_uids.height:,}")
print(f"Unique pg_uids in NPI xref: {xref_npi_pg_uids.height:,}")
print(f"Unique pg_uids in TIN xref: {xref_tin_pg_uids.height:,}")

# Check for overlap (this should now be > 0!)
fact_npi_overlap = fact_pg_uids.join(xref_npi_pg_uids, on="pg_uid", how="inner")
fact_tin_overlap = fact_pg_uids.join(xref_tin_pg_uids, on="pg_uid", how="inner")

print(f"\nFact pg_uids that have NPI mappings: {fact_npi_overlap.height:,}")
print(f"Fact pg_uids that have TIN mappings: {fact_tin_overlap.height:,}")

# Show some sample matching pg_uids
if fact_npi_overlap.height > 0:
    print("\n=== SAMPLE MATCHING PG_UIDS ===")
    sample_matches = fact_npi_overlap.head(3)
    print("Sample matching pg_uids:")
    print(sample_matches)
    
    # Show the actual data for one matching pg_uid
    sample_pg_uid = sample_matches.item(0, 0)
    print(f"\nData for pg_uid: {sample_pg_uid}")
    
    fact_sample = fact.filter(pl.col("pg_uid") == sample_pg_uid).head(1)
    npi_sample = xref_npi.filter(pl.col("pg_uid") == sample_pg_uid).head(1)
    tin_sample = xref_tin.filter(pl.col("pg_uid") == sample_pg_uid).head(1)
    
    print("Fact table sample:")
    print(fact_sample.select(["pg_uid", "provider_group_id_raw", "payer_slug"]))
    print("NPI xref sample:")
    print(npi_sample)
    print("TIN xref sample:")
    print(tin_sample)

print("\n=== SUCCESS CHECK ===")
if fact_npi_overlap.height > 0 and fact_tin_overlap.height > 0:
    print("✅ SUCCESS! pg_uids now match between fact and cross-reference tables!")
    print("✅ The ETL fix worked!")
else:
    print("❌ Still no matches. Need to investigate further.")

=== TESTING PG_UID FIX ===
Unique pg_uids in fact table: 637
Unique pg_uids in NPI xref: 952
Unique pg_uids in TIN xref: 952

Fact pg_uids that have NPI mappings: 637
Fact pg_uids that have TIN mappings: 637

=== SAMPLE MATCHING PG_UIDS ===
Sample matching pg_uids:
shape: (3, 1)
┌─────────────────────────────────┐
│ pg_uid                          │
│ ---                             │
│ str                             │
╞═════════════════════════════════╡
│ 2fb9b1fcfceaec741dd7ea83bde1f6… │
│ f58ad428df5da16b81c443a4e72868… │
│ fdb18ec69c97eb87293fae8b7fa4a8… │
└─────────────────────────────────┘

Data for pg_uid: 2fb9b1fcfceaec741dd7ea83bde1f60c
Fact table sample:
shape: (1, 3)
┌─────────────────────────────────┬───────────────────────┬─────────────────────────────────┐
│ pg_uid                          ┆ provider_group_id_raw ┆ payer_slug                      │
│ ---                             ┆ ---                   ┆ ---                             │
│ str                         