In [None]:
# Spark Session
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("Understand Plans and DAG")
    .master("spark://spark-master:7077")
    .getOrCreate()
)

spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/07 16:46:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Disable AQE and Broadcast join

spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", False)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

In [3]:
# Check default Parallism

spark.sparkContext.defaultParallelism

8

In [4]:
# Create dataframes

df_1 = spark.range(4, 200, 2)
df_2 = spark.range(2, 200, 4)

In [6]:
df_1.rdd.getNumPartitions()

8

In [7]:
# Re-partition data

df_3 = df_1.repartition(5)
df_4 = df_2.repartition(7)

In [9]:
df_4.rdd.getNumPartitions()

7

In [10]:
# Join the dataframes

df_joined = df_3.join(df_4, on="id")

In [11]:
# Get the sum of ids

df_sum = df_joined.selectExpr("sum(id) as total_sum")

In [12]:
# View data
df_sum.show()



+---------+
|total_sum|
+---------+
|     4998|
+---------+



                                                                                

In [11]:
# Explain plan

df_sum.explain()

== Physical Plan ==
*(6) HashAggregate(keys=[], functions=[sum(id#0L)])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=166]
   +- *(5) HashAggregate(keys=[], functions=[partial_sum(id#0L)])
      +- *(5) Project [id#0L]
         +- *(5) SortMergeJoin [id#0L], [id#2L], Inner
            :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(id#0L, 200), ENSURE_REQUIREMENTS, [plan_id=150]
            :     +- Exchange RoundRobinPartitioning(5), REPARTITION_BY_NUM, [plan_id=149]
            :        +- *(1) Range (4, 200, step=2, splits=2)
            +- *(4) Sort [id#2L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#2L, 200), ENSURE_REQUIREMENTS, [plan_id=157]
                  +- Exchange RoundRobinPartitioning(7), REPARTITION_BY_NUM, [plan_id=156]
                     +- *(3) Range (2, 200, step=4, splits=2)




In [13]:
df_3 = df_1.coalesce(5)
df_4 = df_2.coalesce(7)
df_joined = df_3.join(df_4, on="id")
df_sum = df_joined.selectExpr("sum(id) as total_sum")
df_sum.show()



+---------+
|total_sum|
+---------+
|     4998|
+---------+



                                                                                

In [15]:
# Union the data again to see the skipped stages

df_union = df_sum.union(df_4)

In [16]:
df_union.show()

                                                                                

+---------+
|total_sum|
+---------+
|     4998|
|        2|
|        6|
|       10|
|       14|
|       18|
|       22|
|       26|
|       30|
|       34|
|       38|
|       42|
|       46|
|       50|
|       54|
|       58|
|       62|
|       66|
|       70|
|       74|
+---------+
only showing top 20 rows



In [14]:
# Explain plan

df_union.explain()

== Physical Plan ==
Union
:- *(6) HashAggregate(keys=[], functions=[sum(id#0L)])
:  +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=404]
:     +- *(5) HashAggregate(keys=[], functions=[partial_sum(id#0L)])
:        +- *(5) Project [id#0L]
:           +- *(5) SortMergeJoin [id#0L], [id#2L], Inner
:              :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
:              :  +- Exchange hashpartitioning(id#0L, 200), ENSURE_REQUIREMENTS, [plan_id=388]
:              :     +- Exchange RoundRobinPartitioning(5), REPARTITION_BY_NUM, [plan_id=387]
:              :        +- *(1) Range (4, 200, step=2, splits=2)
:              +- *(4) Sort [id#2L ASC NULLS FIRST], false, 0
:                 +- Exchange hashpartitioning(id#2L, 200), ENSURE_REQUIREMENTS, [plan_id=395]
:                    +- Exchange RoundRobinPartitioning(7), REPARTITION_BY_NUM, [plan_id=394]
:                       +- *(3) Range (2, 200, step=4, splits=2)
+- ReusedExchange [id#20L], Exchange RoundRobinPartitioning(

In [17]:
spark.stop()