In [4]:
##################################################################################################################
# description: setting common environment variables
##################################################################################################################

from notebookutils import mssparkutils
workspacename = mssparkutils.env.getWorkspaceName()

# Fail fast: only the two workspaces named below have settings defined here.
if workspacename not in ("fe-d-syn-enterpriseanalytics", "fe-p-syn-enterpriseanalytics"):
    raise Exception("Invalid workspace name")

# Each workspace gets its own database name, SQL endpoint, temp ADLS location
# and storage account; the four values are assigned together per workspace.
if workspacename == "fe-d-syn-enterpriseanalytics":
    dw_db_name, dw_server_name, temp_adls_loc, storage_account = (
        "feddwentanalytics",
        "fe-d-syn-enterpriseanalytics.sql.azuresynapse.net",
        "abfss://temp@fedsaentanalytics.dfs.core.windows.net/notebooks_data/",
        "fedsaentanalytics",
    )
else:
    dw_db_name, dw_server_name, temp_adls_loc, storage_account = (
        "fepdwentanalytics",
        "fe-p-syn-enterpriseanalytics.sql.azuresynapse.net",
        "abfss://temp@fepsaentanalytics.dfs.core.windows.net/notebooks_data/",
        "fepsaentanalytics",
    )

In [6]:
import pyodbc
from struct import pack
import re
import pandas as pd

##################################################################################################################
# description: returns dataframe (path, name, size) for files in directory
# parms:    
#           path = string abfss path
#           search_pattern = string regex search pattern
#           match_case = bool; True makes the regex search case sensitive
##################################################################################################################

def ls_files_to_data_frame(path, search_pattern="", match_case=True):
    """List a directory and return its entries as a pandas DataFrame.

    Args:
        path: abfss path of the directory to list (passed to mssparkutils.fs.ls).
        search_pattern: optional regex; when given, only entries whose name
            matches are returned. The pattern is applied with ``match``, so it
            is anchored at the start of the file name only.
        match_case: when False the regex is compiled with re.IGNORECASE.

    Returns:
        DataFrame with columns path, name, size sorted by path; an empty
        DataFrame with the same columns when nothing matches.
    """
    files = mssparkutils.fs.ls(path)

    if search_pattern:
        flags = 0 if match_case else re.IGNORECASE
        pattern = re.compile(search_pattern, flags)
        matching_files = [f for f in files if pattern.match(f.name)]
    else:
        matching_files = files

    schema = ['path', 'name', 'size']

    if len(matching_files) == 0:
        return pd.DataFrame(columns=schema)

    # Kept from the original implementation: Arrow conversion is switched off
    # on the Spark session before the frame is built.
    spark.conf.set("spark.sql.execution.arrow.enabled", "false")

    # Build the frame from the *filtered* list (the previous code iterated over
    # the unfiltered listing, so search_pattern had no effect on the result).
    rows = [[getattr(entry, col) for col in schema] for entry in matching_files]
    return pd.DataFrame(rows, columns=schema).sort_values('path')

##################################################################################################################
# description: returns connection string and token for serverless sql pool
# parms:    
#           server = string serverless sql endpoint
#           dbname = string serverless sql database name
##################################################################################################################

def get_serverless_synapse_conn(dbname="default",server="fe-d-syn-enterpriseanalytics-ondemand.sql.azuresynapse.net,1433"):
    """Return (connection string, token struct) for an ODBC connection.

    Args:
        dbname: database name placed in the connection string.
        server: server endpoint placed in the connection string.

    Returns:
        Tuple of the ODBC connection string and a bytes value holding a
        native-order 4-byte length followed by the expanded token bytes.
    """
    # see https://www.aizoo.info/post/dropping-a-sql-table-in-your-synapse-spark-notebooks-python-edition
    # NOTE(review): the default server is the "fe-d" endpoint regardless of the
    # workspace the notebook runs in — confirm this is intended for "fe-p".
    token_utf8 = bytes(mssparkutils.credentials.getToken("DW"), 'utf8')

    # Every token byte is followed by a single zero byte.
    widened_token = b"".join(bytes((value, 0)) for value in token_utf8)
    token_struct = pack("=i", len(widened_token)) + widened_token

    cnnstr = f"Driver={{ODBC Driver 18 for SQL Server}};Server={server};Database={dbname};"
    return (cnnstr, token_struct)

##################################################################################################################
# description: runs sql against serverless sql pool
# parms:    sql = string for sql
#           cnnstr = string the odbc conn string      
#           token = struct the auth token for serverless sql pool
#           server = string serverless sql endpoint
#           dbname = string serverless sql database name
##################################################################################################################

def run_serverless_sql(sql, cnnstr, token, server="", dbname=""):
    """Execute one SQL statement over an ODBC connection with autocommit on.

    Args:
        sql: SQL text to execute.
        cnnstr: ODBC connection string; when it or ``token`` is empty a new
            pair is obtained from get_serverless_synapse_conn.
        token: access-token struct passed to the driver before connecting.
        server: optional server endpoint used only when a connection has to be
            built here; empty means "use get_serverless_synapse_conn's default".
        dbname: optional database name, same rule as ``server``.
    """
    if not token or not cnnstr:
        # Pass by keyword: get_serverless_synapse_conn takes (dbname, server),
        # so the previous positional call (server, dbname) swapped the two.
        # Empty values are omitted so the helper's own defaults apply.
        conn_kwargs = {}
        if dbname:
            conn_kwargs["dbname"] = dbname
        if server:
            conn_kwargs["server"] = server
        cnnstr, token = get_serverless_synapse_conn(**conn_kwargs)

    # 1256 is the pre-connect attribute carrying the access token
    # (SQL_COPT_SS_ACCESS_TOKEN in the Microsoft ODBC driver).
    conn = pyodbc.connect(cnnstr, attrs_before={ 1256:token })
    try:
        conn.autocommit = True
        cursor = conn.cursor()
        try:
            print(f"running {sql}")
            cursor.execute(sql)
        finally:
            cursor.close()
    finally:
        # pyodbc's connection context manager does not close the connection,
        # so close it explicitly to avoid leaking one connection per call.
        conn.close()

##################################################################################################################
# description: gets tracked_cols (name, type), partitionBy (list of partition column names) and typed_cols (explicit view column types)
# parms:    source_df = spark dataframe containing data
#           view_col_def = string valid json containing sql type to create in view
#           partition_by = string comma delimited of partition columns
##################################################################################################################

def get_schema_info(source_df, view_col_def, partition_by):
    """Derive column/type metadata and partition columns for a source frame.

    Args:
        source_df: object exposing ``dtypes`` as (column name, spark type) pairs.
        view_col_def: JSON string loadable into a DataFrame with ``col_name``
            and ``sql_type`` columns, or empty for no explicit types.
        partition_by: comma-delimited partition columns. An entry of the form
            ``col_name=<sql expression>`` declares a generated column.
            NOTE(review): entries are split on ",", so an expression that
            itself contains a comma cannot be expressed here.

    Returns:
        (tracked_cols, partitionBy, typed_cols):
        tracked_cols - DataFrame col_name / spark_type / col_type / sql_type,
            including partition-only columns (their type columns are empty);
        partitionBy - list of partition column names;
        typed_cols - DataFrame parsed from ``view_col_def``.
    """
    import json  # local import: the visible notebook cells never import json

    # get list of columns to iterate over
    tracked_cols = pd.DataFrame(source_df.dtypes, columns=["col_name", "spark_type"])
    # Strip the "(precision,scale)" suffix; regex=True is explicit because the
    # pandas default for str.replace is a literal match in pandas 2 and later.
    tracked_cols['col_type'] = tracked_cols['spark_type'].str.replace(r"\(.*\)", "", regex=True)

    default_type_mapping_list = [
        ['string', 'varchar(max)'],
        ['long', 'bigint'],
        ['boolean', 'bit'],
        ['decimal', 'decimal(38,8)'],
        ['double', 'float'],
        ['float', 'float'],
        ['int', 'int'],
        ['bigint', 'bigint'],
        ['tinyint', 'smallint'],
        ['date', 'date'],
        ['timestamp', 'datetime2'],
        ['char', 'char'],
        ['binary', 'varbinary(max)'],
    ]
    default_types_df = pd.DataFrame(default_type_mapping_list, columns=['col_type', 'sql_type'])

    # The parameter is view_col_def (the old body read an undefined name).
    if view_col_def:
        typed_cols = pd.DataFrame(json.loads(view_col_def))
    else:
        typed_cols = pd.DataFrame(columns=["col_name", "sql_type"])

    partition_sql_parts = []
    partitionBy = []

    # create delta tables partition by syntax and col list
    if partition_by:
        # Compare against the column *names*; "x in Series" tests the index.
        known_cols = set(tracked_cols['col_name'])

        # format for a generated column is col_name=<any valid sql>
        for raw_spec in partition_by.split(","):
            spec = raw_spec.strip()
            if not spec:
                continue  # ignore empty entries such as a trailing comma
            if spec not in known_cols and "=" in spec:
                # Split once so expressions containing "=" stay intact.
                name, expression = spec.split("=", 1)
                partition_sql_parts.append(f"{expression.replace('*','').strip()} as {name.strip()}")
                partitionBy.append(name.strip())
            else:
                partitionBy.append(spec)

        partition_sql = ",".join(partition_sql_parts)

        print(f"partition_sql is {partition_sql}")
        print(f"partitionBy is {partitionBy}")

        tracked_partition_by_cols = pd.DataFrame(partitionBy, columns=["col_name"])
        tracked_cols = (
            pd.concat([tracked_cols, tracked_partition_by_cols])
            .drop_duplicates(subset=["col_name"])
            .reset_index(drop=True)
        )

    tracked_cols = pd.merge(tracked_cols, default_types_df, on=['col_type'], how='left')

    return tracked_cols, partitionBy, typed_cols

##################################################################################################################
# description: initialize a delta table... create a view in serverless sql pool to catalog
# parms:    source_df = the spark dataframe containing data to be loaded into delta
#           delta_table_path = string abfss path to delta table to write to      
#           parquet_file_path = string abfss path to parquet source... added to meta data of commit
#           dest_db_name = string name of database in serverless sql pool
#           dest_schema_name = string name of schema in serverless sql pool 
#           dest_view_name = string name of view in serverless sql pool 
#           tracked_cols = pandas dataframe with all columns that are in source table plus a few others
#           partitionBy = list of partition columns
##################################################################################################################

def initialize_delta_table(source_df, delta_table_path, parquet_file_path, dest_db_name, dest_schema_name, dest_view_name, tracked_cols, partitionBy, typed_cols):
    """Write source_df as a Delta table and (re)create a SQL view over it.

    Args:
        source_df: Spark DataFrame to write.
        delta_table_path: abfss path the Delta table is saved to.
        parquet_file_path: source path, recorded in the commit's userMetadata.
        dest_db_name: database created if missing and used for the view.
        dest_schema_name: schema created if missing and used for the view.
        dest_view_name: name of the view that is dropped and recreated.
        tracked_cols: pandas DataFrame with ``col_name`` and ``sql_type``; one
            view column is emitted per row.
        partitionBy: list of partition columns (empty/None for no partitioning).
        typed_cols: pandas DataFrame with ``col_name`` / ``sql_type`` overrides,
            matched to tracked_cols case-insensitively; first match wins.
    """
    if mssparkutils.fs.exists(delta_table_path):
        # NOTE(review): the message says contents are kept, but the write below
        # uses mode("overwrite") — confirm which behaviour is intended.
        print(f"Found {delta_table_path} and init flag set... will not overwrite contents but will overwrite schema.")

    # write it out as delta table to adls; partitioning is the only difference
    # between the partitioned and un-partitioned write.
    writer = (source_df.write
                .format("delta")
                .mode("overwrite")
                .option("overwriteSchema", "true")
                .option("delta.enableChangeDataFeed", "true")
                .option("userMetadata", f"init with {parquet_file_path}"))
    if partitionBy:
        print(partitionBy)
        writer = writer.partitionBy(partitionBy)
    writer.save(delta_table_path)

    # Explicit overrides keyed by upper-cased column name (first one wins).
    explicit_types = {}
    for _, typed_row in typed_cols.iterrows():
        explicit_types.setdefault(str(typed_row['col_name']).upper(), typed_row['sql_type'])

    # create strongly typed WITH definition for view creation
    view_columns = []
    for _, row in tracked_cols.iterrows():
        lookup_key = str(row['col_name']).upper()
        sql_type = explicit_types[lookup_key] if lookup_key in explicit_types else row['sql_type']
        view_columns.append(f"[{row['col_name']}] {sql_type}")
    sql_view_with_clause = ",".join(view_columns)

    # The view is named by the dest_view_name parameter (the old body read an
    # undefined global named table_name here and in the drop statement).
    create_view_statement = f"""create view [{dest_schema_name}].[{dest_view_name}] as
            select * 
            from openrowset(bulk '{delta_table_path}', format='DELTA')
            with ({sql_view_with_clause}) as v;
    """

    # NOTE(review): both connections use get_serverless_synapse_conn's default
    # server — confirm that default is right for every workspace.
    cnnstr, token = get_serverless_synapse_conn()
    # create sql db for adls container ... aka bronze,silver,gold
    run_serverless_sql(f"IF NOT EXISTS (SELECT * FROM master.sys.databases WHERE name = '{dest_db_name}') BEGIN CREATE DATABASE [{dest_db_name}] END", cnnstr, token)

    # switch db and create schema if not exists and consumption view
    cnnstr, token = get_serverless_synapse_conn(dbname=f'{dest_db_name}')
    # create schema
    run_serverless_sql(f"""
        IF (NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{dest_schema_name}')) 
        BEGIN
            EXEC ('CREATE SCHEMA [{dest_schema_name}]')
        END
    """, cnnstr, token)
    # drop the view if it already exists, then create it
    run_serverless_sql(f"""
        IF (EXISTS (select * from information_schema.tables where table_schema = '{dest_schema_name}' and table_name='{dest_view_name}')) 
        BEGIN
            EXEC('DROP VIEW [{dest_schema_name}].[{dest_view_name}]')
        END
    """, cnnstr, token)
    run_serverless_sql(create_view_statement, cnnstr, token)

##################################################################################################################
# description: incrementally load a delta table based on mode
# parms:    mode = append, partition+overwrite, merge, upsert, truncate+fill, del+insert
#           source_df = the spark dataframe containing data to be loaded into delta
#           delta_table_path = string abfss path to delta table to write to      
#           commit_meta_data = string to add meta data for commit.. eg. abfss path to parquet source... 
#           tracked_cols = pandas dataframe with all columns that are in source table plus a few others
#           keys = array of primary keys
#           partitionBy = list of partition columns
#           del_filter = deletion filter predicate only used if mode is del+insert
##################################################################################################################

def incremental_delta_table(mode, source_df, delta_table_path, commit_meta_data, tracked_cols, keys, partitionBy, del_filter):
    """Incrementally load source_df into an existing Delta table.

    Args:
        mode: one of append, partition+overwrite, merge, upsert, truncate+fill,
            del+insert. Any other value only prints a message.
        source_df: Spark DataFrame holding the rows to load.
        delta_table_path: abfss path of the target Delta table.
        commit_meta_data: string stored as the commit's userMetadata.
        tracked_cols: pandas DataFrame with a ``col_name`` column; key columns
            are taken from it in its row order.
        keys: collection of key column names used for the merge condition.
        partitionBy: list of partition columns (empty/None if unpartitioned).
        del_filter: delete predicate, only used when mode is del+insert.

    Raises:
        ValueError: for merge, upsert and del+insert when none of ``keys`` is
            present in tracked_cols, since the merge would have no key condition.

    Notes:
        Partition values are read directly from source_df, so no temp view is
        registered and no global table name is required.
        NOTE(review): DeltaTable and spark must already be in scope; neither is
        imported or created in the visible cells.
    """
    from datetime import datetime  # local import: not imported by the visible cells

    key_names = keys or []
    key_predicates = [
        f"s.{row['col_name']} = t.{row['col_name']}"
        for _, row in tracked_cols.iterrows()
        if row['col_name'] in key_names
    ]
    print(" AND ".join(key_predicates))

    dt = DeltaTable.forPath(spark, delta_table_path)
    replaceWhere = ""
    source_partition_predicates = []

    if partitionBy:
        if mode == "truncate+fill":
            mode = "partition+overwrite"
            print("switching mode to partition+overwrite because table is processed partitioned")

        # create replacewhere partition pruning (can get rid of when dynamic partition overwrite occurs in synapse.. in preview for dbricks)
        # and add to key_join for partition pruning of merges
        print(partitionBy)
        target_partition_predicates = []
        for p in partitionBy:
            vals = source_df.selectExpr(f"{p} as col").distinct().toPandas()["col"].unique()
            if len(vals) == 0:
                continue  # empty source: "in ()" would not be valid SQL
            # Double embedded single quotes so a value cannot break the literal.
            in_list = ", ".join("'" + str(v).replace("'", "''") + "'" for v in vals)
            # Build both forms separately; stripping "s." from the finished
            # string also mangled values/columns that merely contain "s.".
            source_partition_predicates.append(f"s.{p} in ({in_list})")
            target_partition_predicates.append(f"{p} in ({in_list})")
        replaceWhere = " AND ".join(target_partition_predicates)

    key_join = " AND ".join(source_partition_predicates + key_predicates)

    # Without a key predicate the merge would match on partition values alone
    # (or on nothing), so stop before any delete or merge is issued.
    if mode in ("merge", "upsert", "del+insert") and not key_predicates:
        raise ValueError(f"mode '{mode}' requires at least one key column present in tracked_cols")

    def _merge_upsert():
        # Update rows that match the key condition and insert the rest.
        (dt.alias("t")
            .merge(source_df.alias("s"), key_join)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute())

    # current time is printed so a rollback point can be found later
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(current_time)
    print(f"mode={mode},del_filter={del_filter}")

    spark.conf.set("spark.databricks.delta.commitInfo.userMetadata", commit_meta_data)

    if mode == "append":
        print(f"Appending with append mode")
        writer = (source_df.write
                    .format("delta")
                    .mode("append")
                    .option("overwriteSchema", "true"))
        if partitionBy:  # create partition by columns
            writer = writer.partitionBy(partitionBy)
        writer.save(delta_table_path)
    elif mode == "partition+overwrite":
        print(f"Inserting with partition+overwrite mode")
        if replaceWhere:
            print(replaceWhere)
            (source_df.write
                .format("delta")
                .mode("overwrite")
                .option("overwriteSchema", "true")
                .option("replaceWhere", replaceWhere)
                .save(delta_table_path))
        else:
            # Previously a silent no-op; say so explicitly.
            print("Nothing written: partition+overwrite needs partition columns with values in the source.")
    elif mode == "merge":
        print(key_join)
        _merge_upsert()
        # .whenNotMatchedBySourceDelete() **doesn't exist yet in version of delta on synapse spark ... needs 2.3.0
    elif mode == "upsert":
        print(f"Upserting with upsert mode")
        print(key_join)
        _merge_upsert()
    elif mode == "truncate+fill":
        # Only reachable for unpartitioned tables: with partition columns the
        # mode was switched to partition+overwrite above.
        print(f"Overwriting for truncate+fill mode")
        (source_df.write
            .format("delta")
            .mode("overwrite")
            .option("overwriteSchema", "true")
            .save(delta_table_path))
    elif mode == "del+insert":
        print(f"Merging with del+insert")
        if del_filter:
            dt.delete(del_filter)
        _merge_upsert()
        # once delta is upgraded can do with one merge statement and adding
        # .whenNotMatchedBySource(condition=del_filter).delete()
    else:
        print("No operation happening as mode selection has not been given enough information.")


##################################################################################################################
# description: databricks like function for table_changes... which isn't implemented in synapse. 
#               gives spark dataframe for version passed in... or just base table if that passed in
#               creates temp view too to be referenced in spark.sql
# parms:    tablename = can be managed or external table name... or in format of delta.`<abfss path>`
#           temp_view_name = name of temporary view to create in spark session for table
#           list_from_to_version = list of version numbers or timestamps ... item 0 is "from" and optional item 1 is "to"
##################################################################################################################

def fn_table(tablename, temp_view_name, list_from_to_version=[]):
    """Read a Delta table (or its change feed) and register it as a temp view.

    Args:
        tablename: table name handed to the Delta reader's ``table`` call.
        temp_view_name: name of the temp view created for the result.
        list_from_to_version: empty to read the table itself; otherwise item 0
            is the change-feed start and optional item 1 the end. When every
            item is an int they are passed as versions, else as timestamps.

    Returns:
        The DataFrame that was read.
    """
    reader = spark.read.format("delta")
    bound_count = len(list_from_to_version)

    if bound_count > 0:
        # Choose the option pair by the type of the supplied bounds.
        if all(isinstance(bound, int) for bound in list_from_to_version):
            start_option, end_option = "startingVersion", "endingVersion"
        else:
            start_option, end_option = "startingTimestamp", "endingTimestamp"

        reader = reader.option("readChangeFeed", "true")
        reader = reader.option(start_option, list_from_to_version[0])
        if bound_count > 1:
            reader = reader.option(end_option, list_from_to_version[1])

    df = reader.table(tablename)
    df.createOrReplaceTempView(temp_view_name)
    return df