In [1]:
# Build (or reuse) the Spark session used by every cell in this notebook

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("Pivot & Un-Pivot")
    .master("spark://spark-master:7077")
    .getOrCreate()
)

# Display the session summary as the cell output
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/04 16:37:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Sample data set: one record per (NAME, SUBJECT) pair with the MARKS scored.
# Not every student has every subject, so the pivoted result will contain nulls.

_data = [
    ["Ramesh", "PHY", 90],
    ["Ramesh", "MATH", 95],
    ["Ramesh", "CHEM", 100],
    ["Sangeeta", "PHY", 90],
    ["Sangeeta", "MATH", 100],
    ["Sangeeta", "CHEM", 83],
    ["Mohan", "BIO", 90],
    ["Mohan", "MATH", 70],
    ["Mohan", "CHEM", 76],
    ["Imran", "PHY", 96],
    ["Imran", "MATH", 87],
    ["Imran", "CHEM", 79],
    ["Imran", "BIO", 82],
]

_cols = ["NAME", "SUBJECT", "MARKS"]

# Turn the records into a DataFrame and print it without truncating cell values
df = spark.createDataFrame(_data, schema=_cols)
df.show(truncate=False)

                                                                                

+--------+-------+-----+
|NAME    |SUBJECT|MARKS|
+--------+-------+-----+
|Ramesh  |PHY    |90   |
|Ramesh  |MATH   |95   |
|Ramesh  |CHEM   |100  |
|Sangeeta|PHY    |90   |
|Sangeeta|MATH   |100  |
|Sangeeta|CHEM   |83   |
|Mohan   |BIO    |90   |
|Mohan   |MATH   |70   |
|Mohan   |CHEM   |76   |
|Imran   |PHY    |96   |
|Imran   |MATH   |87   |
|Imran   |CHEM   |79   |
|Imran   |BIO    |82   |
+--------+-------+-----+



### Method 1 - Without specifying column names

In [3]:
# Method 1: pivot on SUBJECT without listing the pivot values up front,
# leaving Spark to work out the distinct SUBJECT values on its own.
from pyspark.sql import functions as F

pivot_df_1 = (
    df.groupBy("NAME")
    .pivot("SUBJECT")
    .agg(F.sum("MARKS"))
)

# Inspect the resulting columns, then the pivoted rows
pivot_df_1.printSchema()
pivot_df_1.show(truncate=False)

                                                                                

root
 |-- NAME: string (nullable = true)
 |-- BIO: long (nullable = true)
 |-- CHEM: long (nullable = true)
 |-- MATH: long (nullable = true)
 |-- PHY: long (nullable = true)

+--------+----+----+----+----+
|NAME    |BIO |CHEM|MATH|PHY |
+--------+----+----+----+----+
|Mohan   |90  |76  |70  |null|
|Ramesh  |null|100 |95  |90  |
|Imran   |82  |79  |87  |96  |
|Sangeeta|null|83  |100 |90  |
+--------+----+----+----+----+



In [4]:
# Push the whole pivot through the "noop" sink so the query is executed
# end to end without persisting any output.
(
    pivot_df_1.write
    .format("noop")
    .mode("overwrite")
    .save()
)

### Method 2 - Specifying column names

In [6]:
# Get the distinct list of subjects to hand to pivot() explicitly.
# The collected Rows are unpacked on the driver with a plain comprehension;
# for a list this small there is no need to convert the DataFrame to an RDD.
_subjects = [row["SUBJECT"] for row in df.select("SUBJECT").distinct().collect()]
_subjects

['PHY', 'BIO', 'MATH', 'CHEM']

In [7]:
# Method 2: pivot on SUBJECT again, this time passing the known subject values,
# and check the schema of the result.
pivot_df_2 = (
    df.groupBy("NAME")
    .pivot("SUBJECT", values=_subjects)
    .agg(F.sum("MARKS"))
)
pivot_df_2.printSchema()

root
 |-- NAME: string (nullable = true)
 |-- PHY: long (nullable = true)
 |-- BIO: long (nullable = true)
 |-- MATH: long (nullable = true)
 |-- CHEM: long (nullable = true)



In [8]:
# Push the explicit-values pivot through the "noop" sink so the query is
# executed end to end without persisting any output.
(
    pivot_df_2.write
    .format("noop")
    .mode("overwrite")
    .save()
)

In [9]:
# Shut down the Spark session now that the notebook is finished
spark.stop()