<a href="https://colab.research.google.com/github/brendanlooker/colab-examples/blob/main/spark/Building_and_Optimizing_Spark_Pipelines.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building and Optimizing Spark Pipelines
This hands-on lab explores key optimization and maintenance strategies for using Apache Spark and Apache Iceberg. You will learn how to:

- Tune Spark to leverage the performance benefits of Broadcast Hash Join over Shuffle Sort-Merge Join when dealing with large fact tables and small dimension tables.

- Perform Iceberg file compaction to merge small files created by continuous writes.

- Build a Structured Streaming pipeline that uses Iceberg's transactional MERGE INTO capability to perform real-time updates and inserts into a target table.

# Setup

1. Create a Spark Serverless Runtime

In [None]:
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# --- Configuration constants (replace with values for your own project) ---
CATALOG_NAME = "my_catalog"
NAMESPACE = "spark_lab"
PROJECT_ID = "oh-lab-477016"
REGION = "us-central1"
WAREHOUSE_DIR = "gs://oh-lab-477016-warehouse/warehouse"
CHECKPOINT_DIR = "gs://oh-lab-477016-checkpoints/checkpoint"

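# --- Spark session initialization ---
# autoBroadcastJoinThreshold is measured in bytes; 100 bytes effectively disables
# automatic broadcasting, and AQE is turned off so that Module 1 gets a
# Sort-Merge Join baseline instead of a runtime-optimized plan.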
spark = DataprocSparkSession.builder \
    .appName("spark-lab") \
    .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") \
    .config(f"spark.sql.catalog.{CATALOG_NAME}", "org.apache.iceberg.spark.SparkCatalog") \
    .config(f"spark.sql.catalog.{CATALOG_NAME}.catalog-impl", "org.apache.iceberg.gcp.bigquery.BigQueryMetastoreCatalog") \
    .config(f"spark.sql.catalog.{CATALOG_NAME}.gcp_project", f"{PROJECT_ID}") \
    .config(f"spark.sql.catalog.{CATALOG_NAME}.gcp_location", f"{REGION}") \
    .config(f"spark.sql.catalog.{CATALOG_NAME}.warehouse", f"{WAREHOUSE_DIR}") \
    .config("spark.executor.instances", "8") \
    .config("spark.sql.autoBroadcastJoinThreshold", "100") \
    .config("spark.sql.adaptive.enabled", "false") \
    .config("spark.sql.adaptive.skewJoin.enabled", "false") \
    .getOrCreate()

2. Set up data for the lab

We create a large fact table, Employees, with 5 million rows and a small dimension table, Regions, with 500 rows. This size difference is key for testing join strategies.

In [None]:
TOTAL_EMPLOYEES = 5000000
TOTAL_REGIONS = 500

# Create Region (Dimension) Table
df_region = spark.range(1, TOTAL_REGIONS + 1).withColumnRenamed("id", "RegionID") \
    .withColumn("RegionName", F.concat(F.lit("Rgn_"), F.col("RegionID")))

# Create Employee (Fact) Table
df_emp = spark.range(TOTAL_EMPLOYEES).withColumnRenamed("id", "EmployeeID") \
    .withColumn("Salary", F.round(F.rand() * 120000)) \
    .withColumn("TransactionDate", F.date_sub(F.current_date(), (F.rand() * 1000).cast("int"))) \
    .withColumn("RegionID", (F.rand() * TOTAL_REGIONS + 1).cast("int"))
df_emp = df_emp.withColumn("Notes", F.lit("Long string data field that increases I/O size"))

# Create temporary views
df_region.createOrReplaceTempView("region_raw")
df_emp.repartition(100,F.col("RegionID")).createOrReplaceTempView("employee_raw_for_write")

print(f"Employee Fact Table rows simulated: {TOTAL_EMPLOYEES}")

3. Create the namespace

In [None]:
spark.sql(f"USE {CATALOG_NAME};")
spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {NAMESPACE};")
spark.sql(f"USE {NAMESPACE};")

4. Create Source Tables in Iceberg

In [None]:
CATALOG_PATH = f"{CATALOG_NAME}.{NAMESPACE}"
EMPLOYEE_TABLE = f"{CATALOG_PATH}.employees"
REGION_TABLE = f"{CATALOG_PATH}.regions"

# 1. Write the large Fact table, PARTITIONED by RegionID.
spark.sql(f"""
    CREATE OR REPLACE TABLE {EMPLOYEE_TABLE}
    USING iceberg
    PARTITIONED BY (RegionID)
    AS SELECT * FROM employee_raw_for_write
""")
print(f"Fact Table 1: Partitioned Iceberg created: {EMPLOYEE_TABLE}")

# 2. Write the small Dimension table
spark.sql(f"""
    CREATE OR REPLACE TABLE {REGION_TABLE}
    USING iceberg
    AS SELECT * FROM region_raw
""")
print(f"Region Dimension Table created: {REGION_TABLE}")

# Module 1: Join Strategy Tuning

This module compares the performance of the default Shuffle Sort-Merge Join against the optimized Broadcast Hash Join for a Fact-Dimension join.

## 1. Shuffle Sort-Merge Join

Shuffle Sort-Merge Join (SMJ) is Spark's default strategy for joining large tables when neither side is small enough to broadcast. It requires Spark to:

- Shuffle: Move data from both tables across the network to group rows by the join key.

- Sort: Sort the shuffled data within each partition.

- Merge: Iterate through both sorted datasets in parallel to find matching keys.

This process is I/O and CPU intensive, which makes it the slower baseline.

After running this query, check the Spark UI. The DAG Visualization should show a SortMergeJoin operator.
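
To make the three phases concrete, here is a minimal pure-Python sketch of a shuffle sort-merge join on lists of `(key, value)` pairs. It is only a mental model: the function names and the `key % num_partitions` partitioner are assumptions for illustration, and Spark's real implementation works on distributed, spillable data.

```python
def shuffle(rows, num_partitions):
    """Group rows into partitions by join key (stand-in for the network shuffle)."""
    partitions = [[] for _ in range(num_partitions)]
    for key, value in rows:
        partitions[key % num_partitions].append((key, value))
    return partitions


def merge_sorted(left, right):
    """Merge two key-sorted lists, emitting every matching pair (inner join)."""
    out, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        left_key, right_key = left[i][0], right[j][0]
        if left_key < right_key:
            i += 1
        elif left_key > right_key:
            j += 1
        else:
            # Pair each left row with every right row that shares this key.
            run_start = j
            while i < len(left) and left[i][0] == left_key:
                j = run_start
                while j < len(right) and right[j][0] == left_key:
                    out.append((left_key, left[i][1], right[j][1]))
                    j += 1
                i += 1
    return out


def shuffle_sort_merge_join(left, right, num_partitions=2):
    joined = []
    left_parts = shuffle(left, num_partitions)
    right_parts = shuffle(right, num_partitions)
    for left_part, right_part in zip(left_parts, right_parts):
        # Sort phase: each side is sorted by key within its partition.
        left_sorted = sorted(left_part, key=lambda row: row[0])
        right_sorted = sorted(right_part, key=lambda row: row[0])
        joined.extend(merge_sorted(left_sorted, right_sorted))
    return joined


employees = [(2, "e1"), (1, "e2"), (2, "e3"), (3, "e4")]
regions = [(1, "Rgn_1"), (2, "Rgn_2")]
smj_rows = shuffle_sort_merge_join(employees, regions)
```

Note that both inputs are shuffled and sorted, even though `regions` is tiny. That cost on the large side is what the broadcast strategy in the next section avoids.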



In [None]:
print("Starting SMJ baseline...")
smj_join_baseline = spark.sql(f"""
    SELECT
        e.EmployeeID, r.RegionName
    FROM
        {EMPLOYEE_TABLE} e
    INNER JOIN
        {REGION_TABLE} r ON e.RegionID = r.RegionID
""").count()

print(f"SMJ Join Count: {smj_join_baseline}")

## 2. Broadcast Hash Join

Broadcast Hash Join (BHJ) is an optimization that applies when one table is small enough to fit in memory on every executor. By default, Spark broadcasts a table automatically when its estimated size is below spark.sql.autoBroadcastJoinThreshold (10 MB). Spark collects the smaller table onto the driver and then broadcasts it to all executors. Each executor joins its partitions of the large table against an in-memory hash map, which avoids shuffling and sorting the large table.
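
The same idea can be sketched in plain Python. This is an illustrative model, not Spark's implementation: `fact_partitions` stands in for the partitions of the large table as they already sit on executors, and the function name is made up for this example.

```python
def broadcast_hash_join(fact_partitions, dim_rows):
    """Inner-join partitioned fact rows against a small dimension table."""
    # "Broadcast": build one hash map from the small table. In Spark, a copy
    # of this map would be shipped to every executor.
    lookup = {}
    for key, dim_value in dim_rows:
        lookup.setdefault(key, []).append(dim_value)

    joined = []
    # Each fact partition is probed in place; no shuffle or sort is needed.
    for partition in fact_partitions:
        for key, fact_value in partition:
            for dim_value in lookup.get(key, []):
                joined.append((key, fact_value, dim_value))
    return joined


fact_partitions = [[(1, "e1"), (2, "e2")], [(2, "e3"), (9, "e4")]]
regions = [(1, "Rgn_1"), (2, "Rgn_2")]
bhj_rows = broadcast_hash_join(fact_partitions, regions)
```

The fact row with key 9 has no matching region, so an inner join drops it.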

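Because this lab's session sets spark.sql.autoBroadcastJoinThreshold to 100 bytes, Spark will not broadcast the regions table on its own. The code below uses the /*+ BROADCAST(r) */ hint to force the broadcast.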

The DAG Visualization should now show a BroadcastHashJoin operator, and you should observe a significantly faster execution time.

In [None]:
print("Starting BHJ optimized join...")

bhj_join_optimized = spark.sql(f"""
    SELECT
        /*+ BROADCAST(r) */ e.EmployeeID, r.RegionName
    FROM
        {EMPLOYEE_TABLE} e
    INNER JOIN
        {REGION_TABLE} r ON e.RegionID = r.RegionID
""").count()

print(f"BHJ Optimized Count: {bhj_join_optimized}")
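
If the Spark UI is not handy, the join operator can also be read from the textual query plan. The helper below is a small sketch that scans plan text for well-known physical join operator names; the function name and the sample plan lines are hand-written for illustration.

```python
JOIN_OPERATORS = (
    "BroadcastHashJoin",
    "SortMergeJoin",
    "ShuffledHashJoin",
    "BroadcastNestedLoopJoin",
)


def join_strategy(plan_text):
    """Return the first known join operator found in a physical plan string."""
    for operator in JOIN_OPERATORS:
        if operator in plan_text:
            return operator
    return "unknown"


# Hand-written sample lines in the style of a Spark physical plan.
sample_smj_plan = "+- SortMergeJoin [RegionID#4], [RegionID#10L], Inner"
sample_bhj_plan = "+- BroadcastHashJoin [RegionID#4], [RegionID#10L], Inner, BuildRight"
```

In the lab, the plan text should be obtainable with something like `spark.sql("EXPLAIN " + query).collect()[0][0]`, since Spark SQL's EXPLAIN statement returns the plan as a single string column. The helper assumes the plan contains a single join.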

To see the runtime difference more clearly, try increasing the employee dataset to 500 million rows and running both queries again.
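
A small timing helper makes the comparison easier to read off. This is a generic sketch using only the standard library; the name `timed` is an assumption, not part of the lab's code.

```python
import time


def timed(action):
    """Run a zero-argument callable and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = action()
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload; in the lab this would be a lambda wrapping a join query.
total, seconds = timed(lambda: sum(range(1000)))
```

For the lab queries, a call could look like `count, seconds = timed(lambda: spark.sql(query).count())`. The first run of any query may include warm-up costs, so it is worth running each variant more than once before comparing.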

# Module 2: Iceberg File Compaction

This module demonstrates file compaction in Iceberg. Continuous writes (especially small, frequent inserts) create many small data files, which can lead to slow query performance because the query engine has to read much more metadata and initiate more I/O requests. Compaction merges these small files into larger, optimal-sized files.
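
As a rough mental model, compaction planning resembles bin packing: small files are grouped until a group approaches a target size, and each group is rewritten as one larger file. The sketch below is a simplified illustration and not Iceberg's actual planner; the function name and the 512 MB target are assumptions.

```python
def plan_compaction(file_sizes_mb, target_mb):
    """Greedily group file sizes so that each group stays within target_mb."""
    groups, current, current_size = [], [], 0
    for size in sorted(file_sizes_mb, reverse=True):
        # Close the current group when adding this file would exceed the target.
        if current and current_size + size > target_mb:
            groups.append(current)
            current, current_size = [], 0
        current.append(size)
        current_size += size
    if current:
        groups.append(current)
    return groups


# Five tiny files fit in a single group, so they would become one file.
tiny_plan = plan_compaction([1, 1, 1, 1, 1], target_mb=512)
# Larger files are split into groups that each stay under the target.
mixed_plan = plan_compaction([300, 200, 100, 50], target_mb=512)
```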

## 1. Create Table

In [None]:
# --- Create NameSpace---
spark.sql(f"USE {CATALOG_NAME};")
spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {NAMESPACE};")
spark.sql(f"USE {NAMESPACE};")

In [None]:
spark.sql(f"""
    CREATE OR REPLACE TABLE Users (
        user_id INT,
        status STRING,
        joined_date DATE
    )
    USING iceberg
    PARTITIONED BY (joined_date)
""")


## 2. Insert Data to Iceberg Table

We perform 5 separate INSERT INTO statements. Because each statement commits as its own transaction and writes its own data file, the table ends up with 5 distinct small data files.

In [None]:
spark.sql(f"""
    Insert Into Users (
        user_id,
        status,
        joined_date
    )
    Values
    (
      1, 'ACTIVE',DATE '2025-11-03'
    )

""")

In [None]:
spark.sql(f"""
    Insert Into Users (
        user_id,
        status,
        joined_date
    )
    Values
    (
      2, 'ACTIVE',DATE '2025-11-03'
    )

""")

In [None]:
spark.sql(f"""
    Insert Into Users (
        user_id,
        status,
        joined_date
    )
    Values
    (
      3, 'ACTIVE',DATE '2025-11-03'
    )

""")
spark.sql(f"""
    Insert Into Users (
        user_id,
        status,
        joined_date
    )
    Values
    (
      4, 'ACTIVE',DATE '2025-11-03'
    )

""")
spark.sql(f"""
    Insert Into Users (
        user_id,
        status,
        joined_date
    )
    Values
    (
      5, 'ACTIVE',DATE '2025-11-03'
    )

""")

## 3. Rewrite Data Files

The rewrite_data_files stored procedure is an Iceberg maintenance operation. It reads the small data files that meet the where condition and merges their contents into a set of new, larger, optimized files.
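
One way to check the effect of compaction is to list the table's data files before and after the procedure runs, using Iceberg's `files` metadata table. The helper below only builds the query string; its name is hypothetical, and the fully qualified table name in the example assumes the catalog and namespace values configured in Setup.

```python
def data_files_query(table):
    """Build a query against the Iceberg `files` metadata table of `table`."""
    return (
        "SELECT file_path, record_count, file_size_in_bytes "
        f"FROM {table}.files"
    )


files_sql = data_files_query("my_catalog.spark_lab.Users")
```

Running `spark.sql(files_sql).show(truncate=False)` before compaction should list one row per small file, and fewer rows afterwards. The `files` table reflects the current snapshot, so files replaced by compaction no longer appear there even though they may still exist in storage until old snapshots are expired.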

The code below compacts the files in a specific partition.

In [None]:
spark.sql(f"""CALL {CATALOG_NAME}.system.rewrite_data_files(
    table => '{NAMESPACE}.Users',
    options => map('rewrite-all', 'true'),
    where => "joined_date >= '2025-11-03 00:00:00' AND joined_date < '2025-11-04 00:00:00'"
)
""").show()

# Module 3: Streaming Pipeline

This module demonstrates building a real-time streaming pipeline using Spark Structured Streaming and Iceberg's transactional MERGE INTO operation to perform UPSERTs.

## 1. Define UPSERT Logic and Target Table

Spark Structured Streaming processes data in small micro-batches. The foreachBatch function allows you to apply operations like complex SQL DML statements on the data within each micro-batch.
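
As a mental model for the per-batch logic in this module, the sketch below mimics a dedupe-then-upsert step with plain dictionaries. It is an illustration only: the target table is modeled as a dict keyed by `user_id`, timestamps are plain integers, and the function names are made up.

```python
def dedupe_latest(batch):
    """Keep only the row with the greatest timestamp for each user_id."""
    latest = {}
    for row in batch:
        current = latest.get(row["user_id"])
        if current is None or row["timestamp"] > current["timestamp"]:
            latest[row["user_id"]] = row
    return list(latest.values())


def upsert(target, batch):
    """Apply MERGE-like semantics: update newer matches, insert new keys."""
    for row in dedupe_latest(batch):
        existing = target.get(row["user_id"])
        if existing is None:
            # WHEN NOT MATCHED THEN INSERT
            target[row["user_id"]] = dict(row)
        elif existing["timestamp"] < row["timestamp"]:
            # WHEN MATCHED AND target is older THEN UPDATE
            existing["status"] = row["status"]
            existing["timestamp"] = row["timestamp"]
    return target


target = {100: {"user_id": 100, "status": "ACTIVE", "timestamp": 5}}
batch = [
    {"user_id": 100, "status": "LOGGED_IN", "timestamp": 7},
    {"user_id": 100, "status": "ACTIVE", "timestamp": 6},
    {"user_id": 101, "status": "ACTIVE", "timestamp": 3},
]
upsert(target, batch)
```

Deduplicating first matters because a MERGE generally expects at most one source row per target row; two updates for the same key in one batch would otherwise be ambiguous.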

In [None]:
from pyspark.sql import Window
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType
from pyspark.sql import functions as F

# Define the target table name for streaming
STREAM_TARGET_TABLE = "realtime_users1"
CHECKPOINT_LOCATION = f"{CHECKPOINT_DIR}/user_activity_stream"

def upsert_to_iceberg(updates_df, batch_id):
    """
    Performs a transactional MERGE INTO operation for each micro-batch,
    including a deduplication step using a window function.
    """
    if updates_df.count() == 0:
        print(f"Batch {batch_id}: No new records, skipping merge.")
        return

    executor_spark = updates_df.sparkSession

    # Deduplicate the incoming micro-batch on user_id, keeping the row with the latest timestamp
    window_spec = Window.partitionBy("user_id").orderBy(F.col("timestamp").desc())

    deduplicated_updates_df = updates_df \
        .withColumn("rank", F.row_number().over(window_spec)) \
        .filter(F.col("rank") == 1) \
        .drop("rank")

    # Safety Check
    if deduplicated_updates_df.isEmpty():
        print(f"Batch {batch_id}: Records dropped during deduplication, skipping merge.")
        return

    deduplicated_updates_df.createOrReplaceTempView("stream_updates")

    print(f"Batch {batch_id}: Merging {deduplicated_updates_df.count()} unique records into {STREAM_TARGET_TABLE}...")

    try:
        # Execute the transactional UPSERT (MERGE INTO)
        executor_spark.sql(f"""
          MERGE INTO {STREAM_TARGET_TABLE} T
          USING stream_updates S
          ON T.user_id = S.user_id
          WHEN MATCHED AND T.timestamp < S.timestamp THEN
            UPDATE SET T.status = S.status, T.timestamp = S.timestamp
          WHEN NOT MATCHED THEN
            INSERT (user_id, status, timestamp)
            VALUES (S.user_id, S.status, S.timestamp)
        """)
        print(f"Batch {batch_id}: Merge completed successfully.")

    except Exception as e:
        # Log and re-raise so that a failed micro-batch is not treated as processed.
        print(f"Batch {batch_id}: Error during merge: {e}")
        raise


## 2. Create the initial, empty target Iceberg table

We partition the target table by days(timestamp) so that time-range queries can prune partitions, which can also reduce the number of files a MERGE INTO has to rewrite.
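
The day transform maps each timestamp to a whole number of days since 1970-01-01, so all rows from the same calendar day share one partition value. The sketch below mimics that mapping for timezone-aware timestamps at or after the epoch; the function name is hypothetical and this is not Iceberg's own code.

```python
from datetime import datetime, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def days_partition(ts):
    """Return whole days between the Unix epoch and a timezone-aware timestamp."""
    return (ts.astimezone(timezone.utc) - EPOCH).days


morning = datetime(2025, 11, 3, 0, 5, tzinfo=timezone.utc)
evening = datetime(2025, 11, 3, 23, 55, tzinfo=timezone.utc)
next_day = datetime(2025, 11, 4, 0, 0, tzinfo=timezone.utc)
```

Here `morning` and `evening` map to the same partition value, while `next_day` maps to the following one.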

In [None]:
# --- Create NameSpace---
spark.sql(f"USE {CATALOG_NAME};")
spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {NAMESPACE};")
spark.sql(f"USE {NAMESPACE};")

In [None]:
spark.sql(f"""
    CREATE OR REPLACE TABLE {STREAM_TARGET_TABLE} (
        user_id INT,
        status STRING,
        timestamp TIMESTAMP
    )
    USING iceberg
    PARTITIONED BY (days(timestamp))
""")
print(f"Real-time Iceberg target table '{STREAM_TARGET_TABLE}' created.")


## 3. Start and Monitor the Stream

The stream uses the rate source to simulate incoming data at 5 rows per second. The processingTime='10 seconds' trigger ensures the foreachBatch (and thus the MERGE INTO) runs every 10 seconds.
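
A quick back-of-the-envelope simulation shows what each micro-batch should roughly contain. It assumes a perfectly steady rate and that each trigger picks up exactly 10 seconds of rows (real batch sizes vary), and it reuses the lab's parsing rule that derives user_id as value % 5 + 100.

```python
ROWS_PER_SECOND = 5
TRIGGER_SECONDS = 10


def simulate_batch(first_value, rows):
    """Build the rows one micro-batch would hold for consecutive rate values."""
    return [
        {
            "user_id": value % 5 + 100,
            "status": "LOGGED_IN" if value % 3 == 0 else "ACTIVE",
        }
        for value in range(first_value, first_value + rows)
    ]


rows_per_batch = ROWS_PER_SECOND * TRIGGER_SECONDS
batch = simulate_batch(0, rows_per_batch)
distinct_users = sorted({row["user_id"] for row in batch})
```

Under these assumptions each batch carries about 50 rows but only 5 distinct user IDs (100 to 104), so after deduplication each MERGE touches at most 5 rows of the target table.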


In [None]:
# The rate source emits two columns: timestamp (TimestampType) and value (LongType).
# Read stream using a rate source for testing.
# The checkpoint location is configured on the writer, not on the reader.
stream_df = spark.readStream \
    .format("rate") \
    .option("rowsPerSecond", 5) \
    .load()

# Parse the incoming simulated data to match the target schema
parsed_stream = stream_df \
  .withColumn("user_id", (F.col("value") % 5) + 100) \
  .withColumn("status", F.when(F.col("value") % 3 == 0, "LOGGED_IN").otherwise("ACTIVE")) \
  .withColumn("timestamp", F.current_timestamp()) \
  .select("user_id", "status", "timestamp")


# Start the streaming query
streaming_query = parsed_stream.writeStream \
    .outputMode("update") \
    .foreachBatch(upsert_to_iceberg) \
    .trigger(processingTime='10 seconds') \
    .option("checkpointLocation", CHECKPOINT_LOCATION) \
    .queryName("Iceberg_UPSERT_Stream") \
    .start()

print("Streaming query started. Data is now being processed every 10 seconds.")

In [None]:
# --- Run the cell below to check status and stop the stream when done ---

# Check the status of the stream (Run this cell repeatedly)
streaming_query.status

Run this cell to query the target Iceberg table. Re-running it should show the timestamp values advancing as new micro-batches are merged.

In [None]:
spark.sql(f"SELECT * FROM {STREAM_TARGET_TABLE}").show(truncate=False)

In [None]:
streaming_query.stop()
print("Streaming query stopped.")