##### Functions
###### This notebook is the source for the utility functions that are called from the main processing notebook



In [None]:
# Session-level Spark settings applied when this cell runs.
# Turn on the V-Order setting for parquet.
spark.conf.set("spark.sql.parquet.vorder.enabled", "true")
# Turn on Delta optimize-write and set its bin size.
# 1073741824 == 1024 ** 3 (presumably bytes, i.e. 1 GiB - confirm against the platform docs).
spark.conf.set("spark.microsoft.delta.optimizeWrite.enabled", "true")
spark.conf.set("spark.microsoft.delta.optimizeWrite.binSize", "1073741824")

In [None]:
import json
import yaml
from pyspark.sql.functions import *
from pyspark.sql.types import *
from datetime import datetime
from delta.tables import DeltaTable


In [None]:
# Assumes the metadata files for each data source live in one folder per data source
def metadata_loader(data_source, entity):
    """Load the YAML pipeline metadata for one entity of a data source.

    The file is read from
    ``/lakehouse/default/Files/PipelineMetadata/<data_source>/<data_source>_<entity>.yml``.

    Args:
        data_source: Data source name; used as the folder name and file prefix.
        entity: Entity name; used as the file-name suffix.

    Returns:
        The parsed document as returned by ``yaml.safe_load``.

    Raises:
        OSError: If the metadata file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # The folder is keyed on the data_source argument.
    config_folder = f'/lakehouse/default/Files/PipelineMetadata/{data_source}/'
    config_path = config_folder + data_source + '_' + entity + '.yml'

    # The context manager closes the handle even if parsing fails.
    with open(config_path) as cf:
        meta_data = yaml.safe_load(cf)

    return meta_data

In [None]:
def template_loader(data_source, entity, config_version):
    """Load the YAML config template for a data source / entity / version.

    The file is read from
    ``/lakehouse/default/Files/PipelineMetadata/config_templates/<data_source>/``
    ``template_<data_source>_<entity>_<config_version>.yml``.

    Args:
        data_source: Data source name; used as folder name and in the file name.
        entity: Entity name; part of the file name.
        config_version: Template version; last part of the file name.

    Returns:
        The parsed document as returned by ``yaml.safe_load``.

    Raises:
        OSError: If the template file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    template_location = f'/lakehouse/default/Files/PipelineMetadata/config_templates/{data_source}/template_{data_source}_{entity}_{config_version}.yml'

    # The context manager closes the handle even if parsing fails.
    with open(template_location) as tf:
        return yaml.safe_load(tf)


In [None]:
def verify_config(template, config):
    """Check that *config* has the same nested structure as *template*.

    Only structure is compared, never leaf values:

    * two dicts match when they have exactly the same keys and every
      value pair matches recursively;
    * two lists match when they have the same length and every element
      pair (by position) matches recursively;
    * any other combination is accepted as a match.

    NOTE: the last rule means a container in the template paired with a
    scalar (or a different container type) in the config is reported as
    a match.

    Args:
        template: Reference structure (dict, list, or scalar).
        config: Structure to validate against the template.

    Returns:
        True when the structures match under the rules above, else False.
    """
    both_dicts = isinstance(template, dict) and isinstance(config, dict)
    if both_dicts:
        if set(template) != set(config):
            return False
        return all(
            verify_config(sub_template, config[name])
            for name, sub_template in template.items()
        )

    both_lists = isinstance(template, list) and isinstance(config, list)
    if both_lists:
        if len(template) != len(config):
            return False
        return all(
            verify_config(sub_template, sub_config)
            for sub_template, sub_config in zip(template, config)
        )

    # Scalars and mismatched container types are not inspected further.
    return True


In [None]:
def clean_columns(df):
    """Replace special characters in every column name with an underscore.

    Args:
        df: DataFrame-like object exposing ``columns`` and
            ``withColumnRenamed(old, new)``.

    Returns:
        The DataFrame with each affected column renamed; columns whose
        names contain none of the special characters are left untouched.
    """
    # Imported locally so the function does not rely on `re` happening to be
    # in scope through one of the notebook's wildcard imports.
    import re

    # Characters that are replaced by '_' in column names.
    pattern = re.compile(r'[.\`\'",;(){}\[\]$*!?%&|<>+=/\-\\]')

    # df.columns is read once here; rebinding df inside the loop does not
    # affect the list being iterated.
    for old_name in df.columns:
        new_name = pattern.sub('_', old_name)
        if new_name != old_name:
            df = df.withColumnRenamed(old_name, new_name)

    return df

In [None]:
# Create a dataframe schema from a list of field definitions
def to_dataframe_schema(custom_schema):
    """Build a Spark ``StructType`` from a list of field descriptions.

    Args:
        custom_schema: Iterable of mappings, each with the keys
            ``"fieldName"``, ``"fieldType"`` (the Spark type written as a
            constructor string such as ``"StringType()"``) and
            ``"nullable"`` (``'true'`` / ``'false'`` in any letter case, or
            a real bool). Any ``nullable`` value other than true is treated
            as not nullable.

    Returns:
        A ``StructType`` with one ``StructField`` per entry, in input order.

    Raises:
        ValueError: If a ``fieldType`` string is not in the supported list.
        TypeError: For ``"ArrayType()"``, ``"MapType()"`` and
            ``"StructField()"``, whose constructors require arguments that
            this flat description format does not provide.
    """
    # Maps the type string from the config to the class that builds it.
    type_constructors = {
        "ByteType()": ByteType,
        "ShortType()": ShortType,
        "IntegerType()": IntegerType,
        "LongType()": LongType,
        "FloatType()": FloatType,
        "DoubleType()": DoubleType,
        "DecimalType()": DecimalType,
        "StringType()": StringType,
        "BinaryType()": BinaryType,
        "BooleanType()": BooleanType,
        "TimestampType()": TimestampType,
        "DateType()": DateType,
        "DayTimeIntervalType()": DayTimeIntervalType,
        # NOTE: the three entries below need constructor arguments, so the
        # bare call fails with TypeError.
        "ArrayType()": ArrayType,
        "MapType()": MapType,
        "StructField()": StructField,
    }

    fields = []
    for field_info in custom_schema:
        field_name = field_info["fieldName"]
        type_name = field_info["fieldType"]

        try:
            constructor = type_constructors[type_name]
        except KeyError:
            # Fail here with the offending field rather than later inside
            # StructField with a less specific message.
            raise ValueError(
                f"Unsupported fieldType {type_name!r} for field {field_name!r}"
            ) from None
        field_type = constructor()

        # str() lets a YAML/JSON boolean and the strings 'true'/'false'
        # go through the same comparison.
        nullable = str(field_info["nullable"]).lower() == 'true'

        fields.append(StructField(field_name, field_type, nullable))

    # Create the schema
    return StructType(fields)

In [None]:
def move_landing_file(file_path, move_location):
    """Move a landed file to the processed or failed area with a timestamped name.

    ``/Landing/`` in the path is replaced by ``/Processed/`` (for
    ``'Success'``) or ``/Failed/`` (for ``'Failure'``), and
    ``_<YYYYmmddHHMMSS>`` is inserted before the extension of the file name
    (or appended when the file name has no extension). Only the final path
    component is altered, so dots in the host or folder names are preserved.

    Args:
        file_path: Full path of the file currently under ``/Landing/``.
        move_location: ``'Success'`` or ``'Failure'``. Any other value
            results in no move being performed.

    Returns:
        None.
    """
    target_folders = {'Success': '/Processed/', 'Failure': '/Failed/'}
    target_folder = target_folders.get(move_location)
    if target_folder is None:
        # Unknown outcome: leave the file where it is.
        return

    # get and format datetime
    filename_timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    relocated_path = file_path.replace('/Landing/', target_folder)

    # Split off the file name so the timestamp is only applied to it.
    directory, separator, file_name = relocated_path.rpartition('/')
    stem, dot, extension = file_name.rpartition('.')
    if dot:
        file_name = stem + '_' + filename_timestamp + '.' + extension
    else:
        file_name = file_name + '_' + filename_timestamp

    updated_path = directory + separator + file_name
    mssparkutils.fs.mv(file_path, updated_path, True)

In [None]:
def create_table_if_not_exists(
    config_table_options_raw_lakehouse_name,
    config_table_options_raw_layer_name,
    config_table_options_raw_table_name,
    write_schema,
    config_table_options_raw_partition_columns,
    config_table_options_raw_partition_type,
):
    """Create an empty Delta table with the given schema if it is missing.

    The table name is ``<layer>_<table>`` when a layer name is given,
    otherwise just ``<table>``; it is created inside the given lakehouse.
    The existence check uses that same final name, so an existing
    layer-prefixed table is never overwritten with an empty frame.

    Args:
        config_table_options_raw_lakehouse_name: Lakehouse (database) name.
        config_table_options_raw_layer_name: Optional layer prefix; ``""``
            means no prefix.
        config_table_options_raw_table_name: Base table name.
        write_schema: Schema used for the empty DataFrame that defines the table.
        config_table_options_raw_partition_columns: Columns passed to
            ``partitionBy`` (ignored for the ``'reference'`` partition type).
        config_table_options_raw_partition_type: ``'reference'`` creates an
            unpartitioned table; any other value creates a partitioned one.

    Returns:
        None. Failures are reported with ``print`` and not re-raised.
    """
    try:
        if config_table_options_raw_layer_name != "":
            table_name = config_table_options_raw_layer_name + '_' + config_table_options_raw_table_name
        else:
            table_name = config_table_options_raw_table_name

        # Check for the name that would actually be created.
        if spark.catalog.tableExists(table_name, config_table_options_raw_lakehouse_name):
            return

        adjusted_table_name = config_table_options_raw_lakehouse_name + '.' + table_name

        # Create a DataFrame with the specified schema (no data)
        data = spark.createDataFrame([], schema=write_schema)
        writer = data.write.format("delta").mode("overwrite")
        if config_table_options_raw_partition_type != 'reference':
            writer = writer.partitionBy(*config_table_options_raw_partition_columns)
        writer.saveAsTable(adjusted_table_name)

    except Exception as exc:
        # Kept best-effort, but surface the reason instead of hiding it.
        print(f'create table failed: {exc}')


In [None]:
def get_config_values(config, keys, valueType):
    """Look up a nested value in *config* by following *keys* in order.

    Traversal stops as soon as a non-dict value is reached, and that value
    is returned even if some keys were not consumed.

    Args:
        config: Nested dict to read from.
        keys: Sequence of keys describing the path into *config*.
        valueType: Currently unused; kept so existing callers keep working.

    Returns:
        The value found at the path (including falsy values such as
        ``False`` or ``0``), or ``None`` when a key on the path is missing.
    """
    # Private sentinel so a stored None/False/0 is not mistaken for "missing".
    missing = object()

    current_level = config
    for key in keys:
        current_level = current_level.get(key, missing)
        if current_level is missing:
            return None
        if not isinstance(current_level, dict):
            break

    return current_level


In [None]:
# Config values can arrive as a mix of strings and real booleans; normalise them here
def str_to_bool(string):
    """Convert a ``'true'`` / ``'false'`` string (any letter case) to a bool.

    Args:
        string: A bool, or a string to interpret.

    Returns:
        The input unchanged when it is already a bool; ``True`` / ``False``
        for the matching strings; ``None`` for any other string.
    """
    if isinstance(string, bool):
        return string

    recognised = {'true': True, 'false': False}
    return recognised.get(string.lower())

In [None]:
def df_schema_to_table(config_tableOptions_raw_lakehouse_name, config_tableOptions_raw_table_name, df_schema):
    """Report whether a table's schema equals the given DataFrame schema.

    The comparison is a plain ``==`` between the two schema objects, which
    is stricter than a name-only check.
    For an order-only comparison of column names, compare
    ``table_schema.fieldNames()`` with ``df_schema.fieldNames()`` instead.

    Args:
        config_tableOptions_raw_lakehouse_name: Lakehouse (database) name.
        config_tableOptions_raw_table_name: Table name inside the lakehouse.
        df_schema: Schema to compare against the table's schema.

    Returns:
        True when the schemas are equal, otherwise False.
    """
    qualified_name = config_tableOptions_raw_lakehouse_name + '.' + config_tableOptions_raw_table_name
    existing_schema = spark.table(qualified_name).schema
    return existing_schema == df_schema

In [None]:
def create_date_partitions(df, config_tableOptions_raw_partition_date_format, onfig_tableOptions_raw_dataframe_partition_columns):
    """Add year / month / day partition columns derived from a date column.

    The partition format is a ``/``-separated string (compared
    case-insensitively). For each of ``yyyy``, ``mm`` and ``dd`` present in
    it, the matching column ``PartitionYear``, ``PartitionMonth`` or
    ``PartitionDay`` is added using ``date_format`` on the source column.

    Args:
        df: DataFrame to extend.
        config_tableOptions_raw_partition_date_format: Format string such
            as ``"yyyy/MM/dd"``.
        onfig_tableOptions_raw_dataframe_partition_columns: Source date
            column passed to ``date_format``.

    Returns:
        The DataFrame with the requested partition columns added.
    """
    requested_parts = [
        part.lower()
        for part in config_tableOptions_raw_partition_date_format.split('/')
    ]

    # (token in the config format, output column, date_format pattern)
    partition_specs = (
        ('yyyy', "PartitionYear", "yyyy"),
        ('mm', "PartitionMonth", "MM"),
        ('dd', "PartitionDay", "dd"),
    )

    for token, column_name, pattern in partition_specs:
        if token in requested_parts:
            df = df.withColumn(
                column_name,
                date_format(onfig_tableOptions_raw_dataframe_partition_columns, pattern),
            )

    return df

In [None]:
def compare_schema(dataframe_schema, custom_schema):
    """Compare a DataFrame schema with the expected schema, ignoring metadata columns.

    The fields ``MetaCreatedDate``, ``MetaUpdatedDate`` and
    ``MetaSourceFilename`` are dropped from *dataframe_schema* before the
    comparison because they are added after load.

    Args:
        dataframe_schema: Schema of the loaded DataFrame.
        custom_schema: Expected schema to compare against.

    Returns:
        The result of ``==`` between the filtered schema and *custom_schema*.
    """
    metadata_columns = ('MetaCreatedDate', 'MetaUpdatedDate', 'MetaSourceFilename')

    kept_fields = []
    for field in dataframe_schema:
        if field.name in metadata_columns:
            continue
        kept_fields.append(field)

    return StructType(kept_fields) == custom_schema


In [None]:
def get_lakehouse_Id(config_tableOptions_raw_lakehouse_name):
    """Return the lakehouse id taken from the database's location URI.

    The Spark catalog databases are searched for one whose name equals the
    given lakehouse name; the id is the second-to-last ``/``-separated
    segment of that database's ``locationUri``.

    Args:
        config_tableOptions_raw_lakehouse_name: Name of the lakehouse to find.

    Returns:
        The second-to-last segment of the matching database's location URI.

    Raises:
        ValueError: If no database with that name exists in the catalog.
    """
    for lakehouse in spark.catalog.listDatabases():
        if lakehouse.name == config_tableOptions_raw_lakehouse_name:
            uri_parts = lakehouse.locationUri.split('/')
            return uri_parts[-2]

    # Fail with a clear message instead of an unbound-variable error.
    raise ValueError(
        f"Lakehouse '{config_tableOptions_raw_lakehouse_name}' was not found in the Spark catalog"
    )

In [None]:
# Basic raise-error helper, not yet fully implemented
def basic_error(message):
    """Raise a generic ``Exception`` carrying *message*.

    Args:
        message: Text of the error.

    Raises:
        Exception: Always, with *message* as its argument.
    """
    # NOTE: nothing placed after this raise can execute, so no notebook
    # exit call is made here.
    raise Exception(message)