# TPC-DS Generator

In [None]:
%pip install duckdb

In [None]:
import duckdb
import os

from functools import reduce
from pyspark.sql import functions as F

In [None]:
# Run configuration: TPC-DS scale factor and whether fact tables get a year partition.
scale_factor = 100
partition_facts_by_year = True

# Every artifact is named after the same dataset identifier; derive the paths from it.
target_schema = f"tpc_ds__sf_{scale_factor}"            # Spark database for the Delta tables
staging_base = f"Files/{target_schema}"                 # Lakehouse-relative Parquet staging dir (Spark reads)
staging_path = f"/lakehouse/default/{staging_base}"     # same staging dir as a local mount path (DuckDB writes)
duckdb_path = f"{staging_path}.duckdb"                  # DuckDB database file used for generation

## Connect To DuckDB

In [None]:
def connect(
    duckdb_path,
    *,
    threads=8,
    memory_limit="64GB",
    temp_dir="/lakehouse/default/Files/duckdb_tmp",
    progress_bar=True,
    preserve_insertion_order=False,
):
    """Open the DuckDB database at *duckdb_path* and apply session PRAGMAs.

    Args:
        duckdb_path: database file passed to ``duckdb.connect``.
        threads: value for ``PRAGMA threads``.
        memory_limit: value for ``PRAGMA memory_limit`` (e.g. ``"64GB"``).
        temp_dir: value for ``PRAGMA temp_directory``.
        progress_bar: value for ``PRAGMA enable_progress_bar``.
        preserve_insertion_order: value for ``PRAGMA preserve_insertion_order``.

    Returns:
        The open DuckDB connection; the caller is responsible for closing it.

    Raises:
        Whatever ``duckdb`` raises if connecting or a PRAGMA fails; the
        connection is closed before the error propagates.
    """

    def sql_literal(value):
        # Single-quoted SQL string literal with embedded quotes doubled.
        return "'" + str(value).replace("'", "''") + "'"

    def sql_bool(flag):
        # DuckDB accepts lowercase true/false.
        return str(bool(flag)).lower()

    pragmas = [
        f"PRAGMA threads = {int(threads)};",
        f"PRAGMA memory_limit = {sql_literal(memory_limit)};",
        f"PRAGMA temp_directory = {sql_literal(temp_dir)};",
        f"PRAGMA enable_progress_bar = {sql_bool(progress_bar)};",
        f"PRAGMA preserve_insertion_order = {sql_bool(preserve_insertion_order)};",
    ]

    con = duckdb.connect(duckdb_path)
    try:
        for pragma in pragmas:
            con.execute(pragma)
    except Exception:
        # Don't leak an open connection if a setting is rejected.
        con.close()
        raise
    return con

## Generate Data

In [None]:
def generate_data(con):
    """Generate all TPC-DS tables inside *con* with DuckDB's ``tpcds`` extension.

    The scale factor is read from the module-level ``scale_factor`` setting, and
    the confirmation message reports the module-level ``duckdb_path``.
    """
    for setup_statement in ("INSTALL tpcds;", "LOAD tpcds;"):
        con.execute(setup_statement)

    con.execute("CALL dsdgen(sf = {});".format(scale_factor))

    print("TPC-DS SF={} generated in {}".format(scale_factor, duckdb_path))

## Stage Data

In [None]:
def stage_data(con, staging_path):
    """Export every DuckDB table to ``<staging_path>/<table>.parquet``.

    Tables are listed with ``SHOW ALL TABLES``. Each one is written as a
    single zstd-compressed Parquet file with 100k-row row groups.

    Args:
        con: open DuckDB connection containing the generated tables.
        staging_path: output directory. It is created if needed, including
            any missing parent directories. An existing directory is reused,
            so the cell can be re-run.
    """
    tables = con.sql("SHOW ALL TABLES").df()

    # exist_ok=True: re-running the notebook must not fail on an existing
    # directory; makedirs also creates missing parents.
    os.makedirs(staging_path, exist_ok=True)

    for table in tables["name"]:
        table_staging_path = f"{staging_path}/{table}.parquet"

        # Quote/escape so unusual table names or paths cannot break the SQL text.
        quoted_table = '"' + str(table).replace('"', '""') + '"'
        quoted_path = "'" + table_staging_path.replace("'", "''") + "'"

        con.execute(f"""
            COPY {quoted_table} TO {quoted_path} (
                FORMAT parquet,
                COMPRESSION zstd,
                ROW_GROUP_SIZE 100_000
            );
        """)

        print(f"Exported [{table}] to [{table_staging_path}]")

## Load Data

### Helpers

In [None]:
def setup_schema(target_schema):
    """Create the Spark database *target_schema* unless it already exists."""
    create_statement = "CREATE DATABASE IF NOT EXISTS " + target_schema
    spark.sql(create_statement)

def parquet_file(staging_base, name: str) -> str:
    """Return the staged Parquet file path for table *name* under *staging_base*."""
    return "{}/{}.parquet".format(staging_base, name)

def save_table(target_schema, df, name: str, partition_cols=None, mode="overwrite"):
    """Write *df* as the Delta table ``<target_schema>.<name>``.

    Args:
        target_schema: Spark database to write into.
        df: Spark DataFrame to persist.
        name: table name within *target_schema*.
        partition_cols: optional column names passed to ``partitionBy``;
            ignored when falsy.
        mode: Spark save mode (default ``"overwrite"``). ``overwriteSchema`` is
            always set to ``"true"``.
    """
    qualified_name = f"{target_schema}.{name}"

    delta_writer = (
        df.write
        .format("delta")
        .mode(mode)
        .option("overwriteSchema", "true")
    )
    delta_writer = delta_writer.partitionBy(*partition_cols) if partition_cols else delta_writer

    delta_writer.saveAsTable(qualified_name)
    print(f"Wrote {qualified_name}")

def maybe_add_year_partition(df, date_year, enabled=None):
    """
    Optionally add a ``_part__year`` column to a fact DataFrame for partitioning.

    Joins *df* to *date_year* on the unified dim__date key ``_key__dim__date``.
    The join is a left join, so rows whose date key is NULL or unmatched are kept
    with a NULL ``_part__year``.

    Args:
        df: fact DataFrame; must contain ``_key__dim__date``.
        date_year: DataFrame with ``_key__dim__date`` and ``year``.
        enabled: whether to partition. ``None`` (default) falls back to the
            module-level ``partition_facts_by_year`` setting.

    Returns:
        ``(df, partition_cols)``. When disabled, this is ``(df, None)``.
        Otherwise it is the joined DataFrame and ``["_part__year"]``.
    """
    if enabled is None:
        enabled = partition_facts_by_year
    if not enabled:
        return df, None

    # Rename on the lookup side so only the lookup's column becomes _part__year.
    # date_year is expected to be small (derived from the date dimension).
    # Broadcasting it avoids shuffling the potentially very large fact table.
    years = F.broadcast(date_year.withColumnRenamed("year", "_part__year"))
    return df.join(years, "_key__dim__date", "left"), ["_part__year"]

def concat_pk(*cols, sep="~"):
    """Build a stable, positional string key by joining *cols* with *sep*.

    ``concat_ws`` silently drops NULL arguments, which would shift later
    components left. (1, NULL, 2) and (1, 2, NULL) would then both become "1~2".
    Each component is therefore coalesced to an empty string first, so it keeps
    its position: "1~~2" vs "1~2~".

    Args:
        *cols: column names making up the key, in key order.
        sep: separator placed between components.

    Returns:
        A Spark string Column expression.
    """
    return F.concat_ws(
        sep,
        *[F.coalesce(F.col(c).cast("string"), F.lit("")) for c in cols],
    )

### Dimensions

In [None]:
# -------------------------
# Dimensions
# -------------------------
def load_dimensions(target_schema, staging_base):
    """Read staged TPC-DS dimension Parquet files and write ``dim__*`` Delta tables.

    Every dimension except ``dim__customer`` is a renaming/casting projection of
    one staged table. ``dim__customer`` flattens customer, customer_address,
    customer_demographics, household_demographics and income_band with left joins.

    Args:
        target_schema: Spark database the tables are written into.
        staging_base: directory holding the ``<table>.parquet`` staging files.

    Returns:
        DataFrame with ``_key__dim__date`` and ``year`` (from ``dim__date``),
        used to derive the fact-table year partition.
    """

    def read_staged(source):
        return spark.read.parquet(parquet_file(staging_base, source))

    # (target table, staged source table, selectExpr projection), in write order.
    single_source_dims = [
        ("dim__date", "date_dim", [
            "cast(d_date_sk as int) as _key__dim__date",
            "d_date as date",
            "cast(d_year as int) as year",
            "cast(d_moy as int) as month_of_year",
            "cast(d_dom as int) as day_of_month",
            "cast(d_qoy as int) as quarter_of_year",
            "d_day_name as day_name",
            "cast(d_week_seq as int) as week_seq",
        ]),
        ("dim__time", "time_dim", [
            "cast(t_time_sk as int) as _key__dim__time",
            "t_time_id as time_id",
            "t_time as time",
            "cast(t_hour as int) as hour",
            "cast(t_minute as int) as minute",
            "cast(t_second as int) as second",
            "t_am_pm as am_pm",
            "t_shift as shift",
            "t_sub_shift as sub_shift",
            "t_meal_time as meal_time",
        ]),
        ("dim__item", "item", [
            "cast(i_item_sk as int) as _key__dim__item",
            "i_item_id as item_id",
            "i_item_desc as item_desc",
            "i_category as category",
            "i_class as class",
            "i_brand as brand",
            "cast(i_manufact_id as int) as manufact_id",
            "i_current_price as current_price",
            "i_size as size",
        ]),
        ("dim__promotion", "promotion", [
            "cast(p_promo_sk as int) as _key__dim__promotion",
            "p_promo_id as promo_id",
            "p_channel_dmail as channel_dmail",
            "p_channel_email as channel_email",
            "p_channel_tv as channel_tv",
            "p_discount_active as discount_active",
        ]),
        ("dim__reason", "reason", [
            "cast(r_reason_sk as int) as _key__dim__reason",
            "r_reason_id as reason_id",
            "r_reason_desc as reason_desc",
        ]),
        ("dim__store", "store", [
            "cast(s_store_sk as int) as _key__dim__store",
            "s_store_id as store_id",
            "s_store_name as store_name",
            "cast(s_company_id as int) as company_id",
            "s_company_name as company_name",
            "s_state as state",
            "s_country as country",
            "s_city as city",
        ]),
        ("dim__web_site", "web_site", [
            "cast(web_site_sk as int) as _key__dim__web_site",
            "web_site_id as web_site_id",
            "web_name as web_name",
            "cast(web_company_id as int) as company_id",
            "web_company_name as company_name",
            "web_country as country",
        ]),
        # web_page keys use the wp_web_page_* column names.
        ("dim__web_page", "web_page", [
            "cast(wp_web_page_sk as int) as _key__dim__web_page",
            "wp_web_page_id as web_page_id",
            "wp_url as url",
            "wp_type as page_type",
            "wp_autogen_flag as autogen_flag",
            "cast(wp_creation_date_sk as int) as _key__dim__date__creation",
            "cast(wp_access_date_sk as int) as _key__dim__date__access",
        ]),
        ("dim__catalog_page", "catalog_page", [
            "cast(cp_catalog_page_sk as int) as _key__dim__catalog_page",
            "cp_catalog_page_id as catalog_page_id",
            "cp_department as department",
            "cast(cp_catalog_number as int) as catalog_number",
        ]),
        ("dim__call_center", "call_center", [
            "cast(cc_call_center_sk as int) as _key__dim__call_center",
            "cc_call_center_id as call_center_id",
            "cc_name as name",
            "cc_country as country",
            "cc_state as state",
            "cc_city as city",
        ]),
        ("dim__warehouse", "warehouse", [
            "cast(w_warehouse_sk as int) as _key__dim__warehouse",
            "w_warehouse_id as warehouse_id",
            "w_warehouse_name as warehouse_name",
            "w_state as state",
            "w_country as country",
            "w_city as city",
        ]),
        # Referenced by the web and catalog facts.
        ("dim__ship_mode", "ship_mode", [
            "cast(sm_ship_mode_sk as int) as _key__dim__ship_mode",
            "sm_ship_mode_id as ship_mode_id",
            "sm_type as ship_mode_type",
            "sm_carrier as carrier",
            "sm_contract as contract",
        ]),
    ]

    written = {}
    for target, source, projection in single_source_dims:
        frame = read_staged(source).selectExpr(*projection)
        save_table(target_schema, frame, target)
        written[target] = frame

    # dim__customer: flatten the customer snowflake into one wide table.
    customer = read_staged("customer").alias("c")
    address = read_staged("customer_address").alias("a")
    cust_demo = read_staged("customer_demographics").alias("cd")
    hh_demo = read_staged("household_demographics").alias("hd")
    income = read_staged("income_band").alias("ib")

    flattened = customer.join(address, customer["c_current_addr_sk"] == address["ca_address_sk"], "left")
    flattened = flattened.join(cust_demo, customer["c_current_cdemo_sk"] == cust_demo["cd_demo_sk"], "left")
    flattened = flattened.join(hh_demo, customer["c_current_hdemo_sk"] == hh_demo["hd_demo_sk"], "left")
    flattened = flattened.join(income, hh_demo["hd_income_band_sk"] == income["ib_income_band_sk"], "left")

    dim__customer = flattened.selectExpr(
        "cast(c_customer_sk as int) as _key__dim__customer",
        "c_customer_id as customer_id",
        "c_first_name as first_name",
        "c_last_name as last_name",
        "cast(c_birth_year as int) as birth_year",
        "ca_country as country",
        "ca_state as state",
        "ca_county as county",
        "ca_city as city",
        "ca_zip as zip",
        "cd_gender as gender",
        "cd_marital_status as marital_status",
        "cd_education_status as education_status",
        "hd_buy_potential as buy_potential",
        "cast(ib_lower_bound as int) as income_lower",
        "cast(ib_upper_bound as int) as income_upper"
    )
    save_table(target_schema, dim__customer, "dim__customer")

    return written["dim__date"].select("_key__dim__date", "year")

### Facts

In [None]:
# -------------------------
# Facts (full measures; simplified customer: one customer key per fact row)
# -------------------------
def _read_staged_fact(staging_base, source, exprs):
    """Read one staged TPC-DS fact Parquet file and project it with ``selectExpr``."""
    return spark.read.parquet(parquet_file(staging_base, source)).selectExpr(*exprs)


def _save_fact(df, name, date_year, target_schema):
    """Add the optional year partition column, then write the fact table."""
    df, parts = maybe_add_year_partition(df, date_year)
    save_table(target_schema, df, name, parts)


def _load_fact__sales(date_year, target_schema, staging_base):
    """Write ``fact__sales``: the union of store, web and catalog sales.

    ``_key__fact__sales`` is ``<channel>~<item_sk>~<order/ticket number>``. The
    channel prefix is required: store ticket numbers and web/catalog order
    numbers are independent sequences. Without it, the same (item, number) pair
    from two channels would produce the same key in the unioned table.
    """
    store_sales = _read_staged_fact(staging_base, "store_sales", [
        "'store' as channel",
        "concat('store', '~', cast(ss_item_sk as string), '~', cast(ss_ticket_number as string)) as _key__fact__sales",

        "cast(ss_sold_date_sk as int) as _key__dim__date",
        "cast(ss_sold_time_sk as int) as _key__dim__time",
        "cast(ss_item_sk as int) as _key__dim__item",
        "cast(ss_customer_sk as int) as _key__dim__customer",
        "cast(ss_store_sk as int) as _key__dim__store",
        "cast(ss_promo_sk as int) as _key__dim__promotion",

        "cast(ss_ticket_number as long) as _degenerate__order_number",

        "cast(ss_quantity as int) as _measure__sales__quantity",
        "ss_wholesale_cost as _measure__sales__wholesale_cost",
        "ss_list_price as _measure__sales__list_price",
        "ss_sales_price as _measure__sales__sales_price",
        "ss_ext_discount_amt as _measure__sales__ext_discount_amt",
        "ss_ext_sales_price as _measure__sales__ext_sales_price",
        "ss_ext_wholesale_cost as _measure__sales__ext_wholesale_cost",
        "ss_ext_list_price as _measure__sales__ext_list_price",
        "ss_ext_tax as _measure__sales__ext_tax",
        "ss_coupon_amt as _measure__sales__coupon_amt",
        "ss_net_paid as _measure__sales__net_paid",
        "ss_net_paid_inc_tax as _measure__sales__net_paid_inc_tax",
        "ss_net_profit as _measure__sales__net_profit",
    ])

    web_sales = _read_staged_fact(staging_base, "web_sales", [
        "'web' as channel",
        "concat('web', '~', cast(ws_item_sk as string), '~', cast(ws_order_number as string)) as _key__fact__sales",

        "cast(ws_sold_date_sk as int) as _key__dim__date",
        "cast(ws_sold_time_sk as int) as _key__dim__time",
        "cast(ws_ship_date_sk as int) as _key__dim__ship_date",
        "cast(ws_item_sk as int) as _key__dim__item",

        "cast(ws_bill_customer_sk as int) as _key__dim__customer",

        "cast(ws_web_page_sk as int) as _key__dim__web_page",
        "cast(ws_web_site_sk as int) as _key__dim__web_site",
        "cast(ws_ship_mode_sk as int) as _key__dim__ship_mode",
        "cast(ws_warehouse_sk as int) as _key__dim__warehouse",
        "cast(ws_promo_sk as int) as _key__dim__promotion",

        "cast(ws_order_number as long) as _degenerate__order_number",

        "cast(ws_quantity as int) as _measure__sales__quantity",
        "ws_wholesale_cost as _measure__sales__wholesale_cost",
        "ws_list_price as _measure__sales__list_price",
        "ws_sales_price as _measure__sales__sales_price",
        "ws_ext_discount_amt as _measure__sales__ext_discount_amt",
        "ws_ext_sales_price as _measure__sales__ext_sales_price",
        "ws_ext_wholesale_cost as _measure__sales__ext_wholesale_cost",
        "ws_ext_list_price as _measure__sales__ext_list_price",
        "ws_ext_tax as _measure__sales__ext_tax",
        "ws_coupon_amt as _measure__sales__coupon_amt",
        "ws_ext_ship_cost as _measure__sales__ext_ship_cost",
        "ws_net_paid as _measure__sales__net_paid",
        "ws_net_paid_inc_tax as _measure__sales__net_paid_inc_tax",
        "ws_net_paid_inc_ship as _measure__sales__net_paid_inc_ship",
        "ws_net_paid_inc_ship_tax as _measure__sales__net_paid_inc_ship_tax",
        "ws_net_profit as _measure__sales__net_profit",
    ])

    catalog_sales = _read_staged_fact(staging_base, "catalog_sales", [
        "'catalog' as channel",
        "concat('catalog', '~', cast(cs_item_sk as string), '~', cast(cs_order_number as string)) as _key__fact__sales",

        "cast(cs_sold_date_sk as int) as _key__dim__date",
        "cast(cs_sold_time_sk as int) as _key__dim__time",
        "cast(cs_ship_date_sk as int) as _key__dim__ship_date",
        "cast(cs_item_sk as int) as _key__dim__item",

        "cast(cs_bill_customer_sk as int) as _key__dim__customer",

        "cast(cs_call_center_sk as int) as _key__dim__call_center",
        "cast(cs_catalog_page_sk as int) as _key__dim__catalog_page",
        "cast(cs_ship_mode_sk as int) as _key__dim__ship_mode",
        "cast(cs_warehouse_sk as int) as _key__dim__warehouse",
        "cast(cs_promo_sk as int) as _key__dim__promotion",

        "cast(cs_order_number as long) as _degenerate__order_number",

        "cast(cs_quantity as int) as _measure__sales__quantity",
        "cs_wholesale_cost as _measure__sales__wholesale_cost",
        "cs_list_price as _measure__sales__list_price",
        "cs_sales_price as _measure__sales__sales_price",
        "cs_ext_discount_amt as _measure__sales__ext_discount_amt",
        "cs_ext_sales_price as _measure__sales__ext_sales_price",
        "cs_ext_wholesale_cost as _measure__sales__ext_wholesale_cost",
        "cs_ext_list_price as _measure__sales__ext_list_price",
        "cs_ext_tax as _measure__sales__ext_tax",
        "cs_coupon_amt as _measure__sales__coupon_amt",
        "cs_ext_ship_cost as _measure__sales__ext_ship_cost",
        "cs_net_paid as _measure__sales__net_paid",
        "cs_net_paid_inc_tax as _measure__sales__net_paid_inc_tax",
        "cs_net_paid_inc_ship as _measure__sales__net_paid_inc_ship",
        "cs_net_paid_inc_ship_tax as _measure__sales__net_paid_inc_ship_tax",
        "cs_net_profit as _measure__sales__net_profit",
    ])

    # Channels carry different dimension keys; missing ones become NULL.
    fact__sales = (
        store_sales
        .unionByName(web_sales, allowMissingColumns=True)
        .unionByName(catalog_sales, allowMissingColumns=True)
    )
    _save_fact(fact__sales, "fact__sales", date_year, target_schema)


def _load_fact__returns(date_year, target_schema, staging_base):
    """Write ``fact__returns``: the union of store, web and catalog returns.

    ``_key__fact__returns`` is channel-qualified for the same reason as
    ``_key__fact__sales``: (item, number) pairs repeat across channels.
    """
    store_returns = _read_staged_fact(staging_base, "store_returns", [
        "'store' as channel",
        "concat('store', '~', cast(sr_item_sk as string), '~', cast(sr_ticket_number as string)) as _key__fact__returns",

        "cast(sr_returned_date_sk as int) as _key__dim__date",
        "cast(sr_return_time_sk as int) as _key__dim__time",
        "cast(sr_item_sk as int) as _key__dim__item",
        "cast(sr_customer_sk as int) as _key__dim__customer",
        "cast(sr_store_sk as int) as _key__dim__store",
        "cast(sr_reason_sk as int) as _key__dim__reason",

        "cast(sr_ticket_number as long) as _degenerate__order_number",

        "cast(sr_return_quantity as int) as _measure__returns__return_quantity",
        "sr_return_amt as _measure__returns__return_amt",
        "sr_return_tax as _measure__returns__return_tax",
        "sr_return_amt_inc_tax as _measure__returns__return_amt_inc_tax",
        "sr_fee as _measure__returns__fee",
        "sr_return_ship_cost as _measure__returns__return_ship_cost",
        "sr_refunded_cash as _measure__returns__refunded_cash",
        "sr_reversed_charge as _measure__returns__reversed_charge",
        "sr_store_credit as _measure__returns__store_credit",
        "sr_net_loss as _measure__returns__net_loss",
    ])

    web_returns = _read_staged_fact(staging_base, "web_returns", [
        "'web' as channel",
        "concat('web', '~', cast(wr_item_sk as string), '~', cast(wr_order_number as string)) as _key__fact__returns",

        "cast(wr_returned_date_sk as int) as _key__dim__date",
        "cast(wr_returned_time_sk as int) as _key__dim__time",
        "cast(wr_item_sk as int) as _key__dim__item",

        "cast(wr_returning_customer_sk as int) as _key__dim__customer",

        "cast(wr_web_page_sk as int) as _key__dim__web_page",
        "cast(wr_reason_sk as int) as _key__dim__reason",

        "cast(wr_order_number as long) as _degenerate__order_number",

        "cast(wr_return_quantity as int) as _measure__returns__return_quantity",
        "wr_return_amt as _measure__returns__return_amt",
        "wr_return_tax as _measure__returns__return_tax",
        "wr_return_amt_inc_tax as _measure__returns__return_amt_inc_tax",
        "wr_fee as _measure__returns__fee",
        "wr_return_ship_cost as _measure__returns__return_ship_cost",
        "wr_refunded_cash as _measure__returns__refunded_cash",
        "wr_reversed_charge as _measure__returns__reversed_charge",
        # web_returns names this measure account_credit; unified as store_credit.
        "wr_account_credit as _measure__returns__store_credit",
        "wr_net_loss as _measure__returns__net_loss",
    ])

    catalog_returns = _read_staged_fact(staging_base, "catalog_returns", [
        "'catalog' as channel",
        "concat('catalog', '~', cast(cr_item_sk as string), '~', cast(cr_order_number as string)) as _key__fact__returns",

        "cast(cr_returned_date_sk as int) as _key__dim__date",
        "cast(cr_returned_time_sk as int) as _key__dim__time",
        "cast(cr_item_sk as int) as _key__dim__item",

        "cast(cr_returning_customer_sk as int) as _key__dim__customer",

        "cast(cr_call_center_sk as int) as _key__dim__call_center",
        "cast(cr_catalog_page_sk as int) as _key__dim__catalog_page",
        "cast(cr_ship_mode_sk as int) as _key__dim__ship_mode",
        "cast(cr_warehouse_sk as int) as _key__dim__warehouse",
        "cast(cr_reason_sk as int) as _key__dim__reason",

        "cast(cr_order_number as long) as _degenerate__order_number",

        "cast(cr_return_quantity as int) as _measure__returns__return_quantity",
        # catalog_returns spells this column cr_return_amount.
        "cr_return_amount as _measure__returns__return_amt",
        "cr_return_tax as _measure__returns__return_tax",
        "cr_return_amt_inc_tax as _measure__returns__return_amt_inc_tax",
        "cr_fee as _measure__returns__fee",
        "cr_return_ship_cost as _measure__returns__return_ship_cost",
        "cr_refunded_cash as _measure__returns__refunded_cash",
        "cr_reversed_charge as _measure__returns__reversed_charge",
        "cr_store_credit as _measure__returns__store_credit",
        "cr_net_loss as _measure__returns__net_loss",
    ])

    fact__returns = (
        store_returns
        .unionByName(web_returns, allowMissingColumns=True)
        .unionByName(catalog_returns, allowMissingColumns=True)
    )
    _save_fact(fact__returns, "fact__returns", date_year, target_schema)


def _load_fact__inventory(date_year, target_schema, staging_base):
    """Write ``fact__inventory`` keyed by (date, item, warehouse); no degenerate dims."""
    fact__inventory = _read_staged_fact(staging_base, "inventory", [
        "concat(cast(inv_date_sk as string), '~', cast(inv_item_sk as string), '~', cast(inv_warehouse_sk as string)) as _key__fact__inventory",

        "cast(inv_date_sk as int) as _key__dim__date",
        "cast(inv_item_sk as int) as _key__dim__item",
        "cast(inv_warehouse_sk as int) as _key__dim__warehouse",

        "cast(inv_quantity_on_hand as int) as _measure__inventory__qty_on_hand",
    ])
    _save_fact(fact__inventory, "fact__inventory", date_year, target_schema)


def load_facts(date_year, target_schema, staging_base):
    """Load ``fact__sales``, ``fact__returns`` and ``fact__inventory`` from staged Parquet.

    Args:
        date_year: DataFrame (``_key__dim__date``, ``year``) used for the
            optional ``_part__year`` partition column.
        target_schema: Spark database the fact tables are written into.
        staging_base: directory holding the ``<table>.parquet`` staging files.
    """
    _load_fact__sales(date_year, target_schema, staging_base)
    _load_fact__returns(date_year, target_schema, staging_base)
    _load_fact__inventory(date_year, target_schema, staging_base)

## Generate Puppini Bridge

In [None]:
def generate_bridge(target_schema):
    """
    Build ``{target_schema}._bridge`` by vertically stacking every persistent table
    in the schema into a single "long" table.

    Bridge columns:
      - peripheral: name of the source table each row came from
      - _key__*: BIGINT if the column has an integral Spark type in every table
        that contains it, otherwise STRING (e.g. the "~"-joined _key__fact__* keys)
      - _degenerate__*: STRING
      - _measure__*: DOUBLE
    A table lacking a bridge column contributes a typed NULL for it. Columns are
    ordered deterministically: peripheral, then keys, degenerates and measures,
    each group sorted alphabetically. Other columns are not carried over.

    Temporary views and any existing ``_bridge`` table are excluded from the input.

    Raises:
        ValueError: if the schema contains no tables other than the bridge.
    """
    # Local import: only this function needs Spark's type hierarchy.
    from pyspark.sql.types import IntegralType

    bridge_table = "_bridge"

    # --- discover tables in schema ---
    tables_df = spark.sql(f"SHOW TABLES IN {target_schema}")  # columns: namespace/database, tableName, isTemporary
    table_names = [
        row["tableName"]
        for row in tables_df.where("isTemporary = false").select("tableName").collect()
        if row["tableName"] != bridge_table
    ]

    if not table_names:
        raise ValueError(f"No tables found in schema {target_schema} (excluding {bridge_table}).")

    # --- pass 1: collect bridge columns and key types across tables ---
    key_cols, deg_cols, meas_cols = set(), set(), set()
    # Key column -> True only if it is integral in every table where it appears.
    # Integral-only: casting decimal/double/interval keys to BIGINT would be lossy.
    key_is_integral = {}

    for t in table_names:
        for field in spark.table(f"{target_schema}.{t}").schema.fields:
            c = field.name
            if c.startswith("_key__"):
                key_cols.add(c)
                integral = isinstance(field.dataType, IntegralType)
                key_is_integral[c] = key_is_integral.get(c, True) and integral
            elif c.startswith("_measure__"):
                meas_cols.add(c)
            elif c.startswith("_degenerate__"):
                deg_cols.add(c)

    # Target SQL type per bridge column; dict insertion order fixes column order.
    bridge_types = {}
    for c in sorted(key_cols):
        bridge_types[c] = "bigint" if key_is_integral[c] else "string"
    for c in sorted(deg_cols):
        bridge_types[c] = "string"
    for c in sorted(meas_cols):
        bridge_types[c] = "double"

    # --- helper: normalized projection of one table onto the bridge schema ---
    def bridge_select_for_table(table_name: str):
        df = spark.table(f"{target_schema}.{table_name}")
        present = set(df.columns)
        exprs = [F.lit(table_name).alias("peripheral")]
        for c, dtype in bridge_types.items():
            source = F.col(c) if c in present else F.lit(None)
            exprs.append(source.cast(dtype).alias(c))
        return df.select(*exprs)

    # --- pass 2: union all (every part already has the identical column list) ---
    parts = [bridge_select_for_table(t) for t in table_names]
    bridge_df = reduce(lambda a, b: a.unionByName(b), parts)

    # --- write bridge (save_table prints the "Wrote ..." line) ---
    save_table(target_schema, bridge_df, bridge_table, mode="overwrite")

    print(f"Tables included: {len(table_names)}")
    print(f"Bridge columns: 1 + {len(bridge_types)} (peripheral + keys/degenerates/measures)")
    print(f"Key columns cast to BIGINT: {sum(key_is_integral.values())} / {len(key_is_integral)}")

## Orchestrate

In [None]:
# 1) Generate TPC-DS in DuckDB and stage every table as Parquet.
con = connect(duckdb_path)
try:
    generate_data(con)
    stage_data(con, staging_path)
finally:
    # Release the DuckDB connection even if generation or export fails.
    con.close()

# 2) Load the staged Parquet into Delta dimensions/facts, then build the bridge.
setup_schema(target_schema)
date_year = load_dimensions(target_schema, staging_base)
load_facts(date_year, target_schema, staging_base)
generate_bridge(target_schema)