In [33]:
from pyspark.sql.functions import *
from delta.tables import *
from datetime import datetime, timezone
import json
import base64
from pathlib import Path
import os

StatementMeta(, bab08844-21d0-4e24-a0eb-ae15aa15e06a, 35, Finished, Available)

In [12]:
class SyncConstants:
    """Literal names shared by the sync classes: modes, strategies and table names."""

    # Load types; also used as the write mode when saving lakehouse tables.
    APPEND = "APPEND"
    OVERWRITE = "OVERWRITE"

    # Summary label reported for a table's initial load.
    INITIAL_FULL_OVERWRITE = "INITIAL_FULL_OVERWRITE"

    # Load strategies.
    FULL = "FULL"
    PARTITION = "PARTITION"
    WATERMARK = "WATERMARK"

    # AUTO is the default table interval.
    AUTO = "AUTO"
    # NOTE(review): TIME is presumably the partition type for time-based
    # partitioning; it is not referenced in the visible code -- confirm.
    TIME = "TIME"

    # BigQuery INFORMATION_SCHEMA views that are mirrored into the metadata lakehouse.
    INFORMATION_SCHEMA_TABLES = "INFORMATION_SCHEMA.TABLES"
    INFORMATION_SCHEMA_COLUMNS = "INFORMATION_SCHEMA.COLUMNS"
    INFORMATION_SCHEMA_PARTITIONS = "INFORMATION_SCHEMA.PARTITIONS"
    INFORMATION_SCHEMA_TABLE_CONSTRAINTS = "INFORMATION_SCHEMA.TABLE_CONSTRAINTS"
    INFORMATION_SCHEMA_KEY_COLUMN_USAGE = "INFORMATION_SCHEMA.KEY_COLUMN_USAGE"

    # Tables queried by the sync SQL statements.
    SQL_TBL_SYNC_CONFIG = "bq_sync_configuration"
    SQL_TBL_SYNC_SCHEDULE = "bq_sync_schedule"
    SQL_TBL_SYNC_SCHEDULE_PARTITION = "bq_sync_schedule_partition"
    SQL_TBL_DATA_TYPE_MAP = "bq_data_type_map"

StatementMeta(, bab08844-21d0-4e24-a0eb-ae15aa15e06a, 14, Finished, Available)

In [41]:
class ConfigBase():
    """
    Common base for the BigQuery sync classes.

    Loads the JSON user config into a ConfigDataset, normalises the GCP
    credential to a base64 string, and provides helpers to read from
    BigQuery, write lakehouse tables and create the temporary proxy views
    that the sync SQL statements query.
    """

    def __init__(self, config_path, gcp_credential):
        """
        Load the user config and the GCP credential.

        config_path: location of the JSON user config (read with spark.read.json).
        gcp_credential: a base64-encoded credential, or a path to a credential file.

        Raises ValueError when either argument is None, or when the
        credential is neither an existing file nor valid base64.
        """
        if config_path is None:
            raise ValueError("Missing Path to JSON User Config")

        if gcp_credential is None:
            raise ValueError("Missing GCP Credentials")

        self.ConfigPath = config_path
        self.UserConfig = None
        self.GCPCredential = None

        self.UserConfig = self.ensure_user_config()
        self.GCPCredential = self.load_gcp_credential(gcp_credential)

    def ensure_user_config(self):
        """Return the parsed user config, loading and validating it if not yet loaded."""
        if self.UserConfig is None and self.ConfigPath is not None:
            config = self.load_user_config(self.ConfigPath)

            cfg = ConfigDataset(config)

            self.validate_user_config(cfg)

            return cfg

        return self.UserConfig

    def load_user_config(self, config_path):
        """
        Read the multi-line JSON config and return its first record as a dict.

        Side effect: the DataFrame is registered as the temp view
        "user_config_json" (and cached); the user-config proxy views query it.
        """
        config_df = spark.read.option("multiline","true").json(config_path)
        config_df.createOrReplaceTempView("user_config_json")
        config_df.cache()
        return json.loads(config_df.toJSON().first())

    def validate_user_config(self, cfg):
        """Return True for a usable config; raise RuntimeError when cfg is None."""
        if cfg is None:
            raise RuntimeError("Invalid User Config")
        return True

    def load_gcp_credential(self, credential):
        """
        Normalise the supplied credential to a base64 string.

        credential: path to a credential file, or an already base64-encoded value.

        An existing file takes precedence over base64 detection: a short,
        extension-less path made only of base64-alphabet characters (for
        example "/tmp/key") round-trips through base64 and would otherwise be
        mistaken for an encoded credential.

        Raises ValueError when the credential is empty, or is neither an
        existing file nor valid base64.
        """
        # Empty input decodes/encodes to itself, so is_base64 alone would accept it.
        if not credential:
            raise ValueError("Missing GCP Credentials")

        if os.path.exists(credential):
            file_contents = self.read_credential_file(credential)
            return self.convert_to_base64string(file_contents)

        if self.is_base64(credential):
            return credential

        raise ValueError("Invalid GCP Credential path supplied.")

    def read_credential_file(self, credential_path):
        """Return the credential file's text with all CR/LF characters removed."""
        # Explicit UTF-8 so the result does not depend on the platform default encoding.
        txt = Path(credential_path).read_text(encoding="utf-8")
        txt = txt.replace("\n", "").replace("\r", "")

        return txt

    def convert_to_base64string(self, credential_val):
        """Return the base64 encoding (as str) of the given text."""
        # UTF-8 is a superset of ASCII: identical output for ASCII input, and
        # non-ASCII characters no longer raise UnicodeEncodeError.
        credential_val_bytes = credential_val.encode("utf-8")

        base64_bytes = base64.b64encode(credential_val_bytes)
        base64_string = base64_bytes.decode("ascii")

        return base64_string

    def is_base64(self, val):
        """
        Return True when val (str or bytes) is canonical base64.

        The value is decoded and re-encoded; only an exact round trip counts.
        Any other type, non-ASCII text or undecodable input yields False.
        """
        try:
            if isinstance(val, str):
                sb_bytes = bytes(val, 'ascii')
            elif isinstance(val, bytes):
                sb_bytes = val
            else:
                raise ValueError("Argument must be string or bytes")
            return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
        except ValueError:
            # Covers binascii.Error and UnicodeEncodeError (both ValueError subclasses).
            return False

    def read_bq_to_dataframe(self, query, cache_results=False):
        """
        Load a BigQuery table or query through the Spark "bigquery" source.

        query: value handed to .load().
        cache_results: when True, cache() is called on the DataFrame before returning.
        """
        df = spark.read \
            .format("bigquery") \
            .option("parentProject", self.UserConfig.ProjectID) \
            .option("credentials", self.GCPCredential) \
            .option("viewsEnabled", "true") \
            .option("materializationDataset", self.UserConfig.Dataset) \
            .load(query)

        if cache_results:
            df.cache()

        return df

    def write_lakehouse_table(self, df, lakehouse, tbl_nm, mode=SyncConstants.OVERWRITE):
        """Save df as table <lakehouse>.<tbl_nm> using the given write mode."""
        dest_table = self.UserConfig.get_lakehouse_tablename(lakehouse, tbl_nm)

        df.write \
            .mode(mode) \
            .saveAsTable(dest_table)

    def create_infosys_proxy_view(self, trgt):
        """
        Create temp view BQ_<trgt with dots replaced by underscores> over the
        matching table in the metadata lakehouse.
        """
        clean_nm = trgt.replace(".", "_")
        tbl = self.UserConfig.flatten_3part_tablename(clean_nm)
        lakehouse_tbl = self.UserConfig.get_lakehouse_tablename(self.UserConfig.MetadataLakehouse, tbl)

        sql = f"""
        CREATE OR REPLACE TEMPORARY VIEW BQ_{clean_nm}
        AS
        SELECT *
        FROM {lakehouse_tbl}
        """
        spark.sql(sql)

    def create_userconfig_tables_proxy_view(self):
        """Create temp view user_config_tables: one row per entry of the config's tables array."""
        sql = """
            CREATE OR REPLACE TEMPORARY VIEW user_config_tables
            AS
            SELECT
                project_id, dataset, tbl.table_name,
                tbl.enabled,tbl.load_priority,tbl.source_query,
                tbl.load_strategy,tbl.load_type,tbl.interval,
                tbl.watermark.column as watermark_column,
                tbl.partitioned.enabled as partition_enabled,
                tbl.partitioned.type as partition_type,
                tbl.partitioned.column as partition_column,
                tbl.partitioned.partition_grain,
                tbl.lakehouse_target.lakehouse,
                tbl.lakehouse_target.table_name AS lakehouse_target_table,
                tbl.keys
            FROM (SELECT project_id, dataset, EXPLODE(tables) AS tbl FROM user_config_json)
        """
        spark.sql(sql)

    def create_userconfig_tables_cols_proxy_view(self):
        """Create temp view user_config_table_keys: one row per configured key column."""
        sql = """
            CREATE OR REPLACE TEMPORARY VIEW user_config_table_keys
            AS
            SELECT
                project_id, dataset, table_name, pkeys.column
            FROM (
                SELECT
                    project_id, dataset, tbl.table_name, EXPLODE(tbl.keys) AS pkeys
                FROM (SELECT project_id, dataset, EXPLODE(tables) AS tbl FROM user_config_json)
            )
        """
        spark.sql(sql)

    def create_proxy_views(self):
        """Create the user-config views and one proxy view per mirrored INFORMATION_SCHEMA view."""
        self.create_userconfig_tables_proxy_view()
        self.create_userconfig_tables_cols_proxy_view()
        self.create_infosys_proxy_view(SyncConstants.INFORMATION_SCHEMA_TABLES)
        self.create_infosys_proxy_view(SyncConstants.INFORMATION_SCHEMA_PARTITIONS)
        self.create_infosys_proxy_view(SyncConstants.INFORMATION_SCHEMA_COLUMNS)
        self.create_infosys_proxy_view(SyncConstants.INFORMATION_SCHEMA_TABLE_CONSTRAINTS)
        self.create_infosys_proxy_view(SyncConstants.INFORMATION_SCHEMA_KEY_COLUMN_USAGE)

StatementMeta(, bab08844-21d0-4e24-a0eb-ae15aa15e06a, 43, Finished, Available)

In [13]:
class Scheduler(ConfigBase):
    """Inserts new rows into the sync schedule table for the configured project and dataset."""

    def __init__(self, config_path, gcp_credential):
        """Delegate all setup to ConfigBase."""
        super().__init__(config_path, gcp_credential)

    def run(self):
        """
        Insert one schedule row per enabled configured table that has no
        row in status 'SCHEDULED' yet.

        A row is marked 'SCHEDULED' when the table has no 'COMPLETE' load, or
        when its latest BigQuery partition modification time is at or after
        the start of its last completed load; otherwise it is 'SKIPPED'.
        Only rows for this config's project_id and dataset are inserted.

        NOTE(review): the statement reads the bq_information_schema_partitions
        view, which this class does not create itself -- confirm the proxy
        views exist in the session before run() is called.
        """
        schedule_sql = f"""
        WITH new_schedule AS ( 
            SELECT CURRENT_TIMESTAMP() as scheduled
        ),
        last_bq_tbl_updates AS (
            SELECT table_catalog, table_schema, table_name, max(last_modified_time) as last_bq_tbl_update
            FROM bq_information_schema_partitions
            GROUP BY table_catalog, table_schema, table_name
        ),
        last_load AS (
            SELECT project_id, dataset, table_name, MAX(started) AS last_load_update
            FROM {SyncConstants.SQL_TBL_SYNC_SCHEDULE}
            WHERE status='COMPLETE'
            GROUP BY project_id, dataset, table_name
        ),
        schedule AS (
            SELECT
                UUID() AS schedule_id,
                c.project_id,
                c.dataset,
                c.table_name,
                n.scheduled,
                CASE WHEN ((l.last_load_update IS NULL) OR
                     (b.last_bq_tbl_update >= l.last_load_update))
                    THEN 'SCHEDULED' ELSE 'SKIPPED' END as status,
                NULL as started,
                NULL as completed,
                NULL as src_row_count,
                NULL as dest_row_count,
                NULL as dest_inserted_row_count,
                NULL as dest_updated_row_count,
                NULL as delta_version,
                NULL as spark_application_id,
                NULL as max_watermark,
                NULL as summary_load,
                c.priority
            FROM {SyncConstants.SQL_TBL_SYNC_CONFIG} c 
            LEFT JOIN {SyncConstants.SQL_TBL_SYNC_SCHEDULE} s ON 
                c.project_id= s.project_id AND
                c.dataset = s.dataset AND
                c.table_name = s.table_name AND
                s.status = 'SCHEDULED'
            LEFT JOIN last_bq_tbl_updates b ON
                c.project_id= b.table_catalog AND
                c.dataset = b.table_schema AND
                c.table_name = b.table_name
            LEFT JOIN last_load l ON 
                c.project_id= l.project_id AND
                c.dataset = l.dataset AND
                c.table_name = l.table_name
            CROSS JOIN new_schedule n
            WHERE s.schedule_id IS NULL
            AND c.enabled = TRUE
        )

        INSERT INTO {SyncConstants.SQL_TBL_SYNC_SCHEDULE}
        SELECT * FROM schedule s
        WHERE s.project_id = '{self.UserConfig.ProjectID}'
        AND s.dataset = '{self.UserConfig.Dataset}'
        """

        spark.sql(schedule_sql)

StatementMeta(, bab08844-21d0-4e24-a0eb-ae15aa15e06a, 15, Finished, Available)

In [14]:
class SyncSchedule:
    """State and telemetry for a single scheduled table load, built from a schedule row."""

    # Class-level defaults; instances overwrite these as the load progresses.
    EndTime = None
    SourceRows = 0
    DestRows = 0
    InsertedRows = 0
    UpdatedRows = 0
    DeltaVersion = None
    SparkAppId = None
    MaxWatermark = None
    Status = None

    def __init__(self, row):
        """Keep the schedule row, stamp the start time (UTC) and copy the row's fields."""
        self.Row = row
        self.StartTime = datetime.now(timezone.utc)

        # (attribute name, row key) pairs, bound in this order.
        field_map = (
            ("ScheduleId", "schedule_id"),
            ("LoadStrategy", "load_strategy"),
            ("LoadType", "load_type"),
            ("InitialLoad", "initial_load"),
            ("ProjectId", "project_id"),
            ("Dataset", "dataset"),
            ("TableName", "table_name"),
            ("SourceQuery", "source_query"),
            ("MaxWatermark", "max_watermark"),
            ("IsPartitioned", "is_partitioned"),
            ("PartitionColumn", "partition_column"),
            ("PartitionType", "partition_type"),
            ("PartitionGrain", "partition_grain"),
            ("WatermarkColumn", "watermark_column"),
            ("LastScheduleLoadDate", "last_schedule_dt"),
            ("Lakehouse", "lakehouse"),
            ("DestinationTableName", "lakehouse_table_name"),
        )

        for attr_name, row_key in field_map:
            setattr(self, attr_name, row[row_key])

    @property
    def SummaryLoadType(self):
        """INITIAL_FULL_OVERWRITE for an initial load, otherwise "<strategy>_<type>"."""
        if not self.InitialLoad:
            return "{0}_{1}".format(self.LoadStrategy, self.LoadType)

        return SyncConstants.INITIAL_FULL_OVERWRITE

    @property
    def Mode(self):
        """Write mode: OVERWRITE for an initial load, otherwise the row's load type."""
        return SyncConstants.OVERWRITE if self.InitialLoad else self.LoadType

    @property
    def PrimaryKey(self):
        """First entry of the row's primary_keys, or None when there are none."""
        keys = self.Row["primary_keys"]
        return keys[0] if keys else None

    @property
    def LakehouseTableName(self):
        """Destination table as "<lakehouse>.<table>"."""
        return "{0}.{1}".format(self.Lakehouse, self.DestinationTableName)

    @property
    def BQTableName(self):
        """Source table as "<project>.<dataset>.<table>"."""
        return "{0}.{1}.{2}".format(self.ProjectId, self.Dataset, self.TableName)

    @property
    def IsTimeIngestionPartitioned(self):
        """True when the partition column is _PARTITIONTIME or _PARTITIONDATE."""
        return self.PartitionColumn in ("_PARTITIONTIME", "_PARTITIONDATE")

    def UpdateRowCounts(self, src, dest, insert, update):
        """
        Add src/dest to the running source and destination row counts.

        Inserted rows grow by src for the WATERMARK strategy and by dest for
        every other strategy. UpdatedRows is reset to 0; the insert and
        update arguments are accepted but not used.
        """
        self.SourceRows += src
        self.DestRows += dest

        if self.LoadStrategy == SyncConstants.WATERMARK:
            self.InsertedRows += src
        else:
            self.InsertedRows += dest

        self.UpdatedRows = 0

StatementMeta(, bab08844-21d0-4e24-a0eb-ae15aa15e06a, 16, Finished, Available)

In [15]:
class ConfigDataset:
    """Dataset-level settings parsed from the JSON user config, plus its table entries."""

    def __init__(self, json_config):
        """
        json_config: the user config as a dict.

        Missing keys fall back to: None for project_id, dataset,
        metadata_lakehouse and target_lakehouse; True for load_all_tables and
        autodetect; False for master_reset. Each entry under "tables" becomes
        a ConfigBQTable.
        """
        self.ProjectID = self.get_json_conf_val(json_config, "project_id", None)
        self.Dataset = self.get_json_conf_val(json_config, "dataset", None)
        self.LoadAllTables = self.get_json_conf_val(json_config, "load_all_tables", True)
        self.Autodetect = self.get_json_conf_val(json_config, "autodetect", True)
        self.MasterReset = self.get_json_conf_val(json_config, "master_reset", False)
        self.MetadataLakehouse = self.get_json_conf_val(json_config, "metadata_lakehouse", None)
        self.TargetLakehouse = self.get_json_conf_val(json_config, "target_lakehouse", None)

        self.Tables = []

        if "tables" in json_config:
            self.Tables = [ConfigBQTable(t) for t in json_config["tables"]]

    def get_delimited_tables_list(self, delimiter=","):
        """
        Return the configured table names joined by delimiter (comma by default).

        The names were previously joined with an empty string, which ran them
        together with nothing separating them.
        """
        return delimiter.join(self.get_table_name_list())

    def get_table_name_list(self):
        """Return the configured table names as a list of str."""
        return [str(x.TableName) for x in self.Tables]

    def get_bq_table_fullname(self, tbl_name):
        """Return "<project>.<dataset>.<tbl_name>"."""
        return f"{self.ProjectID}.{self.Dataset}.{tbl_name}"

    def get_lakehouse_tablename(self, lakehouse, tbl_name):
        """Return "<lakehouse>.<tbl_name>"."""
        return f"{lakehouse}.{tbl_name}"

    def flatten_3part_tablename(self, tbl_name):
        """Return "<project>_<dataset>_<tbl_name>" with hyphens in the project id replaced by underscores."""
        clean_project_id = self.ProjectID.replace("-", "_")
        return f"{clean_project_id}_{self.Dataset}_{tbl_name}"

    def get_json_conf_val(self, json, config_key, default_val = None):
        """Return json[config_key] when the key is present, otherwise default_val."""
        return json[config_key] if config_key in json else default_val

class ConfigTableColumn:
    """Holds a single column reference from the user config."""

    def __init__(self, col = ""):
        """col: the stored column value; defaults to an empty string."""
        self.Column = col

class ConfigLakehouseTarget:
    """Destination of a synced table: a lakehouse and a table value."""

    def __init__(self, lakehouse = "", table = ""):
        """Both values default to an empty string."""
        self.Lakehouse, self.Table = lakehouse, table

class ConfigPartition:
    """Partitioning settings for one configured table."""

    def __init__(self, enabled = False, partition_type = "", col = None, grain = ""):
        """
        enabled: whether partitioning is enabled.
        partition_type: partition type value.
        col: partition column. When omitted (None) a new, empty
            ConfigTableColumn is created for this instance. The default used
            to be one ConfigTableColumn built at definition time and shared
            by every instance, so mutating it on one object changed them all.
            Note that ConfigBQTable passes the column as a plain string.
        grain: partition granularity value.
        """
        self.Enabled = enabled
        self.PartitionType = partition_type
        self.PartitionColumn = ConfigTableColumn() if col is None else col
        self.Granularity = grain

class ConfigBQTable:
    """Per-table settings parsed from one entry of the user config's "tables" list."""

    def __str__(self):
        return str(self.TableName)

    def __init__(self, json_config):
        """
        json_config: one table entry of the user config, as a dict.

        Missing keys fall back to defaults: empty strings for names and the
        source query, priority 100, FULL strategy, OVERWRITE load type, AUTO
        interval, enabled True, and empty target/watermark/partition objects.
        """
        self.TableName = self.get_json_conf_val(json_config, "table_name", "")
        self.Priority = self._read_priority(json_config)
        self.SourceQuery = self.get_json_conf_val(json_config, "source_query", "")
        self.LoadStrategy = self.get_json_conf_val(json_config, "load_strategy", SyncConstants.FULL)
        self.LoadType = self.get_json_conf_val(json_config, "load_type", SyncConstants.OVERWRITE)
        self.Interval = self.get_json_conf_val(json_config, "interval", SyncConstants.AUTO)
        self.Enabled = self.get_json_conf_val(json_config, "enabled", True)

        if "lakehouse_target" in json_config:
            target_cfg = json_config["lakehouse_target"]
            self.LakehouseTarget = ConfigLakehouseTarget(
                self.get_json_conf_val(target_cfg, "lakehouse", ""),
                self.get_json_conf_val(target_cfg, "table_name", ""))
        else:
            self.LakehouseTarget = ConfigLakehouseTarget()

        if "watermark" in json_config:
            self.Watermark = ConfigTableColumn(
                self.get_json_conf_val(json_config["watermark"], "column", ""))
        else:
            self.Watermark = ConfigTableColumn()

        if "partitioned" in json_config:
            partition_cfg = json_config["partitioned"]
            # The partition column is passed through as the raw config value (a string).
            self.Partitioned = ConfigPartition(
                self.get_json_conf_val(partition_cfg, "enabled", False),
                self.get_json_conf_val(partition_cfg, "type", ""),
                self.get_json_conf_val(partition_cfg, "column", ""),
                self.get_json_conf_val(partition_cfg, "partition_grain", ""))
        else:
            self.Partitioned = ConfigPartition()

        self.Keys = []

        if "keys" in json_config:
            self.Keys = [
                ConfigTableColumn(self.get_json_conf_val(c, "column", ""))
                for c in json_config["keys"]
            ]

    def _read_priority(self, json_config):
        """
        Return the table's load priority (default 100).

        "load_priority" is the key the user_config_tables SQL view reads from
        the same JSON, so it is preferred here to keep both code paths in
        agreement; the previously-read "priority" key is still honoured as a
        fallback.
        """
        if "load_priority" in json_config:
            return json_config["load_priority"]

        return self.get_json_conf_val(json_config, "priority", 100)

    def get_json_conf_val(self, json, config_key, default_val = None):
        """Return json[config_key] when the key is present, otherwise default_val."""
        return json[config_key] if config_key in json else default_val

StatementMeta(, bab08844-21d0-4e24-a0eb-ae15aa15e06a, 17, Finished, Available)

In [1]:
class ConfigMetadataLoader(ConfigBase):
    """
    Mirrors BigQuery INFORMATION_SCHEMA metadata into the metadata lakehouse
    and derives a sync configuration row for every discovered table.
    """

    def __init__(self, config_path, gcp_credential):
        """Delegate all setup to ConfigBase."""
        super().__init__(config_path, gcp_credential)

    def create_autodetect_view(self):
        """
        Create temp view bq_table_metadata_autodetect.

        For each table in bq_information_schema_tables the view exposes the
        detected partitioning (flag, column, data type, type, grain inferred
        from the average partition_id length) and, for tables with exactly
        one primary-key column whose data type is flagged is_watermark = 'YES'
        in bq_data_type_map, that column as pk_col.
        """
        sql = """
        CREATE OR REPLACE TEMPORARY VIEW bq_table_metadata_autodetect
        AS
        WITH pkeys AS (    
            SELECT
                c.table_catalog, c.table_schema, c.table_name, 
                k.column_name AS pk_col
            FROM bq_information_schema_table_constraints c
            JOIN bq_information_schema_key_column_usage k ON
                k.table_catalog = c.table_catalog AND
                k.table_schema = c.table_schema AND
                k.table_name = c.table_name AND
                k.constraint_name = c.constraint_name
            JOIN bq_information_schema_columns n ON
                n.table_catalog = k.table_catalog AND
                n.table_schema = k.table_schema AND
                n.table_name = k.table_name AND
                n.column_name = k.column_name
            JOIN bq_data_type_map m ON n.data_type = m.data_type
            WHERE c.constraint_type = 'PRIMARY KEY'
            AND m.is_watermark = 'YES'
        ),
        pkeys_cnt AS (
            SELECT 
                table_catalog, table_schema, table_name, 
                COUNT(*) as pk_cnt
            FROM pkeys
            GROUP BY table_catalog, table_schema, table_name
        ),
        watermark_cols AS (
            SELECT 
                k.*
            FROM pkeys k
            JOIN pkeys_cnt c ON 
                k.table_catalog = c.table_catalog AND
                k.table_schema = c.table_schema AND
                k.table_name = c.table_name
            WHERE c.pk_cnt = 1
        ),
        partitions AS (
            SELECT
                table_catalog, table_schema, table_name, 
                count(*) as partition_count,
                avg(len(partition_id)) AS partition_id_len,
                sum(case when partition_id is NULL then 1 else 0 end) as null_partition_count
            FROM bq_information_schema_partitions
            GROUP BY table_catalog, table_schema, table_name
        ), 
        partition_columns AS
        (
            SELECT
                table_catalog, table_schema, table_name,
                column_name, c.data_type,
                m.partition_type AS partitioning_type
            FROM bq_information_schema_columns c
            JOIN bq_data_type_map m ON c.data_type=m.data_type
            WHERE is_partitioning_column = 'YES'
        ),
        partition_cfg AS
        (
            SELECT
                p.*,
                CASE WHEN p.partition_count = 1 AND p.null_partition_count = 1 THEN FALSE ELSE TRUE END AS is_partitioned,
                c.column_name AS partition_col,
                c.data_type AS partition_data_type,
                c.partitioning_type,
                CASE WHEN (c.partitioning_type = 'TIME')
                    THEN 
                        CASE WHEN (partition_id_len = 4) THEN 'YEAR'
                            WHEN (partition_id_len = 6) THEN 'MONTH'
                            WHEN (partition_id_len = 8) THEN 'DAY'
                            WHEN (partition_id_len = 10) THEN 'HOUR'
                            ELSE NULL END
                    ELSE NULL END AS partitioning_strategy
            FROM partitions p
            LEFT JOIN partition_columns c ON 
                p.table_catalog = c.table_catalog AND
                p.table_schema = c.table_schema AND
                p.table_name = c.table_name
        )

        SELECT 
            t.table_catalog, t.table_schema, t.table_name, t.is_insertable_into,
            p.is_partitioned, p.partition_col, p.partition_data_type, p.partitioning_type, p.partitioning_strategy,
            w.pk_col
        FROM bq_information_schema_tables t
        LEFT JOIN watermark_cols w ON 
            t.table_catalog = w.table_catalog AND
            t.table_schema = w.table_schema AND
            t.table_name = w.table_name
        LEFT JOIN partition_cfg p ON
            t.table_catalog = p.table_catalog AND
            t.table_schema = p.table_schema AND
            t.table_name = p.table_name
        """

        spark.sql(sql)

    def _filter_to_configured_tables(self, df):
        """Keep only rows for tables named in the user config, unless LoadAllTables is set."""
        if self.UserConfig.LoadAllTables:
            return df

        filter_list = self.UserConfig.get_table_name_list()
        return df.filter(col("table_name").isin(filter_list))

    def sync_bq_information_schema_tables(self):
        """
        Copy INFORMATION_SCHEMA.TABLES (base tables only) into the metadata lakehouse.

        Tables whose name starts with '_bqc_' are excluded. STARTS_WITH is
        used rather than LIKE '_bqc_%' because '_' is a single-character
        wildcard in LIKE, so that pattern also matched names that merely have
        "bqc" as their 2nd-4th characters.
        """
        bq_table = self.UserConfig.get_bq_table_fullname(SyncConstants.INFORMATION_SCHEMA_TABLES)
        tbl_nm = self.UserConfig.flatten_3part_tablename(SyncConstants.INFORMATION_SCHEMA_TABLES.replace(".", "_"))

        bql = f"""
        SELECT *
        FROM {bq_table}
        WHERE table_type='BASE TABLE'
        AND NOT STARTS_WITH(table_name, '_bqc_')
        """

        df = self._filter_to_configured_tables(self.read_bq_to_dataframe(bql))

        self.write_lakehouse_table(df, self.UserConfig.MetadataLakehouse, tbl_nm)

    def sync_bq_information_schema_table_dependent(self, dependent_tbl):
        """
        Copy one INFORMATION_SCHEMA view into the metadata lakehouse, limited
        to rows whose table is a base table in INFORMATION_SCHEMA.TABLES.

        dependent_tbl: view name relative to the dataset, e.g. "INFORMATION_SCHEMA.COLUMNS".

        Tables whose name starts with '_bqc_' are excluded; STARTS_WITH avoids
        the '_' wildcard semantics of LIKE.
        """
        bq_table = self.UserConfig.get_bq_table_fullname(SyncConstants.INFORMATION_SCHEMA_TABLES)
        bq_dependent_tbl = self.UserConfig.get_bq_table_fullname(dependent_tbl)
        tbl_nm = self.UserConfig.flatten_3part_tablename(dependent_tbl.replace(".", "_"))

        bql = f"""
        SELECT c.*
        FROM {bq_dependent_tbl} c
        JOIN {bq_table} t ON 
        t.table_catalog=c.table_catalog AND
        t.table_schema=c.table_schema AND
        t.table_name=c.table_name
        WHERE t.table_type='BASE TABLE'
        AND NOT STARTS_WITH(t.table_name, '_bqc_')
        """

        df = self._filter_to_configured_tables(self.read_bq_to_dataframe(bql))

        self.write_lakehouse_table(df, self.UserConfig.MetadataLakehouse, tbl_nm)

    def sync_bq_metadata(self):
        """Mirror TABLES first, then each dependent INFORMATION_SCHEMA view."""
        self.sync_bq_information_schema_tables()
        self.sync_bq_information_schema_table_dependent(SyncConstants.INFORMATION_SCHEMA_PARTITIONS)
        self.sync_bq_information_schema_table_dependent(SyncConstants.INFORMATION_SCHEMA_COLUMNS)
        self.sync_bq_information_schema_table_dependent(SyncConstants.INFORMATION_SCHEMA_TABLE_CONSTRAINTS)
        self.sync_bq_information_schema_table_dependent(SyncConstants.INFORMATION_SCHEMA_KEY_COLUMN_USAGE)

    def create_proxy_views(self):
        """Create the base proxy views, then the autodetect view built on top of them."""
        super().create_proxy_views()
        self.create_autodetect_view()

    def auto_detect_table_profiles(self):
        """
        Merge one configuration row per detected table into the sync config table.

        Values from the user config override autodetected ones. Existing rows
        still in sync_state 'INIT' are fully replaced; other existing rows only
        get enabled, interval, priority and last_updated_dt refreshed; unknown
        tables are inserted.
        """
        self.create_proxy_views()

        sql = f"""
        WITH default_config AS (
            SELECT autodetect, target_lakehouse FROM user_config_json
        ),
        pk AS (
            SELECT
            a.table_catalog, a.table_schema, a.table_name, array_agg(COALESCE(a.pk_col, u.column)) as pk
            FROM bq_table_metadata_autodetect a
            LEFT JOIN user_config_table_keys u ON
                a.table_catalog = u.project_id AND
                a.table_schema = u.dataset AND
                a.table_name = u.table_name
            GROUP BY a.table_catalog, a.table_schema, a.table_name
        ),
        source AS (
            SELECT
                a.table_catalog as project_id,
                a.table_schema as dataset,
                a.table_name as table_name,
                COALESCE(u.enabled, TRUE) AS enabled,
                COALESCE(u.lakehouse, d.target_lakehouse) AS lakehouse,
                COALESCE(u.lakehouse_target_table, a.table_name) AS lakehouse_table_name,
                COALESCE(u.source_query, '') AS source_query,
                COALESCE(u.load_priority, '100') AS priority,
                CASE WHEN (COALESCE(u.watermark_column, a.pk_col) IS NOT NULL AND
                        COALESCE(u.watermark_column, a.pk_col) <> '') THEN 'WATERMARK' 
                    WHEN (COALESCE(u.partition_enabled, a.is_partitioned) = TRUE) 
                        AND COALESCE(u.partition_column, a.partition_col, '') NOT IN 
                            ('_PARTITIONTIME', '_PARTITIONDATE')
                    THEN 'PARTITION' ELSE 'FULL' END AS load_strategy,
                CASE WHEN (COALESCE(u.watermark_column, a.pk_col) IS NOT NULL AND
                        COALESCE(u.watermark_column, a.pk_col) <> '') THEN 'APPEND' ELSE
                    'OVERWRITE' END AS load_type,
                COALESCE(u.interval, 'AUTO') AS interval,
                p.pk AS primary_keys,
                COALESCE(u.partition_enabled, a.is_partitioned) AS is_partitioned,
                COALESCE(u.partition_column, a.partition_col, '') AS partition_column,
                COALESCE(u.partition_type, a.partitioning_type, '') AS partition_type,
                COALESCE(u.partition_grain, a.partitioning_strategy, '') AS partition_grain,
                COALESCE(u.watermark_column, a.pk_col, '') AS watermark_column, 
                d.autodetect,
                CASE WHEN u.table_name IS NULL THEN FALSE ELSE TRUE END AS config_override,
                'INIT' AS sync_state,
                CURRENT_TIMESTAMP() as created_dt,
                NULL as last_updated_dt
            FROM bq_table_metadata_autodetect a
            JOIN pk p ON
                a.table_catalog = p.table_catalog AND
                a.table_schema = p.table_schema AND
                a.table_name = p.table_name
            LEFT JOIN user_config_tables u ON 
                a.table_catalog = u.project_id AND
                a.table_schema = u.dataset AND
                a.table_name = u.table_name
            CROSS JOIN default_config d
        )

        MERGE INTO {SyncConstants.SQL_TBL_SYNC_CONFIG} t
        USING source s
        ON t.project_id = s.project_id AND
            t.dataset = s.dataset AND
            t.table_name = s.table_name
        WHEN MATCHED AND t.sync_state <> 'INIT' THEN
            UPDATE SET
                t.enabled = s.enabled,
                t.interval = s.interval,
                t.priority = s.priority,
                t.last_updated_dt = CURRENT_TIMESTAMP()
        WHEN MATCHED AND t.sync_state = 'INIT' THEN
            UPDATE SET *
        WHEN NOT MATCHED THEN
            INSERT *
        """

        spark.sql(sql)

StatementMeta(, , , Waiting, )

In [11]:
class BQScheduleLoader(ConfigBase):
    def __init__(self, config_path, gcp_credential):
        """Initialise the shared configuration, then create the proxy views.

        Args:
            config_path: Path to the JSON user configuration.
            gcp_credential: GCP credential handed on to the base initialiser.
        """
        super().__init__(config_path=config_path, gcp_credential=gcp_credential)
        # Resolved through super(), so the lookup starts after this class in
        # the MRO (presumably ConfigBase.create_proxy_views -- defined outside
        # this section).
        super().create_proxy_views()

    def save_telemetry(self, telemetry: SyncSchedule):
        """Persist the outcome of a finished sync run to the schedule table.

        The schedule row matching ``telemetry.ScheduleId`` is read, overlaid
        with the run statistics and a ``COMPLETE`` status, and merged back into
        the schedule Delta table on ``schedule_id``. For non-initial partition
        loads the partition telemetry is saved afterwards as well.

        Args:
            telemetry: Schedule object carrying the run statistics.
        """
        updates = spark.table(SyncConstants.SQL_TBL_SYNC_SCHEDULE).filter(
            f"schedule_id=='{telemetry.ScheduleId}'"
        )

        # Column -> literal value; insertion order is the order the columns
        # are applied to the DataFrame.
        overlay = {
            "started": telemetry.StartTime,
            "src_row_count": telemetry.SourceRows,
            "dest_row_count": telemetry.DestRows,
            "dest_inserted_row_count": telemetry.InsertedRows,
            "dest_updated_row_count": telemetry.UpdatedRows,
            "delta_version": telemetry.DeltaVersion,
            "spark_application_id": telemetry.SparkAppId,
            "max_watermark": telemetry.MaxWatermark,
            "summary_load": telemetry.SummaryLoadType,
            "status": "COMPLETE",
            "completed": telemetry.EndTime,
        }
        for column_name, value in overlay.items():
            updates = updates.withColumn(column_name, lit(value))

        target = DeltaTable.forName(spark, SyncConstants.SQL_TBL_SYNC_SCHEDULE).alias("t")

        # Every overlaid column is copied from the source alias "s".
        assignments = {column_name: f"s.{column_name}" for column_name in overlay}

        (
            target.merge(updates.alias('s'), 't.schedule_id = s.schedule_id')
            .whenMatchedUpdate(set=assignments)
            .execute()
        )

        incremental_partition_load = (
            telemetry.LoadStrategy == SyncConstants.PARTITION
            and not telemetry.InitialLoad
        )
        if incremental_partition_load:
            self.save_partition_telemtry(telemetry)

    def save_partition_telemtry(self, telemetry: SyncSchedule):
        """Append per-partition BigQuery metadata for a schedule run.

        Joins the schedule row identified by ``telemetry.ScheduleId`` to
        ``bq_information_schema_partitions`` on project, dataset and table
        name, and appends one row per matching partition to the partition
        telemetry table.

        Args:
            telemetry: Schedule object whose ``ScheduleId`` selects the run.
        """
        # The run's identity and timings are read from the schedule table.
        # Reading them from the partition telemetry table (the append target)
        # would only re-select rows already stored there for this schedule id,
        # so a new run would never be recorded.
        sql = f"""
        SELECT 
            s.schedule_id, s.project_id, s.dataset, s.table_name, 
            p.partition_id, p.total_rows AS bq_total_rows, 
            p.last_modified_time AS bq_last_modified, p.storage_tier as bq_storage_tier,
            s.started, s.completed
        FROM {SyncConstants.SQL_TBL_SYNC_SCHEDULE} s
        JOIN bq_information_schema_partitions p ON
            s.project_id=p.table_catalog AND
            s.dataset=p.table_schema AND 
            s.table_name=p.table_name
        WHERE s.schedule_id='{telemetry.ScheduleId}'
        """
        df = spark.sql(sql)

        df.write.mode(SyncConstants.APPEND).saveAsTable(SyncConstants.SQL_TBL_SYNC_SCHEDULE_PARTITION)

    def get_table_delta_version(self, tbl):
        """Return the highest version number in a table's Delta history.

        Args:
            tbl: Name of the table passed to ``DESCRIBE HISTORY``.

        Returns:
            The maximum ``version`` value, or ``None`` when no row comes back.
        """
        history = spark.sql(f"DESCRIBE HISTORY {tbl}")
        # `max` is the Spark SQL aggregate pulled in by the wildcard import of
        # pyspark.sql.functions, not the Python builtin.
        latest = history.select(max(col("version")).alias("delta_version")).collect()

        if not latest:
            return None
        return latest[0]["delta_version"]

    def update_sync_config_state(self, project_id, dataset, table_name):
        """Set ``sync_state`` to ``COMMIT`` for one table's configuration row.

        Args:
            project_id: Project id of the configuration row.
            dataset: Dataset of the configuration row.
            table_name: Table name of the configuration row.
        """
        # The values are interpolated directly into the statement text.
        statement = f"""
        UPDATE {SyncConstants.SQL_TBL_SYNC_CONFIG} 
        SET sync_state='COMMIT' 
        WHERE
            project_id='{project_id}' AND
            dataset='{dataset}' AND
            table_name='{table_name}'
        """
        spark.sql(statement)

    def get_table_partition_metadata(self, schedule):
        """List a table's partitions modified at or after its last completed load.

        ``last_load`` holds, per table, the latest ``started`` timestamp among
        ``COMPLETE`` schedule rows. Partitions are compared against it, so a
        table with no completed load yields no rows: the NULL produced by the
        LEFT JOIN fails the ``>=`` comparison.

        Args:
            schedule: Object exposing ``ProjectId``, ``Dataset`` and
                ``TableName`` for the table to inspect.

        Returns:
            The DataFrame produced by the query, with partition id, row count,
            modification time, storage tier and ``last_part_load``.
        """
        query = f"""
        WITH last_load AS (
            SELECT project_id, dataset, table_name, MAX(started) AS last_load_update
            FROM {SyncConstants.SQL_TBL_SYNC_SCHEDULE}
            WHERE status='COMPLETE'
            GROUP BY project_id, dataset, table_name
        )

        SELECT
            sp.table_name, sp.partition_id, sp.total_rows, sp.last_modified_time, sp.storage_tier,
            s.last_load_update AS last_part_load
        FROM bq_information_schema_partitions sp
        LEFT JOIN last_load s ON 
            sp.table_catalog = s.project_id AND 
            sp.table_schema = s.dataset AND
            sp.table_name = s.table_name
        WHERE sp.table_catalog = '{schedule.ProjectId}'
        AND sp.table_schema = '{schedule.Dataset}'
        AND sp.table_name = '{schedule.TableName}'
        AND sp.last_modified_time >= s.last_load_update
        """
        partitions_df = spark.sql(query)
        return partitions_df

    def get_schedule(self):
        """Return the cached set of scheduled tables to sync for this config.

        Each result row is a sync configuration row joined to its ``SCHEDULED``
        schedule entry (priority 100, enabled tables only, restricted to the
        configured project and dataset). It is enriched with the watermark and
        start time of the table's most recently scheduled ``COMPLETE`` run.
        ``initial_load`` is TRUE when no completed run exists.

        Returns:
            The cached DataFrame produced by the query.
        """
        query = f"""
        WITH last_completed_schedule AS (
            SELECT schedule_id, project_id, dataset, table_name, max_watermark, started AS last_schedule_dt
            FROM (
                SELECT schedule_id, project_id, dataset, table_name, started, max_watermark,
                ROW_NUMBER() OVER(PARTITION BY project_id, dataset, table_name ORDER BY scheduled DESC) AS row_num
                FROM {SyncConstants.SQL_TBL_SYNC_SCHEDULE}
                WHERE status='COMPLETE'
            )
            WHERE row_num = 1
        )

        SELECT c.*, 
            s.schedule_id,
            h.max_watermark,
            h.last_schedule_dt,
            CASE WHEN (h.schedule_id IS NULL) THEN TRUE ELSE FALSE END AS initial_load
        FROM {SyncConstants.SQL_TBL_SYNC_CONFIG} c
        JOIN {SyncConstants.SQL_TBL_SYNC_SCHEDULE} s ON 
            c.project_id = s.project_id AND
            c.dataset = s.dataset AND
            c.table_name = s.table_name
        LEFT JOIN last_completed_schedule h ON
            c.project_id = h.project_id AND
            c.dataset = h.dataset AND
            c.table_name = h.table_name
        WHERE s.status = 'SCHEDULED'
            AND s.priority = 100
            AND c.enabled = TRUE
            AND c.project_id = '{self.UserConfig.ProjectID}' 
            AND c.dataset = '{self.UserConfig.Dataset}'
        """
        # cache() marks the DataFrame for caching and hands the same one back.
        return spark.sql(query).cache()

    def build_bq_partition_query(self, schedule, predicate=None):
        """Build the BigQuery SELECT used for a partition load.

        Args:
            schedule: Object exposing ``BQTableName``, the table to read.
            predicate: Optional list of filter expressions. When non-empty
                they are OR-ed together into a WHERE clause; ``None`` or an
                empty list selects the whole table.

        Returns:
            The query text, which is also printed.
        """
        # `None` replaces the former shared mutable default ([]); both mean
        # "no filter".
        query = f"SELECT * FROM {schedule.BQTableName}"

        if predicate:
            query += " WHERE " + " OR ".join(predicate)

        print(query)
        return query

    def get_max_watermark(self, lakehouse_tbl, watermark_col):
        """Return the largest value of a column in a lakehouse table.

        Args:
            lakehouse_tbl: Name of the table to read.
            watermark_col: Column whose maximum is wanted.

        Returns:
            The maximum value, or ``None`` when no row comes back.
        """
        source = spark.table(lakehouse_tbl)
        # `max` is the Spark SQL aggregate pulled in by the wildcard import of
        # pyspark.sql.functions, not the Python builtin.
        result = source.select(max(col(watermark_col)).alias("watermark")).collect()

        if not result:
            return None
        return result[0]["watermark"]

    def sync_bq_table(self, row):
        """Synchronise one scheduled BigQuery table into its lakehouse table.

        Steps, in order:
          1. Choose the BigQuery source: a query restricted to changed
             partitions (non-initial PARTITION loads), otherwise the table
             name or the configured source query.
          2. Read the source and, for non-initial WATERMARK loads, keep only
             rows beyond the stored maximum watermark.
          3. Resolve the lakehouse partition column, adding a formatted proxy
             column where required.
          4. Write: one ``replaceWhere`` overwrite per changed partition, or a
             single write using ``schedule.Mode``.
          5. Collect row counts, Delta version and end time, save telemetry,
             and flag the configuration row after an initial load.

        Args:
            row: A schedule row (as yielded by ``get_schedule().collect()``);
                wrapped in ``SyncSchedule`` and also indexed by
                ``"partition_grain"``.

        Raises:
            Exception: When a changed partition is found and the partition
                grain is not DAY, MONTH, YEAR or HOUR.
        """
        schedule = SyncSchedule(row)

        print("{0} {1}...".format(schedule.SummaryLoadType, schedule.TableName))

        # --- 1. Choose the BigQuery source -------------------------------
        if schedule.LoadStrategy == SyncConstants.PARTITION and not schedule.InitialLoad:
            print("Load by partition...")
            df_partitions = self.get_table_partition_metadata(schedule)

            predicate = []
            partitions_to_write = []

            for p in df_partitions.collect():
                part_id = p["partition_id"]

                # Only partitions modified strictly after the last load are reloaded.
                if p["last_modified_time"] > p["last_part_load"]:
                    print("Partition {0} has changes...".format(part_id))

                    partitions_to_write.append(part_id)

                    # Datetime format used to parse the BigQuery partition id
                    # for the configured grain.
                    match schedule.PartitionGrain:
                        case "DAY":
                            part_format = "%Y%m%d"
                        case "MONTH":
                            part_format = "%Y%m"
                        case "YEAR":
                            part_format = "%Y"
                        case "HOUR":
                            part_format = "%Y%m%d%H"
                        case _:
                            raise Exception("Unsupported Partition Grain in Table Config")

                    predicate.append(f"date_trunc({schedule.PartitionColumn}, {schedule.PartitionGrain}) = PARSE_DATETIME('{part_format}', '{part_id}')")
                else:
                    print("Partition {0} is up to date...".format(part_id))

            # NOTE(review): when no partition changed, `predicate` is empty and
            # the query built here has no WHERE clause, so the whole table is
            # read even though nothing is written below -- confirm intended.
            src = self.build_bq_partition_query(schedule, predicate)
        else:
            src = schedule.BQTableName

            # A non-empty source query takes precedence over the table name.
            if schedule.SourceQuery != "":
                src = schedule.SourceQuery

        # --- 2. Read and apply the watermark filter ----------------------
        df_bq = super().read_bq_to_dataframe(src)

        # From here on `predicate` is the optional watermark filter string,
        # no longer the list of partition predicates.
        predicate = None

        if schedule.LoadStrategy == SyncConstants.WATERMARK and not schedule.InitialLoad:
            # NOTE(review): the filter column is schedule.PrimaryKey; the
            # config also carries a watermark_column -- confirm these agree.
            pk = schedule.PrimaryKey
            max_watermark = schedule.MaxWatermark

            # All-digit watermarks are compared unquoted, anything else as a
            # quoted string literal.
            # NOTE(review): raises AttributeError if MaxWatermark is None.
            if max_watermark.isdigit():
                predicate = f"{pk} > {max_watermark}"
            else:
                predicate = f"{pk} > '{max_watermark}'"

        if predicate is not None:
            df_bq = df_bq.where(predicate)

        # The DataFrame is reused below for filtering, writing and counting.
        df_bq.cache()

        # --- 3. Resolve the lakehouse partition column -------------------
        partition = None

        if schedule.IsPartitioned and not schedule.IsTimeIngestionPartitioned:
            print('Resolving Fabric partitioning...')
            if schedule.PartitionType == SyncConstants.TIME:
                partition_col = schedule.PartitionColumn
                part_format = ""
                part_col_name = f"__bq_part_{partition_col}"
                use_proxy_col = False

                # DAY grain on a `date` column partitions by the column itself;
                # every other case partitions by a proxy column holding the
                # formatted value.
                match schedule.PartitionGrain:
                    case "DAY":
                        part_format = "yyyyMMdd"

                        if dict(df_bq.dtypes)[partition_col] == "date":
                            partition = partition_col
                        else:
                            partition = f"{part_col_name}_DAY"
                            use_proxy_col = True
                    case "MONTH":
                        part_format = "yyyyMM"
                        partition = f"{part_col_name}_MONTH"
                        use_proxy_col = True
                    case "YEAR":
                        part_format = "yyyy"
                        partition = f"{part_col_name}_YEAR"
                        use_proxy_col = True
                    case "HOUR":
                        part_format = "yyyyMMddHH"
                        partition = f"{part_col_name}_HOUR"
                        use_proxy_col = True
                    case _:
                        # Only reported; `partition` stays None in this case.
                        print('Unsupported partition grain...')

                print("{0} partitioning - partitioned by {1} (Requires Proxy Column: {2})".format( \
                    row["partition_grain"], \
                    partition, \
                    use_proxy_col))

                if use_proxy_col:
                    df_bq = df_bq.withColumn(partition, date_format(col(partition_col), part_format))

        # --- 4. Write ----------------------------------------------------
        if schedule.LoadStrategy == SyncConstants.PARTITION and not schedule.InitialLoad:
            if partitions_to_write:
                for part in partitions_to_write:
                    print(f"Writing {schedule.TableName}${part} partition...")
                    # NOTE(review): if `partition` is still None here the
                    # filter text starts with "None = "; and for a DAY grain
                    # on a `date` column the column is compared with the
                    # BigQuery partition id string -- confirm both cases.
                    part_filter = f"{partition} = '{part}'"

                    pdf = df_bq.where(part_filter)
                    part_cnt = pdf.count()

                    # NOTE(review): UpdateRowCounts is called again after the
                    # writes; whether counts accumulate or are replaced
                    # depends on SyncSchedule -- confirm.
                    schedule.UpdateRowCounts(0, part_cnt, 0, 0)

                    # Overwrite only the rows matching this partition filter.
                    pdf.write \
                        .mode(SyncConstants.OVERWRITE) \
                        .option("replaceWhere", part_filter) \
                        .saveAsTable(schedule.LakehouseTableName)
        else:
            if partition is None:
                df_bq.write \
                    .mode(schedule.Mode) \
                    .saveAsTable(schedule.LakehouseTableName)
            else:
                df_bq.write \
                    .partitionBy(partition) \
                    .mode(schedule.Mode) \
                    .saveAsTable(schedule.LakehouseTableName)

        # --- 5. Telemetry ------------------------------------------------
        # The new watermark is read back from the lakehouse table just written.
        if schedule.LoadStrategy == SyncConstants.WATERMARK:
            schedule.MaxWatermark = self.get_max_watermark(schedule.LakehouseTableName, schedule.PrimaryKey)

        src_cnt = df_bq.count()

        # The destination is counted for every load except non-initial
        # partition loads, which pass 0 here.
        if schedule.LoadStrategy != SyncConstants.PARTITION or schedule.InitialLoad:
            dest_cnt = spark.table(schedule.LakehouseTableName).count()
        else:
            dest_cnt = 0

        schedule.UpdateRowCounts(src_cnt, dest_cnt, 0, 0)
        schedule.SparkAppId = spark.sparkContext.applicationId
        schedule.DeltaVersion = self.get_table_delta_version(schedule.LakehouseTableName)
        schedule.EndTime = datetime.now(timezone.utc)

        df_bq.unpersist()

        self.save_telemetry(schedule)

        # After an initial load the configuration row's sync_state is set to COMMIT.
        if schedule.InitialLoad:
            self.update_sync_config_state(schedule.ProjectId, schedule.Dataset, schedule.TableName)

    def run_schedule(self):
        df_schedule = self.get_schedule()

        for row in df_schedule.collect():
            self.sync_bq_table(row)  

StatementMeta(, 188d4f89-f45c-455e-8263-facaf270c5df, 13, Finished, Available)

In [12]:
class SyncSetup(ConfigBase):
    def __init__(self, config_path, gcp_credential):
        """Initialise the shared configuration via the base class.

        Args:
            config_path: Path to the JSON user configuration.
            gcp_credential: GCP credential handed on to the base initialiser.
        """
        super().__init__(config_path=config_path, gcp_credential=gcp_credential)

    def get_fabric_lakehouse(self, nm):
        """Look up a lakehouse by name, best effort.

        Args:
            nm: Name of the lakehouse to fetch.

        Returns:
            Whatever ``mssparkutils.lakehouse.get`` returns, or ``None`` when
            the lookup raises (a message is printed in that case).
        """
        try:
            return mssparkutils.lakehouse.get(nm)
        except Exception:
            # Any failure is reported as "not found".
            print("Lakehouse not found")
            return None

    def create_fabric_lakehouse(self, nm):
        lakehouse = get_fabric_lakehouse(nm)

        if (lakehouse is None):
            mssparkutils.lakehouse.create(nm)

    def setup(self):
        self.create_fabric_lakehouse(self.UserConfig.MetadataLakehouse)
        self.create_fabric_lakehouse(self.UserConfig.TargetLakehouse)
        self.create_all_tables()

    def drop_table(self, tbl):
        """Drop a table if it exists.

        Args:
            tbl: Name of the table to drop.
        """
        # Interpolate the parameter itself; no other table-name variable
        # exists in this scope.
        sql = f"DROP TABLE IF EXISTS {tbl}"
        spark.sql(sql)

    def get_tbl_name(self, tbl):
        return self.UserConfig.get_lakehouse_tablename(self.UserConfig.MetadataLakehouse, tbl)

    def create_data_type_map_tbl(self):
        """Recreate the data-type map table and load it from the bundled CSV.

        The table is dropped, created with its three string columns, and then
        overwritten with the contents of ``Files/data/bq_data_types.csv``
        (read with a header row).
        """
        tbl_nm = self.get_tbl_name(SyncConstants.SQL_TBL_DATA_TYPE_MAP)
        # Pass the resolved name held in `tbl_nm`; no `tbl_name` variable
        # exists in this scope.
        self.drop_table(tbl_nm)

        sql = f"""CREATE TABLE IF NOT EXISTS {tbl_nm} (data_type STRING, partition_type STRING, is_watermark STRING)"""
        spark.sql(sql)

        df = spark.read.format("csv").option("header","true").load("Files/data/bq_data_types.csv")
        df.write.mode("OVERWRITE").saveAsTable(tbl_nm)

    def create_sync_config_tbl(self):
        """Drop and recreate the (empty) sync configuration table."""
        tbl_nm = self.get_tbl_name(SyncConstants.SQL_TBL_SYNC_CONFIG)
        # Pass the resolved name held in `tbl_nm`; no `tbl_name` variable
        # exists in this scope.
        self.drop_table(tbl_nm)

        sql = f"""
        CREATE TABLE IF NOT EXISTS {tbl_nm}
        (
            project_id STRING,
            dataset STRING,
            table_name STRING,
            enabled BOOLEAN,
            lakehouse STRING,
            lakehouse_table_name STRING,
            source_query STRING,
            priority INTEGER,
            load_strategy STRING,
            load_type STRING,
            interval STRING,
            primary_keys ARRAY<STRING>,
            is_partitioned BOOLEAN,
            partition_column STRING,
            partition_type STRING,
            partition_grain STRING,
            watermark_column STRING,
            autodetect BOOLEAN,
            config_override BOOLEAN,
            sync_state STRING,
            created_dt TIMESTAMP,
            last_updated_dt TIMESTAMP
        )
        """
        spark.sql(sql)
    
    def create_sync_schedule_tbl(self):
        """Drop and recreate the (empty) sync schedule table, partitioned by priority."""
        tbl_nm = self.get_tbl_name(SyncConstants.SQL_TBL_SYNC_SCHEDULE)
        # Pass the resolved name held in `tbl_nm`; no `tbl_name` variable
        # exists in this scope.
        self.drop_table(tbl_nm)

        sql = f"""
        CREATE TABLE IF NOT EXISTS {tbl_nm} (
            schedule_id STRING,
            project_id STRING,
            dataset STRING,
            table_name STRING,
            scheduled TIMESTAMP,
            status STRING,
            started TIMESTAMP,
            completed TIMESTAMP,
            src_row_count BIGINT,
            dest_row_count BIGINT,
            dest_inserted_row_count BIGINT,
            dest_updated_row_count BIGINT,
            delta_version BIGINT,
            spark_application_id STRING,
            max_watermark STRING,
            summary_load STRING,
            priority INTEGER
        )
        PARTITIONED BY (priority)
        """
        spark.sql(sql)
    
    def create_sync_schedule_partition_tbl(self):
        """Drop and recreate the (empty) per-partition schedule telemetry table."""
        tbl_nm = self.get_tbl_name(SyncConstants.SQL_TBL_SYNC_SCHEDULE_PARTITION)
        # Pass the resolved name held in `tbl_nm`; no `tbl_name` variable
        # exists in this scope.
        self.drop_table(tbl_nm)

        sql = f"""
        CREATE TABLE IF NOT EXISTS {tbl_nm} (
            schedule_id STRING,
            project_id STRING,
            dataset STRING,
            table_name STRING,
            partition_id STRING,
            bq_total_rows BIGINT,
            bq_last_modified TIMESTAMP,
            bq_storage_tier STRING,
            started TIMESTAMP,
            completed TIMESTAMP
        )
        """
        spark.sql(sql)

    def create_all_tables(self):
        self.create_data_type_map_tbl()
        self.create_sync_config_tbl()
        self.create_sync_schedule_tbl()
        self.create_sync_schedule_partition_tbl()

StatementMeta(, 188d4f89-f45c-455e-8263-facaf270c5df, 14, Finished, Available)