In [0]:
# Read the raw posts table; `spark` is the session supplied by the notebook runtime.
raw_posts_df = (
    spark.read
    .table('default.raw_posts')
)

In [0]:
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, MapType, IntegerType, StringType
from pyspark.sql import DataFrame


def split_tag_into_array(df: DataFrame) -> DataFrame:
    """Replace the pipe-delimited ``Tags`` string column with an array column.

    ``Tags`` is split on the literal ``|`` character, empty-string pieces
    produced by the split are discarded, and the result is stored in a new
    ``TagsArray`` column. The original ``Tags`` column is dropped.

    Args:
        df: DataFrame containing a ``Tags`` column.

    Returns:
        A new DataFrame with ``TagsArray`` added and ``Tags`` removed.
    """
    # The pattern is a regex, so the pipe must be escaped to match literally.
    tag_pieces = F.split(F.col("Tags"), r'\|')
    non_empty_tags = F.filter(tag_pieces, lambda tag: tag != "")
    return df.withColumn("TagsArray", non_empty_tags).drop("Tags")

def rename_columns(df: DataFrame) -> DataFrame:
    """Rename the ``Id`` column to ``PostId``.

    Args:
        df: DataFrame to rename columns on.

    Returns:
        A new DataFrame in which ``Id`` is called ``PostId``.
    """
    renamed_df = df.withColumnRenamed("Id", "PostId")
    return renamed_df

def map_post_type(df: DataFrame) -> DataFrame:
    """Add a ``PostType`` label column looked up from ``PostTypeId``.

    A small in-memory lookup table (id -> label) is turned into a DataFrame
    and broadcast-joined onto ``df``. Because the join is a left join, rows
    whose ``PostTypeId`` is not in the lookup are kept and get a null
    ``PostType``.

    Args:
        df: DataFrame containing a ``PostTypeId`` column.

    Returns:
        ``df`` with an extra ``PostType`` string column; the lookup's own
        ``PostTypeId`` column is dropped so only the original one remains.
    """
    post_type_labels = {
        1: "Question",
        2: "Answer",
        3: "Orphaned tag wiki",
        4: "Tag wiki excerpt",
        5: "Tag wiki",
        6: "Moderator nomination",
        7: "Wiki placeholder",
        8: "Privilege wiki",
        9: "Article",
        10: "HelpArticle",
        12: "Collection",
        13: "ModeratorQuestionnaireResponse",
        14: "Announcement",
        15: "CollectiveDiscussion",
        17: "CollectiveCollection",
    }

    lookup_schema = StructType([
        StructField("PostTypeId", IntegerType(), False),
        StructField("PostType", StringType(), False),
    ])

    # `spark` is the session provided by the notebook runtime.
    lookup_df = spark.createDataFrame(
        list(post_type_labels.items()),
        schema=lookup_schema,
    )

    # Join on explicit column references so the lookup's duplicate key
    # column can be dropped unambiguously afterwards.
    same_post_type = df["PostTypeId"] == lookup_df["PostTypeId"]
    labelled_df = df.join(F.broadcast(lookup_df), on=same_post_type, how="left")
    return labelled_df.drop(lookup_df["PostTypeId"])

# Build the staging frame. Steps run innermost-first:
#   1. split_tag_into_array  2. rename_columns  3. map_post_type
stg_posts_df = map_post_type(
    rename_columns(
        split_tag_into_array(raw_posts_df)
    )
)

In [0]:

from delta.tables import DeltaTable
import pyspark.sql.functions as F

def incremntal_upsert(df: DataFrame, dest_table: str, unique_key: str, updated_at: str, full_refresh: bool = False) -> DataFrame:
    """
    Incrementally upsert ``df`` into a Delta table, using ``updated_at`` as the
    cursor column and ``unique_key`` as the merge key.

    Behaviour:
      * If ``dest_table`` does not exist, or ``full_refresh`` is true, the table
        is (re)created from ``df`` with an overwrite, including its schema.
      * Otherwise the current maximum of ``updated_at`` in the destination is
        read, only source rows strictly newer than it are kept, and those rows
        are merged in: matching keys are updated, new keys are inserted.
      * If the destination has no cursor value yet (empty table, or every
        ``updated_at`` is NULL), all of ``df`` is merged. Comparing against a
        NULL high-water mark would otherwise filter out every row and the
        table could never be populated.

    Note: in the filtered path, source rows whose ``updated_at`` is NULL never
    satisfy the ``>`` comparison and are therefore not upserted.

    Args:
        df: Source DataFrame to write.
        dest_table: Name of the destination Delta table.
        unique_key: Column used to match source rows to destination rows.
        updated_at: Column used as the incremental cursor.
        full_refresh: When true, overwrite the destination instead of merging.

    Returns:
        The input ``df``, unchanged (not the filtered subset that was merged).
    """

    if not spark.catalog.tableExists(dest_table) or full_refresh:
        (
            df.write.format("delta")
                    .mode("overwrite")
                    .option("overwriteSchema", "true")
                    .saveAsTable(dest_table)
        )
        return df

    # High-water mark of the cursor column in the destination. A global
    # aggregate always yields exactly one row; its value is None when the
    # table is empty or the column holds only NULLs.
    last_max = (
        spark.table(dest_table)
        .agg(F.max(updated_at).alias("max_ts"))
        .first()["max_ts"]
    )

    if last_max is None:
        # No cursor to compare against: treat every source row as new.
        new_data = df
    else:
        # Keep only rows strictly newer than what the destination already has.
        new_data = df.filter(F.col(updated_at) > last_max)

    # Skip the merge entirely when there is nothing new to write.
    if new_data.limit(1).count() > 0:
        delta_table = DeltaTable.forName(spark, dest_table)
        (
            delta_table.alias("t")
                .merge(source=new_data.alias("s"),
                       condition=f"t.{unique_key} = s.{unique_key}")
                .whenMatchedUpdateAll()
                .whenNotMatchedInsertAll()
                .execute()
        )
    return df

# Destination table for the staged posts.
dest_table = "default.stg_posts"

# Upsert keyed on PostId with CreationDate as the incremental cursor.
# NOTE(review): CreationDate presumably never changes for an existing post, so
# later edits to a post would not be picked up by this cursor — confirm intended.
stg_posts_df = incremntal_upsert(
    stg_posts_df,
    dest_table,
    unique_key="PostId",
    updated_at="CreationDate",
)