# 1. Load Data

In [0]:
# Bare reference to the `spark` session: raises NameError right away if no
# session is defined (presumably it is injected by the notebook runtime; it is
# not created anywhere in this file).
# NOTE(review): this is not the last expression of the cell, so it is likely
# not echoed as output — confirm whether it is still needed.
spark
# Shared aliases for the PySpark column functions (F) and data types (T).
from pyspark.sql import functions as F
from pyspark.sql import types as T



In [0]:
# Location of the source Excel workbook passed to the `spark.read` calls below.
# NOTE(review): hard-coded absolute path — confirm it exists in the target workspace.
path = "/Volumes/workspace/default/bops/BOPS data.xlsx"


In [0]:
# List the sheets contained in the workbook so the right names can be
# passed as `dataAddress` when reading the individual sheets.
sheets = (
    spark.read.format("excel")
    .options(operation="listSheets")
    .load(path)
)

display(sheets)


In [0]:
# Read the "B&M Sales" sheet of the workbook into a DataFrame.
bm = (
    spark.read.format("excel")
    .options(
        headerRows=1,             # first row = column names
        dataAddress="B&M Sales",  # whole sheet
        inferSchema=True,         # infer column types
    )
    .load(path)
)

# Preview the first five rows.
display(bm.limit(5))

In [0]:

# Read the "Online Sales" sheet of the workbook into a DataFrame.
online = (
    spark.read.format("excel")
    .options(
        headerRows=1,                # first row = column names
        dataAddress="Online Sales",  # whole sheet
        inferSchema=True,            # infer column types
    )
    .load(path)
)

# Preview the first five rows.
display(online.limit(5))


In [0]:
from pyspark.sql import functions as F

# Normalise the B&M sheet: friendlier id column name plus explicit types.
bm_clean = (
    bm.withColumnRenamed("id (store)", "id_store")
    # keep the calendar date only
    .withColumn("date", F.to_date(F.col("date")))
    # double is easier to use for statistics and plots
    .withColumn("sales", F.col("sales").astype("double"))
    # binary flags are stored explicitly as integers
    .withColumn("after", F.col("after").astype("int"))
    .withColumn("usa", F.col("usa").astype("int"))
)


In [0]:
# Duplicate check: if (id_store, date) is a unique key the two numbers match.
print(f"rows: {bm_clean.count()}")
print(f"unique (id_store, date): {bm_clean.select('id_store', 'date').distinct().count()}")


In [0]:
# Normalise the Online sheet: friendlier id column name plus explicit types.
online_clean = (
    online.withColumnRenamed("id (DMA)", "id_dma")
    # keep the calendar date only
    .withColumn("date", F.to_date(F.col("date")))
    # double is easier to use for statistics and plots
    .withColumn("sales", F.col("sales").astype("double"))
    # binary flags are stored explicitly as integers
    .withColumn("after", F.col("after").astype("int"))
    .withColumn("close", F.col("close").astype("int"))
)


In [0]:
# Duplicate check: if (id_dma, date) is a unique key the two numbers match.
print(f"rows: {online_clean.count()}")
print(f"unique (id_dma, date): {online_clean.select('id_dma', 'date').distinct().count()}")


In [0]:
# Persist both cleaned DataFrames as Delta tables, replacing any previous version.
bm_clean.write.saveAsTable("bm_sales_clean", format="delta", mode="overwrite")
online_clean.write.saveAsTable("online_sales_clean", format="delta", mode="overwrite")

# 2. QA Test

In [0]:
# ---------- Load dataframes if not already in memory ----------
# Reuse the in-memory DataFrames when the cleaning cells were run in this
# session; otherwise fall back to the saved tables.
if "bm_clean" not in globals():
    bm_clean = spark.table("bm_sales_clean")

if "online_clean" not in globals():
    online_clean = spark.table("online_sales_clean")

# ---------- Helpers ----------
def _first_existing_col(df, candidates):
    for c in candidates:
        if c in df.columns:
            return c
    return None

def _dtype_name(df, col):
    return dict(df.dtypes).get(col)

def _is_numeric_dtype(dtype_str):
    if dtype_str is None:
        return False
    d = dtype_str.lower()
    return any(x in d for x in ["int", "bigint", "long", "double", "float", "decimal", "smallint", "tinyint"])

def _is_date_like(dtype_str):
    if dtype_str is None:
        return False
    d = dtype_str.lower()
    return ("date" in d) or ("timestamp" in d)

def _fail(msg):
    raise Exception(f"[QA FAIL] {msg}")

def _warn(msg):
    print(f"[QA WARN] {msg}")

def _pass(msg):
    print(f"[QA PASS] {msg}")

# ---------- Resolve expected columns (robust to naming) ----------
# Each logical column is looked up under a list of accepted spellings; the
# result is the spelling actually present in the table, or None.

# B&M table
bm_id_col = _first_existing_col(
    bm_clean, ["id_store", "store_id", "id (store)", "id_store "]
)
bm_date_col = _first_existing_col(bm_clean, ["date"])
bm_sales_col = _first_existing_col(bm_clean, ["sales"])
bm_after_col = _first_existing_col(bm_clean, ["after"])
bm_usa_col = _first_existing_col(bm_clean, ["usa"])

# Online table
on_id_col = _first_existing_col(
    online_clean, ["dma_id", "id_dma", "id (DMA)", "id_dma "]
)
on_date_col = _first_existing_col(online_clean, ["date"])
on_sales_col = _first_existing_col(online_clean, ["sales"])
on_after_col = _first_existing_col(online_clean, ["after"])
on_close_col = _first_existing_col(online_clean, ["close"])

# Hard requirements: every logical column must have been resolved.
required_bm = [bm_id_col, bm_date_col, bm_sales_col, bm_after_col, bm_usa_col]
required_on = [on_id_col, on_date_col, on_sales_col, on_after_col, on_close_col]

if None in required_bm:
    _fail(f"B&M table missing required columns. Found columns: {bm_clean.columns}")
if None in required_on:
    _fail(f"Online table missing required columns. Found columns: {online_clean.columns}")

# ---------- Gate 1: Schema & type sanity ----------
def gate_schema_types(df, id_col, date_col, sales_col, flag_cols, label):
    """Gate 1: verify the column types of `df`.

    The id, sales and every flag column must be numeric and the date column
    must be a date/timestamp. The first violation aborts through `_fail`;
    when all checks succeed a pass line is printed. `label` prefixes every
    message.
    """
    # ID numeric
    id_type = _dtype_name(df, id_col)
    if not _is_numeric_dtype(id_type):
        _fail(f"{label}: {id_col} should be numeric; got {id_type}")

    # Date date/timestamp
    date_type = _dtype_name(df, date_col)
    if not _is_date_like(date_type):
        _fail(f"{label}: {date_col} should be date/timestamp; got {date_type}")

    # Sales numeric
    sales_type = _dtype_name(df, sales_col)
    if not _is_numeric_dtype(sales_type):
        _fail(f"{label}: {sales_col} should be numeric; got {sales_type}")

    # Flags numeric (int/bigint/etc.)
    for flag in flag_cols:
        flag_type = _dtype_name(df, flag)
        if not _is_numeric_dtype(flag_type):
            _fail(f"{label}: {flag} should be numeric flag (0/1); got {flag_type}")

    _pass(f"{label}: schema/type sanity checks passed.")

# Run the schema/type gate on both tables.
gate_schema_types(
    df=bm_clean,
    id_col=bm_id_col,
    date_col=bm_date_col,
    sales_col=bm_sales_col,
    flag_cols=[bm_after_col, bm_usa_col],
    label="B&M",
)
gate_schema_types(
    df=online_clean,
    id_col=on_id_col,
    date_col=on_date_col,
    sales_col=on_sales_col,
    flag_cols=[on_after_col, on_close_col],
    label="Online",
)

# ---------- Gate 2: Primary key uniqueness ----------
def gate_key_uniqueness(df, key_cols, label, sample=20):
    """Gate 2: require that `key_cols` identify each row of `df` uniquely.

    When some key combination occurs more than once, up to `sample` of the
    most frequent offenders are displayed and the gate aborts through
    `_fail`; otherwise a pass line is printed.
    """
    rows_per_key = df.groupBy(*key_cols).count()
    repeated_keys = rows_per_key.where(F.col("count") > 1)
    n_repeated = repeated_keys.count()

    if n_repeated > 0:
        _warn(f"{label}: found {n_repeated} duplicate keys on {key_cols}. Showing examples:")
        worst_first = repeated_keys.orderBy(F.col("count").desc())
        display(worst_first.limit(sample))
        _fail(f"{label}: key uniqueness failed for {key_cols}.")

    _pass(f"{label}: key uniqueness passed for {key_cols}.")

# Run the key-uniqueness gate on both tables.
gate_key_uniqueness(df=bm_clean, key_cols=[bm_id_col, bm_date_col], label="B&M")
gate_key_uniqueness(df=online_clean, key_cols=[on_id_col, on_date_col], label="Online")

# ---------- Gate 3: Flag domain validity (must be 0/1) ----------
def gate_binary_flags(df, flag_cols, label):
    """Gate 3: require every non-null value of each flag column to be 0 or 1.

    Nulls are ignored here. For the first offending column its distinct
    values are displayed and the gate aborts through `_fail`; when all
    columns are clean a pass line is printed.
    """
    for flag in flag_cols:
        value = F.col(flag)
        outside_domain = value.isNotNull() & ~value.isin([0, 1])
        n_invalid = df.select(flag).filter(outside_domain).count()
        if n_invalid > 0:
            _warn(f"{label}: flag {flag} has {n_invalid} rows not in {{0,1}}. Distinct values:")
            display(df.select(flag).distinct().orderBy(flag))
            _fail(f"{label}: binary flag check failed for {flag}.")

    _pass(f"{label}: binary flag checks passed for {flag_cols}.")

# Run the binary-flag gate on both tables.
gate_binary_flags(df=bm_clean, flag_cols=[bm_after_col, bm_usa_col], label="B&M")
gate_binary_flags(df=online_clean, flag_cols=[on_after_col, on_close_col], label="Online")

# ---------- Gate 4: Completeness (no nulls in key + critical fields) ----------
def gate_not_null(df, cols, label):
    """Gate 4: require that none of `cols` contains a null in `df`.

    Args:
        df: DataFrame to check.
        cols: names of the columns that must be fully populated.
        label: prefix used in the printed / raised messages.

    Raises:
        Exception: through `_fail`, when at least one column contains nulls.
            The offending columns and their null counts are printed first.

    All columns are counted in a single aggregation, so the data is scanned
    once instead of once per column.
    """
    col_list = list(cols)
    null_counts = {}
    if col_list:  # an aggregation needs at least one expression
        # count() skips nulls and when() without otherwise() yields null for
        # non-matching rows, so each expression counts the null rows of one column.
        # Positional aliases keep the result independent of the column names.
        counters = [
            F.count(F.when(F.col(c).isNull(), F.lit(1))).alias(f"n{i}")
            for i, c in enumerate(col_list)
        ]
        totals = df.agg(*counters).collect()[0]
        null_counts = {c: totals[i] for i, c in enumerate(col_list)}

    bad = {c: n for c, n in null_counts.items() if n > 0}
    if bad:
        _warn(f"{label}: nulls found in critical columns: {bad}")
        _fail(f"{label}: not-null gate failed.")
    _pass(f"{label}: not-null gate passed for {cols}.")

# Run the not-null gate on the key and critical columns of both tables.
gate_not_null(
    df=bm_clean,
    cols=[bm_id_col, bm_date_col, bm_sales_col, bm_after_col, bm_usa_col],
    label="B&M",
)
gate_not_null(
    df=online_clean,
    cols=[on_id_col, on_date_col, on_sales_col, on_after_col, on_close_col],
    label="Online",
)

# ---------- Gate 5: Value sanity (sales non-negative) + Outlier WARN ----------
def gate_sales_sanity(df, id_col, date_col, sales_col, label):
    """Gate 5: reject negative sales, then print soft outlier diagnostics.

    Hard check: any row with `sales_col` < 0 aborts through `_fail` after up
    to 20 offending rows are displayed.
    Soft check: the approximate median / p99 / p999 of `sales_col` are
    printed together with the 10 largest rows, and a warning is issued when
    p999 exceeds 1000 times a positive median.
    """
    negative_rows = df.where(F.col(sales_col) < 0)
    n_negative = negative_rows.count()
    if n_negative > 0:
        _warn(f"{label}: found {n_negative} rows with negative {sales_col}. Showing examples:")
        display(negative_rows.select(id_col, date_col, sales_col).limit(20))
        _fail(f"{label}: sales sanity failed (negative sales).")
    _pass(f"{label}: sales non-negative check passed.")

    # Soft outlier diagnostics (WARN only)
    quantiles = df.approxQuantile(sales_col, [0.5, 0.99, 0.999], 0.001)
    if len(quantiles) != 3:
        _warn(f"{label}: could not compute quantiles for {sales_col} (unexpected).")
        return

    p50, p99, p999 = quantiles
    print(f"{label}: {sales_col} quantiles -> median={p50:.4f}, p99={p99:.4f}, p999={p999:.4f}")

    # Largest 10 observations for manual inspection
    largest_first = df.select(id_col, date_col, sales_col).orderBy(F.col(sales_col).desc())
    display(largest_first.limit(10))

    if p50 > 0 and (p999 / p50) > 1_000:
        _warn(f"{label}: extreme tail detected (p999/median > 1000). Consider log1p(sales) robustness.")

# Run the sales sanity gate on both tables.
gate_sales_sanity(
    df=bm_clean, id_col=bm_id_col, date_col=bm_date_col, sales_col=bm_sales_col, label="B&M"
)
gate_sales_sanity(
    df=online_clean, id_col=on_id_col, date_col=on_date_col, sales_col=on_sales_col, label="Online"
)

# ---------- Gate 6: Time coverage sanity (weekly completeness pattern) ----------
def gate_time_coverage(df, date_col, label):
    """Gate 6: check that `df` has dates and flag sparsely populated ones.

    Hard check: zero distinct `date_col` values aborts through `_fail`.
    Soft check: dates whose row count is below half of the approximate
    median row count are displayed (up to 25) with a warning.
    Finally the minimum and maximum date are printed.
    """
    # Number of rows observed for each distinct date
    rows_per_date = df.groupBy(date_col).count()
    n_distinct = rows_per_date.count()
    if n_distinct == 0:
        _fail(f"{label}: no dates found in {date_col}.")
    _pass(f"{label}: found {n_distinct} unique {date_col} values.")

    # Dates with unusually low coverage (WARN only)
    median_cnt = rows_per_date.approxQuantile("count", [0.5], 0.001)[0]
    half_median = F.lit(0.5) * F.lit(median_cnt)
    sparse_dates = rows_per_date.filter(F.col("count") < half_median).orderBy("count")
    n_sparse = sparse_dates.count()
    if n_sparse == 0:
        _pass(f"{label}: weekly coverage looks stable (no weeks <50% median coverage).")
    else:
        _warn(f"{label}: {n_sparse} weeks have <50% of median row coverage (median={median_cnt}). Review missingness.")
        display(sparse_dates.limit(25))

    # Overall date range
    bounds = df.agg(
        F.min(date_col).alias("min_date"),
        F.max(date_col).alias("max_date"),
    ).first()
    print(f"{label}: date range = {bounds['min_date']} to {bounds['max_date']}")

# Run the time-coverage gate on both tables.
gate_time_coverage(df=bm_clean, date_col=bm_date_col, label="B&M")
gate_time_coverage(df=online_clean, date_col=on_date_col, label="Online")

# Reached only when no gate raised: every hard failure aborts through `_fail`.
# The check mark replaces a mis-decoded byte sequence ("âœ…") in the original text.
print("\n✅ All HARD QA gates passed. Soft warnings (if any) are listed above.")

# 3. Single Source of Truth

In [0]:
# Start from the persisted, cleaned online sales table.
on = spark.table("online_sales_clean")

did_panel = (
    on
    # keep only the needed columns, each with an explicit type
    .select(
        F.col("id_dma").astype("bigint").alias("id_dma"),
        F.to_date(F.col("date")).alias("date"),
        *[
            F.col(name).astype("int").alias(name)
            for name in ("year", "week", "after", "close")
        ],
        F.col("sales").astype("double").alias("sales"),
    )
    # interaction term: product of the two flags (1 only when both are 1)
    .withColumn("treated_post", (F.col("close") * F.col("after")).astype("int"))
    # enforce one row per (id_dma, date)
    .dropDuplicates(["id_dma", "date"])
)

# Persist the panel as a Delta table, replacing any previous version.
did_panel.write.saveAsTable("did_panel", format="delta", mode="overwrite")