- The pivot() function in PySpark rotates a DataFrame by turning the unique values of one column into separate columns. It is commonly used for data summarization and reporting.

- When you pivot data, you group it by one or more columns and create a new column for each unique value in another column. An aggregate function (such as sum or avg) is then applied to fill in the values.

- Syntax:
    - df.groupBy(<group_columns>).pivot(<pivot_column>).agg(<aggregate_functions>)
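
To make the group / pivot / aggregate steps concrete, here is a minimal plain-Python sketch of the same logic. It is only an illustration of the idea, not how Spark implements pivot(); the helper `pivot_sum` and the store/product rows are hypothetical and do not come from this notebook's data.

```python
from collections import defaultdict


def pivot_sum(rows, group_key, pivot_key, value_key):
    """Group rows by group_key, spread pivot_key values into columns, sum value_key."""
    # One output column per distinct pivot value (sorted for a stable column order).
    pivot_values = sorted({row[pivot_key] for row in rows})
    # Every group starts with an empty (None) cell for each pivot column.
    table = defaultdict(lambda: {value: None for value in pivot_values})
    for row in rows:
        cells = table[row[group_key]]  # creates the group even if the value is null
        amount = row[value_key]
        if amount is None:
            continue  # nulls are skipped, so an all-null cell stays None
        current = cells[row[pivot_key]]
        cells[row[pivot_key]] = amount if current is None else current + amount
    return dict(table)


# Hypothetical rows, for illustration only.
rows = [
    {"store": "S1", "product": "A", "revenue": 100},
    {"store": "S1", "product": "A", "revenue": 50},
    {"store": "S1", "product": "B", "revenue": 70},
    {"store": "S2", "product": "B", "revenue": None},
]

pivoted = pivot_sum(rows, "store", "product", "revenue")
for store, cells in pivoted.items():
    print(store, cells)
```

Each distinct `product` becomes a key (a column), each `store` becomes a row, and cells with no non-null input stay empty, which mirrors the NULL cells Spark shows in a pivoted result.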

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PivotFunctionExample").getOrCreate()



In [12]:
# create a sample DataFrame
data = [
    (1, "Manta", 2023, "Product A", 240),
    (2, "Dipankar", 2024, "Product A", 270),
    (3, "Souvik", 2022, "Product B", 270),
    (4, "Soukarjya", 2025, "Product A", None),
    (5, "Arvind", 2022, "Product C", 280),
    (6, "Prodipta", 2022, "Product B", 280),
    (7, "Padma", 2025, "Product A", 270),
    (8, "Panta", 2023, "Product C", 270),
    (9, "Sougato", 2022, "Product B", 290)
]

columns = ["id", "name", "year", "product", "revenue"]

df = spark.createDataFrame(data, schema=columns)
df.show()


+---+---------+----+---------+-------+
| id|     name|year|  product|revenue|
+---+---------+----+---------+-------+
|  1|    Manta|2023|Product A|    240|
|  2| Dipankar|2024|Product A|    270|
|  3|   Souvik|2022|Product B|    270|
|  4|Soukarjya|2025|Product A|   NULL|
|  5|   Arvind|2022|Product C|    280|
|  6| Prodipta|2022|Product B|    280|
|  7|    Padma|2025|Product A|    270|
|  8|    Panta|2023|Product C|    270|
|  9|  Sougato|2022|Product B|    290|
+---+---------+----+---------+-------+



                                                                                

In [None]:
# we pivot the "product" column to turn unique product names into columns and sum the "revenue" for each combination of name and year
# note: this import shadows Python's built-in sum() for the rest of the notebook
from pyspark.sql.functions import sum

pivot_df = df.groupBy("name", "year") \
             .pivot("product") \
             .agg(sum("revenue"))

pivot_df.show()




+---------+----+---------+---------+---------+
|     name|year|Product A|Product B|Product C|
+---------+----+---------+---------+---------+
|   Arvind|2022|     NULL|     NULL|      280|
|    Manta|2023|      240|     NULL|     NULL|
|Soukarjya|2025|     NULL|     NULL|     NULL|
|   Souvik|2022|     NULL|      270|     NULL|
|  Sougato|2022|     NULL|      290|     NULL|
| Dipankar|2024|      270|     NULL|     NULL|
|    Padma|2025|      270|     NULL|     NULL|
|    Panta|2023|     NULL|     NULL|      270|
| Prodipta|2022|     NULL|      280|     NULL|
+---------+----+---------+---------+---------+



                                                                                

In [15]:
# we pivot the "year" column to turn years into columns and sum the "revenue" for each name across different years

pivot_df_year = df.groupBy("name") \
                  .pivot("year") \
                  .agg(sum("revenue"))

pivot_df_year.show()


+---------+----+----+----+----+
|     name|2022|2023|2024|2025|
+---------+----+----+----+----+
|    Padma|NULL|NULL|NULL| 270|
|    Manta|NULL| 240|NULL|NULL|
|    Panta|NULL| 270|NULL|NULL|
|  Sougato| 290|NULL|NULL|NULL|
|Soukarjya|NULL|NULL|NULL|NULL|
| Prodipta| 280|NULL|NULL|NULL|
|   Souvik| 270|NULL|NULL|NULL|
| Dipankar|NULL|NULL| 270|NULL|
|   Arvind| 280|NULL|NULL|NULL|
+---------+----+----+----+----+



                                                                                

- pivot() reshapes a DataFrame by turning the unique values of a column into separate columns.
- It is combined with an aggregation such as sum(), avg(), or min() to fill the new columns.
- It is useful for reporting, summarizing data, and preparing data for BI tools.