In [1]:
#we use the findspark library to locate spark on our local machine
import findspark
findspark.init(r'C:\spark\spark-3.5.0-bin-hadoop3')
import pyspark # only run this after findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import expr
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("Banana",1000,"USA"), ("Carrots",1500,"USA"), ("Beans",1600,"USA"), \
      ("Orange",2000,"USA"),("Orange",2000,"USA"),("Banana",400,"China"), \
      ("Carrots",1200,"China"),("Beans",1500,"China"),("Orange",4000,"China"), \
      ("Banana",2000,"Canada"),("Carrots",2000,"Canada"),("Beans",2000,"Mexico")]

columns= ["Product","Amount","Country"]
df = spark.createDataFrame(data = data, schema = columns)
df.printSchema()
df.show(truncate=False)

root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
|Banana |1000  |USA    |
|Carrots|1500  |USA    |
|Beans  |1600  |USA    |
|Orange |2000  |USA    |
|Orange |2000  |USA    |
|Banana |400   |China  |
|Carrots|1200  |China  |
|Beans  |1500  |China  |
|Orange |4000  |China  |
|Banana |2000  |Canada |
|Carrots|2000  |Canada |
|Beans  |2000  |Mexico |
+-------+------+-------+



In [2]:
# The pivot("Country") operation is used to pivot the data based on the values in the "Country" column. This means that for each unique value in the "Country" column, a new column will be created in the resulting DataFrame.

pivotDF = df.groupBy("Product").pivot("Country").sum("Amount")
pivotDF.printSchema()
pivotDF.show(truncate=False)

root
 |-- Product: string (nullable = true)
 |-- Canada: long (nullable = true)
 |-- China: long (nullable = true)
 |-- Mexico: long (nullable = true)
 |-- USA: long (nullable = true)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |NULL  |4000 |NULL  |4000|
|Beans  |NULL  |1500 |2000  |1600|
|Banana |2000  |400  |NULL  |1000|
|Carrots|2000  |1200 |NULL  |1500|
+-------+------+-----+------+----+



In [3]:
pivotDF = df.groupBy("Product","Country") \
      .sum("Amount") \
      .groupBy("Product") \
      .pivot("Country") \
      .sum("sum(Amount)")
pivotDF.printSchema()
pivotDF.show(truncate=False)

root
 |-- Product: string (nullable = true)
 |-- Canada: long (nullable = true)
 |-- China: long (nullable = true)
 |-- Mexico: long (nullable = true)
 |-- USA: long (nullable = true)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |NULL  |4000 |NULL  |4000|
|Beans  |NULL  |1500 |2000  |1600|
|Banana |2000  |400  |NULL  |1000|
|Carrots|2000  |1200 |NULL  |1500|
+-------+------+-----+------+----+



unpivotExpr is a string expression that specifies how to unpivot the data. It uses the stack function, which takes a set of alternating values where the first value in each pair represents the new column name ("Country" in this case), and the second value is the column from pivotDF that contains the corresponding data ("Canada," "China," "Mexico," etc.). The as clause renames the resulting columns to "Country" and "Total."

unPivotDF is created by selecting the "Product" column from pivotDF and applying the expr(unpivotExpr) expression to generate the "Country" and "Total" columns based on the unpivot expression.

The .where("Total is not null") part filters out rows where the "Total" column is null, as you might not want to include rows where there is no data in the pivoted columns.

Finally, unPivotDF.show(truncate=False) is used to display the resulting DataFrame. The truncate=False parameter ensures that column values won't be truncated in the output, allowing you to see the complete data.

In summary, this code takes a pivoted DataFrame, transforms it back into a more tabular format (unpivots it), and filters out rows where the "Total" column is null, giving you a more structured representation of your data.

In [4]:
""" unpivot """
""" unpivot """
unpivotExpr = "stack(3, 'Canada', Canada, 'China', China, 'Mexico', Mexico) as (Country,Total)"
unPivotDF = pivotDF.select("Product", expr(unpivotExpr)) \
    .where("Total is not null")
unPivotDF.show(truncate=False)

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|Orange |China  |4000 |
|Beans  |China  |1500 |
|Beans  |Mexico |2000 |
|Banana |Canada |2000 |
|Banana |China  |400  |
|Carrots|Canada |2000 |
|Carrots|China  |1200 |
+-------+-------+-----+

