#🚀 Step 1: Set Up Spark Session



In [7]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PySparkOptimizationProject").getOrCreate()


#🚀 Step 2: Simulate Large Transaction Data


In [22]:
# small transaction data
transactions = [
    (1, "user_1", "product_1", 100),
    (2, "user_2", "product_2", 150),
    (3, "user_3", "product_3", 120),
]
columns = ["transaction_id", "user_id", "product_id", "amount"]

small_tx_df = spark.createDataFrame(transactions, columns)

# expand 15 times (2^15 ≈ 32k rows)
large_tx_df = small_tx_df
for _ in range(15):
    large_tx_df = large_tx_df.union(small_tx_df)

print(f"Total transaction rows: {large_tx_df.count()}")


Total transaction rows: 48


In [9]:
large_tx_df.show()

+--------------+-------+----------+------+
|transaction_id|user_id|product_id|amount|
+--------------+-------+----------+------+
|             1| user_1| product_1|   100|
|             2| user_2| product_2|   150|
|             3| user_3| product_3|   120|
|             1| user_1| product_1|   100|
|             2| user_2| product_2|   150|
|             3| user_3| product_3|   120|
|             1| user_1| product_1|   100|
|             2| user_2| product_2|   150|
|             3| user_3| product_3|   120|
|             1| user_1| product_1|   100|
|             2| user_2| product_2|   150|
|             3| user_3| product_3|   120|
|             1| user_1| product_1|   100|
|             2| user_2| product_2|   150|
|             3| user_3| product_3|   120|
|             1| user_1| product_1|   100|
|             2| user_2| product_2|   150|
|             3| user_3| product_3|   120|
|             1| user_1| product_1|   100|
|             2| user_2| product_2|   150|
+----------

#🚀 Step 3: Create Product Catalog (Lookup Table)



In [10]:
product_catalog = [
    ("product_1", "Electronics", 250),
    ("product_2", "Clothing", 80),
    ("product_3", "Groceries", 30),
]
product_columns = ["product_id", "category", "unit_price"]

product_df = spark.createDataFrame(product_catalog, product_columns)
product_df.show()


+----------+-----------+----------+
|product_id|   category|unit_price|
+----------+-----------+----------+
| product_1|Electronics|       250|
| product_2|   Clothing|        80|
| product_3|  Groceries|        30|
+----------+-----------+----------+



#🚀 Step 4: Join Transaction with Product Data


In [11]:
joined_df_shuffle = large_tx_df.join(product_df, on="product_id", how="left")
joined_df_shuffle.show(5)


+----------+--------------+-------+------+-----------+----------+
|product_id|transaction_id|user_id|amount|   category|unit_price|
+----------+--------------+-------+------+-----------+----------+
| product_1|             1| user_1|   100|Electronics|       250|
| product_3|             3| user_3|   120|  Groceries|        30|
| product_2|             2| user_2|   150|   Clothing|        80|
| product_1|             1| user_1|   100|Electronics|       250|
| product_3|             3| user_3|   120|  Groceries|        30|
+----------+--------------+-------+------+-----------+----------+
only showing top 5 rows



#🚀 Step 5: Broadcast Join


In [12]:
from pyspark.sql.functions import broadcast

joined_df_broadcast = large_tx_df.join(broadcast(product_df), on="product_id", how="left")

joined_df_broadcast.show(5)


+----------+--------------+-------+------+-----------+----------+
|product_id|transaction_id|user_id|amount|   category|unit_price|
+----------+--------------+-------+------+-----------+----------+
| product_1|             1| user_1|   100|Electronics|       250|
| product_2|             2| user_2|   150|   Clothing|        80|
| product_3|             3| user_3|   120|  Groceries|        30|
| product_1|             1| user_1|   100|Electronics|       250|
| product_2|             2| user_2|   150|   Clothing|        80|
+----------+--------------+-------+------+-----------+----------+
only showing top 5 rows



#⏱ Profile performance




In [13]:
import time

start = time.time()
joined_df_broadcast.count()
print("Broadcast join time: ", time.time() - start)


Broadcast join time:  6.003990650177002


#🚀 Step 6: Repartition


In [14]:
print("Partitions before:", large_tx_df.rdd.getNumPartitions())

repartitioned_df = joined_df_broadcast.repartition(16)
print("Partitions after repartition:", repartitioned_df.rdd.getNumPartitions())

# simulate write
repartitioned_df.write.mode("overwrite").parquet("/tmp/tx_repartitioned")


Partitions before: 32
Partitions after repartition: 16


#🚀 Step 7: Coalesce


In [15]:
coalesced_df = repartitioned_df.coalesce(4)
print("Partitions after coalesce:", coalesced_df.rdd.getNumPartitions())

coalesced_df.write.mode("overwrite").parquet("/tmp/tx_coalesced")


Partitions after coalesce: 4


#🚀 Step 8: Caching and Persist


In [20]:
# using cache
cached_df = joined_df_broadcast.cache()
cached_df.count()  # triggers caching


48

In [21]:


# using persist with MEMORY_AND_DISK
from pyspark import StorageLevel
persisted_df = joined_df_broadcast.persist(StorageLevel.MEMORY_AND_DISK)
persisted_df.count()  # triggers persist


48

#⏱ Profile performance:




In [17]:
start = time.time()
cached_df.count()
print("Count after cache: ", time.time() - start)

start = time.time()
persisted_df.count()
print("Count after persist: ", time.time() - start)


Count after cache:  1.0293893814086914
Count after persist:  1.1298937797546387


#🚀 Step 9: Catalyst Optimizer Demonstration


In [18]:
# hint for broadcast
hinted_df = large_tx_df.join(
    product_df.hint("broadcast"), on="product_id", how="left"
)
hinted_df.explain()  # shows Catalyst logical & physical plan


== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- InMemoryTableScan [product_id#235, transaction_id#233L, user_id#234, amount#236L, category#448, unit_price#449L]
      +- InMemoryRelation [product_id#235, transaction_id#233L, user_id#234, amount#236L, category#448, unit_price#449L], StorageLevel(disk, memory, deserialized, 1 replicas)
            +- AdaptiveSparkPlan isFinalPlan=false
               +- Project [product_id#235, transaction_id#233L, user_id#234, amount#236L, category#448, unit_price#449L]
                  +- BroadcastHashJoin [product_id#235], [product_id#447], LeftOuter, BuildRight, false
                     :- Union
                     :  :- Scan ExistingRDD[transaction_id#233L,user_id#234,product_id#235,amount#236L]
                     :  :- Scan ExistingRDD[transaction_id#241L,user_id#242,product_id#243,amount#244L]
                     :  :- Scan ExistingRDD[transaction_id#249L,user_id#250,product_id#251,amount#252L]
                     :  :- Scan Exi

#🚀 Step 10: Final Profiling


In [19]:
# shuffle join baseline
start = time.time()
joined_df_shuffle.count()
print("Shuffle join count time:", time.time() - start)

# broadcast join with caching
start = time.time()
cached_df.count()
print("Broadcast + cache count time:", time.time() - start)


Shuffle join count time: 0.8022630214691162
Broadcast + cache count time: 0.6562621593475342
