In [None]:
PySpark pivot() function is used to rotate/transpose the data from one column into multiple Dataframe columns and back using
unpivot(). Pivot() It is an aggregation where one of the grouping columns values is transposed into individual columns with
distinct data.
Syntax: pivot_df = original_df.groupBy("grouping_column").pivot("pivot_column").agg({"agg_column": "agg_function"})

In [1]:
# Imports
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr
#Create spark session
spark = SparkSession.builder.master("local[1]") \
                    .appName('SparkByExamples.com') \
                    .getOrCreate()

data = [("Banana",1000,"USA"), ("Carrots",1500,"USA"), ("Beans",1600,"USA"), \
      ("Orange",2000,"USA"),("Orange",2000,"USA"),("Banana",400,"China"), \
      ("Carrots",1200,"China"),("Beans",1500,"China"),("Orange",4000,"China"), \
      ("Banana",2000,"Canada"),("Carrots",2000,"Canada"),("Beans",2000,"Mexico")]

columns= ["Product","Amount","Country"]
df = spark.createDataFrame(data = data, schema = columns)
df.printSchema()
df.show(truncate=False)

root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
|Banana |1000  |USA    |
|Carrots|1500  |USA    |
|Beans  |1600  |USA    |
|Orange |2000  |USA    |
|Orange |2000  |USA    |
|Banana |400   |China  |
|Carrots|1200  |China  |
|Beans  |1500  |China  |
|Orange |4000  |China  |
|Banana |2000  |Canada |
|Carrots|2000  |Canada |
|Beans  |2000  |Mexico |
+-------+------+-------+



In [2]:
# Applying pivot()
pivotDF = df.groupBy("Product").pivot("Country").sum("Amount")
pivotDF.printSchema()
pivotDF.show(truncate=False)

root
 |-- Product: string (nullable = true)
 |-- Canada: long (nullable = true)
 |-- China: long (nullable = true)
 |-- Mexico: long (nullable = true)
 |-- USA: long (nullable = true)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+



In [4]:
countries = ["USA","China","Canada","Mexico"]
pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount")
pivotDF.show(truncate=False)

pivotDF = df.groupBy("Product","Country").sum("Amount").groupBy("Product").pivot("Country").sum("sum(Amount)")
pivotDF.show(truncate=False)

+-------+----+-----+------+------+
|Product|USA |China|Canada|Mexico|
+-------+----+-----+------+------+
|Orange |4000|4000 |null  |null  |
|Beans  |1600|1500 |null  |2000  |
|Banana |1000|400  |2000  |null  |
|Carrots|1500|1200 |2000  |null  |
+-------+----+-----+------+------+

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+



In [None]:
Unpivot is a reverse operation, we can achieve by rotating column values into rows values. PySpark SQL doesn’t have unpivot 
function hence will use the stack() function. Below code converts column countries to row.

In [5]:
# Applying unpivot()
from pyspark.sql.functions import expr
unpivotExpr = "stack(3, 'Canada', Canada, 'China', China, 'Mexico', Mexico) as (Country,Total)"
unPivotDF = pivotDF.select("Product", expr(unpivotExpr)) \
    .where("Total is not null")
unPivotDF.show(truncate=False)
unPivotDF.show()

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|Orange |China  |4000 |
|Beans  |China  |1500 |
|Beans  |Mexico |2000 |
|Banana |Canada |2000 |
|Banana |China  |400  |
|Carrots|Canada |2000 |
|Carrots|China  |1200 |
+-------+-------+-----+

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
| Orange|  China| 4000|
|  Beans|  China| 1500|
|  Beans| Mexico| 2000|
| Banana| Canada| 2000|
| Banana|  China|  400|
|Carrots| Canada| 2000|
|Carrots|  China| 1200|
+-------+-------+-----+



In [None]:
The pivot operation requires unique combinations of grouping and pivot columns. If there are duplicate entries, you may need to perform an aggregation (e.g., using agg parameter) to resolve the conflict.
The pivot operation fills missing values with null. If you need to handle missing values differently, you can use methods like fillna or na.drop on the pivoted DataFrame.

In [None]:
data = [("ABC","Q1",2000),
        ("XYZ","Q1",4000),
        ("ABC","Q2",3000),
        ("XYZ","Q2",1000),
        ("ABC","Q3",2000),
        ("XYZ","Q3",6000),
        ("ABC","Q4",3000),
        ("XYZ","Q4",5000)]
schema = ["Company","Quarter","Revenue"]

df = spark.createDataFrame(data,schema)
display(df)

Company,Quarter,Revenue
ABC,Q1,2000
XYZ,Q1,4000
ABC,Q2,3000
XYZ,Q2,1000
ABC,Q3,2000
XYZ,Q3,6000
ABC,Q4,3000
XYZ,Q4,5000


In [None]:
pivot_df = df.groupBy("Company").pivot("Quarter").sum("Revenue")
display(pivot_df)

Company,Q1,Q2,Q3,Q4
XYZ,4000,1000,6000,5000
ABC,2000,3000,2000,3000


In [None]:
df2 = pivot_df.selectExpr("Company","stack (4,'Q1',Q1,'Q2',Q2,'Q3',Q3,'Q4',Q4) as (Quarter,Revenue)")
display(df2)

Company,Quarter,Revenue
XYZ,Q1,4000
XYZ,Q2,1000
XYZ,Q3,6000
XYZ,Q4,5000
ABC,Q1,2000
ABC,Q2,3000
ABC,Q3,2000
ABC,Q4,3000
