In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from delta.tables import DeltaTable

SILVER_DB = ""   # leave empty when using the default Lakehouse

# ------------------------------------------------------------
# TRANSACTIONS conventions
# ------------------------------------------------------------

# Partition columns handed to the Delta writer for the transactions table.
# "txn_month" is the month-truncated date column added by add_txn_dates().
TRANSACTIONS_PARTITION_COLS = ["txn_month"]


def add_txn_dates(df, ts_col="txn_ts"):
    """Derive the transaction date columns from a timestamp column.

    Adds two columns to ``df``:

    - ``txn_date``: ``ts_col`` converted with ``to_date``;
    - ``txn_month``: ``ts_col`` truncated to the month and cast to date.

    Raises:
        ValueError: if ``ts_col`` is not a column of ``df``.
    """
    if ts_col in df.columns:
        day_expr = F.to_date(F.col(ts_col))
        month_expr = F.date_trunc("month", F.col(ts_col)).cast("date")
        with_day = df.withColumn("txn_date", day_expr)
        return with_day.withColumn("txn_month", month_expr)
    raise ValueError(f"Column '{ts_col}' not found in dataframe")


def normalize_mcc(df, col="mcc"):
    """Add a zero-padded ``mcc_code`` string column derived from ``col``.

    The source column is cast to string and left-padded with "0" to a
    width of 4. The source column itself is not modified.

    Args:
        df: input DataFrame.
        col: name of the column holding the raw MCC value.

    Returns:
        ``df`` with the ``mcc_code`` column added (or replaced).

    Raises:
        ValueError: if ``col`` is not a column of ``df``.
    """
    # Fail fast with the same ValueError the other column helpers in this
    # file raise for a missing input column.
    if col not in df.columns:
        raise ValueError(f"Column '{col}' not found in dataframe")
    return df.withColumn(
        "mcc_code",
        F.lpad(F.col(col).cast("string"), 4, "0")
    )

def cast_amount(df, col="amount", precision=18, scale=2):
    """Cast ``col`` in place to ``decimal(precision, scale)``.

    Args:
        df: input DataFrame.
        col: name of the column to cast.
        precision: total number of digits of the decimal type.
        scale: number of digits after the decimal point.

    Returns:
        ``df`` with ``col`` replaced by its decimal cast.

    Raises:
        ValueError: if ``col`` is not a column of ``df``.
    """
    # Same missing-column guard as cast_rate / cast_decimal in this file.
    if col not in df.columns:
        raise ValueError(f"Column '{col}' not found in dataframe")
    return df.withColumn(
        col,
        F.col(col).cast(f"decimal({precision},{scale})")
    )

def add_tech_columns(
    df,
    source_file_col="source_file",
    ingestion_ts_col="ingestion_ts",
    ingestion_date_col="ingestion_date",
    default_source_file=None
):
    """Standardise the Silver technical columns.

    After the call ``df`` exposes ``source_file``, ``ingestion_ts`` and
    ``ingestion_date``. For each of the three columns:

    - a column found under the given custom name is renamed to the
      standard name;
    - otherwise a column already present under the standard name is kept;
    - otherwise the column is generated: ``default_source_file`` as a
      string literal, the current timestamp, and the date of
      ``ingestion_ts`` respectively.

    Args:
        df: input DataFrame.
        source_file_col: current name of the source-file column.
        ingestion_ts_col: current name of the ingestion-timestamp column.
        ingestion_date_col: current name of the ingestion-date column.
        default_source_file: value used when no source-file column exists.

    Returns:
        The DataFrame with the three standard technical columns.

    NOTE(review): if both a custom-named column and the standard-named
    column exist, the rename yields two columns with the same name.
    Confirm callers never pass such a frame.
    """

    # --- source_file ---
    if source_file_col != "source_file" and source_file_col in df.columns:
        df = df.withColumnRenamed(source_file_col, "source_file")
    elif "source_file" not in df.columns:
        # Cast so that a None default gives a nullable string column
        # instead of an untyped NullType (void) column.
        df = df.withColumn(
            "source_file", F.lit(default_source_file).cast("string")
        )

    # --- ingestion_ts ---
    if ingestion_ts_col != "ingestion_ts" and ingestion_ts_col in df.columns:
        df = df.withColumnRenamed(ingestion_ts_col, "ingestion_ts")
    elif "ingestion_ts" not in df.columns:
        # Generate only when absent: an existing ingestion_ts must not be
        # overwritten just because the custom-named column is missing.
        df = df.withColumn("ingestion_ts", F.current_timestamp())

    # --- ingestion_date ---
    if (
        ingestion_date_col != "ingestion_date"
        and ingestion_date_col in df.columns
    ):
        df = df.withColumnRenamed(ingestion_date_col, "ingestion_date")
    elif "ingestion_date" not in df.columns:
        df = df.withColumn("ingestion_date", F.to_date(F.col("ingestion_ts")))

    return df


def add_record_hash(df, cols):
    """Add a ``record_hash`` column: SHA-256 over the given columns.

    Each column in ``cols`` is cast to string, nulls are replaced by the
    marker "∅", the pieces are joined with "||" and hashed with
    ``sha2(..., 256)``.

    Raises:
        ValueError: if any name in ``cols`` is not a column of ``df``.
    """
    absent = [name for name in cols if name not in df.columns]
    if absent:
        raise ValueError(f"Missing columns for record_hash: {absent}")

    pieces = []
    for name in cols:
        as_text = F.col(name).cast("string")
        # Replace nulls explicitly so a null value still contributes a token.
        pieces.append(F.coalesce(as_text, F.lit("∅")))

    joined = F.concat_ws("||", *pieces)
    return df.withColumn("record_hash", F.sha2(joined, 256))

def assert_required_columns(df, cols):
    """Raise ``ValueError`` listing every name in ``cols`` missing from ``df``.

    Returns ``None`` when all the columns are present.
    """
    absent = []
    for name in cols:
        if name not in df.columns:
            absent.append(name)
    if not absent:
        return
    raise ValueError(f"Missing required columns: {absent}")


def deduplicate_latest(
    df,
    key_cols,
    order_col="ingestion_ts"
):
    """Keep a single row per key: the one ranked first by ``order_col`` desc.

    Rows are numbered with ``row_number`` inside a window partitioned by
    ``key_cols`` and ordered by ``order_col`` descending; only row number 1
    is kept.

    Args:
        df: input DataFrame.
        key_cols: columns identifying a logical record (at least one).
        order_col: column used to pick the row to keep.

    Returns:
        The deduplicated DataFrame, with the same columns as ``df``.

    Raises:
        ValueError: if ``key_cols`` is empty.
    """
    key_cols = list(key_cols)
    if not key_cols:
        # An empty partitionBy() would put every row in one window and
        # reduce the whole DataFrame to a single row.
        raise ValueError("key_cols must contain at least one column")

    # Pick a scratch column name that cannot collide with a real column;
    # otherwise an existing "rn" column would be overwritten and dropped.
    rank_col = "rn"
    while rank_col in df.columns:
        rank_col = f"_{rank_col}"

    w = Window.partitionBy(*key_cols).orderBy(F.col(order_col).desc())
    return (
        df
        .withColumn(rank_col, F.row_number().over(w))
        .filter(F.col(rank_col) == 1)
        .drop(rank_col)
    )

def validate_partitions(
    df,
    partition_cols,
    expect_date_types=True,
    allow_nulls=False,
    max_distinct_threshold=None
):
    """Validate the partition columns of ``df`` before a Delta write.

    Checks, in order: the columns exist, they have the ``date`` type,
    they contain no NULL, and the number of distinct partition values
    stays under a threshold.

    Args:
        df: DataFrame about to be written.
        partition_cols: columns used for partitioning (e.g. ["txn_month"]).
            ``None`` or an empty sequence skips every check.
        expect_date_types: if True, every partition column must be of
            type ``date``.
        allow_nulls: if False, NULL values in partition columns are
            rejected.
        max_distinct_threshold: if set, maximum number of distinct
            combinations of partition values allowed.

    Returns:
        True when every enabled check passes.

    Raises:
        ValueError: on any blocking violation.
    """
    if not partition_cols:
        return True

    # 1) Column existence
    missing = [c for c in partition_cols if c not in df.columns]
    if missing:
        raise ValueError(f"Missing partition columns: {missing}")

    # 2) Types
    if expect_date_types:
        # Build the name -> type mapping once, not once per column.
        dtypes = dict(df.dtypes)
        bad_types = [
            (c, dtypes.get(c))
            for c in partition_cols
            if dtypes.get(c) != "date"
        ]
        if bad_types:
            raise ValueError(
                "Partition columns must be of type 'date'. Invalid: "
                + ", ".join(f"{c}({t})" for c, t in bad_types)
            )

    # 3) Nullability
    if not allow_nulls:
        null_checks = [
            F.sum(F.when(F.col(c).isNull(), 1).otherwise(0)).alias(c)
            for c in partition_cols
        ]
        nulls = df.select(*null_checks).collect()[0].asDict()
        # SUM over zero rows yields NULL (None in Python): treat it as
        # "no nulls" instead of failing on `None > 0`.
        offenders = {c: v for c, v in nulls.items() if (v or 0) > 0}
        if offenders:
            raise ValueError(
                f"Null values found in partition columns: {offenders}"
            )

    # 4) Cardinality (guards against a partition explosion)
    if max_distinct_threshold is not None:
        distinct_count = (
            df.select(*partition_cols).distinct().count()
        )
        if distinct_count > max_distinct_threshold:
            raise ValueError(
                f"Partition cardinality too high ({distinct_count}) "
                f"> threshold ({max_distinct_threshold})."
            )

    return True



def write_silver_delta(df, table_name, partition_cols=None, mode="overwrite"):
    """Save ``df`` as the Delta table ``table_name``.

    The writer uses the "delta" format, the given save ``mode`` and always
    sets the ``overwriteSchema`` option to "true". ``partitionBy`` is
    applied only when ``partition_cols`` is non-empty.
    """
    has_partitions = bool(partition_cols) and len(partition_cols) > 0

    writer = df.write.format("delta")
    writer = writer.mode(mode)
    writer = writer.option("overwriteSchema", "true")
    if has_partitions:
        writer = writer.partitionBy(*partition_cols)

    writer.saveAsTable(table_name)



def write_silver_transactions(
    df,
    table_name="silver_transactions",
    mode="overwrite"
):
    """Validate the partition columns of ``df``, then write the transactions table.

    ``validate_partitions`` runs first with ``TRANSACTIONS_PARTITION_COLS``
    (date type required, nulls rejected, at most 240 distinct partitions).
    It raises on violation, so nothing is written when a check fails.
    """
    partition_cols = TRANSACTIONS_PARTITION_COLS
    checks = {
        "expect_date_types": True,
        "allow_nulls": False,
        "max_distinct_threshold": 240,
    }

    validate_partitions(df, partition_cols, **checks)
    write_silver_delta(df, table_name, partition_cols, mode)


def merge_silver_delta(df, table_name, key_cols):
    """Upsert ``df`` into the Delta table ``table_name`` on ``key_cols``.

    If the table does not exist yet it is created from ``df``. Otherwise
    a Delta MERGE is executed: rows matching on every key column are
    updated with all source columns, the others are inserted.

    Relies on a ``spark`` session name being available in the enclosing
    scope (it is not defined in this function).

    Args:
        df: source DataFrame.
        table_name: name of the target Delta table.
        key_cols: column names forming the merge key (at least one).

    Raises:
        ValueError: if ``key_cols`` is empty.
    """
    key_cols = list(key_cols)
    if not key_cols:
        # Without keys the merge condition would be an empty string; fail
        # before touching the catalog instead of creating the table first
        # and breaking on the next run.
        raise ValueError("key_cols must contain at least one column")

    if not spark.catalog.tableExists(table_name):
        df.write.format("delta").saveAsTable(table_name)
        return

    tgt = DeltaTable.forName(spark, table_name)
    # NOTE(review): "=" is not null-safe, so rows whose key is NULL never
    # match and are inserted each time. Confirm keys are non-null.
    cond = " AND ".join(f"t.{k} = s.{k}" for k in key_cols)

    (
        tgt.alias("t")
        .merge(df.alias("s"), cond)
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

# ------------------------------------------------------------
# FX conventions
# ------------------------------------------------------------
# Partition columns handed to the Delta writer for the FX rates table.
# "fx_month" is the month-truncated date column added by add_fx_dates().
FX_PARTITION_COLS = ["fx_month"]

def add_fx_dates(df, date_col="fx_date"):
    """Derive the FX date columns from ``date_col``.

    Writes ``fx_date`` (``date_col`` converted with ``to_date``) and
    ``fx_month`` (``fx_date`` truncated to the month and cast to date).

    Raises:
        ValueError: if ``date_col`` is not a column of ``df``.
    """
    if date_col in df.columns:
        day_expr = F.to_date(F.col(date_col))
        # Computed from the freshly written fx_date, not from date_col.
        month_expr = F.date_trunc("month", F.col("fx_date")).cast("date")
        with_day = df.withColumn("fx_date", day_expr)
        return with_day.withColumn("fx_month", month_expr)
    raise ValueError(f"Column '{date_col}' not found in dataframe")

def normalize_currency_codes(df, cols):
    """Trim and upper-case each column listed in ``cols``, in place.

    Columns are processed in order.

    Raises:
        ValueError: at the first name in ``cols`` that is not a column.
    """
    result = df
    for name in cols:
        if name in result.columns:
            cleaned = F.upper(F.trim(F.col(name)))
            result = result.withColumn(name, cleaned)
            continue
        raise ValueError(f"Column '{name}' not found in dataframe")
    return result

def cast_rate(df, col="rate", precision=18, scale=8):
    """Cast ``col`` in place to ``decimal(precision, scale)``.

    Raises:
        ValueError: if ``col`` is not a column of ``df``.
    """
    if col in df.columns:
        target_type = f"decimal({precision},{scale})"
        return df.withColumn(col, F.col(col).cast(target_type))
    raise ValueError(f"Column '{col}' not found in dataframe")


def write_silver_fx_rates(
    df,
    table_name="silver_fx_rates",
    mode="overwrite"
):
    """Validate the partition columns of ``df``, then write the FX rates table.

    ``validate_partitions`` runs first with ``FX_PARTITION_COLS`` (date
    type required, nulls rejected, at most 240 distinct partitions).
    It raises on violation, so nothing is written when a check fails.
    """
    partition_cols = FX_PARTITION_COLS
    checks = {
        "expect_date_types": True,
        "allow_nulls": False,
        "max_distinct_threshold": 240,
    }

    validate_partitions(df, partition_cols, **checks)
    write_silver_delta(df, table_name, partition_cols, mode)

# ------------------------------------------------------------
# MCC, USER, CARD conventions
# ------------------------------------------------------------

def normalize_text(df, cols):
    """Trim and upper-case the listed columns that exist in ``df``.

    Names in ``cols`` that are not columns of ``df`` are skipped without
    raising.
    """
    result = df
    for name in cols:
        if name not in result.columns:
            continue
        cleaned = F.upper(F.trim(F.col(name)))
        result = result.withColumn(name, cleaned)
    return result

def cast_decimal(df, col, precision=18, scale=2):
    """Cast ``col`` in place to ``decimal(precision, scale)``.

    Raises:
        ValueError: if ``col`` is not a column of ``df``.
    """
    if col in df.columns:
        target_type = f"decimal({precision},{scale})"
        return df.withColumn(col, F.col(col).cast(target_type))
    raise ValueError(f"Column '{col}' not found in dataframe")

def normalize_mcc_code(df, col="mcc"):
    """Add an ``mcc_code`` column: ``col`` as a string left-padded with "0" to 4.

    Raises:
        ValueError: if ``col`` is not a column of ``df``.
    """
    if col in df.columns:
        as_text = F.col(col).cast("string")
        padded = F.lpad(as_text, 4, "0")
        return df.withColumn("mcc_code", padded)
    raise ValueError(f"Column '{col}' not found in dataframe")

def parse_bool_yn(
    df,
    col,
    true_values=("Y","YES","TRUE","1"),
    false_values=("N","NO","FALSE","0"),
):
    """Replace ``col`` by a boolean parsed from yes/no style tokens.

    The column is cast to string, trimmed and upper-cased, then compared
    with the upper-cased ``true_values`` and ``false_values``. Values in
    neither set become a null boolean.

    Raises:
        ValueError: if ``col`` is not a column of ``df``.
    """
    if col in df.columns:
        token = F.upper(F.trim(F.col(col).cast("string")))
        truthy = [value.upper() for value in true_values]
        falsy = [value.upper() for value in false_values]
        unknown = F.lit(None).cast("boolean")
        parsed = (
            F.when(token.isin(truthy), F.lit(True))
            .when(token.isin(falsy), F.lit(False))
            .otherwise(unknown)
        )
        return df.withColumn(col, parsed)
    raise ValueError(f"Column '{col}' not found in dataframe")

def parse_date_multi(df, col, formats):
    """Parse ``col`` as a date, trying several formats in order.

    The column is cast to string and parsed with ``to_date`` once per
    format; the first non-null result (in the order of ``formats``) wins.

    Args:
        df: input DataFrame.
        col: name of the column to parse (replaced in place).
        formats: a date pattern, or a sequence of patterns tried in order.

    Returns:
        ``df`` with ``col`` replaced by the parsed date.

    Raises:
        ValueError: if ``col`` is not a column of ``df`` or if no format
            is given.
    """
    if col not in df.columns:
        raise ValueError(f"Column '{col}' not found in dataframe")

    # A bare string would otherwise be iterated character by character.
    if isinstance(formats, str):
        formats = [formats]
    formats = list(formats)
    if not formats:
        # Without this guard `None` would be handed to withColumn.
        raise ValueError("At least one date format is required")

    as_text = F.col(col).cast("string")
    candidates = [F.to_date(as_text, fmt) for fmt in formats]
    return df.withColumn(col, F.coalesce(*candidates))

def write_silver_mcc(df, table_name="silver_mcc", mode="overwrite"):
    """Write ``df`` as the MCC Silver table through ``write_silver_delta``.

    An empty partition list is passed, so no ``partitionBy`` is applied.
    """
    no_partitions = []
    write_silver_delta(df, table_name, no_partitions, mode)

def write_silver_users(df, table_name="silver_users", mode="overwrite"):
    """Write ``df`` as the users Silver table through ``write_silver_delta``.

    An empty partition list is passed, so no ``partitionBy`` is applied.
    """
    no_partitions = []
    write_silver_delta(df, table_name, no_partitions, mode)

def write_silver_cards(df, table_name="silver_cards", mode="overwrite"):
    """Write ``df`` as the cards Silver table through ``write_silver_delta``.

    An empty partition list is passed, so no ``partitionBy`` is applied.
    """
    no_partitions = []
    write_silver_delta(df, table_name, no_partitions, mode)
