In [None]:
# nb_gold_utils
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T
from typing import List, Dict, Optional, Tuple

# ------------------------------------------------------------
# Microsoft Fabric / OneLake filesystem access
# ------------------------------------------------------------
from notebookutils import mssparkutils



# -----------------------------------------------------------------------------
# 0) Runtime helpers
# -----------------------------------------------------------------------------
def now_ts():
    """Return a Spark Column expression that evaluates to ``current_timestamp()``."""
    ts_expr = F.current_timestamp()
    return ts_expr

def normalize_run_id(gold_run_id: Optional[str]) -> str:
    """
    Validate a Gold run identifier and return it trimmed.

    Any non-None value is converted with ``str()`` and stripped of surrounding whitespace.

    Raises:
        ValueError: if the value is None or blank after stripping.
    """
    cleaned = None if gold_run_id is None else str(gold_run_id).strip()
    if not cleaned:
        raise ValueError("gold_run_id is required (non-empty).")
    return cleaned

def normalize_entity(entity: str) -> str:
    """
    Validate an entity name and return it trimmed.

    Any non-None value is converted with ``str()`` and stripped of surrounding whitespace.

    Raises:
        ValueError: if the value is None or blank after stripping.
    """
    text = "" if entity is None else str(entity).strip()
    if text == "":
        raise ValueError("entity is required (non-empty).")
    return text

# -----------------------------------------------------------------------------
# 1) Contract-first assertions
# -----------------------------------------------------------------------------
def assert_required_columns(df: DataFrame, required_cols: List[str], ctx: str = "") -> None:
    """
    Raise if any of ``required_cols`` is absent from ``df.columns``.

    Raises:
        ValueError: listing the missing names, in ``required_cols`` order.
    """
    present = df.columns
    missing = list(filter(lambda name: name not in present, required_cols))
    if not missing:
        return
    raise ValueError(f"[{ctx}] Missing required columns: {missing}")

def assert_no_additional_columns(df: DataFrame, allowed_cols: List[str], ctx: str = "") -> None:
    """
    Raise if ``df`` has any column not listed in ``allowed_cols``.

    Raises:
        ValueError: listing the unexpected names, in ``df.columns`` order.
    """
    permitted = allowed_cols
    extra = []
    for name in df.columns:
        if name not in permitted:
            extra.append(name)
    if extra:
        raise ValueError(f"[{ctx}] Additional (unexpected) columns: {extra}")

def assert_column_types(df: DataFrame, expected_types: Dict[str, T.DataType], ctx: str = "") -> None:
    """
    Check that each expected column exists with the expected Spark type class.

    Only the type *class* is compared (e.g. any DecimalType matches any other DecimalType,
    regardless of precision/scale). This is deliberately strict but realistic, since Spark
    types are not always 1:1 with contract types.

    Raises:
        ValueError: with a list of ``(column, actual, expected)`` tuples; ``actual`` is
            ``"MISSING"`` when the column is absent.
    """
    actual_by_name = {fld.name: fld.dataType for fld in df.schema.fields}
    problems = []
    for col_name, wanted in expected_types.items():
        if col_name not in actual_by_name:
            problems.append((col_name, "MISSING", str(wanted)))
            continue
        got = actual_by_name[col_name]
        if type(got) != type(wanted):
            problems.append((col_name, str(got), str(wanted)))
    if problems:
        raise ValueError(f"[{ctx}] Type mismatches: {problems}")

def assert_not_null(df: DataFrame, cols: List[str], ctx: str = "") -> None:
    """
    Raise if any column in ``cols`` contains NULL values.

    Null counts for all columns are computed in a single aggregation pass.

    Args:
        df: DataFrame to check.
        cols: columns that must not contain NULL. An empty list is a no-op.
        ctx: label prefixed to the error message.

    Raises:
        ValueError: mapping each offending column to its NULL count.
    """
    if not cols:
        # Nothing to check. Without this guard the select below has no aggregate
        # expressions, so on an empty DataFrame ``collect()[0]`` raises IndexError.
        return
    checks = [
        F.sum(F.when(F.col(c).isNull(), F.lit(1)).otherwise(F.lit(0))).alias(c)
        for c in cols
    ]
    row = df.select(checks).collect()[0].asDict()
    # sum() over zero rows yields None; `v and ...` treats that as "no nulls".
    bad = {k: v for k, v in row.items() if v and v > 0}
    if bad:
        raise ValueError(f"[{ctx}] NOT NULL violated: {bad}")

def assert_unique_key(df: DataFrame, key_cols: List[str], ctx: str = "") -> None:
    """
    Raise if two or more rows share the same values for ``key_cols``.

    Rows are grouped on the key columns; the check stops at the first group whose count
    exceeds 1 (``limit(1)``). Spark's groupBy puts NULL keys in one group, so repeated
    NULL keys also count as duplicates.

    Raises:
        ValueError: if ``key_cols`` is empty, or if a duplicate key exists.
    """
    if not key_cols:
        raise ValueError(f"[{ctx}] key_cols is required for uniqueness check")

    grouped = df.groupBy([F.col(c) for c in key_cols]).count()
    duplicated_groups = grouped.filter(F.col("count") > 1)
    has_duplicates = duplicated_groups.limit(1).count() > 0
    if has_duplicates:
        raise ValueError(f"[{ctx}] Uniqueness violated for key: {key_cols}")

# -----------------------------------------------------------------------------
# 2) Hash helpers (stable keys)
# -----------------------------------------------------------------------------
def add_key_hash(df: DataFrame, key_cols: List[str], out_col: str = "key_hash") -> DataFrame:
    """
    Add ``out_col`` holding the SHA-256 hex digest of the key columns.

    Each key column is cast to string, and NULL is replaced by the sentinel "∅". The
    parts are joined with "||" before hashing, so the same key values always produce
    the same hash.

    Caveats (kept as-is because stored hashes depend on this exact encoding):
    - a literal "∅" value hashes the same as NULL;
    - values that themselves contain "||" can collide across column boundaries.
    """
    null_safe_parts = []
    for name in key_cols:
        as_text = F.col(name).cast("string")
        null_safe_parts.append(F.coalesce(as_text, F.lit("∅")))
    joined = F.concat_ws("||", *null_safe_parts)
    return df.withColumn(out_col, F.sha2(joined, 256))

def add_natural_keys_json(df: DataFrame, key_cols: List[str], out_col: str = "natural_keys") -> DataFrame:
    """Add ``out_col`` with the key columns serialized as a JSON object via ``to_json(struct(...))``."""
    key_struct = F.struct(*(F.col(name) for name in key_cols))
    return df.withColumn(out_col, F.to_json(key_struct))

# -----------------------------------------------------------------------------
# 3) Conformance helpers
# -----------------------------------------------------------------------------
def split_orphans_left_anti(
    fact_df: DataFrame,
    dim_df: DataFrame,
    fact_key: str,
    dim_key: str,
    rule_id: str,
    anom_type: str,
    entity: str,
    gold_run_id: str,
    source_table: str,
    severity: str = "HIGH",
    anom_domain: str = "CONFORMANCE",
    natural_key_cols_for_event: Optional[List[str]] = None
) -> Tuple[DataFrame, DataFrame]:
    """
    Split fact rows into those whose foreign key exists in the dimension and orphan events.

    Args:
        fact_df: fact rows to check.
        dim_df: dimension rows; only ``dim_key`` is used (deduplicated).
        fact_key / dim_key: join columns on the fact / dimension side.
        rule_id, anom_type, severity, anom_domain, source_table: literal values stamped
            on every orphan event.
        entity, gold_run_id: validated and trimmed, then stamped on every orphan event.
        natural_key_cols_for_event: columns used for ``key_hash`` / ``natural_keys``;
            defaults to ``[fact_key]`` when None or empty.

    Returns:
        ``(conform, orphan_events)``:
        - conform: fact rows with a matching dimension key (left_semi, fact columns only);
        - orphan_events: one anomaly row per non-matching fact row (left_anti), projected
          to the anomaly-event column layout. A NULL fact key never matches, so it is
          reported as an orphan.

    Raises:
        ValueError: if ``gold_run_id`` or ``entity`` is blank.
    """
    gold_run_id = normalize_run_id(gold_run_id)
    entity = normalize_entity(entity)

    fact = fact_df.alias("f")
    dim_keys = dim_df.select(F.col(dim_key).alias(dim_key)).dropDuplicates().alias("d")
    matches = F.col(f"f.{fact_key}") == F.col(f"d.{dim_key}")

    # Conforming rows = a match exists (left_semi); orphans = no match (left_anti).
    # Both joins use the alias-qualified condition.
    conform = fact.join(dim_keys, matches, "left_semi").select("f.*")
    orphan = fact.join(dim_keys, matches, "left_anti").select("f.*")

    event_key_cols = natural_key_cols_for_event or [fact_key]

    constant_cols = {
        "gold_run_id": F.lit(gold_run_id),
        "event_ts": F.current_timestamp(),
        "entity": F.lit(entity),
        "anom_domain": F.lit(anom_domain),
        "anom_type": F.lit(anom_type),
        "severity": F.lit(severity),
        "rule_id": F.lit(rule_id),
        "source_table": F.lit(source_table),
        "detail": F.lit(f"Orphan detected: {fact_key} not found in dim ({dim_key})."),
        "gold_load_ts": F.current_timestamp(),
    }

    events = add_key_hash(orphan, event_key_cols, "key_hash")
    events = add_natural_keys_json(events, event_key_cols, "natural_keys")
    for col_name, col_expr in constant_cols.items():
        events = events.withColumn(col_name, col_expr)

    orphan_evt = events.select(
        "gold_run_id", "event_ts", "entity", "anom_domain", "anom_type", "severity", "rule_id",
        "key_hash", "natural_keys", "source_table", "detail", "gold_load_ts"
    )

    return conform, orphan_evt


# -----------------------------------------------------------------------------
# 4) Anomaly writers (append-only)
# -----------------------------------------------------------------------------
def write_anomaly_events(anom_df: DataFrame, table_name: str = "gold_anomaly_event") -> None:
    """Append anomaly events to a Delta table; this helper only ever appends, never overwrites."""
    writer = anom_df.write.format("delta").mode("append")
    writer.saveAsTable(table_name)

def write_anomaly_kpis(
    anom_event_df: DataFrame,
    gold_run_id: str,
    entity: str,
    table_name: str = "gold_anomaly_kpi",
    sample_limit: int = 10
) -> None:
    """
    Aggregate anomaly events into per-rule KPI rows and append them to a Delta table.

    One row is written per (anom_domain, anom_type, severity, rule_id) group. Each row holds
    the number of events in the group (``row_count``) and a JSON array of at most
    ``sample_limit`` non-null ``natural_keys`` values (``sample_keys``).

    Args:
        anom_event_df: anomaly events; must contain anom_domain, anom_type, severity,
            rule_id and natural_keys.
        gold_run_id: run identifier stamped on every KPI row (validated and trimmed).
        entity: entity name stamped on every KPI row (validated and trimmed).
        table_name: target table, written in append mode.
        sample_limit: maximum number of sample keys kept per group (must be >= 0).

    Raises:
        ValueError: if gold_run_id / entity is blank, or sample_limit is negative.
    """
    from pyspark.sql import Window  # only this helper needs window functions

    gold_run_id = normalize_run_id(gold_run_id)
    entity = normalize_entity(entity)
    if sample_limit < 0:
        raise ValueError(f"sample_limit must be >= 0, got {sample_limit}")

    group_cols = ["anom_domain", "anom_type", "severity", "rule_id"]

    # Number the events in each group and feed only the first `sample_limit` keys to
    # collect_list. Rows past the limit contribute NULL, which collect_list drops.
    # slice(collect_list(...)) would first materialize every key of a group in a single
    # aggregation buffer, which does not scale to large orphan volumes.
    # Ordering by the key also makes the sampled set deterministic. NULL keys sort last,
    # so they never take a slot that a real key could fill.
    sample_window = (
        Window.partitionBy(*group_cols)
        .orderBy(F.col("natural_keys").asc_nulls_last())
    )
    ranked = anom_event_df.withColumn("_sample_rn", F.row_number().over(sample_window))

    kpis = (
        ranked
        .groupBy(*group_cols)
        .agg(
            F.count(F.lit(1)).alias("row_count"),
            F.collect_list(
                F.when(F.col("_sample_rn") <= sample_limit, F.col("natural_keys"))
            ).alias("sample_keys_arr"),
        )
        .withColumn("sample_keys", F.to_json(F.col("sample_keys_arr")))
        .drop("sample_keys_arr")
        .withColumn("gold_run_id", F.lit(gold_run_id))
        .withColumn("kpi_ts", now_ts())
        .withColumn("entity", F.lit(entity))
        .withColumn("gold_load_ts", now_ts())
        .select(
            "gold_run_id", "kpi_ts", "entity",
            "anom_domain", "anom_type", "severity", "rule_id",
            "row_count", "sample_keys", "gold_load_ts"
        )
    )

    (kpis.write
        .mode("append")
        .format("delta")
        .saveAsTable(table_name)
    )

# -----------------------------------------------------------------------------
# 5) Standard write helpers (facts/dims)
# -----------------------------------------------------------------------------
def truncate_table(table_name: str) -> None:
    """
    Run ``TRUNCATE TABLE`` on ``table_name``.

    Uses a module-level ``spark`` session that is not defined in this file (presumably
    supplied by the notebook runtime). ``table_name`` is interpolated into the SQL
    verbatim, so only pass trusted identifiers.
    """
    statement = f"TRUNCATE TABLE {table_name}"
    spark.sql(statement)

def write_gold_table_append(df: DataFrame, table_name: str, partition_cols: Optional[List[str]] = None) -> None:
    """Append ``df`` to a Delta table, partitioning by ``partition_cols`` when any are given."""
    writer = df.write.mode("append").format("delta")
    target = writer.partitionBy(partition_cols) if partition_cols else writer
    target.saveAsTable(table_name)

def rebuild_gold_table(df: DataFrame, table_name: str, partition_cols: Optional[List[str]] = None) -> None:
    """
    Replace the contents of ``table_name`` with ``df`` (TRUNCATE, then APPEND).

    Rerunning with the same input yields the same table contents. The two steps are
    separate statements, not one transaction: if the append fails after the truncate
    succeeded, the table stays empty until the rebuild is rerun.
    """
    truncate_table(table_name)
    write_gold_table_append(df, table_name, partition_cols=partition_cols)

# -----------------------------------------------------------------------------
# 6) Convenience metrics
# -----------------------------------------------------------------------------
def df_count(df: DataFrame) -> int:
    """Return the row count of ``df`` (thin wrapper over ``DataFrame.count()``)."""
    n_rows = df.count()
    return n_rows

def null_counts(df: DataFrame, cols: List[str]) -> Dict[str, int]:
    """
    Return the number of NULL values per column in ``cols``.

    All counts are computed in a single aggregation pass.

    Args:
        df: DataFrame to inspect.
        cols: columns to count NULLs in. An empty list returns ``{}``.

    Returns:
        Mapping column name -> NULL count (0 for an empty DataFrame).
    """
    if not cols:
        # With no aggregate expressions the select is not an aggregation, so
        # ``collect()[0]`` would raise IndexError on an empty DataFrame.
        return {}
    agg_exprs = [F.sum(F.when(F.col(c).isNull(), 1).otherwise(0)).alias(c) for c in cols]
    row = df.select(agg_exprs).collect()[0].asDict()
    # sum() over zero rows returns NULL (None in Python); int(None) would raise TypeError.
    return {k: int(v or 0) for k, v in row.items()}

def log_metrics_dict(metrics: Dict) -> None:
    """
    Emit a metrics dictionary to stdout.

    This only prints for now; routing the metrics to a log table is a possible future step.
    """
    message = str(metrics)
    print(message)


In [None]:
# ============================================================
# nb_gold_utils — Contract Loader + Assertions (Gold)
# ============================================================

from pyspark.sql import DataFrame
from pyspark.sql import types as T
from pyspark.sql import functions as F
from typing import Dict, Any, List, Optional
import re
import json

# ----------------------------
# YAML read helpers
# ----------------------------
def _read_text(path: str, max_bytes: int = 1024 * 1024) -> str:
    """
    Read a small text (YAML contract) file from OneLake Files with ``mssparkutils.fs.head``.

    Args:
        path: OneLake path of the file.
        max_bytes: read limit passed to ``fs.head`` (default 1 MB).

    Returns:
        The file content as a string.

    Raises:
        FileNotFoundError: if the read fails; the underlying error is chained.
        ValueError: if the content reaches ``max_bytes`` and may therefore be truncated.
    """
    try:
        text = mssparkutils.fs.head(path, max_bytes)
    except Exception as e:
        raise FileNotFoundError(f"Cannot read contract at path: {path}. Error: {e}") from e

    # fs.head silently stops at max_bytes. A contract cut mid-file can still parse as
    # valid YAML with columns missing, so refuse anything that hits the limit.
    if text and len(text.encode("utf-8")) >= max_bytes:
        raise ValueError(
            f"Contract at path {path} reached the {max_bytes}-byte read limit and may be "
            "truncated; increase max_bytes or shrink the contract."
        )
    return text


def _parse_yaml(text: str) -> Dict[str, Any]:
    """
    Parse YAML into a Python dict.
    Prefers PyYAML if available; otherwise a minimal fallback is NOT recommended.
    """
    try:
        import yaml  # PyYAML (commonly available)
        return yaml.safe_load(text)
    except Exception as e:
        raise RuntimeError(
            "YAML parsing failed. Ensure PyYAML is available in the Spark environment. "
            f"Original error: {e}"
        )

# ----------------------------
# Type mapping (SQL-ish → Spark)
# ----------------------------
# DECIMAL(p,s) with optional inner whitespace; group 1 = precision, group 2 = scale.
# Not anchored at the end: use fullmatch (not match) to reject trailing text.
_DEC_RE = re.compile(r"DECIMAL\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", re.IGNORECASE)
# CHAR(n) / VARCHAR(n); the length is not captured because both map to StringType.
_VC_RE  = re.compile(r"(VAR)?CHAR\s*\(\s*\d+\s*\)", re.IGNORECASE)

def _to_spark_type(type_str: str) -> T.DataType:
    """
    Convert a contract type string into a Spark DataType.

    Supports (case-insensitive, surrounding whitespace ignored):
      STRING/TEXT, CHAR/VARCHAR and CHAR(n)/VARCHAR(n) -> StringType,
      BIGINT/LONG, INT/INTEGER, SMALLINT, TINYINT, BOOLEAN/BOOL, DATE,
      TIMESTAMP/DATETIME, DOUBLE, FLOAT/REAL, and DECIMAL(p,s).
    ``None`` maps to StringType.

    Raises:
        ValueError: for unknown types, for forms with trailing text
            (e.g. "DECIMAL(18,2) NOT NULL"), and for DECIMAL precision/scale outside
            Spark's supported range (1 <= p <= 38, 0 <= s <= p).
    """
    if type_str is None:
        return T.StringType()

    s = type_str.strip().upper()

    scalar_types = {
        "STRING": T.StringType, "TEXT": T.StringType,
        "CHAR": T.StringType, "VARCHAR": T.StringType,
        "BIGINT": T.LongType, "LONG": T.LongType,
        "INT": T.IntegerType, "INTEGER": T.IntegerType,
        "SMALLINT": T.ShortType,
        "TINYINT": T.ByteType,
        "BOOLEAN": T.BooleanType, "BOOL": T.BooleanType,
        "DATE": T.DateType,
        "TIMESTAMP": T.TimestampType, "DATETIME": T.TimestampType,
        "DOUBLE": T.DoubleType,
        "FLOAT": T.FloatType, "REAL": T.FloatType,
    }
    if s in scalar_types:
        return scalar_types[s]()

    # fullmatch, not match: match only anchors the start, so trailing text
    # (e.g. "VARCHAR(10) NOT NULL") would otherwise be silently accepted.
    if _VC_RE.fullmatch(s):
        return T.StringType()

    m = _DEC_RE.fullmatch(s)
    if m:
        precision = int(m.group(1))
        scale = int(m.group(2))
        # Spark decimals allow precision 1..38 and 0 <= scale <= precision; fail here
        # with a contract-level message instead of later inside Spark.
        if not (1 <= precision <= 38 and 0 <= scale <= precision):
            raise ValueError(f"Unsupported DECIMAL precision/scale in contract: '{type_str}'")
        return T.DecimalType(precision=precision, scale=scale)

    # If unknown: fail fast (banking-grade)
    raise ValueError(f"Unsupported type in contract: '{type_str}'")

def _expected_types_from_contract(contract: Dict[str, Any]) -> Dict[str, T.DataType]:
    """
    Map each contract column name to its Spark DataType.

    A missing ``type`` falls back to StringType (via _to_spark_type). A missing or
    empty ``columns`` entry yields ``{}``.
    """
    column_defs = contract.get("columns", []) or []
    return {col_def["name"]: _to_spark_type(col_def.get("type")) for col_def in column_defs}

def _expected_columns_from_contract(contract: Dict[str, Any]) -> List[str]:
    cols = contract.get("columns", []) or []
    return [c["name"] for c in cols]

# ----------------------------
# Contract path conventions
# ----------------------------
def gold_contract_path(table_name: str, base_dir: str = "Files/governance/schema_registry/gold") -> str:
    """
    Build the conventional contract file path for a Gold table.

    Trailing slashes on ``base_dir`` and whitespace around ``table_name`` are removed.
    Example: gold_fact_transactions -> Files/governance/schema_registry/gold/gold_fact_transactions.yaml

    Raises:
        ValueError: if ``table_name`` is None or blank.
    """
    if table_name is None or not str(table_name).strip():
        raise ValueError("table_name is required")
    directory = base_dir.rstrip("/")
    file_name = table_name.strip() + ".yaml"
    return f"{directory}/{file_name}"

# ----------------------------
# Public API: load contract
# ----------------------------
def load_gold_contract(table_name: str, base_dir: str = "Files/governance/schema_registry/gold") -> Dict[str, Any]:
    """
    Load and sanity-check the YAML contract for a Gold table.

    Args:
        table_name: Gold table name; surrounding whitespace is ignored.
        base_dir: directory holding ``<table_name>.yaml`` contracts.

    Returns:
        The parsed contract dict.

    Raises:
        ValueError: if the name is blank, the root is not a mapping, the contract's
            ``table`` differs from the requested name, or ``columns`` is missing, empty,
            or not a list of mappings with a ``name`` key.
        FileNotFoundError / RuntimeError: propagated from reading / parsing the file.
    """
    path = gold_contract_path(table_name, base_dir)
    # gold_contract_path validated the name and used its stripped form for the file
    # name. Compare against that same form, so " x " does not load x.yaml and then
    # fail the table check below.
    requested = table_name.strip()

    contract = _parse_yaml(_read_text(path))

    # Minimal normalization / sanity checks
    if not isinstance(contract, dict):
        raise ValueError(f"Invalid contract format in {path}: expected dict at root")

    c_table = contract.get("table")
    if c_table and c_table != requested:
        # Contract mismatch is a real governance issue; fail fast.
        raise ValueError(f"Contract table mismatch: file says '{c_table}', requested '{requested}'")

    columns = contract.get("columns")
    if not columns:
        raise ValueError(f"Contract {path} has no columns definition")
    # Downstream helpers index c["name"] on every entry; reject malformed shapes here
    # with a clear message instead of a KeyError/TypeError later.
    if not isinstance(columns, list) or not all(isinstance(c, dict) and "name" in c for c in columns):
        raise ValueError(
            f"Contract {path} has an invalid columns definition: "
            "expected a list of mappings with a 'name' key"
        )

    return contract

# ----------------------------
# Public API: apply assertions
# ----------------------------
def apply_gold_contract_assertions(
    df: DataFrame,
    contract: Dict[str, Any],
    ctx: Optional[str] = None,
    enforce_types: bool = True,
    enforce_not_null: bool = True,
    enforce_unique: bool = True
) -> None:
    """
    Apply contract-first assertions to a DataFrame, raising on the first violation.

    Checks, in order:
    - required columns (every contract column is present)
    - no additional columns (nothing outside the contract)
    - type checks (optional; from ``columns[].type``)
    - NOT NULL (optional; from ``constraints.not_null``)
    - UNIQUE (optional; from ``constraints.unique``)

    Accepted constraint shapes:
    - ``not_null``: a list of column names, or a single column name.
    - ``unique``: a single column name; a list of column names (one composite key);
      or a list of lists (several unique constraints).

    Args:
        df: DataFrame to validate.
        contract: parsed contract (see load_gold_contract).
        ctx: label for error messages; defaults to the contract's ``table``.
        enforce_types / enforce_not_null / enforce_unique: toggle the optional checks.

    Raises:
        ValueError: on any violated check, or on a malformed ``unique`` constraint.
    """
    table_name = contract.get("table", "UNKNOWN_TABLE")
    context = ctx or f"{table_name}"

    expected_cols = _expected_columns_from_contract(contract)
    assert_required_columns(df, expected_cols, ctx=context)
    assert_no_additional_columns(df, expected_cols, ctx=context)

    if enforce_types:
        expected_types = _expected_types_from_contract(contract)
        assert_column_types(df, expected_types, ctx=context)

    constraints = contract.get("constraints", {}) or {}

    if enforce_not_null:
        not_null_cols = constraints.get("not_null", []) or []
        # A bare string ("not_null: id") would otherwise be iterated char by char.
        if isinstance(not_null_cols, str):
            not_null_cols = [not_null_cols]
        if not_null_cols:
            assert_not_null(df, not_null_cols, ctx=context)

    if enforce_unique:
        unique_defs = constraints.get("unique", []) or []
        # Normalize to a list of key-column lists.
        if isinstance(unique_defs, str):
            unique_defs = [[unique_defs]]  # "unique: id"
        elif isinstance(unique_defs, list) and unique_defs and isinstance(unique_defs[0], str):
            unique_defs = [unique_defs]  # single (possibly composite) key as a flat list

        for uq in unique_defs:
            if not isinstance(uq, list) or not uq or not all(isinstance(c, str) for c in uq):
                raise ValueError(f"[{context}] Invalid unique constraint format: {unique_defs}")
            assert_unique_key(df, uq, ctx=f"{context} UNIQUE {uq}")

# ----------------------------
# Optional helper: cast/project dataframe to contract
# (useful before assertions if upstream types are messy)
# ----------------------------
def project_to_gold_contract(df: DataFrame, contract: Dict[str, Any]) -> DataFrame:
    """
    Select contract columns in canonical order and cast each to its contract type.

    Types are resolved through ``_to_spark_type``, the same mapping the type assertions
    use. The cast then uses Spark's canonical type string (``DataType.simpleString()``),
    so the casts stay string-based. A missing or empty ``type`` means STRING.

    Raises:
        ValueError: if a contract type is not supported by ``_to_spark_type``.
    """
    exprs = []
    for c in (contract.get("columns", []) or []):
        name = c["name"]
        # Contract aliases such as TEXT, DATETIME or BOOL are not valid Spark cast
        # targets. Mapping first makes the projection accept exactly what the
        # assertions accept, e.g. "TEXT" -> "string", "DECIMAL(18,2)" -> "decimal(18,2)".
        spark_type = _to_spark_type(c.get("type") or "STRING")
        exprs.append(F.col(name).cast(spark_type.simpleString()).alias(name))
    return df.select(*exprs)

