Paramétrage de la connexion JDBC

In [0]:
# JDBC endpoint: host, port and database are combined into a SQL Server JDBC URL.
jdbc_hostname = "sql-datasource-dev-ghe.database.windows.net"
jdbc_port = 1433
jdbc_database = "sqldb-adventureworks-dev-ghe"
jdbc_url = f"jdbc:sqlserver://{jdbc_hostname}:{jdbc_port};database={jdbc_database}"

# SQL authentication: both credentials are read from the "kv-jdbc" secret scope,
# so nothing sensitive is stored in the notebook itself.
username = dbutils.secrets.get(scope="kv-jdbc", key="sql-username")
password = dbutils.secrets.get(scope="kv-jdbc", key="sql-password")

# Properties shared by every spark.read.jdbc call in this notebook.
connection_properties = dict(
    user=username,
    password=password,
    driver="com.microsoft.sqlserver.jdbc.SQLServerDriver",
)

Détection automatique du catalog Unity

In [0]:
# Every catalog visible to the session; anything other than "hive_metastore"
# is treated as a Unity catalog candidate.
catalogs = [entry.catalog for entry in spark.sql("SHOW CATALOGS").collect()]
unity_catalogs = [name for name in catalogs if name != "hive_metastore"]

# Default choice: the single candidate when there is exactly one, otherwise the
# first candidate whose name starts with "dbw_", otherwise "hive_metastore".
default_catalog = (
    unity_catalogs[0]
    if len(unity_catalogs) == 1
    else next(filter(lambda name: name.startswith("dbw_"), unity_catalogs), "hive_metastore")
)

# Expose the detected value through a widget so it can be overridden by hand;
# the rest of the notebook reads the widget value, not default_catalog.
dbutils.widgets.text("my_catalog", default_catalog, "Catalog détecté")
catalog = dbutils.widgets.get("my_catalog")

# Schema holding the Bronze tables inside the selected catalog.
bronze_schema = "bronze"


Récupère toutes les colonnes d'une table dans la base source.

In [0]:
from pyspark.sql.functions import col

def get_columns_for_table(table_name: str, schema: str = "SalesLT") -> list:
    """Return the column names of a source table, ordered by ordinal position.

    Reads INFORMATION_SCHEMA.COLUMNS over JDBC using the notebook-level
    ``jdbc_url`` and ``connection_properties``.

    Args:
        table_name: Name of the table in the source database.
        schema: Schema that owns the table (defaults to "SalesLT").

    Returns:
        List of COLUMN_NAME values sorted by ORDINAL_POSITION; empty when no
        column matches the schema/table pair.
    """
    metadata_df = spark.read.jdbc(
        url=jdbc_url,
        table="INFORMATION_SCHEMA.COLUMNS",
        properties=connection_properties,
    )

    belongs_to_table = (col("TABLE_SCHEMA") == schema) & (col("TABLE_NAME") == table_name)
    ordered_df = metadata_df.where(belongs_to_table).orderBy("ORDINAL_POSITION")

    column_names = []
    for record in ordered_df.collect():
        column_names.append(record["COLUMN_NAME"])
    return column_names


Fonction qui convertit les noms CamelCase en snake_case (utilisée pour les noms de tables et de colonnes en bronze)

In [0]:
import re

def to_snake_case(name: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    Two passes insert underscores, then the result is lower-cased:
    first before every capitalised word ("OrderDetail" -> "Order_Detail"),
    then between a lowercase letter or digit and the capital that follows it
    ("CustomerID" -> "Customer_ID").

    Args:
        name: Identifier to convert, e.g. "SalesOrderDetail".

    Returns:
        The snake_case form, e.g. "sales_order_detail".
    """
    word_start = re.compile(r'(.)([A-Z][a-z]+)')
    lower_to_upper = re.compile(r'([a-z0-9])([A-Z])')

    with_word_breaks = word_start.sub(r'\1_\2', name)
    with_all_breaks = lower_to_upper.sub(r'\1_\2', with_word_breaks)
    return with_all_breaks.lower()


Récupère les colonnes constituant la clé primaire d'une table SQL Server.

In [0]:
def get_primary_keys(table_name: str, schema: str = "SalesLT") -> list:
    """Return the primary-key column names of a SQL Server table.

    Joins INFORMATION_SCHEMA.KEY_COLUMN_USAGE with the PRIMARY KEY rows of
    INFORMATION_SCHEMA.TABLE_CONSTRAINTS on CONSTRAINT_NAME, both restricted
    to the requested schema and table.

    Args:
        table_name: Name of the table in the source database.
        schema: Schema that owns the table (defaults to "SalesLT").

    Returns:
        List of COLUMN_NAME values ordered by ORDINAL_POSITION; empty when the
        table has no primary key.
    """
    def read_metadata_view(view_name):
        # Same JDBC connection settings for every metadata view.
        return spark.read.jdbc(
            url=jdbc_url,
            table=view_name,
            properties=connection_properties,
        )

    same_table = (col("TABLE_SCHEMA") == schema) & (col("TABLE_NAME") == table_name)

    key_usage = read_metadata_view("INFORMATION_SCHEMA.KEY_COLUMN_USAGE").filter(same_table)
    pk_constraints = read_metadata_view("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").filter(
        same_table & (col("CONSTRAINT_TYPE") == "PRIMARY KEY")
    )

    # Keep only the key columns that belong to a PRIMARY KEY constraint.
    pk_columns = key_usage.join(pk_constraints, on="CONSTRAINT_NAME", how="inner")
    ordered_rows = pk_columns.orderBy("ORDINAL_POSITION").collect()

    return [record["COLUMN_NAME"] for record in ordered_rows]


Récupération du nom des tables du schéma SalesLT dans la base source

In [0]:
# Base tables (views excluded) of the SalesLT schema, read from the source metadata.
source_tables_df = (
    spark.read
    .jdbc(
        url=jdbc_url,
        table="INFORMATION_SCHEMA.TABLES",
        properties=connection_properties,
    )
    .filter("TABLE_SCHEMA = 'SalesLT' AND TABLE_TYPE = 'BASE TABLE'")
)

# Plain Python list of table names, used to drive the checks below.
source_table_names = [record["TABLE_NAME"] for record in source_tables_df.collect()]


Détection des tables présentes dans la couche Bronze

In [0]:
# Tables currently present in the Bronze schema of the selected catalog.
bronze_tables_df = spark.sql(f"SHOW TABLES IN {catalog}.{bronze_schema}")
bronze_table_names = [entry.tableName for entry in bronze_tables_df.collect()]

Construction de la liste des tables à tester (correspondance source ➜ Bronze, clés primaires et colonnes)

In [0]:
# One descriptor per source table that has both a Bronze counterpart and a primary key.
tables_to_test = []

for table_name in source_table_names:
    table_snake = to_snake_case(table_name)
    bronze_table_name = f"bronze_saleslt_{table_snake}"

    # Skip source tables that have not been loaded into Bronze.
    if bronze_table_name not in bronze_table_names:
        print(f"Table non trouvée dans la couche bronze : {bronze_table_name}")
        continue

    # Without a primary key the sampled rows cannot be matched, so skip the table.
    primary_keys = get_primary_keys(table_name)
    if not primary_keys:
        print(f"Aucune clé primaire détectée pour {table_name}, table ignorée.")
        continue

    columns = get_columns_for_table(table_name)

    descriptor = {
        "source": f"SalesLT.{table_name}",
        "bronze": bronze_table_name,
        "primary_keys_source": primary_keys,  # e.g. CustomerID
        "primary_keys_bronze": list(map(to_snake_case, primary_keys)),  # e.g. customer_id
        "columns_source": columns,  # e.g. Title
        "columns_bronze": list(map(to_snake_case, columns)),  # e.g. title
    }
    tables_to_test.append(descriptor)


Fonction de test pour comparer la source et Bronze sur le count et les valeurs de l'échantillon aléatoire

In [0]:
import random
from pyspark.sql.functions import col, trim, regexp_replace
from functools import reduce

def _sql_literal(value):
    """Render a Python value as a T-SQL literal usable in a WHERE clause.

    Python ``repr`` is not valid SQL for many types (strings containing quotes,
    Decimal, datetime, bool, bytes), so each type is rendered explicitly.

    Args:
        value: A non-null value taken from a collected Spark Row.

    Returns:
        The literal as a string, ready to be embedded in a SQL statement.
    """
    # Local imports keep the helper self-contained inside this cell.
    import datetime
    import numbers

    # bool must be tested before Number because bool is a subclass of int.
    if isinstance(value, bool):
        return "1" if value else "0"
    if isinstance(value, numbers.Number):
        return str(value)
    if isinstance(value, (bytes, bytearray)):
        return "0x" + bytes(value).hex()
    if isinstance(value, (datetime.datetime, datetime.date)):
        return "'" + value.isoformat() + "'"
    # Anything else is sent as a Unicode string with embedded quotes doubled.
    return "N'" + str(value).replace("'", "''") + "'"


def _key_predicate(row, key_columns):
    """Build the parenthesised SQL predicate matching one row on all its key columns."""
    parts = []
    for key in key_columns:
        # Bracket-quote the identifier so reserved words or spaces cannot break the query.
        identifier = "[" + key.replace("]", "]]") + "]"
        value = row[key]
        if value is None:
            parts.append(f"{identifier} IS NULL")
        else:
            parts.append(f"{identifier} = {_sql_literal(value)}")
    return "(" + " AND ".join(parts) + ")"


def test_table_sample(source_table, bronze_table, primary_keys_source, primary_keys_bronze, columns_source, columns_bronze):
    """Compare a source table with its Bronze copy on row count and on a random sample.

    Up to 25 primary keys are drawn at random from the source. The matching
    rows are read from both sides, joined on the keys and compared column by
    column with a null-safe comparison. A report is printed; nothing is returned.

    Relies on the notebook-level names ``spark``, ``jdbc_url``,
    ``connection_properties``, ``catalog`` and ``bronze_schema``.

    Args:
        source_table: Qualified source table name, e.g. "SalesLT.Customer".
        bronze_table: Bronze table name (without catalog/schema).
        primary_keys_source: Primary-key column names in the source.
        primary_keys_bronze: Same keys, in the same order, under their Bronze names.
        columns_source: Column names in the source.
        columns_bronze: Same columns, in the same order, under their Bronze names.

    Any exception is caught and reported in the printed output with the
    status "Divergence détectée"; it is not re-raised.
    """
    try:
        print(f"\nTest de la table : {source_table} ➜ {bronze_table}")

        # Each table is read once and reused for both the count and the sample.
        source_df = spark.read.jdbc(
            url=jdbc_url,
            table=source_table,
            properties=connection_properties
        )
        bronze_df = spark.read.table(f"{catalog}.{bronze_schema}.{bronze_table}")

        # Total row counts on both sides
        count_source_total = source_df.count()
        count_bronze_total = bronze_df.count()

        # Distinct primary keys from the source
        all_rows = source_df.select(*primary_keys_source).distinct().collect()
        if not all_rows:
            print(f"Table : {source_table}\n- Aucune donnée disponible dans la source.\nStatut : Test ignoré\n")
            return

        sample_rows = random.sample(all_rows, min(25, len(all_rows)))

        # Dynamic WHERE clause selecting exactly the sampled keys (source column names)
        where_clause = " OR ".join(_key_predicate(r, primary_keys_source) for r in sample_rows)
        query = f"(SELECT * FROM {source_table} WHERE {where_clause}) AS src_sample"

        # Load the source sample and normalise its string columns:
        # trim, then strip non-breaking spaces, carriage returns and line feeds.
        source_sample = spark.read.jdbc(url=jdbc_url, table=query, properties=connection_properties)

        for field in source_sample.schema.fields:
            if field.dataType.simpleString() == "string":
                source_sample = source_sample.withColumn(
                    field.name,
                    regexp_replace(trim(col(field.name)), "[\\u00A0\\r\\n]", "")
                )

        # Pre-filter Bronze on each key column; for composite keys this keeps a
        # superset of the sample, which the join below narrows to exact matches.
        bronze_sample = bronze_df
        for source_key, bronze_key in zip(primary_keys_source, primary_keys_bronze):
            sample_values = [r[source_key] for r in sample_rows]
            bronze_sample = bronze_sample.filter(col(bronze_key).isin(sample_values))

        # Join on all key columns
        join_expr = reduce(lambda a, b: a & b, [
            col(f"src.{source_key}") == col(f"brz.{bronze_key}")
            for source_key, bronze_key in zip(primary_keys_source, primary_keys_bronze)
        ])

        joined_df = source_sample.alias("src").join(
            bronze_sample.alias("brz"),
            on=join_expr,
            how="inner"
        )

        # The inner join drops sampled keys that have no Bronze row; count them
        # explicitly so that missing rows are reported instead of ignored.
        matched_keys = joined_df.select(
            *[col(f"src.{k}") for k in primary_keys_source]
        ).distinct().count()
        missing_in_bronze = len(sample_rows) - matched_keys

        # Column-by-column comparison
        mismatches = []
        absent_columns = []
        compared_columns = 0
        bronze_columns = set(bronze_sample.columns)
        for source_col, bronze_col in zip(columns_source, columns_bronze):
            if bronze_col not in bronze_columns:
                absent_columns.append(bronze_col)
                continue
            compared_columns += 1

            # Null-safe inequality: a plain != evaluates to NULL when exactly one
            # side is NULL, and such rows would be filtered out unnoticed.
            diff_df = joined_df.filter(
                ~col(f"src.{source_col}").eqNullSafe(col(f"brz.{bronze_col}"))
            )
            count_diff = diff_df.count()
            if count_diff > 0:
                mismatches.append((source_col, count_diff))
                print(f"Divergence détectée sur la colonne : {source_col} ({count_diff} ligne(s))")

                diff_df.select(
                    *[col(f"src.{k}").alias(f"{k}_source") for k in primary_keys_source],
                    col(f"src.{source_col}").alias(f"{source_col}_source"),
                    col(f"brz.{bronze_col}").alias(f"{bronze_col}_bronze")
                ).show(5, truncate=False)

        # Global summary
        print(f"\nRésumé : {source_table}")
        print(f"- Total lignes source : {count_source_total}")
        print(f"- Total lignes Bronze : {count_bronze_total}")
        if len(primary_keys_source) == 1:
            print(f"- Clé primaire : {primary_keys_source[0]}")
        else:
            print(f"- Clé primaire : composite ({', '.join(primary_keys_source)})")
        print(f"- Colonnes testées : {compared_columns}")
        if absent_columns:
            print(f"- Colonnes absentes en Bronze (non comparées) : {', '.join(absent_columns)}")
        print(f"- Colonnes divergentes : {len(mismatches)}")
        print(f"- Clés de l'échantillon absentes en Bronze : {missing_in_bronze}")
        if not mismatches and missing_in_bronze == 0 and count_source_total == count_bronze_total:
            print("Statut : Aucune divergence\n")
        else:
            print("Statut : Divergence détectée\n")

    except Exception as e:
        print(f"Table : {source_table}\n- Erreur lors du test : {str(e)}\nStatut : Divergence détectée\n")


Exécution du test pour toutes les tables

In [0]:
# Run the comparison for every descriptor; a failure on one table is reported
# and does not stop the remaining tables from being tested.
for t in tables_to_test:
    try:
        test_table_sample(
            t["source"],
            t["bronze"],
            t["primary_keys_source"],
            t["primary_keys_bronze"],
            t["columns_source"],
            t["columns_bronze"],
        )
    except Exception as exc:
        print(f"Erreur lors du test de {t['source']} : {str(exc)}")


Affiche les logs pour les processus bronze

In [0]:
# Full content of the Bronze processing log table of the selected catalog,
# sorted by its `timestamp` column in ascending order.
query = f"""
SELECT * FROM {catalog}.logs.bronze_processing_log
ORDER BY `timestamp`
"""

# Run the query and render the result with the DataFrame's display method.
df = spark.sql(query)
df.display()