In [0]:
from pyspark.sql import functions as F
from delta.tables import DeltaTable
import logging, traceback, uuid
from datetime import datetime, timezone

In [0]:
# Catalog, schema and table names that together identify the audit table.
AUDIT_CATALOG, AUDIT_SCHEMA, AUDIT_TABLE = "adf_adb_audit", "audit", "pipeline_audit"

# abfss:// storage root under which the audit catalog is kept.
AUDIT_BASE_PATH = "abfss://gizmoboxadb@gijodatabricksextdl.dfs.core.windows.net/adf_adb_audit/"

# Derived locations: the catalog root always ends with exactly one "/",
# and the schema lives in the "audit" folder directly beneath it.
catalog_loc = f"{AUDIT_BASE_PATH.rstrip('/')}/"
schema_loc = f"{catalog_loc}audit"

In [0]:
def _now_utc():
    return datetime.now(timezone.utc)

In [0]:
def _dbutils_ctx():
    out = {"user": None, "triggered_by": "unknown", "job_id": None,
           "run_id": None, "spark_version": spark.version}
    try:
        from pyspark.dbutils import DBUtils
        dbu = DBUtils(spark)  # type: ignore
        nb = dbu.notebook.getContext()
        try:
            out["user"] = spark.sql("select current_user()").collect()[0][0]
        except:
            pass
        try:
            tags = {t.key(): t.value() for t in nb.tags().get()}
            out["job_id"] = tags.get("jobId") or tags.get("job_id")
            out["run_id"] = tags.get("jobRunId") or tags.get("runId")
            if tags.get("jobId"):
                out["triggered_by"] = "schedule"
            else:
                out["triggered_by"] = "manual"
        except:
            pass
    except:
        pass
    if not out["run_id"]:
        out["run_id"] = str(uuid.uuid4())  # interactive fallback
    return out


In [0]:
# DDL that creates the audit catalog when it does not exist yet, with its managed
# storage rooted at catalog_loc. The f-string is interpolated once, when this cell
# runs, so later changes to AUDIT_CATALOG / catalog_loc are not picked up.
CREATE_CATALOG = f"""CREATE CATALOG IF NOT EXISTS {AUDIT_CATALOG}
						MANAGED LOCATION '{catalog_loc}'"""

In [0]:
# DDL that creates the audit schema when it does not exist yet, with its managed
# storage at schema_loc. The schema name is catalog-qualified so the statement
# targets the audit catalog regardless of the session's current catalog.
CREATE_SCHEMA = f"""CREATE SCHEMA IF NOT EXISTS {AUDIT_CATALOG}.{AUDIT_SCHEMA}
					MANAGED LOCATION '{schema_loc}'"""

In [0]:
# DDL that creates the fully-qualified audit table when it does not exist yet.
# The table is a Delta table partitioned by audit_date and holds one row per
# audited run: identifying columns (environment, job_id, job_name, run_id,
# triggered_by, target_table, layer, retry_of_run_id), timing columns
# (start_time_utc, end_time_utc, duration_ms), record_count and run_status.
CREATE_TABLE = f"""CREATE TABLE IF NOT EXISTS {AUDIT_CATALOG}.{AUDIT_SCHEMA}.{AUDIT_TABLE} (
						environment STRING, 
						job_id STRING, 
						job_name STRING, 
						run_id STRING,
						triggered_by STRING, 
						target_table STRING, 
						layer STRING,
						start_time_utc TIMESTAMP, 
						end_time_utc TIMESTAMP, 
						duration_ms BIGINT,
						record_count BIGINT, 
						run_status STRING, 
						retry_of_run_id STRING, 
						audit_date DATE
					) USING DELTA PARTITIONED BY (audit_date)"""

In [0]:

# Ensure audit catalog/schema/table exist with explicit storage locations.

def ensure_audit_objects(logger=None,AUDIT_BASE_PATH="abfss://gizmoboxadb@gijodatabricksextdl.dfs.core.windows.net/adf_adb_audit/"):
    """Create the audit catalog, schema and table if they do not exist yet.

    Runs the module-level CREATE_CATALOG / CREATE_SCHEMA / CREATE_TABLE statements.
    The session's current catalog is switched to the audit catalog while the schema
    and table are created and is then restored, so callers keep resolving their own
    unqualified names against the catalog they were using before the call.

    Args:
        logger: optional logger; receives one info message on success.
        AUDIT_BASE_PATH: accepted for backward compatibility but not read here --
            the storage locations are already baked into the CREATE_* statements.
    """
    # Remember the caller's catalog so the USE CATALOG below does not leak out.
    try:
        previous_catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
    except Exception:
        previous_catalog = None  # best effort: without it we simply cannot restore

    spark.sql(CREATE_CATALOG)
    spark.sql(f"USE CATALOG {AUDIT_CATALOG}")
    try:
        spark.sql(CREATE_SCHEMA)
        spark.sql(CREATE_TABLE)
    finally:
        if previous_catalog and previous_catalog != AUDIT_CATALOG:
            # Backtick-quote the name (doubling embedded backticks) so any valid
            # catalog identifier can be restored.
            quoted = "`" + previous_catalog.replace("`", "``") + "`"
            spark.sql(f"USE CATALOG {quoted}")

    if logger: logger.info("Audit catalog/schema/table ensured.")



In [0]:
def _layer_from_target_table(target_table: str) -> str:
    # expects "<catalog>.<schema>.<table>"
    parts = (target_table or "").split(".")
    return parts[1] if len(parts) >= 2 else "unknown"

In [0]:
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, LongType, DateType
from pyspark.sql import functions as F
from delta.tables import DeltaTable
import uuid
from datetime import datetime, timezone

def audit_start(
    pENV: str = None, ENV: str = None,
    pJOB_NAME: str = None, JOB_NAME: str = None,
    pTARGET_TABLE: str = None, TARGET_TABLE: str = None,
    pRETRY_OF_RUN_ID: str = None, RETRY_OF_RUN_ID: str = None,
    logger=None
) -> dict:
    """Create/Upsert a RUNNING row for (run_id, layer) with an explicit schema.

    Each value can be passed under its ``p*`` name or its plain name; the ``p*``
    variant wins when both are truthy.

    Args:
        pENV / ENV: environment label; stored lower-cased ("" when omitted).
        pJOB_NAME / JOB_NAME: job name; required.
        pTARGET_TABLE / TARGET_TABLE: dotted table name; the layer is derived from it.
        pRETRY_OF_RUN_ID / RETRY_OF_RUN_ID: run id this run retries, if any.
        logger: optional logger for progress messages.

    Returns:
        dict with ``run_id``, ``layer`` and ``start_time_utc`` -- the values needed
        to address the same audit row in later calls.

    Raises:
        ValueError: when no job name was supplied.
    """
    # --- resolve args (support both corporate p* and plain names) ---
    env          = (pENV or ENV or "").lower()
    job_name     = pJOB_NAME or JOB_NAME
    target_table = pTARGET_TABLE or TARGET_TABLE
    retry_id     = pRETRY_OF_RUN_ID or RETRY_OF_RUN_ID

    # job_name is declared non-nullable in the row schema below, so a missing value
    # can never be written. Fail fast with a clear message before any DDL runs.
    if job_name is None:
        raise ValueError("audit_start requires a job name: pass pJOB_NAME or JOB_NAME.")

    ensure_audit_objects(logger)

    ctx = _dbutils_ctx()
    run_id = ctx["run_id"] or str(uuid.uuid4())
    start_ts = _now_utc()
    layer = _layer_from_target_table(target_table)

    # --- explicit schema to avoid NullType inference ---
    # (end_time_utc, duration_ms and record_count are all None at start.)
    schema = StructType([
        StructField("environment",    StringType(),   True),
        StructField("job_id",         StringType(),   True),
        StructField("job_name",       StringType(),   False),
        StructField("run_id",         StringType(),   False),
        StructField("triggered_by",   StringType(),   True),
        StructField("target_table",   StringType(),   True),
        StructField("layer",          StringType(),   True),
        StructField("start_time_utc", TimestampType(),True),
        StructField("end_time_utc",   TimestampType(),True),
        StructField("duration_ms",    LongType(),     True),
        StructField("record_count",   LongType(),     True),
        StructField("run_status",     StringType(),   True),
        StructField("retry_of_run_id",StringType(),   True),
        StructField("audit_date",     DateType(),     True)
    ])

    row = [{
        "environment": env,
        "job_id": ctx["job_id"],
        "job_name": job_name,
        "run_id": run_id,
        "triggered_by": ctx["triggered_by"],
        "target_table": target_table,
        "layer": layer,
        "start_time_utc": start_ts,
        "end_time_utc": None,
        "duration_ms": None,
        "record_count": None,
        "run_status": "RUNNING",
        "retry_of_run_id": retry_id,
        "audit_date": start_ts.date()
    }]

    df = spark.createDataFrame(row, schema=schema)

    # Upsert on (run_id, layer): an existing row for the same key is overwritten
    # with this fresh RUNNING row; otherwise a new row is inserted.
    tgt = f"{AUDIT_CATALOG}.{AUDIT_SCHEMA}.{AUDIT_TABLE}"
    (DeltaTable.forName(spark, tgt).alias("t")
        .merge(df.alias("s"), "t.run_id = s.run_id AND t.layer = s.layer")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute())

    if logger: logger.info(f"[AUDIT] START {job_name} {layer} run_id={run_id}")
    return {"run_id": run_id, "layer": layer, "start_time_utc": start_ts}


In [0]:
def audit_update_count(*, RUN_ID: str, TARGET_TABLE: str,
                       RECORD_COUNT: int, logger: logging.Logger = None):
    """Store RECORD_COUNT on the existing audit row for (RUN_ID, layer).

    The layer is derived from TARGET_TABLE. Only a matching row is updated; when
    no row matches, the merge changes nothing. All arguments are keyword-only.
    """
    layer = _layer_from_target_table(TARGET_TABLE)
    audit_fqn = f"{AUDIT_CATALOG}.{AUDIT_SCHEMA}.{AUDIT_TABLE}"
    audit_delta = DeltaTable.forName(spark, audit_fqn)

    # One-row source frame carrying the merge key plus the new count.
    one_row = spark.range(1)
    source_df = one_row.select(
        F.lit(RUN_ID).alias("run_id"),
        F.lit(layer).alias("layer"),
        F.lit(int(RECORD_COUNT)).alias("record_count"),
    )

    match_on = "t.run_id = s.run_id AND t.layer = s.layer"
    merge_builder = audit_delta.alias("t").merge(source_df.alias("s"), match_on)
    merge_builder.whenMatchedUpdate(set={"record_count": F.col("s.record_count")}).execute()

    if logger:
        logger.info(f"[AUDIT] COUNT {layer} rows={RECORD_COUNT} run_id={RUN_ID}")

In [0]:
from datetime import timezone

def audit_finalize(*, RUN_ID: str, TARGET_TABLE: str, RUN_STATUS: str, logger=None):
    """Close the audit row for (RUN_ID, layer) with a final status and timing.

    Sets run_status, end_time_utc and duration_ms, and overwrites audit_date with
    the UTC date of the end timestamp. The layer is derived from TARGET_TABLE.
    When no audit row exists for the key, a warning is logged and nothing is
    written. All arguments are keyword-only.

    Args:
        RUN_ID: run id of the audit row to close.
        TARGET_TABLE: dotted table name; determines the layer part of the key.
        RUN_STATUS: final status string to store.
        logger: optional logger for progress and warning messages.
    """
    layer = _layer_from_target_table(TARGET_TABLE)
    end_ts = _now_utc()  # aware UTC

    tbl = f"{AUDIT_CATALOG}.{AUDIT_SCHEMA}.{AUDIT_TABLE}"
    row = (spark.table(tbl)
           .where((F.col("run_id")==RUN_ID) & (F.col("layer")==layer))
           .select("start_time_utc")
           .limit(1).collect())

    if not row:
        # Without a matching row the merge below would update nothing, and the
        # FINAL log line would wrongly suggest the run had been closed.
        if logger:
            logger.warning(
                f"[AUDIT] FINAL {layer} run_id={RUN_ID}: no matching audit row found; nothing updated."
            )
        return

    duration_ms = None
    try:
        st = row[0][0]
        if st is not None:
            # make start aware-UTC if Spark returned it naive
            # NOTE(review): this assumes a naive value represents UTC -- confirm the
            # driver/session timezone.
            if getattr(st, "tzinfo", None) is None:
                st = st.replace(tzinfo=timezone.utc)
            else:
                st = st.astimezone(timezone.utc)
            duration_ms = int((end_ts - st).total_seconds() * 1000)
    except Exception as ex:
        if logger: logger.warning(f"[AUDIT] duration calc failed: {ex}; continuing without duration.")

    # One-row source frame carrying the merge key and the closing values.
    upd = (spark.range(1).select(
        F.lit(RUN_ID).alias("run_id"),
        F.lit(layer).alias("layer"),
        F.lit(RUN_STATUS).alias("run_status"),
        F.lit(end_ts).alias("end_time_utc"),
        F.lit(duration_ms).cast("bigint").alias("duration_ms"),
        F.lit(end_ts.date()).alias("audit_date")
    ))
    (DeltaTable.forName(spark, tbl).alias("t")
        .merge(upd.alias("s"), "t.run_id = s.run_id AND t.layer = s.layer")
        .whenMatchedUpdate(set={
            "run_status": F.col("s.run_status"),
            "end_time_utc": F.col("s.end_time_utc"),
            "duration_ms": F.col("s.duration_ms"),
            "audit_date": F.col("s.audit_date"),
        }).execute())

    if logger: logger.info(f"[AUDIT] FINAL {layer} status={RUN_STATUS} run_id={RUN_ID}")