In [0]:
# Source table for this staging notebook.
RAW_POSTS_TABLE = "default.raw_posts"
raw_posts_df = spark.read.table(RAW_POSTS_TABLE)

In [0]:
# Peek at a small sample of the raw table to sanity-check its columns.
raw_preview_df = raw_posts_df.limit(10)
display(raw_preview_df)

In [0]:
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
from pyspark.sql import DataFrame

def split_tag_into_array(df: DataFrame) -> DataFrame:
    """Replace the pipe-delimited ``tags`` string column with a ``TagsArray`` array column.

    The string is split on a literal ``|`` and empty fragments (e.g. produced by
    leading/trailing pipes) are discarded. The original ``tags`` column is dropped.
    """
    # The pattern is a regex, so the pipe must be escaped.
    tag_fragments = F.split(F.col("tags"), r"\|")
    non_empty_tags = F.filter(tag_fragments, lambda fragment: fragment != "")
    with_tag_array = df.withColumn("TagsArray", non_empty_tags)
    return with_tag_array.drop("tags")

def rename_columns(df: DataFrame) -> DataFrame:
    """Rename source columns to their staging names (currently only ``Id`` -> ``PostId``)."""
    column_renames = {"Id": "PostId"}
    renamed_df = df
    for old_name, new_name in column_renames.items():
        renamed_df = renamed_df.withColumnRenamed(old_name, new_name)
    return renamed_df

def map_post_type(df: DataFrame) -> DataFrame:
    """Attach a human-readable ``PostTypeName`` to each row via its ``PostTypeId``.

    Performs a broadcast left join against a small in-memory lookup table, so rows
    whose ``PostTypeId`` is not listed below keep a NULL ``PostTypeName``.
    """
    post_type_names = {
        1: "Question",
        2: "Answer",
        3: "Orphaned tag wiki",
        4: "Tag wiki excerpt",
        5: "Tag wiki",
        6: "Moderator nomination",
        7: "Wiki placeholder",
        8: "Privilege wiki",
        9: "Article",
        10: "HelpArticle",
        12: "Collection",
        13: "ModeratorQuestionnaireResponse",
        14: "Announcement",
        15: "CollectiveDiscussion",
        17: "CollectiveCollection",
    }
    lookup_df = spark.createDataFrame(
        list(post_type_names.items()),
        ["PostTypeId", "PostTypeName"],
    )
    # The lookup is tiny, so broadcasting it avoids shuffling the posts table.
    joined_df = df.join(F.broadcast(lookup_df), on="PostTypeId", how="left")
    return joined_df.drop(lookup_df["PostTypeId"])

# Stage the raw posts: split the pipe-delimited tags into an array,
# rename Id -> PostId, then attach human-readable post type names.
posts_with_tags_df = split_tag_into_array(raw_posts_df)
posts_renamed_df = rename_columns(posts_with_tags_df)
stg_post_df = map_post_type(posts_renamed_df)

display(stg_post_df)

In [0]:
# Show only a handful of staged rows for a quick visual check.
stg_preview_df = stg_post_df.limit(3)
display(stg_preview_df)

In [0]:
from delta.tables import DeltaTable
import pyspark.sql.functions as F


def incremental_upsert(
    dest_table: str,
    df: DataFrame,
    uniqueKey: str,
    updated_at: str,
    full_refresh: bool = False,
) -> None:
    """Load ``df`` into the Delta table ``dest_table``, incrementally by a cursor column.

    Behaviour:
      - If ``dest_table`` does not exist, or ``full_refresh`` is True, the table is
        (re)created from ``df`` using mode "overwrite" with ``overwriteSchema=true``.
      - Otherwise only rows whose ``updated_at`` value is strictly greater than the
        current maximum of that column in ``dest_table`` are MERGEd on ``uniqueKey``:
        matching rows are updated, new rows are inserted.
      - If the existing table has no non-null cursor value (e.g. it is empty), the
        whole of ``df`` is merged.

    Limitations (intentionally minimal):
      - Deletes in the source are not propagated.
      - Source rows whose cursor equals the current maximum (tied/late arrivals)
        are skipped.
      - Delta MERGE errors when several source rows match the same target row, so
        the incremental slice must be unique on ``uniqueKey``.

    Args:
        dest_table: Destination table name, e.g. "default.stg_posts".
        df: Source DataFrame to load.
        uniqueKey: Column used to match source and destination rows in the MERGE.
        updated_at: Cursor column compared against the destination's max value.
        full_refresh: Force a full overwrite instead of an incremental merge.
    """
    if full_refresh or not spark.catalog.tableExists(dest_table):
        (
            df.write
            .format("delta")
            .mode("overwrite")
            .option("overwriteSchema", "true")
            .saveAsTable(dest_table)
        )
        return

    last_max = (
        spark.table(dest_table)
        .agg(F.max(updated_at).alias("max_ts"))
        .collect()[0]["max_ts"]
    )

    # `col > NULL` evaluates to NULL for every row, which would silently load nothing
    # into an existing-but-empty table; merge the full source in that case.
    incr_df = df if last_max is None else df.filter(F.col(updated_at) > last_max)

    # Nothing new: skip the (comparatively expensive) MERGE.
    if not incr_df.head(1):
        return

    # Use the caller's key (not a global); backticks keep odd column names valid SQL.
    merge_condition = f"s.`{uniqueKey}` = t.`{uniqueKey}`"
    (
        DeltaTable.forName(spark, dest_table)
        .alias("t")
        .merge(source=incr_df.alias("s"), condition=merge_condition)
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )
 

# Destination and merge settings for the staged posts table.
dest_table = "default.stg_posts"
unique_key = "PostId"
updated_at = "CreationDate"

incremental_upsert(
    dest_table=dest_table,
    df=stg_post_df,
    uniqueKey=unique_key,
    updated_at=updated_at,
)


In [0]:
# spark.table(dest_table).rdd.getNumPartitions()  # RDD APIs are not supported on shared-access-mode clusters; there, inspect the Delta table's file layout with spark.sql(f"DESCRIBE DETAIL {dest_table}").select("numFiles") instead.

In [0]:
# Rebuild the table from scratch, writing the staged data as 4 partitions.
repartitioned_stg_post_df = stg_post_df.repartition(4)
incremental_upsert(
    dest_table=dest_table,
    df=repartitioned_stg_post_df,
    uniqueKey="PostId",
    updated_at="CreationDate",
    full_refresh=True,
)

In [0]:
# Session-level setting; it only affects queries run after this cell,
# not the writes performed by the cells above.
SHUFFLE_PARTITIONS = 4
spark.conf.set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS)

In [0]:
# Read back a few rows from the destination table written above.
dest_preview_df = spark.table(dest_table).limit(5)
display(dest_preview_df)