# Chapter 7 - Aggregation

In [2]:
# Load every retail CSV file matching the glob into one DataFrame, treating
# the first row as the header and letting Spark infer the column types,
# then reduce the result to 5 partitions.
df = (
    spark.read
    .format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load("../pyspark-training/data/The-Definitive-Guide/retail-data/all/*.csv")
    .coalesce(5)
)

# Mark the DataFrame for caching because the cells below query it repeatedly.
df.cache()

# Register the DataFrame under the name "dfTable" so it can be queried from SQL.
df.createOrReplaceTempView("dfTable")

In [3]:
# Preview the first 5 rows to check the columns that were loaded.
df.show(n=5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 5 rows



In [4]:
# Here count() is called as a method on the DataFrame itself: it is an action
# and returns the number of rows directly (contrast with the count() column
# function imported from pyspark.sql.functions and used with select() below).
df.count()

541909

## Aggregation Functions
### Note: These aggregation functions are used with `select()` and accept a column name as a string; a `col()` expression works as well

In [9]:
from pyspark.sql.functions import count, countDistinct, approx_count_distinct, first, last

In [11]:
# count() is a function here
# Here count() is the column function imported from pyspark.sql.functions,
# not the DataFrame.count() method used earlier.
# All five aggregates are computed over the whole DataFrame in one select().
df.select([
    count("StockCode"),
    countDistinct("StockCode"),
    approx_count_distinct("StockCode", 0.1),
    first("StockCode"),
    last("StockCode"),
]).show()

+----------------+-------------------------+--------------------------------+-----------------------+----------------------+
|count(StockCode)|count(DISTINCT StockCode)|approx_count_distinct(StockCode)|first(StockCode, false)|last(StockCode, false)|
+----------------+-------------------------+--------------------------------+-----------------------+----------------------+
|          541909|                     4070|                            3364|                  20868|                 90149|
+----------------+-------------------------+--------------------------------+-----------------------+----------------------+



In [14]:
from pyspark.sql.functions import min, max, sum, sumDistinct, avg, expr

In [16]:
# min, max and sum here are the pyspark.sql.functions versions imported in
# the previous cell (they shadow the Python builtins of the same name).
# expr("mean(Quantity)") is shown next to avg() for comparison; in the output
# both columns are labelled avg(Quantity) and hold the same value.
df.select([
    min("Quantity"),
    max("Quantity"),
    sum("Quantity"),
    sumDistinct("Quantity"),
    avg("Quantity"),
    expr("mean(Quantity)"),
]).show()

+-------------+-------------+-------------+----------------------+----------------+----------------+
|min(Quantity)|max(Quantity)|sum(Quantity)|sum(DISTINCT Quantity)|   avg(Quantity)|   avg(Quantity)|
+-------------+-------------+-------------+----------------------+----------------+----------------+
|       -80995|        80995|      5176450|                 29310|9.55224954743324|9.55224954743324|
+-------------+-------------+-------------+----------------------+----------------+----------------+



In [17]:
from pyspark.sql.functions import variance, stddev, var_pop, stddev_pop, var_samp, stddev_samp, skewness, kurtosis

In [21]:
# Compare the generic variance()/stddev() functions with their explicit
# population (_pop) and sample (_samp) variants. In the output below,
# "var" matches "var_samp" and "stddev" matches "stddev_samp".
df.select([
    variance("Quantity").alias("var"),
    var_pop("Quantity").alias("var_pop"),
    var_samp("Quantity").alias("var_samp"),
    stddev("Quantity").alias("stddev"),
    stddev_pop("Quantity").alias("stddev_pop"),
    stddev_samp("Quantity").alias("stddev_samp"),
]).show()

+------------------+------------------+------------------+------------------+------------------+------------------+
|               var|           var_pop|          var_samp|            stddev|        stddev_pop|       stddev_samp|
+------------------+------------------+------------------+------------------+------------------+------------------+
|47559.391409298754|47559.303646609056|47559.391409298754|218.08115785023418|218.08095663447796|218.08115785023418|
+------------------+------------------+------------------+------------------+------------------+------------------+



In [22]:
# Skewness and kurtosis of Quantity, computed over the whole DataFrame.
df.select([
    skewness("Quantity").alias("skewness"),
    kurtosis("Quantity").alias("kurtosis"),
]).show()

+-------------------+------------------+
|           skewness|          kurtosis|
+-------------------+------------------+
|-0.2640755761052562|119768.05495536952|
+-------------------+------------------+



In [23]:
from pyspark.sql.functions import corr, covar_pop, covar_samp

In [24]:
# Correlation plus sample and population covariance between the
# InvoiceNo and Quantity columns.
df.select([
    corr("InvoiceNo", "Quantity"),
    covar_samp("InvoiceNo", "Quantity"),
    covar_pop("InvoiceNo", "Quantity"),
]).show()

+-------------------------+-------------------------------+------------------------------+
|corr(InvoiceNo, Quantity)|covar_samp(InvoiceNo, Quantity)|covar_pop(InvoiceNo, Quantity)|
+-------------------------+-------------------------------+------------------------------+
|     4.912186085635685E-4|             1052.7280543902734|            1052.7260778741693|
+-------------------------+-------------------------------+------------------------------+



In [26]:
from pyspark.sql.functions import collect_set, collect_list

### NOTE: the `agg()` method is used here when aggregating to complex types

In [27]:
# Aggregate the Country column into two array-valued columns over the whole
# DataFrame; show() truncates the long arrays in the displayed output.
df.agg(
    collect_set("Country"),
    collect_list("Country"),
).show()

+--------------------+---------------------+
|collect_set(Country)|collect_list(Country)|
+--------------------+---------------------+
|[Portugal, Italy,...| [United Kingdom, ...|
+--------------------+---------------------+



## Grouping

In [29]:
# Special case: after groupBy(), count() is called as a method on the grouped
# data and adds a "count" column holding the number of rows in each group.
(
    df.groupBy("InvoiceNo", "CustomerId")
    .count()
    .show(2)
)

+---------+----------+-----+
|InvoiceNo|CustomerId|count|
+---------+----------+-----+
|   536846|     14573|   76|
|   537026|     12395|   12|
+---------+----------+-----+
only showing top 2 rows



After `groupBy()`, use `agg()` instead of `select()` to apply aggregation functions

In [30]:
# The same per-invoice count expressed two ways: with the count() column
# function (aliased to "quan") and with a SQL expression string via expr().
(
    df.groupBy("InvoiceNo")
    .agg(count("Quantity").alias("quan"), expr("count(Quantity)"))
    .show(2)
)

+---------+----+---------------+
|InvoiceNo|quan|count(Quantity)|
+---------+----+---------------+
|   536596|   6|              6|
|   536938|  14|             14|
+---------+----+---------------+
only showing top 2 rows



In [31]:
# Per-invoice average and population standard deviation of Quantity,
# both written as SQL expression strings.
(
    df.groupBy("InvoiceNo")
    .agg(
        expr("avg(Quantity)"),
        expr("stddev_pop(Quantity)"),
    )
    .show(2)
)

+---------+------------------+--------------------+
|InvoiceNo|     avg(Quantity)|stddev_pop(Quantity)|
+---------+------------------+--------------------+
|   536596|               1.5|  1.1180339887498947|
|   536938|33.142857142857146|  20.698023172885524|
+---------+------------------+--------------------+
only showing top 2 rows



## Window Functions

In [32]:
from pyspark.sql.functions import col, to_date, desc, dense_rank, rank
from pyspark.sql.window import Window

In [42]:
# Create a "date" column by parsing the InvoiceDate string.
# InvoiceDate values are not zero-padded (e.g. "12/1/2010 8:26"), so the
# pattern uses the single-letter fields M, d and H, which accept both
# one- and two-digit values. The zero-padded pattern 'MM/dd/yyyy HH:mm'
# only parses this data under the lenient Spark 2.x parser and is rejected
# by the stricter datetime parser used in Spark 3.x.
dfWithDate = df.withColumn("date", to_date(col("InvoiceDate"), "M/d/yyyy H:mm"))

# Window specification, built from the Window class:
#   - one partition per (CustomerId, date) pair,
#   - rows inside a partition sorted by Quantity, largest first,
#   - a frame running from the first row of the partition up to the current
#     row (Window.unboundedPreceding / Window.currentRow are attributes of
#     the Window class).
windowSpec = (
    Window
    .partitionBy("CustomerId", "date")
    .orderBy(desc("Quantity"))
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

# Applying a function .over(windowSpec) yields a Column expression, not a
# DataFrame. rank() and dense_rank() are evaluated within each partition.
purchaseRank = rank().over(windowSpec)
purchaseDenseRank = dense_rank().over(windowSpec)
maxPurchaseQuantity = max(col("Quantity")).over(windowSpec)

# Build the final DataFrame: keep rows with a known customer and select the
# plain columns alongside the window-derived columns defined above.
(
    dfWithDate
    .where("CustomerId IS NOT NULL")
    .orderBy("CustomerId")
    .select(
        col("CustomerId"),
        col("date"),
        col("Quantity"),
        purchaseRank.alias("quantityRank"),
        purchaseDenseRank.alias("quantityDenseRank"),
        maxPurchaseQuantity.alias("maxPurchaseQuantity"),
    )
    .show()
)

+----------+----------+--------+------------+-----------------+-------------------+
|CustomerId|      date|Quantity|quantityRank|quantityDenseRank|maxPurchaseQuantity|
+----------+----------+--------+------------+-----------------+-------------------+
|     12346|2011-01-18|   74215|           1|                1|              74215|
|     12346|2011-01-18|  -74215|           2|                2|              74215|
|     12347|2010-12-07|      36|           1|                1|                 36|
|     12347|2010-12-07|      30|           2|                2|                 36|
|     12347|2010-12-07|      24|           3|                3|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|             

## Rollups
Grouping with subtotals for each `date` plus a grand total (rows where a grouping column is `null` hold the totals)

In [43]:
# Roll up Quantity by (date, Country). In the output, rows with a null
# Country hold the total for that date, and the row where both columns are
# null holds the grand total.
(
    dfWithDate
    .rollup("date", "Country")
    .agg(sum("Quantity"))
    .selectExpr("date", "Country", "`sum(Quantity)` as total_quantity")
    .orderBy("date")
    .show()
)

+----------+--------------+--------------+
|      date|       Country|total_quantity|
+----------+--------------+--------------+
|      null|          null|       5176450|
|2010-12-01|        France|           449|
|2010-12-01|          EIRE|           243|
|2010-12-01|          null|         26814|
|2010-12-01|       Germany|           117|
|2010-12-01|United Kingdom|         23949|
|2010-12-01|   Netherlands|            97|
|2010-12-01|     Australia|           107|
|2010-12-01|        Norway|          1852|
|2010-12-02|       Germany|           146|
|2010-12-02|          EIRE|             4|
|2010-12-02|          null|         21023|
|2010-12-02|United Kingdom|         20873|
|2010-12-03|        Poland|           140|
|2010-12-03|   Switzerland|           110|
|2010-12-03|       Germany|           170|
|2010-12-03|      Portugal|            65|
|2010-12-03|         Spain|           400|
|2010-12-03|         Italy|           164|
|2010-12-03|       Belgium|           528|
+----------

## Cube
Grouping with subtotals across every combination of the grouping columns (per `date`, per `Country`, and the grand total)

In [45]:
# Cube Quantity over date and Country. Rows with a null Date in the output
# hold the totals for a Country across all dates.
(
    dfWithDate
    .cube("date", "Country")
    .agg(sum(col("Quantity")))
    .select("Date", "Country", "sum(Quantity)")
    .orderBy("Date")
    .show(50)
)

+----------+--------------------+-------------+
|      Date|             Country|sum(Quantity)|
+----------+--------------------+-------------+
|      null|        Saudi Arabia|           75|
|      null|              Sweden|        35637|
|      null|           Australia|        83653|
|      null|              Israel|         4353|
|      null|             Finland|        10666|
|      null|              Poland|         3653|
|      null|                EIRE|       142637|
|      null|              Canada|         2763|
|      null|      United Kingdom|      4263829|
|      null|              France|       110480|
|      null|             Belgium|        23152|
|      null|              Greece|         1556|
|      null|              Brazil|          356|
|      null|                 RSA|          352|
|      null|                 USA|         1034|
|      null|              Cyprus|         6317|
|      null|           Hong Kong|         4769|
|      null|               Italy|       

## Pivot

In [53]:
# Pivot the Country values into columns: one row per date, and for each
# country one column per summed numeric column (Quantity and UnitPrice).
# The pivot values are passed explicitly. The filter already restricts the
# data to these two countries, and supplying the list spares Spark the
# extra job it would otherwise run to discover the distinct Country values.
# It also fixes the set and order of the output columns.
pivoted = (
    dfWithDate
    .where("Country IN ('USA', 'Japan')")
    .select("date", "Country", "Quantity", "UnitPrice")
    .groupBy("date")
    .pivot("Country", ["Japan", "USA"])
    .sum()
)
pivoted.show()

+----------+-----------------------------------+--------------------+---------------------------------+------------------+
|      date|Japan_sum(CAST(Quantity AS BIGINT))|Japan_sum(UnitPrice)|USA_sum(CAST(Quantity AS BIGINT))|USA_sum(UnitPrice)|
+----------+-----------------------------------+--------------------+---------------------------------+------------------+
|2011-03-07|                                  9|   95.22000000000001|                             null|              null|
|2011-04-19|                               null|                null|                              137|              74.2|
|2011-02-09|                               2962|              115.62|                             null|              null|
|2011-10-12|                               null|                null|                            -1228|            217.93|
|2011-11-29|                               2040|                1.79|                             null|              null|
|2010-12-09|    