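# Spark AQE Coalesce Explained

This notebook compares the number of shuffle partitions produced by a `groupBy` with Adaptive Query Execution (AQE) disabled and enabled. With AQE off, the aggregation keeps all 289 configured shuffle partitions. With AQE and partition coalescing on, Spark merges small shuffle partitions at runtime: down to 1 partition when grouping by `city_id`, and 8 when grouping by `trx_id`.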

In [1]:
# Create Spark Session
# Connects to the standalone cluster master; getOrCreate() hands back the already-active
# session instead of building a new one if this cell is re-run.

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("Spark AQE Explained")
    .master("spark://spark-master:7077")
    .getOrCreate()
)

spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/26 13:20:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Check the current Spark conf for AQE and shuffle partitions before changing anything.
# advisoryPartitionSizeInBytes is reported as "67108864b", i.e. exactly 64 MiB (the default).
for conf_key in (
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
    "spark.sql.shuffle.partitions",
    "spark.sql.adaptive.advisoryPartitionSizeInBytes",
):
    print(spark.conf.get(conf_key))

true
true
200
67108864b


In [3]:
# Disable AQE (including shuffle-partition coalescing) and change the shuffle partition count.
# 289 differs from the 200 reported above, which makes it easy to recognise in the output below.
baseline_conf = {
    "spark.sql.adaptive.enabled": False,
    "spark.sql.adaptive.coalescePartitions.enabled": False,
    "spark.sql.shuffle.partitions": 289,
}
for conf_key, conf_value in baseline_conf.items():
    spark.conf.set(conf_key, conf_value)

In [5]:
# Read example data set
# pandas downloads the CSV over HTTPS on the driver, then Spark builds a DataFrame from that
# local data using the explicit schema below.
# NOTE(review): the "task of very large size" warnings in the aggregation cells most likely
# come from this driver-side data being serialized into the tasks.
import pandas as pd

data_file_https_url = "https://media.githubusercontent.com/media/subhamkharwal/pyspark-zero-to-hero/refs/heads/master/datasets/sales.csv"
schema = "transacted_at string, trx_id long, retailer_id long, description string, amount float, city_id float"

sales_pdf = pd.read_csv(data_file_https_url)
df = spark.createDataFrame(data=sales_pdf, schema=schema)
df.printSchema()
print(f"Initial Partition after read: {df.rdd.getNumPartitions()}")

root
 |-- transacted_at: string (nullable = true)
 |-- trx_id: long (nullable = true)
 |-- retailer_id: long (nullable = true)
 |-- description: string (nullable = true)
 |-- amount: float (nullable = true)
 |-- city_id: float (nullable = true)

Initial Partition after read: 8


In [6]:
# GroupBy operation to trigger a shuffle.
# With AQE disabled, the shuffle output keeps exactly spark.sql.shuffle.partitions (289) partitions.
# Import the functions module under an alias instead of `from pyspark.sql.functions import sum`,
# which would shadow Python's built-in sum() in the kernel namespace.
from pyspark.sql import functions as F

df_count = (
    df.selectExpr("city_id", "cast(amount as double) as amount_double")
    .groupBy("city_id")
    .agg(F.sum("amount_double"))
)
print(f"Output shuffle partitions: {df_count.rdd.getNumPartitions()}")

Output shuffle partitions: 289


In [7]:
# Enable AQE together with shuffle-partition coalescing.
# The shuffle partition count stays at 289, so the run below is directly comparable to the AQE-off run.
aqe_conf = {
    "spark.sql.adaptive.enabled": True,
    "spark.sql.adaptive.coalescePartitions.enabled": True,
    "spark.sql.shuffle.partitions": 289,
}
for conf_key, conf_value in aqe_conf.items():
    spark.conf.set(conf_key, conf_value)

In [8]:
# Read example data set (same source and schema as the first read; relies on `pd` imported there).
# NOTE(review): reusing the earlier `df` would likely work too, since AQE settings apply when a
# query executes; re-reading just keeps this section self-contained.
data_file_https_url = "https://media.githubusercontent.com/media/subhamkharwal/pyspark-zero-to-hero/refs/heads/master/datasets/sales.csv"
schema = "transacted_at string, trx_id long, retailer_id long, description string, amount float, city_id float"

sales_pdf = pd.read_csv(data_file_https_url)
df = spark.createDataFrame(data=sales_pdf, schema=schema)
df.printSchema()
print(f"Initial Partition after read: {df.rdd.getNumPartitions()}")

root
 |-- transacted_at: string (nullable = true)
 |-- trx_id: long (nullable = true)
 |-- retailer_id: long (nullable = true)
 |-- description: string (nullable = true)
 |-- amount: float (nullable = true)
 |-- city_id: float (nullable = true)

Initial Partition after read: 8


In [9]:
# GroupBy operation to trigger a shuffle, now with AQE coalescing enabled.
# Grouping by city_id shuffles only a small amount of data (presumably few distinct city_id
# values), so AQE coalesces all 289 shuffle partitions into a single partition.
# Import the functions module under an alias instead of `from pyspark.sql.functions import sum`,
# which would shadow Python's built-in sum() in the kernel namespace.
from pyspark.sql import functions as F

df_count = (
    df.selectExpr("city_id", "cast(amount as double) as amount_double")
    .groupBy("city_id")
    .agg(F.sum("amount_double"))
)
print(f"Output shuffle partitions: {df_count.rdd.getNumPartitions()}")

24/10/26 13:21:27 WARN TaskSetManager: Stage 0 contains a task of very large size (8828 KiB). The maximum recommended task size is 1000 KiB.
[Stage 0:>                                                          (0 + 8) / 8]

Output shuffle partitions: 1


In [10]:
# GroupBy operation to trigger a shuffle, this time on trx_id. trx_id presumably has far more
# distinct values than city_id, so much more data survives partial aggregation and is shuffled.
# AQE still coalesces the 289 shuffle partitions, but into several partitions instead of one.
# NOTE(review): the result (8) equals the input partition count / task parallelism seen earlier.
# On Spark 3.2+, spark.sql.adaptive.coalescePartitions.parallelismFirst (default true) makes AQE
# ignore the 64 MiB advisory size and target the default parallelism instead. So 8 partitions
# is not by itself evidence that the shuffle output exceeds 64 MiB; confirm the Spark version
# and setting before relying on that explanation.
# Import the functions module under an alias instead of `from pyspark.sql.functions import sum`,
# which would shadow Python's built-in sum() in the kernel namespace.
from pyspark.sql import functions as F

df_count = (
    df.selectExpr("trx_id", "cast(amount as double) as amount_double")
    .groupBy("trx_id")
    .agg(F.sum("amount_double"))
)
print(f"Output shuffle partitions: {df_count.rdd.getNumPartitions()}")

24/10/26 13:21:41 WARN TaskSetManager: Stage 1 contains a task of very large size (8828 KiB). The maximum recommended task size is 1000 KiB.

Output shuffle partitions: 8


In [11]:
# Stop the Spark session (and its underlying SparkContext) now that the demo is finished.
spark.stop()