# PySpark Window Functions
PySpark window functions are a set of SQL-like operations that allow you to perform calculations across a group of rows that are related to the current row, but without collapsing the rows into a single row. These functions are particularly useful for tasks such as ranking, aggregating over specific partitions, and calculating cumulative or rolling statistics.

**Key Features of PySpark Window Functions**

1. Operate on a "Window" of Rows: Define a subset of data (the "window") for each row based on certain criteria like partitioning and ordering.
1. Non-collapsing: Unlike groupBy, window functions keep the number of rows unchanged.
1. SQL and Functional API: Can be used with both SQL queries and PySpark's DataFrame API.


**Common Use Cases**<br>
- **Ranking rows**: Assign ranks to rows within a partition.
- **Cumulative calculations**: Compute running totals, averages, etc.
- **Lag/Lead**: Access previous or next row values.
- **Aggregations**: Perform operations like min, max, avg over a specified window.

**Syntax**<br>
Using a window function involves three steps:

1. Define the Window:
    - Partition: Specifies the grouping of rows.
    - Order: Specifies the sorting within each partition.
1. Apply the Function: Perform an operation (e.g., sum, rank).
1. Use the Result: Add the calculated column to the DataFrame.

**Examples**

In [1]:
from pyspark.sql import SparkSession

# Initialize the SparkSession with a specific application name
spark = (SparkSession.builder
         .appName('PySpark Window Function')
         .getOrCreate())

spark

In [30]:
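# Note: the wildcard import shadows Python built-ins such as sum, round, min, and max
# with their Spark column equivalents for the rest of the notebook.
from pyspark.sql.functions import *
from pyspark.sql.window import Window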

**Ranking**: Rank employees within each department based on salary

In [55]:
data = [
    ("James", "Sales", 3000),
    ("Michael", "Sales", 4600),
    ("Robert", "Sales", 4100),
    ("Maria", "Finance", 3000),
    ("Scott", "Finance", 3300),
    ("Jen", "Finance", 3900),
    ("Jeff", "Marketing", 3000),
    ("Kumar", "Marketing", 2000),
    ("Saif", "Sales", 4100)
]

columns= ["employee_name", "department", "salary"]
df = spark.createDataFrame(data = data, schema = columns)

df.printSchema()
df.show()

root
 |-- employee_name: string (nullable = true)
 |-- department: string (nullable = true)
 |-- salary: long (nullable = true)

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
|         Saif|     Sales|  4100|
+-------------+----------+------+



### Ranking functions

PySpark’s window ranking functions, such as row_number(), rank(), and dense_rank(), assign a sequence number or rank to each row within a partition of a dataset. row_number() is always unique within a partition, while rank() and dense_rank() give tied rows the same value. These functions operate over a window, which is a subset of data defined by partitioning and ordering logic. They are useful for tasks like ordering, ranking, and picking out specific rows, such as the top row of each group.

**Key Concepts**
- **Partition**: Divides the dataset into groups based on one or more columns (e.g., department).
- **Ordering**: Determines the sequence of rows within each partition (e.g., by salary).
- **Sequential Assignment**: These functions assign numbers to rows in the order defined by the partition and sorting criteria.

**Key Benefits**
- **Enhanced Data Insights**: Easily analyze and compare rows within groups.
- **Versatility**: Useful in real-world scenarios like leaderboard rankings, pagination, and top-N analysis.
- **Control Over Ties**: Choose between rank and dense_rank depending on how you want to handle ties.
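To make the tie handling concrete, here is a plain-Python sketch (not Spark code) of how the three numbers relate for one partition sorted in ascending order. The helper name `rank_rows` is hypothetical; the input is the set of Sales department salaries from the sample data.

```python
def rank_rows(values):
    """Return (value, row_number, rank, dense_rank) for values sorted ascending."""
    out = []
    rank = 0
    dense = 0
    prev = object()  # sentinel that never equals a real value
    for row_number, value in enumerate(sorted(values), start=1):
        if value != prev:
            rank = row_number  # rank jumps to the row position, leaving gaps after ties
            dense += 1         # dense_rank advances by one per distinct value, no gaps
            prev = value
        out.append((value, row_number, rank, dense))
    return out

sales_salaries = [3000, 4600, 4100, 4100]
ranked = rank_rows(sales_salaries)
for row in ranked:
    print(row)
# (3000, 1, 1, 1)
# (4100, 2, 2, 2)
# (4100, 3, 2, 2)
# (4600, 4, 4, 3)
```

The two 4100 salaries share rank 2; `rank` then skips to 4 while `dense_rank` continues with 3.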

In [56]:
window_spec_ex  = Window.partitionBy("department").orderBy("salary")

df.withColumn("row_number",row_number().over(window_spec_ex)) \
    .withColumn("rank",rank().over(window_spec_ex)) \
    .withColumn("dense_rank",dense_rank().over(window_spec_ex)) \
    .show()

+-------------+----------+------+----------+----+----------+
|employee_name|department|salary|row_number|rank|dense_rank|
+-------------+----------+------+----------+----+----------+
|        Maria|   Finance|  3000|         1|   1|         1|
|        Scott|   Finance|  3300|         2|   2|         2|
|          Jen|   Finance|  3900|         3|   3|         3|
|        Kumar| Marketing|  2000|         1|   1|         1|
|         Jeff| Marketing|  3000|         2|   2|         2|
|        James|     Sales|  3000|         1|   1|         1|
|       Robert|     Sales|  4100|         2|   2|         2|
|         Saif|     Sales|  4100|         3|   2|         2|
|      Michael|     Sales|  4600|         4|   4|         3|
+-------------+----------+------+----------+----+----------+



### Top Selling Product

In [None]:
salesdf = spark.read.csv('dataset/sales_data_cleaned.csv', header=True, inferSchema=True).drop('Month')
print(salesdf.count())
salesdf.printSchema()
salesdf.show(5)

791
root
 |-- Category: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Region: string (nullable = true)
 |-- Date: date (nullable = true)
 |-- Sales: double (nullable = true)
 |-- Quantity: double (nullable = true)

+----------+--------+------+----------+------+--------+
|  Category| Product|Region|      Date| Sales|Quantity|
+----------+--------+------+----------+------+--------+
|Home Decor| Cushion|  East|2022-04-13| 388.1|     8.0|
|  Clothing|  Jacket|  East|2023-12-02|278.27|     3.0|
|  Clothing|   Jeans|  West|2023-01-08|280.61|     8.0|
|    Sports|Football| South|2022-12-10|407.84|     6.0|
|Home Decor| Cushion| South|2023-01-21|  50.0|     5.0|
+----------+--------+------+----------+------+--------+
only showing top 5 rows



In [40]:
tot_salesdf = (salesdf
               .filter("Category != 'Unknown'")
               .groupBy('Category', 'Product')
               .agg(
                   round(sum(col("Sales") * col("Quantity"))).alias("TotalSales")
               ))
tot_salesdf.show()

+-----------+---------------+----------+
|   Category|        Product|TotalSales|
+-----------+---------------+----------+
|Electronics|     Headphones|   28898.0|
|     Sports|       Yoga Mat|   47373.0|
| Home Decor|           Lamp|   36884.0|
|Electronics|Unknown Product|   45391.0|
| Home Decor|        Cushion|   33569.0|
|      Books|        Fiction|   46667.0|
|      Books|         Comics|   56466.0|
|     Sports|       Football|   52603.0|
|Electronics|         Laptop|   46642.0|
|      Books|    Non-Fiction|   53522.0|
|Electronics|         Mobile|   49148.0|
| Home Decor|        Curtain|   70236.0|
|     Sports|  Tennis Racket|   86067.0|
|   Clothing|         Jacket|   53495.0|
|   Clothing|          Shirt|   63192.0|
|   Clothing|          Jeans|   64855.0|
+-----------+---------------+----------+



In [61]:
window_spec = Window.partitionBy("Category").orderBy(col("TotalSales").desc())

# Apply the ranking functions
ranked_df = tot_salesdf.withColumn("row_number", row_number().over(window_spec)) \
    .withColumn("rank", rank().over(window_spec)) \
    .withColumn("dense_rank", dense_rank().over(window_spec))
    
(ranked_df
    .filter(col("row_number") == 1)
    .select('Category', 'Product', 'TotalSales')
    .sort(col('TotalSales').desc())
    .show())

+-----------+-------------+----------+
|   Category|      Product|TotalSales|
+-----------+-------------+----------+
|     Sports|Tennis Racket|   86067.0|
| Home Decor|      Curtain|   70236.0|
|   Clothing|        Jeans|   64855.0|
|      Books|       Comics|   56466.0|
|Electronics|       Mobile|   49148.0|
+-----------+-------------+----------+



### ntile window function

The ```ntile()``` window function in PySpark is used to distribute rows of data into a specified number of buckets or groups, based on the ordering of the rows within a partition. The rows are divided as evenly as possible into the given number of buckets, and each row is assigned a bucket number from 1 to n, where n is the number of buckets.

**How ntile() works:**
- The function takes a single argument, which is the number of buckets (n) to divide the data into.
- It returns the bucket number for each row in the ordered set.
- The rows are first ordered according to a specified column and then divided into n groups as evenly as possible.
    - If the rows cannot be evenly divided, some buckets may contain one more row than others.
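The bucket assignment described above can be sketched in plain Python (this is an illustration of the rule, not Spark's implementation). The helper name `ntile_buckets` is hypothetical; it assumes the rows are already ordered and that any leftover rows go to the lowest-numbered buckets.

```python
def ntile_buckets(num_rows, n):
    """Bucket number (1..n) for each of num_rows already-ordered rows."""
    base, extra = divmod(num_rows, n)  # the first `extra` buckets get one additional row
    buckets = []
    for bucket in range(1, n + 1):
        size = base + (1 if bucket <= extra else 0)
        buckets.extend([bucket] * size)
    return buckets

# Nine rows split into quartiles: bucket sizes 3, 2, 2, 2
quartiles = ntile_buckets(9, 4)
print(quartiles)  # [1, 1, 1, 2, 2, 3, 3, 4, 4]
```

Note that the split depends only on row position, so rows with equal values can land in different buckets.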

**Use Cases for ntile():**
- **Percentile Calculations**: Dividing data into quantiles like quartiles, deciles, etc., for analysis such as statistical summaries.
- **Categorization**: Assigning categories to data points based on ranks, such as splitting data into high, medium, and low categories.
- **Segmenting Data**: Segmenting users, customers, or employees based on certain metrics (e.g., income, sales performance, etc.) into equal-sized buckets.

**Example**: Each employee is assigned a Salary_Quartile based on their salary relative to others.

In [59]:
# Define the window specification: no partitioning, ordered by salary (highest first)
window_spec_ex = Window.orderBy(col("salary").desc())
df.withColumn("Salary_Quartile", ntile(4).over(window_spec_ex)).show()

+-------------+----------+------+---------------+
|employee_name|department|salary|Salary_Quartile|
+-------------+----------+------+---------------+
|      Michael|     Sales|  4600|              1|
|       Robert|     Sales|  4100|              1|
|         Saif|     Sales|  4100|              1|
|          Jen|   Finance|  3900|              2|
|        Scott|   Finance|  3300|              2|
|        James|     Sales|  3000|              3|
|        Maria|   Finance|  3000|              3|
|         Jeff| Marketing|  3000|              4|
|        Kumar| Marketing|  2000|              4|
+-------------+----------+------+---------------+



In [None]:
# Order all products together (no partitionBy) and split them into two halves
window_spec = Window.orderBy(col("TotalSales").desc())
tot_salesdf.withColumn("ntile", ntile(2).over(window_spec)).show()

+-----------+---------------+----------+-----+
|   Category|        Product|TotalSales|ntile|
+-----------+---------------+----------+-----+
|     Sports|  Tennis Racket|   86067.0|    1|
| Home Decor|        Curtain|   70236.0|    1|
|   Clothing|          Jeans|   64855.0|    1|
|   Clothing|          Shirt|   63192.0|    1|
|      Books|         Comics|   56466.0|    1|
|      Books|    Non-Fiction|   53522.0|    1|
|   Clothing|         Jacket|   53495.0|    1|
|     Sports|       Football|   52603.0|    1|
|Electronics|         Mobile|   49148.0|    2|
|     Sports|       Yoga Mat|   47373.0|    2|
|      Books|        Fiction|   46667.0|    2|
|Electronics|         Laptop|   46642.0|    2|
|Electronics|Unknown Product|   45391.0|    2|
| Home Decor|           Lamp|   36884.0|    2|
| Home Decor|        Cushion|   33569.0|    2|
|Electronics|     Headphones|   28898.0|    2|
+-----------+---------------+----------+-----+



### Cumulative Distribution Window Function

The ```cume_dist()``` window function in PySpark calculates the cumulative distribution of a value within a partition: it measures where the current row stands relative to all the rows in the partition under the window's ordering. The result is greater than 0 and at most 1, and rows with tied ordering values receive the same result.

That is, with an ascending order, ```cume_dist()``` gives the fraction of rows in the partition whose value is less than or equal to the current row's value.
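As a quick check of that definition, the fraction can be computed by hand in plain Python for one partition. The helper name `cume_dist_reference` is hypothetical and assumes an ascending ordering; the input is the set of Sales department salaries from the sample data.

```python
def cume_dist_reference(values):
    """For each value in ascending order: (rows with value <= it) / (rows in partition)."""
    n = len(values)
    return [len([u for u in values if u <= v]) / n for v in sorted(values)]

sales_salaries = [3000, 4600, 4100, 4100]
distribution = cume_dist_reference(sales_salaries)
print(distribution)  # [0.25, 0.75, 0.75, 1.0]
```

The two tied 4100 salaries both get 0.75, because three of the four rows are less than or equal to 4100.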

In [62]:
# Apply cume_dist function to calculate the cumulative distribution of salaries
window_spec_ex  = Window.partitionBy("department").orderBy("salary")
df.withColumn("CumeDist", cume_dist().over(window_spec_ex)).show()

+-------------+----------+------+------------------+
|employee_name|department|salary|          CumeDist|
+-------------+----------+------+------------------+
|        Maria|   Finance|  3000|0.3333333333333333|
|        Scott|   Finance|  3300|0.6666666666666666|
|          Jen|   Finance|  3900|               1.0|
|        Kumar| Marketing|  2000|               0.5|
|         Jeff| Marketing|  3000|               1.0|
|        James|     Sales|  3000|              0.25|
|       Robert|     Sales|  4100|              0.75|
|         Saif|     Sales|  4100|              0.75|
|      Michael|     Sales|  4600|               1.0|
+-------------+----------+------+------------------+



In [64]:
window_spec = Window.partitionBy("Category").orderBy(col("TotalSales"))
tot_salesdf.withColumn("CumeDist",cume_dist().over(window_spec)).show()

+-----------+---------------+----------+------------------+
|   Category|        Product|TotalSales|          CumeDist|
+-----------+---------------+----------+------------------+
|      Books|        Fiction|   46667.0|0.3333333333333333|
|      Books|    Non-Fiction|   53522.0|0.6666666666666666|
|      Books|         Comics|   56466.0|               1.0|
|   Clothing|         Jacket|   53495.0|0.3333333333333333|
|   Clothing|          Shirt|   63192.0|0.6666666666666666|
|   Clothing|          Jeans|   64855.0|               1.0|
|Electronics|     Headphones|   28898.0|              0.25|
|Electronics|Unknown Product|   45391.0|               0.5|
|Electronics|         Laptop|   46642.0|              0.75|
|Electronics|         Mobile|   49148.0|               1.0|
| Home Decor|        Cushion|   33569.0|0.3333333333333333|
| Home Decor|           Lamp|   36884.0|0.6666666666666666|
| Home Decor|        Curtain|   70236.0|               1.0|
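|     Sports|       Yoga Mat|   47373.0|0.3333333333333333|
|     Sports|       Football|   52603.0|0.6666666666666666|
|     Sports|  Tennis Racket|   86067.0|               1.0|
+-----------+---------------+----------+------------------+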

### Lag and Lead Window Functions

The ```lag()``` window function in PySpark is used to access data from a previous row in the same result set, without needing to perform a self-join. This function provides a way to compare the current row's value to previous row values based on a specified order within a partition. The ```lead()``` function is its counterpart: it takes the same arguments but reads from a following row instead.

**Syntax:** ```df.withColumn("lag_column", lag(column_name, offset, default_value).over(window_spec))```

**Where:**
- **column_name**: The column from which you want to retrieve the lagged value.
- **offset**: (Optional) The number of rows before the current row to access. By default, it's 1.
- **default_value**: (Optional) The value to return if there is no previous row. By default, it's None.
- **window_spec**: Defines how to partition and order the data.

**How it Works:**
- The lag() function returns the value of the column from a previous row based on the specified offset (number of rows before the current row).
- It requires a window specification that defines how the data should be partitioned and ordered.
- If there is no previous row (e.g., for the first row), the lag() function returns the default value or null if no default value is provided.

In [72]:
window_spec_ex  = Window.partitionBy("department").orderBy("salary")
df.withColumn("lag",lag("salary").over(window_spec_ex)) \
    .withColumn("lead",lead("salary").over(window_spec_ex)).show()

+-------------+----------+------+----+----+
|employee_name|department|salary| lag|lead|
+-------------+----------+------+----+----+
|        Maria|   Finance|  3000|NULL|3300|
|        Scott|   Finance|  3300|3000|3900|
|          Jen|   Finance|  3900|3300|NULL|
|        Kumar| Marketing|  2000|NULL|3000|
|         Jeff| Marketing|  3000|2000|NULL|
|        James|     Sales|  3000|NULL|4100|
|       Robert|     Sales|  4100|3000|4100|
|         Saif|     Sales|  4100|4100|4600|
|      Michael|     Sales|  4600|4100|NULL|
+-------------+----------+------+----+----+



In [73]:
# Filter for the year 2023
df_2023 = salesdf.filter(year(col("Date")) == 2023)

# Group by Month and calculate total sales
df_monthly_sales = df_2023.groupBy(month(col("Date")).alias("Month")).agg(
    round(sum("Sales"),2).alias("TotalSales")
).sort('Month')

In [70]:
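# Define the window specification (order by Month).
# There is no partitionBy, so Spark processes all rows in a single partition;
# that is fine for these 12 monthly rows but does not scale to large data.
window_spec = Window.orderBy("Month")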

# Apply lag() and lead() to get the previous and next month's sales
df_with_lag_lead = df_monthly_sales.withColumn(
    "PreviousMonthSales", lag("TotalSales", 1).over(window_spec)
).withColumn(
    "NextMonthSales", lead("TotalSales", 1).over(window_spec)
)

df_with_lag_lead.show()

+-----+----------+------------------+--------------+
|Month|TotalSales|PreviousMonthSales|NextMonthSales|
+-----+----------+------------------+--------------+
|    1|   9980.59|              NULL|       8739.22|
|    2|   8739.22|           9980.59|       9191.11|
|    3|   9191.11|           8739.22|       7492.81|
|    4|   7492.81|           9191.11|       9045.63|
|    5|   9045.63|           7492.81|       7538.85|
|    6|   7538.85|           9045.63|       9150.08|
|    7|   9150.08|           7538.85|       7514.56|
|    8|   7514.56|           9150.08|       8443.14|
|    9|   8443.14|           7514.56|       6863.55|
|   10|   6863.55|           8443.14|      10019.73|
|   11|  10019.73|           6863.55|       6847.05|
|   12|   6847.05|          10019.73|          NULL|
+-----+----------+------------------+--------------+



In [74]:
spark.stop()