In [0]:
# Pivot and Un-Pivot Data Frame

In [0]:
# Create Spark Session

from pyspark.sql import SparkSession

# Build (or reuse, via getOrCreate) a Spark session against the local[*] master
spark = (
    SparkSession.builder
    .appName("Pivot & Un-Pivot")
    .master("local[*]")
    .getOrCreate()
)

# Bare expression so the notebook displays the session object
spark

In [0]:
# Example Data Set

# One row per (student, subject) pair: [NAME, SUBJECT, MARKS]
_data = [
    # Ramesh
    ["Ramesh", "PHY", 90],
    ["Ramesh", "MATH", 95],
    ["Ramesh", "CHEM", 100],
    # Sangeeta
    ["Sangeeta", "PHY", 90],
    ["Sangeeta", "MATH", 100],
    ["Sangeeta", "CHEM", 83],
    # Mohan
    ["Mohan", "BIO", 90],
    ["Mohan", "MATH", 70],
    ["Mohan", "CHEM", 76],
    # Imran
    ["Imran", "PHY", 96],
    ["Imran", "MATH", 87],
    ["Imran", "CHEM", 79],
    ["Imran", "BIO", 82],
]

# Column names, in the same order as the fields of each row above
_cols = ["NAME", "SUBJECT", "MARKS"]

# Build the data frame from the rows and column names, then print it untruncated
df = spark.createDataFrame(data=_data, schema=_cols)
df.show(truncate=False)

In [0]:
# Let's create a simple Python decorator - {get_time} - to get the execution timings
# If you don't know about Python decorators, check out: https://www.geeksforgeeks.org/decorators-in-python/
import time

def get_time(func):
    """Run ``func`` once, print how long it took, and hand ``func`` back.

    Meant to be applied as a decorator to a zero-argument function. The
    function is executed immediately, at decoration time, and the elapsed
    time is printed in milliseconds.

    Args:
        func: Callable that takes no arguments. Its return value is discarded.

    Returns:
        The original ``func``, unchanged, so the decorated name remains a
        usable callable instead of being rebound to ``None``.
    """
    def inner_get_time() -> str:
        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring intervals; it is unaffected by system clock adjustments.
        start_time = time.perf_counter()
        func()
        end_time = time.perf_counter()
        return (f"Execution time: {(end_time - start_time)*1000} ms")
    print(inner_get_time())
    # Without this return the decorator would evaluate to None and the
    # decorated name could never be called again.
    return func

### Method 1 - Without specifying column names

In [0]:
# Pivot data without specifying the column names (values) and check the execution time
from pyspark.sql.functions import sum

@get_time
def x():
    """Build the NAME-by-SUBJECT pivot without passing explicit pivot values.

    The resulting data frame is discarded; this function exists only so that
    get_time can run and time it.
    NOTE(review): per the GroupedData.pivot docs, omitting the values makes
    Spark compute the distinct SUBJECT values itself - confirm that is the
    cost being measured here.
    """
    grouped_by_name = df.groupBy("NAME")
    pivoted_on_subject = grouped_by_name.pivot("SUBJECT")
    pivoted_on_subject.agg(sum("MARKS"))

In [0]:
# Build the pivot without explicit values, then inspect its schema and rows
pivot_df_1 = (
    df.groupBy("NAME")
    .pivot("SUBJECT")
    .agg(sum("MARKS"))
)
pivot_df_1.printSchema()
pivot_df_1.show(truncate=False)

### Method 2 - Specifying column names

In [0]:
# Time how long it takes to extract the distinct list of subjects
@get_time
def x():
    """Collect the distinct SUBJECT values; the result is discarded.

    This function exists only so that get_time can run and time it.
    """
    distinct_subjects = df.select("SUBJECT").distinct()
    # Each row has a single field, so index 0 is the subject itself
    distinct_subjects.rdd.map(lambda row: row[0]).collect()

In [0]:
# Collect the distinct subjects as a plain Python list
_subjects = (
    df.select("SUBJECT")
    .distinct()
    .rdd
    .map(lambda row: row[0])  # single-field rows: index 0 is the subject
    .collect()
)
_subjects

In [0]:
# Pivot data specifying the column names (values) and check the execution time
from pyspark.sql.functions import sum

@get_time
def x():
    """Build the NAME-by-SUBJECT pivot with the pivot values given up front.

    The values come from the _subjects list. The resulting data frame is
    discarded; this function exists only so that get_time can run and time it.
    """
    grouped_by_name = df.groupBy("NAME")
    pivoted_on_subject = grouped_by_name.pivot("SUBJECT", _subjects)
    pivoted_on_subject.agg(sum("MARKS"))

In [0]:
# Build the pivot with explicit values, then inspect its schema and rows
pivot_df_2 = (
    df.groupBy("NAME")
    .pivot("SUBJECT", _subjects)
    .agg(sum("MARKS"))
)
pivot_df_2.printSchema()
pivot_df_2.show(truncate=False)