##### Functions
###### This notebook is the source for the utility functions that are called from the main processing notebook



In [None]:
# Session-level Spark settings applied whenever this notebook is run.
for _confKey, _confValue in (
    ("spark.sql.parquet.vorder.enabled", "true"),
    ("spark.microsoft.delta.optimizeWrite.enabled", "true"),
    ("spark.microsoft.delta.optimizeWrite.binSize", "1073741824"),  # 1024 ** 3
):
    spark.conf.set(_confKey, _confValue)

In [None]:
import json
from pyspark.sql.functions import *
from pyspark.sql.types import *
from datetime import datetime
from delta.tables import DeltaTable


In [None]:
def metadataLoader(filePath, configFolder):
    """Load the JSON config that belongs to a landed file.

    The source and entity names are the third and fourth '/'-separated
    segments of ``filePath``; the config is read from
    ``configFolder + '<source>_<entity>.json'``.

    Args:
        filePath: '/'-separated path of the file; needs at least four segments.
        configFolder: Prefix of the config location. It is concatenated as-is,
            so it must already end with a path separator.

    Returns:
        The parsed JSON content of the config file.

    Raises:
        IndexError: If ``filePath`` has fewer than four segments.
        OSError: If the config file cannot be opened.
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    # Strip out the names from the folder location
    pathParts = filePath.split('/')
    sourceFolder = pathParts[2]
    entityFolder = pathParts[3]

    # Load the config file based on the source & entity
    configPath = configFolder + sourceFolder + '_' + entityFolder + '.json'
    # The context manager closes the handle even if parsing fails
    with open(configPath) as configFile:
        return json.load(configFile)

In [None]:
# Create schema from json file
def jsonToSchema(custom_schema):
    """Build a Spark ``StructType`` from a list of field descriptions.

    Args:
        custom_schema: Iterable of dicts, each with the keys ``"fieldName"``,
            ``"fieldType"`` (a type name such as ``"StringType()"``) and
            ``"nullable"``.

    Returns:
        A ``StructType`` with one ``StructField`` per entry, in input order.

    Raises:
        KeyError: If an entry lacks one of the three keys.
        ValueError: If ``"fieldType"`` is not one of the supported type names.
        TypeError: For ``"ArrayType()"``, ``"MapType()"`` and
            ``"StructField()"``, whose constructors need arguments that this
            config format cannot supply.
    """
    # Type name as written in the config -> constructor called with no arguments.
    typeFactories = {
        "ByteType()": ByteType,
        # The original chain tested "ByteType()" twice, so ShortType was unreachable.
        "ShortType()": ShortType,
        "IntegerType()": IntegerType,
        "LongType()": LongType,
        "FloatType()": FloatType,
        "DoubleType()": DoubleType,
        "DecimalType()": DecimalType,
        "StringType()": StringType,
        "BinaryType()": BinaryType,
        "BooleanType()": BooleanType,
        "TimestampType()": TimestampType,
        "DateType()": DateType,
        "DayTimeIntervalType()": DayTimeIntervalType,
        # These three need constructor arguments; calling them bare raises
        # TypeError, exactly as the original branches did.
        "ArrayType()": ArrayType,
        "MapType()": MapType,
        "StructField()": StructField,
    }

    fields = []
    for field_info in custom_schema:
        field_name = field_info["fieldName"]
        type_name = field_info["fieldType"]

        factory = typeFactories.get(type_name)
        if factory is None:
            # Fail with a readable message instead of handing a raw string
            # to StructField as its data type.
            raise ValueError(
                "Unsupported fieldType " + repr(type_name)
                + " for field " + repr(field_name)
            )
        field_type = factory()

        # Only 'true' (any casing) means nullable; everything else is False.
        # str() lets a real JSON boolean work as well as the string form.
        nullable = str(field_info["nullable"]).lower() == 'true'

        fields.append(StructField(field_name, field_type, nullable))

    # Create the schema
    return StructType(fields)

In [None]:
def moveLandingFile(filePath, moveLocation):
    """Move a landed file to the processed or failed area with a timestamped name.

    ``'/Landing/'`` in ``filePath`` is replaced by ``'/Processed/'`` when
    ``moveLocation`` is ``'Success'`` and by ``'/Failed/'`` when it is
    ``'Failure'``. A ``_YYYYmmddHHMMSS`` stamp (local time) is inserted before
    the file extension, or appended when the file name has no extension.
    Any other ``moveLocation`` leaves the file untouched.

    Args:
        filePath: '/'-separated path of the file to move.
        moveLocation: ``'Success'`` or ``'Failure'``.
    """
    targetFolders = {'Success': '/Processed/', 'Failure': '/Failed/'}
    targetFolder = targetFolders.get(moveLocation)
    if targetFolder is None:
        # Unrecognised outcome: do nothing, as before
        return

    # get and format datetime
    filenameTimestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    # Stamp only the file name. Replacing every '.' in the whole path (the
    # previous behaviour) also rewrote dots in folder names and in file names
    # containing more than one dot.
    movedPath = filePath.replace('/Landing/', targetFolder)
    directory, slash, fileName = movedPath.rpartition('/')
    stem, dot, extension = fileName.rpartition('.')
    if dot:
        stampedName = stem + '_' + filenameTimestamp + '.' + extension
    else:
        stampedName = fileName + '_' + filenameTimestamp
    updatedPath = directory + slash + stampedName

    # Third positional argument kept exactly as the original call passed it
    mssparkutils.fs.mv(filePath, updatedPath, True)

In [None]:
def createTableIfNotExists(lakehouseName, layerName, tableName, writeSchema, partitionColumns, rawPartitionType):
    """Create an empty partitioned Delta table when it does not exist yet.

    Nothing happens when ``spark.catalog.tableExists(tableName, lakehouseName)``
    is true or when ``partitionColumns`` is empty.

    Args:
        lakehouseName: Lakehouse (database) that holds the table.
        layerName: Optional prefix; when not ``""`` the table is created as
            ``<lakehouseName>.<layerName>_<tableName>``.
        tableName: Name of the table.
        writeSchema: Schema used for the empty DataFrame that defines the table.
        partitionColumns: Column names to partition the table by.
        rawPartitionType: Not used by this function.
    """
    # NOTE(review): the existence check uses the un-prefixed tableName while
    # the table is created with the layer prefix - confirm this is intended.
    if spark.catalog.tableExists(tableName, lakehouseName) or len(partitionColumns) == 0:
        return

    # Use the function's own parameters. The previous body read rawLayerName,
    # rawTableName and rawTablePartitionColumns, which are not defined here.
    if layerName != "":
        adjustedTableName = lakehouseName + '.' + layerName + '_' + tableName
    else:
        adjustedTableName = lakehouseName + '.' + tableName

    # Create a DataFrame with the specified schema (no data)
    emptyFrame = spark.createDataFrame([], schema=writeSchema)
    emptyFrame.write.format("delta").mode("overwrite").partitionBy(*partitionColumns).saveAsTable(adjustedTableName)

In [None]:
def getConfigValues(config, keys, valueType, default=None):
    """Look up a nested value in a config dict by following a list of keys.

    The walk stops as soon as a non-dict value is reached, even if keys
    remain; that value is then the result. A missing key yields an empty
    dict, which counts as "no value".

    Args:
        config: Nested dict to read from.
        keys: Keys to follow, outermost first.
        valueType: Not used by this function; kept for existing callers.
        default: Returned when the lookup ends on a falsy value (missing key,
            empty string, empty dict, 0, False, None).

    Returns:
        The value found, or ``default`` when that value is falsy.
    """
    current_level = config
    for key in keys:
        current_level = current_level.get(key, {})
        if not isinstance(current_level, dict):
            break

    # 'default' is now a parameter; it used to be an undefined name, so every
    # empty or missing lookup raised NameError.
    return current_level if current_level else default


In [None]:
def strToBool(string):
    """Convert the text 'true' / 'false' (any casing) to a bool.

    Returns None for any other text.
    """
    knownValues = {'true': True, 'false': False}
    return knownValues.get(string.lower())

In [None]:
def schemaDFToTable(lakehouseName, tableName, dfSchema):
    """Report whether a table's schema equals the given DataFrame schema.

    The schema of ``<lakehouseName>.<tableName>`` is compared to ``dfSchema``
    with ``==``, i.e. the whole schema objects are compared, not only the
    column names. For a names-only check compare ``fieldNames()`` of both
    schemas instead.

    Returns:
        True when the two schemas are equal, otherwise False.
    """
    qualifiedName = lakehouseName + '.' + tableName
    existingSchema = spark.table(qualifiedName).schema
    return bool(existingSchema == dfSchema)

In [None]:
def createDatePartitions(df, partitionDateFormat, rawDataframePartitionColumns):
    """Add PartitionYear / PartitionMonth / PartitionDay columns to a DataFrame.

    ``partitionDateFormat`` is a '/'-separated list of tokens, matched
    case-insensitively: 'yyyy' adds PartitionYear, 'mm' adds PartitionMonth
    and 'dd' adds PartitionDay. Each new column is
    ``date_format(rawDataframePartitionColumns, <pattern>)``.

    Returns:
        The DataFrame with the requested partition columns added.
    """
    requestedTokens = [token.lower() for token in partitionDateFormat.split('/')]

    # (token in the format string, new column name, date_format pattern)
    partitionSpecs = (
        ('yyyy', "PartitionYear", "yyyy"),
        ('mm', "PartitionMonth", "MM"),
        ('dd', "PartitionDay", "dd"),
    )
    for token, columnName, pattern in partitionSpecs:
        if token in requestedTokens:
            df = df.withColumn(columnName, date_format(rawDataframePartitionColumns, pattern))

    return df

In [None]:
def getLakehouseId(lakehouseName):
    """Return the id of a lakehouse, taken from its catalog location URI.

    The id is the second-to-last '/'-separated segment of the matching
    database's ``locationUri``.

    Args:
        lakehouseName: Name of the lakehouse as listed by
            ``spark.catalog.listDatabases()``.

    Returns:
        The id segment of the location URI.

    Raises:
        ValueError: If no listed database has that name. Previously this
            surfaced as an UnboundLocalError on the return statement.
    """
    for lakehouse in spark.catalog.listDatabases():
        if lakehouse.name == lakehouseName:
            # Stop at the first match instead of scanning the rest
            return lakehouse.locationUri.split('/')[-2]

    raise ValueError("Lakehouse not found in the Spark catalog: " + str(lakehouseName))

In [None]:
# Basic raise error, not yet implemented
def basicError(message):
    """Raise a plain ``Exception`` carrying ``message``.

    The ``mssparkutils.notebook.exit(message)`` call that used to follow the
    raise could never run, so it has been removed.

    Raises:
        Exception: Always, with ``message`` as its argument.
    """
    raise Exception(message)