In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.window import WindowSpec, Window
from pyspark.sql.functions import *

In [2]:
# Build (or reuse, via getOrCreate) the Spark session for this notebook,
# using the "local[*]" master and the application name "Window Functions".
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Window Functions")
    .getOrCreate()
)

# Bare expression: the notebook renders the session object as the cell output.
spark

In [3]:
# Read the invoice CSV: the first line is treated as the header and the
# column types are inferred from the data instead of being declared.
order_df = spark.read.csv(
    path="./dataset/windowdatamodified.csv",
    inferSchema=True,
    header=True,
)
# Preview the first five rows.
order_df.show(n=5)

+---------+-------+-----------+-------------+------------+
|  country|weeknum|numinvoices|totalquantity|invoicevalue|
+---------+-------+-----------+-------------+------------+
|    Spain|     49|          1|           67|      174.72|
|  Germany|     48|         11|         1795|      1600.0|
|Lithuania|     48|          3|          622|     1598.06|
|  Germany|     49|         12|         1852|      1800.0|
|  Bahrain|     51|          1|           54|      205.74|
+---------+-------+-----------+-------------+------------+
only showing top 5 rows



In [4]:
# Re-bind order_df to a sorted view: countries A-Z, and within each country
# the highest weeknum first.
order_df = order_df.sort("country", "weeknum", ascending=[True, False])
order_df.show()

+---------------+-------+-----------+-------------+------------+
|        country|weeknum|numinvoices|totalquantity|invoicevalue|
+---------------+-------+-----------+-------------+------------+
|      Australia|     50|          2|          133|      387.95|
|      Australia|     49|          1|          214|       258.9|
|      Australia|     48|          1|          107|      358.25|
|        Austria|     50|          2|            3|      257.04|
|        Bahrain|     51|          1|           54|      205.74|
|        Belgium|     51|          2|          942|       800.0|
|        Belgium|     50|          2|          285|      625.16|
|        Belgium|     48|          1|          528|       800.0|
|Channel Islands|     49|          1|           80|      363.53|
|         Cyprus|     50|          1|          917|     1590.82|
|        Denmark|     49|          1|          454|      1281.5|
|        Finland|     50|          1|         1254|       892.8|
|         France|     51|

### Aggregate Functions + Window

In [5]:
# One partition per country, with no ordering inside the partition.
window1 = Window.partitionBy(col("country"))

In [6]:
# Attach per-country max / min / sum / average of invoicevalue to every row.
# Each aggregate is evaluated over `window1`, so rows are not collapsed: every
# row of a country carries that country's summary values.
#
# NOTE: `max`, `min` and `sum` here are the Spark column functions brought in
# by `from pyspark.sql.functions import *`; they shadow the Python builtins.
#
# All four columns are added in a single select() instead of four chained
# withColumn() calls: each withColumn() adds another projection to the plan,
# and the PySpark documentation recommends select() when adding several
# columns at once. The resulting column names and their order are unchanged.
order_df.select(
    "*",
    max(col("invoicevalue")).over(window1).alias("max_value"),
    min(col("invoicevalue")).over(window1).alias("min_value"),
    sum(col("invoicevalue")).over(window1).alias("sum_value"),
    avg(col("invoicevalue")).over(window1).alias("avg_value"),
).show()

+---------------+-------+-----------+-------------+------------+---------+---------+------------------+------------------+
|        country|weeknum|numinvoices|totalquantity|invoicevalue|max_value|min_value|         sum_value|         avg_value|
+---------------+-------+-----------+-------------+------------+---------+---------+------------------+------------------+
|      Australia|     50|          2|          133|      387.95|   387.95|    258.9|1005.0999999999999| 335.0333333333333|
|      Australia|     49|          1|          214|       258.9|   387.95|    258.9|1005.0999999999999| 335.0333333333333|
|      Australia|     48|          1|          107|      358.25|   387.95|    258.9|1005.0999999999999| 335.0333333333333|
|        Austria|     50|          2|            3|      257.04|   257.04|   257.04|            257.04|            257.04|
|        Bahrain|     51|          1|           54|      205.74|   205.74|   205.74|            205.74|            205.74|
|        Belgium

### Rank, Dense_Rank, Row_Number

In [7]:
# Per-country window ordered from the largest invoicevalue to the smallest.
window2 = Window.partitionBy(col("country")).orderBy(desc("invoicevalue"))

In [8]:
# Rank each row inside its country by invoicevalue (ordering comes from
# `window2`), using the three ranking functions side by side:
#   rank        - tied rows share a rank and the following rank is skipped
#   dense_rank  - tied rows share a rank and no rank is skipped
#   row_number  - every row gets a distinct, consecutive number
# The window orders by invoicevalue only, so the order of rows that tie on
# invoicevalue (and therefore which of them gets the lower row_number) is
# not determined by the window specification.
#
# The three columns are added with one select() rather than three chained
# withColumn() calls, which would each add a projection to the plan; column
# names and order are the same as before.
ranked_df = order_df.select(
    "*",
    rank().over(window2).alias("rank"),
    dense_rank().over(window2).alias("dense_rank"),
    row_number().over(window2).alias("row_number"),
)

ranked_df.show()

+---------------+-------+-----------+-------------+------------+----+----------+----------+
|        country|weeknum|numinvoices|totalquantity|invoicevalue|rank|dense_rank|row_number|
+---------------+-------+-----------+-------------+------------+----+----------+----------+
|      Australia|     50|          2|          133|      387.95|   1|         1|         1|
|      Australia|     48|          1|          107|      358.25|   2|         2|         2|
|      Australia|     49|          1|          214|       258.9|   3|         3|         3|
|        Austria|     50|          2|            3|      257.04|   1|         1|         1|
|        Bahrain|     51|          1|           54|      205.74|   1|         1|         1|
|        Belgium|     51|          2|          942|       800.0|   1|         1|         1|
|        Belgium|     48|          1|          528|       800.0|   1|         1|         2|
|        Belgium|     50|          2|          285|      625.16|   3|         2|

In [9]:
# Keep only the row numbered 1 in each country, i.e. one row per country.
ranked_df.filter(col("row_number") == 1).show()

+---------------+-------+-----------+-------------+------------+----+----------+----------+
|        country|weeknum|numinvoices|totalquantity|invoicevalue|rank|dense_rank|row_number|
+---------------+-------+-----------+-------------+------------+----+----------+----------+
|      Australia|     50|          2|          133|      387.95|   1|         1|         1|
|        Austria|     50|          2|            3|      257.04|   1|         1|         1|
|        Bahrain|     51|          1|           54|      205.74|   1|         1|         1|
|        Belgium|     51|          2|          942|       800.0|   1|         1|         1|
|Channel Islands|     49|          1|           80|      363.53|   1|         1|         1|
|         Cyprus|     50|          1|          917|     1590.82|   1|         1|         1|
|        Denmark|     49|          1|          454|      1281.5|   1|         1|         1|
|        Finland|     50|          1|         1254|       892.8|   1|         1|

### Lead & Lag

In [10]:
# Per-country window ordered by ascending weeknum.
window3 = Window.partitionBy(col("country")).orderBy(col("weeknum").asc())

In [11]:
# Look backwards within each country (rows ordered by `window3`):
#   previous_week   - invoicevalue of the row one position earlier
#   previous_week_2 - invoicevalue of the row two positions earlier
#   diff            - change in invoicevalue versus the previous row
# lag() counts rows, not week numbers, so where a country has a gap in
# weeknum the "previous" value comes from the nearest earlier row present.
# Rows with no earlier row get NULL, and diff is NULL for them as well.
lag_df = (
    order_df
    .withColumn("previous_week", lag(col("invoicevalue"), 1).over(window3))
    .withColumn("previous_week_2", lag(col("invoicevalue"), 2).over(window3))
    .withColumn("diff", col("invoicevalue") - col("previous_week"))
)

lag_df.show()

+---------------+-------+-----------+-------------+------------+-------------+---------------+-------------------+
|        country|weeknum|numinvoices|totalquantity|invoicevalue|previous_week|previous_week_2|               diff|
+---------------+-------+-----------+-------------+------------+-------------+---------------+-------------------+
|      Australia|     48|          1|          107|      358.25|         NULL|           NULL|               NULL|
|      Australia|     49|          1|          214|       258.9|       358.25|           NULL| -99.35000000000002|
|      Australia|     50|          2|          133|      387.95|        258.9|         358.25|             129.05|
|        Austria|     50|          2|            3|      257.04|         NULL|           NULL|               NULL|
|        Bahrain|     51|          1|           54|      205.74|         NULL|           NULL|               NULL|
|        Belgium|     48|          1|          528|       800.0|         NULL|  

In [12]:
# Look forwards within each country (rows ordered by `window3`):
#   next_week   - invoicevalue of the row one position later
#   next_week_2 - invoicevalue of the row two positions later
#   diff        - current invoicevalue minus the next row's invoicevalue
# lead() counts rows, not week numbers, so where a country has a gap in
# weeknum the "next" value comes from the nearest later row present.
# Rows with no later row get NULL, and diff is NULL for them as well.
lead_df = (
    order_df
    .withColumn("next_week", lead(col("invoicevalue"), 1).over(window3))
    .withColumn("next_week_2", lead(col("invoicevalue"), 2).over(window3))
    .withColumn("diff", col("invoicevalue") - col("next_week"))
)

lead_df.show()

+---------------+-------+-----------+-------------+------------+---------+-----------+-------------------+
|        country|weeknum|numinvoices|totalquantity|invoicevalue|next_week|next_week_2|               diff|
+---------------+-------+-----------+-------------+------------+---------+-----------+-------------------+
|      Australia|     48|          1|          107|      358.25|    258.9|     387.95|  99.35000000000002|
|      Australia|     49|          1|          214|       258.9|   387.95|       NULL|            -129.05|
|      Australia|     50|          2|          133|      387.95|     NULL|       NULL|               NULL|
|        Austria|     50|          2|            3|      257.04|     NULL|       NULL|               NULL|
|        Bahrain|     51|          1|           54|      205.74|     NULL|       NULL|               NULL|
|        Belgium|     48|          1|          528|       800.0|   625.16|      800.0| 174.84000000000003|
|        Belgium|     50|          2|