In [0]:
from datetime import datetime
import pytz

def ensure_widget(name: str, default: str):
    """Create the text widget ``name`` with ``default`` unless it can already be read.

    The widget is probed with ``dbutils.widgets.get``; if that call raises for
    any reason the widget is created via ``dbutils.widgets.text``. A widget
    that already exists (e.g. one injected by a job run) is left untouched.
    """
    try:
        dbutils.widgets.get(name)
    except Exception:
        # Lookup failed -> treat the widget as missing (interactive run).
        needs_creation = True
    else:
        # Lookup succeeded -> keep whatever value is already there.
        needs_creation = False

    if needs_creation:
        dbutils.widgets.text(name, default)

# "Today" in the Indianapolis time zone is the default process_date.
tz = pytz.timezone("America/Indiana/Indianapolis")
today_str = datetime.now(tz).strftime("%Y-%m-%d")

# Widget name -> default value. A default is only applied when the widget does
# not already exist, so a job-supplied parameter takes precedence.
widget_defaults = {
    "env": "dev",
    "storage_root": "s3://pakeyj-data-sales",
    "process_date": today_str,
    "dedup_columns": "order_id",
    "drop_columns": "run_id,source_system",
    "bronze_table": "bronze_orders",
    "silver_table": "silver_orders",
    "data_label": "orders",
}
for widget_name, widget_default in widget_defaults.items():
    ensure_widget(widget_name, widget_default)

# Pull the widget values into notebook variables.
env = dbutils.widgets.get("env")
storage_root = dbutils.widgets.get("storage_root")
process_date = dbutils.widgets.get("process_date")
# Comma-separated widgets become lists of trimmed, non-empty column names.
dedup_columns = [
    col for col in map(str.strip, dbutils.widgets.get("dedup_columns").split(","))
    if col
]
drop_columns = [
    col for col in map(str.strip, dbutils.widgets.get("drop_columns").split(","))
    if col
]
bronze_table = dbutils.widgets.get("bronze_table")
silver_table = dbutils.widgets.get("silver_table")
data_label = dbutils.widgets.get("data_label")

# Echo the raw widget values, in the same order they were declared above.
for widget_name in widget_defaults:
    print(f"{widget_name}:", dbutils.widgets.get(widget_name))

In [0]:
from src.common.config import load_config
import importlib
import src.common.config as config

# Re-execute the src.common.config module so that edits made to it since it
# was first imported are picked up without restarting the kernel.
importlib.reload(config)

# `from src.common.config import load_config` bound the name *before* the
# reload, so it would still reference the pre-reload function object. Rebind
# it to the freshly loaded definition so the call below runs the current code.
load_config = config.load_config

cfg = load_config(env=env, storage_root=storage_root, process_date=process_date)

# From here on, use the values carried by the config object rather than the
# raw widget values.
storage_root = cfg.storage_root
process_date = cfg.process_date

# cfg.paths is indexed by the table names supplied through the widgets.
bronze_path = cfg.paths[bronze_table]
silver_path = cfg.paths[silver_table]

print(f"storage_root: {storage_root}")
print(f"process_date: {process_date}")
print(f"bronze_path: {bronze_path}")
print(f"silver_path: {silver_path}")

In [0]:
def peek(df, name, n=5):
    """Print a quick summary of ``df``.

    Output, in order: the column list, a ``=== name ===`` banner, the row
    count from ``df.count()``, and the first ``n`` rows via ``df.show`` with
    truncation disabled.
    """
    print(df.columns)

    banner = f"\n=== {name} ==="
    print(banner)

    row_total = df.count()
    print(f"rows: {row_total}")

    df.show(n, truncate=False)

# foo=spark.read.format('delta').load(bronze_path)
# peek(foo, 'bronze')

In [0]:
import json

# Widgets whose current values are dumped below.
widget_names = [
    "env", "storage_root", "process_date", "dedup_columns",
    "drop_columns", "bronze_table", "silver_table", "data_label",
]

# Snapshot each widget's current value, keyed by widget name.
widgets_dict = {}
for widget_name in widget_names:
    widgets_dict[widget_name] = dbutils.widgets.get(widget_name)

print(json.dumps(widgets_dict, indent=2))

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from delta.tables import DeltaTable

def dedupe_latest(df, key_cols, ts_col = 'ingested_at'):
  """Keep a single row per key, choosing the one ranked first by ``ts_col`` descending.

  Rows are partitioned by ``key_cols`` and numbered with ``row_number`` in
  descending ``ts_col`` order; only row number 1 of each partition is kept.
  The temporary ``_rn`` column is removed before returning.
  """
  newest_first = F.col(ts_col).desc()
  per_key_window = Window.partitionBy(*key_cols).orderBy(newest_first)

  ranked = df.withColumn("_rn", F.row_number().over(per_key_window))
  top_ranked = ranked.filter(F.col("_rn") == 1)
  return top_ranked.drop("_rn")

def drop_cols(df, dropped_columns):
  """Return the result of ``df.drop`` called with every name in ``dropped_columns``."""
  names_to_remove = tuple(dropped_columns)
  return df.drop(*names_to_remove)

def build_silver(df, silver_path, process_date, *, dataset_name = 'dataset'):
  """Write ``df`` as Delta to ``silver_path``, replacing one ``ingest_date`` slice.

  The write uses overwrite mode with a ``replaceWhere`` predicate of
  ``ingest_date = '<process_date>'`` and is partitioned by ``ingest_date``.
  ``dataset_name`` is accepted but not used by this function. Returns nothing.
  """
  slice_predicate = f"ingest_date = '{process_date}'"

  writer = df.write.format("delta").mode("overwrite")
  writer = writer.option("replaceWhere", slice_predicate)
  writer = writer.partitionBy("ingest_date")
  writer.save(silver_path)
# def build_silver_full(df, silver_path):
#     (df.write.format("delta")
#        .mode("overwrite")
#        .save(silver_path))


ts_col = "ingested_at"

# An empty key list would turn the merge condition built below into an empty
# string and leave dedupe_latest with no partition keys, so fail fast with a
# clear message instead of letting a misconfigured run proceed.
if not dedup_columns:
    raise ValueError("dedup_columns must name at least one key column")

# Read only the bronze rows whose ingest_date equals process_date.
df_batch = (spark.read.format("delta").load(bronze_path)
            .filter(F.col("ingest_date") == process_date))

# Remove the configured columns, then keep one row per key (first by ts_col desc).
df_batch = drop_cols(df_batch, drop_columns)
df_batch = dedupe_latest(df_batch, dedup_columns, ts_col)

# Target ("t") and source ("s") rows match when every key column is equal.
merge_cond = " AND ".join([f"t.{c} = s.{c}" for c in dedup_columns])

if DeltaTable.isDeltaTable(spark, silver_path):
    silver_dt = DeltaTable.forPath(spark, silver_path)

    # Upsert: matched rows are updated only when the incoming row's timestamp
    # is at least as recent as the stored one; unmatched rows are inserted.
    (silver_dt.alias("t")
     .merge(df_batch.alias("s"), merge_cond)
     .whenMatchedUpdateAll(condition=f"s.{ts_col} >= t.{ts_col}")
     .whenNotMatchedInsertAll()
     .execute())
else:
    # first run: create the table
    (df_batch.write.format("delta")
       .mode("overwrite")
       .save(silver_path))

# write to history table (prod only): append this batch to "<silver_path>_history"
if env == "prod":
    df_batch.write.format("delta").mode("append").save(f"{silver_path}_history")


In [0]:
# for showing duplicate orders if needed

# def duplicated_orders(df, key_cols, ts_col = 'ingested_at', *, dataset_name = 'dataset', show_count = True):
#   dup_keys = (df.groupBy(*key_cols)
#               .count()
#               .filter(F.col("count") > 1)
#               .select(*key_cols)
#               )
#   if show_count:
#     n_dup_keys = dup_keys.count()
#     print(f"Found {n_dup_keys} unique keys")
#   return (df.join(dup_keys, key_cols, 'inner')
#           .orderBy(*[F.col(c).asc() for c in key_cols], F.col(ts_col).desc())
#           )
# df = spark.read.format("delta").load(silver_path)
# peek(df,silver_path)
# df = drop_cols(df,drop_columns)
# df_dupes = duplicated_orders(df,dedup_columns, ts_col = 'ingested_at', dataset_name=data_label, show_count=True) 
