#🟢  Create Optimized Spark Session



In [19]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
import time

# Build (or reuse) the notebook's Spark session.
#   - Kryo serializer instead of the default Java serializer
#   - adaptive query execution switched on
#   - 200 shuffle partitions
spark = (
    SparkSession.builder
    .appName("Spark Optimization Project")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.shuffle.partitions", "200")
    .getOrCreate()
)


#🟢 Generate Synthetic Data
✅ 10 million transactions
✅ 5 million customers
✅ 10,000 products





Customer Dataset (5 million)


In [20]:
# Synthetic customer table: one row per id in 1..customer_count.
customer_count = 5_000_000

customers = (
    spark.range(1, customer_count + 1)
    .select(col("id").alias("customer_id"))
    .withColumn("name", concat(lit("Customer_"), col("customer_id")))
    # Even ids are 'M', odd ids are 'F'.
    .withColumn("gender", expr("CASE WHEN customer_id % 2 = 0 THEN 'M' ELSE 'F' END"))
    # Random age in 18..67; rand() is unseeded, so values differ between runs.
    .withColumn("age", (rand() * 50 + 18).cast("integer"))
    # Every fifth id is placed in 'CityA', all others in 'CityB'.
    .withColumn("location", expr("CASE WHEN customer_id % 5 = 0 THEN 'CityA' ELSE 'CityB' END"))
)


Product Dataset (10,000)




In [21]:
# Synthetic product table: one row per id in 1..product_count.
product_count = 10_000

products = (
    spark.range(1, product_count + 1)
    .select(col("id").alias("product_id"))
    .withColumn("product_name", concat(lit("Product_"), col("product_id")))
    # Every fifth product is 'CategoryA'; the rest are 'CategoryB'.
    .withColumn("category", expr("CASE WHEN product_id % 5 = 0 THEN 'CategoryA' ELSE 'CategoryB' END"))
    # Random price of at least 5 and below 105, stored as a two-decimal value (unseeded).
    .withColumn("price", (rand() * 100 + 5).cast("decimal(10,2)"))
)


Transaction Dataset (10 million)




In [26]:
from pyspark.sql.functions import unix_timestamp, from_unixtime

# Number of synthetic transactions (10 million). The range below reads
# `transaction_count`, so it is defined here to keep this cell runnable on a
# fresh kernel (Restart Kernel -> Run All).
transaction_count = 10_000_000

# Current time as epoch seconds.
current_ts = unix_timestamp(current_timestamp())

# Random offset in seconds, anywhere within the last 365 days.
random_offset = (rand() * 365 * 24 * 60 * 60).cast("integer")

transactions = (
    spark.range(1, transaction_count + 1)
    .withColumnRenamed("id", "transaction_id")
    # rand() is in [0, 1), so these keys fall in 1..customer_count and
    # 1..product_count and always match a generated customer / product row.
    .withColumn("customer_id", (rand() * customer_count).cast("integer") + 1)
    .withColumn("product_id", (rand() * product_count).cast("integer") + 1)
    # Random amount of at least 1 and below 201, stored as a two-decimal value.
    .withColumn("amount", (rand() * 200 + 1).cast("decimal(10,2)"))
    # NOTE(review): from_unixtime formats epoch seconds as text, so this column
    # is presumably string-typed rather than a timestamp; confirm with printSchema().
    .withColumn("timestamp", from_unixtime(current_ts - random_offset))
    # Every third transaction is paid by 'CARD'; the rest are 'CASH'.
    .withColumn(
        "payment_type",
        expr("CASE WHEN transaction_id % 3 = 0 THEN 'CARD' ELSE 'CASH' END"),
    )
)


#📁 3. Save to Parquet for Pushdown



In [29]:
# Persist each generated dataset as Parquet, replacing any previous output.
for frame, target_dir in (
    (customers, "customers_parquet"),
    (products, "products_parquet"),
    (transactions, "transactions_parquet"),
):
    frame.write.parquet(target_dir, mode="overwrite")


#📁 4. Baseline Performance



In [30]:
# Read the three datasets back from the Parquet directories written earlier.
customers_df = spark.read.parquet("customers_parquet")
products_df = spark.read.parquet("products_parquet")
transactions_df = spark.read.parquet("transactions_parquet")

start = time.time()

# Baseline: join the full-width tables, then total revenue per product category.
joined = (
    transactions_df
    .join(customers_df, on="customer_id", how="inner")
    .join(products_df, on="product_id", how="inner")
)

# `sum` here is the Spark column function brought in by the star import,
# not Python's builtin sum.
baseline_result = joined.groupBy("category").agg(
    sum("amount").alias("total_revenue")
)

# show() is the action that triggers execution, so it sits inside the timed span.
baseline_result.show()

elapsed = time.time() - start
print(f"Baseline time: {elapsed:.2f} sec")


+---------+-------------+
| category|total_revenue|
+---------+-------------+
|CategoryB| 807905245.81|
|CategoryA| 202004432.78|
+---------+-------------+

Baseline time: 17.28 sec


📁 5. Optimizations Step by Step


5.1 Column Pruning


In [32]:
# Turn off automatic broadcast joins (a threshold of -1 disables them) for the
# rest of the session.
# NOTE(review): the baseline cell was timed before this setting was applied, so
# its timing is not directly comparable with the timings measured below;
# confirm whether that is intended.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")


In [33]:
# Keep only the columns the revenue query needs from each table.
transactions_pruned = (
    spark.read.parquet("transactions_parquet")
    .select("customer_id", "product_id", "amount")
)
products_pruned = spark.read.parquet("products_parquet").select("product_id", "category")
customers_pruned = spark.read.parquet("customers_parquet").select("customer_id")

start = time.time()

# Same join and aggregation as the baseline, on the narrowed inputs.
joined_pruned = (
    transactions_pruned
    .join(customers_pruned, on="customer_id", how="inner")
    .join(products_pruned, on="product_id", how="inner")
    .groupBy("category")
    .agg(sum("amount").alias("total_revenue"))
)
joined_pruned.show()

elapsed = time.time() - start
print(f"Column pruning time: {elapsed:.2f} sec")


+---------+-------------+
| category|total_revenue|
+---------+-------------+
|CategoryB| 807905245.81|
|CategoryA| 202004432.78|
+---------+-------------+

Column pruning time: 41.99 sec


5.2 Predicate Pushdown


In [34]:

# Restrict the scan to high-value transactions (amount above 50) as the data is read.
transactions_filtered = (
    spark.read.parquet("transactions_parquet")
    .select("customer_id", "product_id", "amount")
    .where(col("amount") > 50)
)

start = time.time()

# Reuses customers_pruned / products_pruned from the column-pruning cell.
joined_filtered = (
    transactions_filtered
    .join(customers_pruned, on="customer_id", how="inner")
    .join(products_pruned, on="product_id", how="inner")
    .groupBy("category")
    .agg(sum("amount").alias("total_revenue"))
)
joined_filtered.show()

elapsed = time.time() - start
print(f"Predicate pushdown time: {elapsed:.2f} sec")


+---------+-------------+
| category|total_revenue|
+---------+-------------+
|CategoryB| 757925086.09|
|CategoryA| 189503719.52|
+---------+-------------+

Predicate pushdown time: 29.08 sec


5.3 Filter Pushdown
("Filter pushdown" is another name for predicate pushdown: the `amount > 50` filter in 5.2 is applied at the data scan stage, and Spark handles both the same way, so no extra step is needed here.)



5.4 Project Pushdown



In [35]:
# already done via .select() in column pruning, no extra step


5.5 Sorting for Joins



In [36]:
# Order the pruned transactions by the product_id join key.
# The result is lazy; nothing runs until an action is called on it.
# NOTE(review): a global sort usually needs its own shuffle, so check with
# explain() that it actually helps the join.
transactions_sorted = transactions_pruned.orderBy("product_id")
# note: if join keys are skewed, bucketing is better


5.6 Efficient File Formats
✅ Already using Parquet (efficient, supports predicate + column pushdown).



5.7 WholeStage Code Generation


In [37]:
# Whole-stage code generation is on by default; set it explicitly to confirm.
spark.conf.set(
    "spark.sql.codegen.wholeStage",
    "true",
)


5.8 Adaptive Query Execution (AQE)



In [38]:
# Adaptive Query Execution: the same flag is also passed when the session is
# built, so this call simply re-asserts it at runtime.
spark.conf.set(
    "spark.sql.adaptive.enabled",
    "true",
)


5.9 Kryo Serialization


In [40]:
# Kryo was already requested when the session was first built at the top of the
# notebook. The serializer is fixed once the SparkContext is running, so this
# getOrCreate() call hands back the existing session rather than starting a new
# one with a different serializer.
spark = (
    SparkSession.builder
    .appName("Spark Optimization Project")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.sql.shuffle.partitions", "200")
    .getOrCreate()
)

# Read the setting back from the running context to confirm Kryo is in effect.
spark.sparkContext.getConf().get("spark.serializer")


5.10 Avoiding Shuffles



In [42]:
# products is the small side (10,000 rows): mark it with a broadcast hint so the
# large transactions table does not have to be shuffled for this join.
from pyspark.sql.functions import broadcast
products_small = broadcast(products_pruned)
joined_broadcast = transactions_pruned.join(products_small, on="product_id", how="inner")


5.11 Avoiding UDFs


In [43]:
# Built-in column arithmetic instead of a Python UDF: 90% of the amount.
# The expression is left as the cell's last line so the resulting schema is displayed.
discount_multiplier = 0.9
joined_broadcast.withColumn("discounted", col("amount") * discount_multiplier)



DataFrame[product_id: int, customer_id: int, amount: decimal(10,2), category: string, discounted: double]

5.12 Minimizing Data Movement



In [44]:
# Repartition by customer_id so rows sharing a customer land in the same partition.
transactions_partitioned = transactions_pruned.repartition(col("customer_id"))


5.13 Tuning spark.sql.shuffle.partitions



In [45]:
# Raise the shuffle partition count from the 200 set at session creation to 400.
spark.conf.set(
    "spark.sql.shuffle.partitions",
    "400",
)


5.14 Avoid collect() on Large Data


In [48]:
# Pulling every row to the driver is dangerous on a table this size:
# data = transactions_pruned.collect()
# Preview a handful of rows instead.
transactions_pruned.show(n=5)




+-----------+----------+------+
|customer_id|product_id|amount|
+-----------+----------+------+
|    2315473|      5129|  3.65|
|    3126779|       114| 93.97|
|    3500607|      8312| 42.01|
|    2761934|      1992| 42.79|
|    2206611|      2092|197.16|
+-----------+----------+------+
only showing top 5 rows



In [47]:
# Fetch just the first 10 rows to the driver as a list of Row objects.
transactions_pruned.head(10)



[Row(customer_id=2315473, product_id=5129, amount=Decimal('3.65')),
 Row(customer_id=3126779, product_id=114, amount=Decimal('93.97')),
 Row(customer_id=3500607, product_id=8312, amount=Decimal('42.01')),
 Row(customer_id=2761934, product_id=1992, amount=Decimal('42.79')),
 Row(customer_id=2206611, product_id=2092, amount=Decimal('197.16')),
 Row(customer_id=798749, product_id=2618, amount=Decimal('152.06')),
 Row(customer_id=4019106, product_id=9297, amount=Decimal('103.20')),
 Row(customer_id=441124, product_id=1984, amount=Decimal('92.07')),
 Row(customer_id=1452174, product_id=8459, amount=Decimal('51.02')),
 Row(customer_id=489062, product_id=2879, amount=Decimal('152.05'))]

5.15 Reuse DataFrames



In [49]:
# Mark the pruned transactions for caching so later queries can reuse them.
# cache() is lazy: nothing is stored until an action runs on the DataFrame.
# NOTE(review): the later visible cells keep using `transactions_pruned`;
# confirm which name downstream queries are meant to reference.
transactions_cached = transactions_pruned.cache()
# reused in multiple joins or queries


5.16 Writing Data with Optimal Partition Size
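(Rule of thumb: aim for roughly 128 MB per output file, e.g. 80 files × 128 MB ≈ 10 GB of output. The aggregated result below has at most 10,000 rows, one per product, so the value 80 is illustrative only.)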






In [50]:
# Total sales per product, written out as Parquet in at most 80 partitions.
final_result = (
    transactions_pruned
    .groupBy("product_id")
    .agg(sum("amount").alias("total_sales"))
)
# coalesce() only ever reduces the partition count; it does not add a shuffle.
final_result.coalesce(80).write.parquet("final_sales_data", mode="overwrite")
