In [1]:
from pyspark.sql import SparkSession 
spark = SparkSession.builder.appName("pyspark-by-examples").getOrCreate()
sc = spark.sparkContext
spark
sc

In [13]:
import time

In [2]:
from pyspark.sql.functions import expr
#Create spark session
data = [("Banana",1000,"USA"), ("Carrots",1500,"USA"), ("Beans",1600,"USA"), \
      ("Orange",2000,"USA"),("Orange",2000,"USA"),("Banana",400,"China"), \
      ("Carrots",1200,"China"),("Beans",1500,"China"),("Orange",4000,"China"), \
      ("Banana",2000,"Canada"),("Carrots",2000,"Canada"),("Beans",2000,"Mexico")]

columns= ["Product","Amount","Country"]
df = spark.createDataFrame(data = data, schema = columns)
df.printSchema()
df.show(truncate=False)


root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
|Banana |1000  |USA    |
|Carrots|1500  |USA    |
|Beans  |1600  |USA    |
|Orange |2000  |USA    |
|Orange |2000  |USA    |
|Banana |400   |China  |
|Carrots|1200  |China  |
|Beans  |1500  |China  |
|Orange |4000  |China  |
|Banana |2000  |Canada |
|Carrots|2000  |Canada |
|Beans  |2000  |Mexico |
+-------+------+-------+



    Pivot PySpark DataFrame
    PySpark SQL provides pivot() function to rotate the data from one column into multiple columns. It is an aggregation where one of the grouping columns values is transposed into individual columns with distinct data. To get the total amount exported to each country of each product, will do group by Product, pivot by Country, and the sum of Amount.

In [3]:

pivotDF = df.groupBy("Product").pivot("Country").sum("Amount")
pivotDF.printSchema()
pivotDF.show(truncate=False)


root
 |-- Product: string (nullable = true)
 |-- Canada: long (nullable = true)
 |-- China: long (nullable = true)
 |-- Mexico: long (nullable = true)
 |-- USA: long (nullable = true)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+



In [21]:
startTimeQuery = time.perf_counter()
pivotDF1 = df.groupBy("Country").pivot("Product").sum("Amount")
endTimeQuery = time.perf_counter()
runTimeQuery = endTimeQuery - startTimeQuery
print(runTimeQuery)
pivotDF1.printSchema()
pivotDF1.show(truncate=False)
print("This will transpose the countries from DataFrame rows into columns and produces the below output. where ever data is not present, it represents as null by default.")

7.02711330000011
root
 |-- Country: string (nullable = true)
 |-- Banana: long (nullable = true)
 |-- Beans: long (nullable = true)
 |-- Carrots: long (nullable = true)
 |-- Orange: long (nullable = true)

+-------+------+-----+-------+------+
|Country|Banana|Beans|Carrots|Orange|
+-------+------+-----+-------+------+
|China  |400   |1500 |1200   |4000  |
|USA    |1000  |1600 |1500   |4000  |
|Mexico |null  |2000 |null   |null  |
|Canada |2000  |null |2000   |null  |
+-------+------+-----+-------+------+

This will transpose the countries from DataFrame rows into columns and produces the below output. where ever data is not present, it represents as null by default.


    Pivot Performance improvement in PySpark 2.0 version 2.0 on-wards performance has been improved on Pivot, however, if you are using the lower version; note that pivot is a very expensive operation hence, it is recommended to provide column data (if known) as an argument to function as shown below.

In [22]:
startTimeQuery = time.perf_counter()
countries = ["USA","China","Canada","Mexico"]
pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount")
pivotDF.show(truncate=False)

endTimeQuery = time.perf_counter()
runTimeQuery = endTimeQuery - startTimeQuery
print(runTimeQuery)

+-------+----+-----+------+------+
|Product|USA |China|Canada|Mexico|
+-------+----+-----+------+------+
|Orange |4000|4000 |null  |null  |
|Beans  |1600|1500 |null  |2000  |
|Banana |1000|400  |2000  |null  |
|Carrots|1500|1200 |2000  |null  |
+-------+----+-----+------+------+

7.3973318999999265


Another approach is to do two-phase aggregation. PySpark 2.0 uses this implementation in order to improve the performance Spark-13749 <https://issues.apache.org/jira/browse/SPARK-13749>

In [23]:
startTimeQuery = time.perf_counter()
pivotDF = df.groupBy("Product","Country") \
      .sum("Amount") \
      .groupBy("Product") \
      .pivot("Country") \
      .sum("sum(Amount)") 
endTimeQuery = time.perf_counter()
runTimeQuery = endTimeQuery - startTimeQuery
print(runTimeQuery)
pivotDF.show(truncate=False)


6.919432800000095
+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+



    Unpivot PySpark DataFrame
Unpivot is a reverse operation, we can achieve by rotating column values into rows values. PySpark SQL doesn’t have unpivot function hence will use the stack() function. Below code converts column countries to row.

In [42]:

from pyspark.sql.functions import expr
unpivotExpr = "stack(3, 'Canada', Canada, 'China', China, 'Mexico', Mexico) as (Country,Total)"
unPivotDF = pivotDF.select("Product", expr(unpivotExpr)) \
    .where("Total is not null")
unPivotDF.show(truncate=False)
unPivotDF.show()

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|Orange |China  |4000 |
|Beans  |China  |1500 |
|Beans  |Mexico |2000 |
|Banana |Canada |2000 |
|Banana |China  |400  |
|Carrots|Canada |2000 |
|Carrots|China  |1200 |
+-------+-------+-----+

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
| Orange|  China| 4000|
|  Beans|  China| 1500|
|  Beans| Mexico| 2000|
| Banana| Canada| 2000|
| Banana|  China|  400|
|Carrots| Canada| 2000|
|Carrots|  China| 1200|
+-------+-------+-----+

