In [0]:
import json
import base64
import tempfile
import os
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.sql.types import *
from google.cloud import bigquery
from google.oauth2 import service_account

In [0]:
# --- 1. CATALOG & CONFIG HELPERS ---

def load_catalog(catalog_path):
    """Read the catalog file at ``catalog_path`` and return it parsed from JSON.

    The file content is fetched with ``dbutils.fs.head`` using a size limit
    of 1_000_000, so a larger catalog would presumably be cut off before
    parsing (NOTE(review): confirm the limit is in bytes).
    """
    size_limit = 1_000_000
    raw_catalog = dbutils.fs.head(catalog_path, size_limit)
    return json.loads(raw_catalog)

In [0]:
# Catalog type name -> Spark SQL type name used when casting source columns.
CATALOG_TO_SPARK_TYPE = dict(
    string="string",
    integer="long",  # safer for MySQL / BigQuery
    number="double",
    boolean="boolean",
)

In [0]:
def build_table_mappings(catalog):
    """Turn the catalog's stream definitions into per-table pipeline settings.

    Args:
        catalog: Parsed catalog dict; streams are read from
            ``catalog["catalog"]["streams"]`` (missing keys yield no tables).

    Returns:
        A list with one dict per stream containing ``source_table``,
        ``destination_table``, ``replication_method``, ``bookmark_columns``,
        ``unique_constraints``, ``unique_conflict_method``,
        ``selected_columns`` and ``column_types`` (column name -> Spark type
        name).

    Raises:
        KeyError: If a stream lacks ``stream``, ``destination_table`` or
            ``replication_method``.
    """
    tables = []

    for stream in catalog.get("catalog", {}).get("streams", []):
        properties = stream.get("schema", {}).get("properties", {})
        selected_cols = []
        column_types = {}

        for entry in stream.get("metadata", []):
            breadcrumb = entry.get("breadcrumb")
            meta = entry.get("metadata", {})

            # Only entries of the form ["properties", ..., <column>] that are
            # flagged as selected describe a column to replicate.
            is_selected_column = (
                meta.get("selected")
                and isinstance(breadcrumb, list)
                and len(breadcrumb) >= 2
                and breadcrumb[0] == "properties"
            )
            if not is_selected_column:
                continue

            col_name = breadcrumb[-1]
            selected_cols.append(col_name)

            col_schema = properties.get(col_name, {})
            if col_schema.get("format") == "date-time":
                column_types[col_name] = "timestamp"
                continue

            declared_types = col_schema.get("type") or []
            # A bare string such as "integer" must be treated as one type
            # name, not iterated character by character.
            if isinstance(declared_types, str):
                declared_types = [declared_types]

            # First non-null type wins; a column with no usable type falls
            # back to "string" instead of raising StopIteration.
            base_type = next((t for t in declared_types if t != "null"), None)
            column_types[col_name] = CATALOG_TO_SPARK_TYPE.get(base_type, "string")

        tables.append({
            "source_table": stream["stream"],
            "destination_table": stream["destination_table"],
            "replication_method": stream["replication_method"],
            "bookmark_columns": stream.get("bookmark_properties", []),
            "unique_constraints": stream.get("unique_constraints", []),
            "unique_conflict_method": stream.get("unique_conflict_method"),
            "selected_columns": selected_cols,
            "column_types": column_types,
        })

    return tables

In [0]:
from pyspark.sql.functions import col

def cast_df_from_catalog(df, column_types):
    """Cast the DataFrame columns named in ``column_types`` to the given types.

    Args:
        df: DataFrame whose columns should be cast.
        column_types: Mapping of column name -> Spark type name. Names that
            are not present in ``df.columns`` are ignored.

    Returns:
        The DataFrame with each matching column replaced by its cast version.
    """
    for name in column_types:
        if name not in df.columns:
            continue
        target_type = column_types[name]
        df = df.withColumn(name, col(name).cast(target_type))
    return df

In [0]:
def get_bq_client(bq_cfg):
    """Build a BigQuery client from the base64-encoded service-account key.

    Args:
        bq_cfg: Dict providing ``credentials_b64`` (base64 of the key JSON)
            and ``project_id``.

    Returns:
        A ``bigquery.Client`` created with the decoded credentials and the
        configured project.
    """
    decoded_key = base64.b64decode(bq_cfg["credentials_b64"]).decode()
    return bigquery.Client(
        credentials=service_account.Credentials.from_service_account_info(
            json.loads(decoded_key)
        ),
        project=bq_cfg["project_id"],
    )

In [0]:
# --- 2. SCHEMA EVOLUTION HELPERS ---

def pandas_dtype_to_bq(dtype):
    """Map a pandas dtype (or its string name) to a BigQuery column type.

    The comparison is case-insensitive so both NumPy-backed dtypes
    (``int64``, ``bool``) and pandas nullable dtypes (``Int64``, ``boolean``)
    are accepted.

    Args:
        dtype: A pandas/NumPy dtype object or its string representation.

    Returns:
        One of ``"INT64"``, ``"FLOAT64"``, ``"BOOL"``, ``"STRING"`` or
        ``"DATETIME"``.

    Raises:
        ValueError: If the dtype has no mapping.
    """
    dtype_name = str(dtype)
    key = dtype_name.lower()

    # Exact names only: a prefix match on "int" would also catch "interval[...]".
    integer_names = ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32")
    float_names = ("float16", "float32", "float64")

    if key in integer_names:
        return "INT64"
    if key in float_names:
        return "FLOAT64"
    if key in ("bool", "boolean"):
        return "BOOL"
    if key in ("object", "str") or key.startswith("string"):
        return "STRING"
    if key.startswith("datetime64"):
        # Timezone-aware variants are kept as DATETIME, as before.
        return "DATETIME"
    raise ValueError(f"Unsupported dtype: {dtype_name}")

In [0]:
def spark_type_to_bq(f_type):
    """Map a Spark ``DataType`` instance to a BigQuery SQL type name.

    Timestamps map to ``DATETIME``. Every integral type, including
    ``ByteType`` and ``ShortType``, maps to ``INT64``, as does
    ``DecimalType``. Anything unrecognised falls back to ``STRING``.

    Args:
        f_type: Spark data type instance (e.g. ``field.dataType``).

    Returns:
        ``"DATETIME"``, ``"INT64"``, ``"BOOL"``, ``"FLOAT64"`` or ``"STRING"``.
    """
    if isinstance(f_type, TimestampType):
        return "DATETIME"  # Maintains the datetime64 requirement
    # ByteType/ShortType are included because booleans are cast to 'byte'
    # elsewhere in this file; without them such columns would become STRING.
    if isinstance(f_type, (DecimalType, LongType, IntegerType, ShortType, ByteType)):
        return "INT64"
    if isinstance(f_type, BooleanType):
        return "BOOL"
    if isinstance(f_type, (FloatType, DoubleType)):
        return "FLOAT64"
    return "STRING"

In [0]:
def evolve_bq_schema(
    bq_cfg,
    target_table,
    df_schema,
    client,
    replication_method,
    staging_table_exists=None
):
    """Add columns present in the Spark schema but missing from BigQuery.

    New columns are appended to the target table and, for incremental loads
    whose staging table exists, to ``<target_table>_staging`` as well.

    Args:
        bq_cfg: Dict with ``project_id`` and ``dataset``; ``staging_dataset``
            is used for the staging table when present, otherwise ``dataset``.
        target_table: Name of the target table (without project/dataset).
        df_schema: Iterable of fields exposing ``name`` and ``dataType``.
        client: BigQuery client used for ``get_table`` and ``query``.
        replication_method: Staging is only altered for ``"INCREMENTAL"``.
        staging_table_exists: Truthy when the staging table already exists.

    Returns:
        None.
    """
    target_full_path = f"{bq_cfg['project_id']}.{bq_cfg['dataset']}.{target_table}"
    staging_table = f"{target_table}_staging"
    # Staging tables are checked/created in the staging dataset elsewhere in
    # this file, so the ALTER has to address the same dataset.
    staging_dataset = bq_cfg.get("staging_dataset") or bq_cfg["dataset"]
    staging_full_path = f"{bq_cfg['project_id']}.{staging_dataset}.{staging_table}"

    # Fetch target schema (names compared case-insensitively; "_mage*"
    # columns are left out of the comparison set).
    bq_table = client.get_table(target_full_path)
    existing_columns = {
        field.name.lower()
        for field in bq_table.schema
        if not field.name.lower().startswith("_mage")
    }

    # Detect new columns
    new_columns_ddl = [
        f"ADD COLUMN `{field.name}` {spark_type_to_bq(field.dataType)}"
        for field in df_schema
        if field.name.lower() not in existing_columns
    ]

    if not new_columns_ddl:
        print("No new columns from source")
        return

    alter_clause = ", ".join(new_columns_ddl)

    # Alter target table
    alter_target_sql = f"""
        ALTER TABLE `{target_full_path}`
        {alter_clause}
    """
    print(f"Schema Evolution for target_table {target_table}")
    client.query(alter_target_sql).result()

    # Alter staging table only for incremental loads
    if replication_method != "INCREMENTAL":
        print("Staging table schema evolution skipped (FULL load)")
        return
    if not staging_table_exists:
        print("Staging table schema evolution skipped (staging table does not exist)")
        return

    alter_staging_sql = f"""
        ALTER TABLE `{staging_full_path}`
        {alter_clause}
    """
    print(f"Schema Evolution for {staging_table}")
    client.query(alter_staging_sql).result()
    return


In [0]:
# --- 3. DATA TRANSFORMATIONS ---

def normalize_decimals(df):
    """Cast every ``DecimalType`` column of the DataFrame to ``long``.

    Returns the DataFrame with those columns replaced; other columns are
    left as they are.
    """
    decimal_names = [
        field.name
        for field in df.schema.fields
        if isinstance(field.dataType, DecimalType)
    ]
    for name in decimal_names:
        df = df.withColumn(name, F.col(name).cast("long"))
    return df

In [0]:
def boolean_to_int64(df):
    """Cast every ``BooleanType`` column of the DataFrame to Spark ``byte``.

    Returns the DataFrame with those columns replaced; other columns are
    left as they are.
    """
    boolean_names = [
        field.name
        for field in df.schema.fields
        if isinstance(field.dataType, BooleanType)
    ]
    for name in boolean_names:
        df = df.withColumn(name, F.col(name).cast('byte'))
    return df

In [0]:
def normalize_timestamps(df):
    """Render timestamp and date columns as ``yyyy-MM-dd HH:mm:ss`` strings.

    Every column whose type is ``TimestampType`` or ``DateType`` is replaced
    by its ``date_format`` output.

    Returns:
        The transformed DataFrame only (a single value, not a tuple).
    """
    temporal_names = [
        field.name
        for field in df.schema.fields
        if isinstance(field.dataType, (TimestampType, DateType))
    ]
    for name in temporal_names:
        formatted = F.date_format(F.col(name), "yyyy-MM-dd HH:mm:ss")
        df = df.withColumn(name, formatted)
    return df

In [0]:

# DATABASE CONFIGS
def get_mysql_config(secret_scope):
    """Assemble MySQL JDBC connection settings from a secret scope.

    Reads the ``database``, ``host``, ``port``, ``username`` and ``password``
    secrets from ``secret_scope``.

    Returns:
        Dict with ``jdbc_url``, ``user`` and ``password``.
    """
    def read_secret(key):
        return dbutils.secrets.get(secret_scope, key)

    db_name = read_secret("database")
    host = read_secret("host")
    port = read_secret("port")
    credentials = {
        "user": read_secret("username"),
        "password": read_secret("password"),
    }
    jdbc_url = f"jdbc:mysql://{host}:{port}/{db_name}?useSSL=false&serverTimezone=UTC"
    return {"jdbc_url": jdbc_url, **credentials}

In [0]:
def get_bigquery_config(secret_scope, secret_prefix_bq, bq_sa_name):
    """Assemble BigQuery settings from a secret scope.

    Reads ``project_id``, ``<secret_prefix_bq>_dataset``, ``staging_dataset``
    and the secret named ``bq_sa_name`` (the service-account key JSON), which
    is returned base64-encoded.

    Returns:
        Dict with ``project_id``, ``dataset``, ``credentials_b64`` and
        ``staging_dataset``.
    """
    def read_secret(key):
        return dbutils.secrets.get(secret_scope, key)

    settings = {"project_id": read_secret("project_id")}
    settings["dataset"] = read_secret(f"{secret_prefix_bq}_dataset")
    staging_dataset = read_secret("staging_dataset")
    raw_key = read_secret(bq_sa_name)
    settings["credentials_b64"] = base64.b64encode(raw_key.encode()).decode()
    settings["staging_dataset"] = staging_dataset
    return settings

In [0]:
def table_exists_check(client, bq_cfg, staging_table):
    """Report whether ``staging_table`` exists in the staging dataset.

    Args:
        client: BigQuery client; ``get_table`` is used for the lookup.
        bq_cfg: Dict with ``project_id`` and ``staging_dataset``.
        staging_table: Table name without project/dataset.

    Returns:
        True when the lookup succeeds, False when it raises an ``Exception``.
        ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed.
    """
    table_path = f"{bq_cfg['project_id']}.{bq_cfg['staging_dataset']}.{staging_table}"
    try:
        client.get_table(table_path)
    except Exception:
        # NOTE(review): any lookup failure (including permission or network
        # errors) is still reported as "does not exist", as before.
        return False
    return True

In [0]:
def get_number_type_columns_for_all_streams(catalog):
    """Collect, per stream, the columns whose catalog type includes "number".

    Returns a dict of the form::

        {stream_name: [number_type_columns]}

    Streams without any such column are left out of the result.
    """
    def number_columns(properties):
        return [
            name
            for name, definition in properties.items()
            if "number" in definition.get("type", [])
        ]

    columns_by_stream = {}
    for entry in catalog["catalog"]["streams"]:
        name = entry["stream"]
        found = number_columns(entry["schema"]["properties"])
        if found:
            columns_by_stream[name] = found
    return columns_by_stream

In [0]:
def cast_number_columns(df, number_columns):
    """Cast each listed column that exists in the DataFrame to ``DoubleType``.

    Names in ``number_columns`` that are not in ``df.columns`` are ignored.
    Returns the DataFrame with the matching columns replaced.
    """
    for column_name in number_columns:
        if column_name not in df.columns:
            continue
        as_double = F.col(column_name).cast(DoubleType())
        df = df.withColumn(column_name, as_double)
    return df

In [0]:
# --- 4. MAIN PIPELINE ---

def _write_df_to_bigquery(df, key_b64, project_id, dataset, table):
    """Overwrite ``dataset.table`` with ``df`` via the Spark BigQuery connector."""
    (
        df.write.format("bigquery")
        .option("credentials", key_b64)
        .option("parentProject", project_id)
        .option("project", project_id)
        .option("dataset", dataset)
        .option("table", table)
        .option("allowFieldAddition", "true")
        .option("writeMethod", "direct")
        .mode("overwrite")
        .save()
    )


def run_pipeline(
    catalog_path,
    mysql_secret_scope,
    bq_secret_scope,
    bq_sa_name,
    secret_prefix_bq,
    target_table_prefix,
    start_date,
    end_date
):
    """Replicate every table described in the catalog from MySQL to BigQuery.

    For each table the source rows are read over JDBC, reduced to the
    selected columns, cast according to the catalog and then either written
    as a full overwrite (``FULL_TABLE``) or staged and merged into the
    target table (any other replication method).

    Args:
        catalog_path: Path of the catalog JSON passed to ``load_catalog``.
        mysql_secret_scope: Secret scope holding the MySQL connection secrets.
        bq_secret_scope: Secret scope holding the BigQuery secrets.
        bq_sa_name: Name of the secret holding the service-account key.
        secret_prefix_bq: Prefix of the ``<prefix>_dataset`` secret.
        target_table_prefix: Prefix joined to each destination table name
            with an underscore.
        start_date: Inclusive lower bound for the bookmark column, or a
            falsy value to disable date-range filtering.
        end_date: Exclusive upper bound for the bookmark column, or a falsy
            value to disable date-range filtering.

    Returns:
        None.
    """
    # Initialization
    catalog = load_catalog(catalog_path)
    tables = build_table_mappings(catalog)
    mysql_cfg = get_mysql_config(mysql_secret_scope)
    bq_cfg = get_bigquery_config(bq_secret_scope, secret_prefix_bq, bq_sa_name)
    key_b64 = bq_cfg["credentials_b64"]
    client = get_bq_client(bq_cfg)
    print('********************************************************************', bq_cfg["staging_dataset"])
    # Get number columns for all tables
    number_columns_map = get_number_type_columns_for_all_streams(catalog)
    print(f"Number columns map:{number_columns_map}")

    for t in tables:
        src_table = t["source_table"]
        dest_table = t["destination_table"]
        target_table = f"{target_table_prefix}_{dest_table}"
        target_full_path = f"{bq_cfg['project_id']}.{bq_cfg['dataset']}.{target_table}"

        print(f"\nProcessing table: {dest_table}")

        # Check existence
        try:
            client.get_table(target_full_path)
            table_exists = True
        except Exception:
            table_exists = False

        # Incremental logic: Bookmark fetch
        bookmark_col = t["bookmark_columns"][0] if t["bookmark_columns"] else None
        where_clause = "1=1"

        # --------------------------------------------------
        # Step 1: Build WHERE clause
        # --------------------------------------------------
        # NOTE(review): the filter values are interpolated into SQL text;
        # start_date/end_date must come from a trusted caller.
        if start_date and end_date and bookmark_col:
            # Date-range based UPSERT
            where_clause = (
                f"{bookmark_col} >= '{start_date}' "
                f"AND {bookmark_col} < '{end_date}'"
            )
            print(f"Date range UPSERT -> {start_date} to {end_date}")

        elif start_date and end_date:
            # Without a bookmark column the range filter would render as
            # "None >= ...", which is invalid SQL; read the whole table.
            print(f"No bookmark column for {src_table}; date range ignored")

        elif t["replication_method"] == "INCREMENTAL" and table_exists and bookmark_col:
            try:
                max_query = f"SELECT MAX({bookmark_col}) FROM `{target_full_path}`"
                res = list(client.query(max_query).result())
                last_val = res[0][0]
                # "is not None" so that a legitimate falsy bookmark (e.g. 0)
                # still produces an incremental filter.
                if last_val is not None:
                    where_clause = f"{bookmark_col} > '{last_val}'"
                    print(f"Incremental Filter: {where_clause}")
            except Exception as e:
                print(f"Skipping bookmark fetch: {e}")

        # Read Source Data
        source_query = f"(SELECT * FROM {src_table} WHERE {where_clause}) t"
        df = (
            spark.read.format("jdbc")
            .option("url", mysql_cfg["jdbc_url"])
            .option("dbtable", source_query)
            .option("user", mysql_cfg["user"])
            .option("password", mysql_cfg["password"])
            .load()
        )

        # One count serves both the emptiness check and the log line, so
        # the source is not scanned twice.
        row_count = df.count()
        if row_count == 0:
            print("No data to process.")
            continue
        print(f"Rows fetched -> {row_count}")

        if t["selected_columns"]:
            df = df.select(*[c for c in t["selected_columns"] if c in df.columns])

        df = cast_df_from_catalog(df, t["column_types"])

        if src_table in number_columns_map:
            df = cast_number_columns(df, number_columns_map[src_table])

        df = normalize_decimals(df)
        df = boolean_to_int64(df)

        # Write Logic
        if t["replication_method"] == "FULL_TABLE":
            # Schema evolution inspects the existing target, so it only
            # applies once the table has been created.
            if table_exists:
                evolve_bq_schema(bq_cfg, target_table, df.schema, client, t["replication_method"])
            df = normalize_timestamps(df)
            _write_df_to_bigquery(
                df, key_b64, bq_cfg["project_id"], bq_cfg["dataset"], target_table
            )
            print(f"Full load completed {target_table}.")

        elif not table_exists:
            # First run for this table: there is nothing to merge into yet,
            # so load the batch straight into the target.
            df = normalize_timestamps(df)
            _write_df_to_bigquery(
                df, key_b64, bq_cfg["project_id"], bq_cfg["dataset"], target_table
            )
            print(f"Initial load completed {target_table}.")

        else:
            # Method 3: Merge Schema + Dynamic Merge
            staging_table = f"{target_table}_staging"
            staging_full_path = (
                f"{bq_cfg['project_id']}.{bq_cfg['staging_dataset']}.{staging_table}"
            )
            staging_table_exists = table_exists_check(client, bq_cfg, staging_table)

            evolve_bq_schema(bq_cfg, target_table, df.schema, client, t["replication_method"], staging_table_exists)
            df = normalize_timestamps(df)

            # Only column names and dtypes are needed, so no rows are
            # fetched from the target table.
            schema_probe_sql = f"""
                    SELECT *
                    FROM `{target_full_path}`
                    LIMIT 0
                    """
            target_df = client.query(schema_probe_sql).to_dataframe()
            target_columns = target_df.columns.tolist()

            if not staging_table_exists:
                columns_ddl = [
                    f"`{column_name}` {pandas_dtype_to_bq(dtype)}"
                    for column_name, dtype in target_df.dtypes.items()
                ]
                columns_sql = ",\n  ".join(columns_ddl)

                create_table_sql = f"""
                CREATE OR REPLACE TABLE `{staging_full_path}` (
                {columns_sql}
                )
                """
                client.query(create_table_sql).result()
                print("Staging Table created successfully")
            else:
                print("Staging Table already exists")

            # Write batch to staging
            _write_df_to_bigquery(
                df, key_b64, bq_cfg["project_id"], bq_cfg["staging_dataset"], staging_table
            )

            if t["unique_conflict_method"] == "UPDATE" and t["unique_constraints"]:
                merge_condition = " AND ".join(
                    [f"T.`{k}` = S.`{k}`" for k in t["unique_constraints"]]
                )
                # Backticks keep column names that collide with reserved
                # words valid inside the MERGE statement.
                update_columns = ",\n".join(
                    [f"T.`{c}` = S.`{c}`" for c in target_columns]
                )

                merge_sql = f"""
                MERGE `{target_full_path}` T
                USING `{staging_full_path}` S
                ON {merge_condition}
                WHEN MATCHED THEN
                UPDATE SET {update_columns}
                WHEN NOT MATCHED THEN
                INSERT ROW
                """
                client.query(merge_sql).result()
                print("*****************Merge Completed*****************")
            else:
                # NOTE(review): the staged rows are discarded by the
                # TRUNCATE below without reaching the target; confirm
                # whether an append is intended for this configuration.
                print(
                    f"No merge performed for {target_table}: unique_conflict_method "
                    "is not UPDATE or no unique constraints are configured"
                )

            client.query(f"TRUNCATE TABLE `{staging_full_path}`").result()
            print("\nPipeline execution finished")