In [0]:
from pyspark.sql.functions import current_timestamp

def add_ingestion_date(input_df):
    """Return ``input_df`` with an ``ingestion_date`` column added.

    The new column is Spark's ``current_timestamp()``. Per the PySpark docs,
    that value is taken at the start of query evaluation, not at the moment
    this function is called. If ``ingestion_date`` already exists,
    ``withColumn`` replaces it.

    Args:
        input_df: Any Spark DataFrame.

    Returns:
        A new DataFrame; ``input_df`` itself is left untouched.
    """
    return input_df.withColumn('ingestion_date', current_timestamp())

In [0]:
def rearrange_partition_column(input_df, partition_column):
    """Return ``input_df`` with ``partition_column`` moved to the last position.

    Every other column keeps its original relative order. The partition column
    is appended unconditionally, so callers should make sure it actually
    exists in ``input_df``; otherwise the final ``select`` references a
    missing column.

    Args:
        input_df: Spark DataFrame to reorder.
        partition_column: Name of the column to place last.

    Returns:
        A new DataFrame with the reordered columns.
    """
    leading_columns = [name for name in input_df.schema.names if name != partition_column]
    return input_df.select(leading_columns + [partition_column])

In [0]:
def overwrite_partition(input_df, db_name, table_name, partition_column):
    """Write ``input_df`` to ``db_name.table_name``, replacing only touched partitions.

    If the table already exists, the data is written with ``insertInto`` in
    overwrite mode. Dynamic partition overwrite is enabled first, so only
    partitions present in ``input_df`` are replaced. If the table does not
    exist yet, it is created as a Parquet table partitioned by
    ``partition_column``.

    Args:
        input_df: Spark DataFrame holding the new data.
        db_name: Target database (schema).
        table_name: Target table within ``db_name``.
        partition_column: Column the table is partitioned by.

    Side effects:
        Sets ``spark.sql.sources.partitionOverwriteMode`` to ``dynamic`` on the
        active session. The setting stays in effect after this call returns.
    """
    qualified_name = f"{db_name}.{table_name}"

    # insertInto resolves columns by position rather than by name, so the
    # partition column is moved to the end to line up with the table layout.
    ordered_df = rearrange_partition_column(input_df, partition_column)

    # Without this, an overwrite would wipe every partition, not just the ones
    # present in ordered_df.
    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

    table_exists = spark._jsparkSession.catalog().tableExists(qualified_name)
    if not table_exists:
        # First load: create the partitioned Parquet table from scratch.
        (
            ordered_df.write.mode("overwrite")
            .partitionBy(partition_column)
            .format("parquet")
            .saveAsTable(qualified_name)
        )
        return

    ordered_df.write.mode("overwrite").insertInto(qualified_name)

In [0]:
def df_column_to_list(input_df, column_name):
    """Collect the distinct values of one column into a Python list.

    ``collect()`` brings every distinct value back to the driver, so this is
    only suitable for low-cardinality columns. ``distinct()`` does not promise
    any particular ordering of the result.

    Args:
        input_df: Spark DataFrame to read from.
        column_name: Name of the column whose distinct values are wanted.

    Returns:
        A list of the distinct values of ``column_name``.
    """
    distinct_rows = input_df.select(column_name).distinct().collect()
    return [record[column_name] for record in distinct_rows]

In [0]:
def merge_delta_data(input_df, db_name, table_name, folder_path, merge_condition, partition_column):
    """Upsert ``input_df`` into the Delta table ``db_name.table_name``.

    If the table already exists, a Delta MERGE is run:

    * target rows matching ``merge_condition`` are fully updated from the source;
    * source rows with no match are inserted.

    If the table does not exist, it is created from ``input_df`` as a Delta
    table partitioned by ``partition_column``.

    Args:
        input_df: Source DataFrame. It is aliased as ``src`` for use inside
            ``merge_condition``.
        db_name: Database (schema) holding the target table.
        table_name: Target table name within ``db_name``.
        folder_path: No longer used. The existing table is resolved through
            the catalog by name, the same identifier the existence check uses.
            The parameter is kept so existing callers keep working.
        merge_condition: SQL boolean expression over the aliases ``tgt``
            (target) and ``src`` (source), e.g. ``"tgt.id = src.id"``.
        partition_column: Column to partition by when the table is first created.

    Side effects:
        Sets ``spark.databricks.optimizer.dynamicPartitionPruning`` to ``"true"``
        on the active session. The setting persists after the call.
    """
    qualified_name = f"{db_name}.{table_name}"

    # The key was previously misspelled ("dynamicPartionPruning"). Spark accepts
    # arbitrary conf keys, so the typo was stored but never had any effect.
    spark.conf.set("spark.databricks.optimizer.dynamicPartitionPruning", "true")

    from delta.tables import DeltaTable

    if spark._jsparkSession.catalog().tableExists(qualified_name):
        # Resolve by catalog name -- the identifier that was just checked --
        # rather than a path that may not match the table's real location.
        target_table = DeltaTable.forName(spark, qualified_name)
        (
            target_table.alias("tgt")
            .merge(input_df.alias("src"), merge_condition)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )
    else:
        # First load: create a partitioned Delta table from the input.
        (
            input_df.write.mode("overwrite")
            .partitionBy(partition_column)
            .format("delta")
            .saveAsTable(qualified_name)
        )