In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum, spark_partition_id, asc

In [2]:
spark = SparkSession.builder \
    .appName("DataFrameShuffleExample") \
    .master("spark://spark-master:7077") \
    .config("spark.sql.shuffle.partitions", "20") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/30 08:01:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# DataFrame 생성
data = [(i, i % 5) for i in range(20)]
df = spark.createDataFrame(data, ["value", "key"])

In [4]:
# narrow transformation - stage 1
df2 = df.filter(col("value") > 5).select("value", "key")

In [5]:
# 각 Row가 속한 Partition 확인
df2_with_pid = df2.withColumn("partition_id", spark_partition_id())
print("Stage 1 Partition Info:")
df2_with_pid.sort("partition_id", ascending=True).show()

Stage 1 Partition Info:


                                                                                

+-----+---+------------+
|value|key|partition_id|
+-----+---+------------+
|    6|  1|           0|
|    7|  2|           0|
|    8|  3|           0|
|    9|  4|           0|
|   10|  0|           1|
|   11|  1|           1|
|   12|  2|           1|
|   13|  3|           1|
|   14|  4|           1|
|   15|  0|           1|
|   16|  1|           1|
|   17|  2|           1|
|   18|  3|           1|
|   19|  4|           1|
+-----+---+------------+



In [6]:
# wide transformation - shuffle이 발생 - stage 2
df3 = df2.groupBy("key").agg(spark_sum("value").alias("sum_value"))

In [7]:
result = df3.collect()

                                                                                

In [8]:
print("Result:")
for row in result:
    print(row)

Result:
Row(key=2, sum_value=36)
Row(key=4, sum_value=42)
Row(key=0, sum_value=25)
Row(key=1, sum_value=33)
Row(key=3, sum_value=39)


In [9]:
# Stage 2 Partition 확인
df3_with_pid = df3.withColumn("partition_id", spark_partition_id())
print("Stage 2 Partition Info (after shuffle):")
df3_with_pid.show()

Stage 2 Partition Info (after shuffle):
+---+---------+------------+
|key|sum_value|partition_id|
+---+---------+------------+
|  2|       36|           0|
|  4|       42|           0|
|  1|       33|           0|
|  3|       39|           0|
|  0|       25|           0|
+---+---------+------------+



In [10]:
df3.explain(extended=True)

== Parsed Logical Plan ==
'Aggregate ['key], ['key, sum('value) AS sum_value#26]
+- Project [value#0L, key#1L]
   +- Filter (value#0L > cast(5 as bigint))
      +- LogicalRDD [value#0L, key#1L], false

== Analyzed Logical Plan ==
key: bigint, sum_value: bigint
Aggregate [key#1L], [key#1L, sum(value#0L) AS sum_value#26L]
+- Project [value#0L, key#1L]
   +- Filter (value#0L > cast(5 as bigint))
      +- LogicalRDD [value#0L, key#1L], false

== Optimized Logical Plan ==
Aggregate [key#1L], [key#1L, sum(value#0L) AS sum_value#26L]
+- Filter (isnotnull(value#0L) AND (value#0L > 5))
   +- LogicalRDD [value#0L, key#1L], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
   *(2) HashAggregate(keys=[key#1L], functions=[sum(value#0L)], output=[key#1L, sum_value#26L])
   +- AQEShuffleRead coalesced
      +- ShuffleQueryStage 0
         +- Exchange hashpartitioning(key#1L, 4), ENSURE_REQUIREMENTS, [plan_id=41]
            +- *(1) HashAggregate(keys=[key#1L], function