# Spark Learning Note - Data Aggregations

Jia Geng | gjia0214@gmail.com

<a id='directory'></a>

## Directory

- [Data Source](https://github.com/databricks/Spark-The-Definitive-Guide/tree/master/data/)
- [1. DataFrame Level Aggregation](#sec1)
- [2. GroupBy and Aggregate](#sec2)
- [3. Window Function](#sec3)
- [4. GroupSets, Rollup, and Cube](#sec4)
- [5. Pivot](#sec5)


In [22]:
# check java version 
# use sudo update-alternatives --config java to switch java version if needed.
!java -version

openjdk version "1.8.0_252"
OpenJDK Runtime Environment (build 1.8.0_252-8u252-b09-1~19.10-b09)
OpenJDK 64-Bit Server VM (build 25.252-b09, mixed mode)


In [23]:
from pyspark.sql.session import SparkSession

spark = SparkSession.builder.appName('Spark Learning').getOrCreate()
spark

In [24]:
data_example = '/home/jgeng/Documents/Git/SparkLearning/book_data/retail-data/all/online-retail-dataset.csv'

In [26]:
df = spark.read.format('csv').option('header', True).option('inferSchema', True).load(data_example)
df.printSchema()
df.show(3)
df.cache()  # cache() is lazy: nothing is cached until an action runs
df.count()  # count() is an action that scans every row, so calling it materializes the whole cache in memory

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: integer (nullable = true)
 |-- Country: string (nullable = true)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 3 rows

541909

## Aggregation

Aggregation groups rows by one or more keys and summarizes each group with an aggregation function. In Spark's Scala/Java API, the groupBy operation returns a `RelationalGroupedDataset` object; in PySpark the equivalent is `GroupedData`.

Grouping types in Spark include:
- **DataFrame-level aggregation**: summarize the whole DataFrame (e.g. inside a `select`) without any key
- **group by**: aggregate using one or more keys and one or more aggregation functions
- **window**: aggregate using one or more keys and one or more aggregation functions, where the rows fed to the function are defined relative to the current row
- **grouping sets**: aggregate at multiple different levels
    - **rollup**: one or more keys and one or more values, summarized hierarchically
    - **cube**: one or more keys and one or more values, summarized across all combinations of the key columns

### 1. DataFrame Level Aggregation <a id='sec1'></a>

Common aggregation functions live in `pyspark.sql.functions`. They take columns (or column names) and are used inside `select()` or `agg()`.
- `count()`: `count(col)` is a transformation used inside `select()`/`agg()`, while `df.count()` is an action that returns a number right away.
- `countDistinct()`: exact distinct count; can be slow when the data is large
- `approx_count_distinct(col, rsd)`: faster approximate option; `rsd` is the maximum allowed relative standard deviation (default 0.05)
- `first()`, `last()`: get the first/last value of a column (the result depends on row order, which is not guaranteed unless the data is sorted)
- `min()`, `max()`, `sum()`, `sumDistinct()`, `avg()`: work as the names suggest
- `var_pop()`, `var_samp()`, `stddev_pop()`, `stddev_samp()`: population (divide by n) and sample (divide by n - 1) variance and standard deviation
- `skewness()`, `kurtosis()`
    - skewness: a measure of symmetry, or more precisely, the lack of symmetry. A distribution, or data set, is symmetric if it looks the same to the left and right of the center point.
        - normal dist. skewness = 0 (symmetric, left/right tails are the same)
        - positive skewness: right skew, the right tail is longer
        - negative skewness: left skew, the left tail is longer
    - kurtosis: a measure of whether the data are heavy-tailed or light-tailed relative to a normal distribution. Data sets with high kurtosis tend to have heavy tails, or outliers; data sets with low kurtosis tend to have light tails, or lack of outliers. Spark reports excess kurtosis (kurtosis minus 3), so:
        - normal dist. kurtosis = 0
        - positive kurtosis: heavy tailed
        - negative kurtosis: light tailed (a uniform distribution, for example, has excess kurtosis -1.2)

- `corr()`, `covar_pop()`, `covar_samp()`

Spark also supports aggregating column values into an array with the `collect_set()` (distinct values) or `collect_list()` (all values, duplicates kept) function.
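To make the population/sample distinction and the moment-based statistics concrete, here is a plain-Python sketch that needs no Spark. It assumes Spark's definitions as understood here: skewness and kurtosis are built from population central moments, and kurtosis is reported as excess kurtosis. The function names are hypothetical stand-ins (chosen so they do not shadow the `pyspark.sql.functions` imports), and `math.fsum` is used because the notebook's `from pyspark.sql.functions import sum` shadows the built-in `sum`.

```python
import math


def _central_sums(xs):
    """Return n and the sums of squared, cubed and fourth-power deviations from the mean."""
    n = len(xs)
    mean = math.fsum(xs) / n
    devs = [x - mean for x in xs]
    m2 = math.fsum(d ** 2 for d in devs)
    m3 = math.fsum(d ** 3 for d in devs)
    m4 = math.fsum(d ** 4 for d in devs)
    return n, m2, m3, m4


def pop_variance(xs):
    n, m2, _, _ = _central_sums(xs)
    return m2 / n  # population variance: divide by n (like var_pop)


def sample_variance(xs):
    n, m2, _, _ = _central_sums(xs)
    return m2 / (n - 1)  # sample variance: divide by n - 1 (like var_samp)


def pop_skewness(xs):
    n, m2, m3, _ = _central_sums(xs)
    return (m3 / n) / (m2 / n) ** 1.5


def excess_kurtosis(xs):
    n, m2, _, m4 = _central_sums(xs)
    return (m4 / n) / (m2 / n) ** 2 - 3.0  # minus 3, so a normal distribution gives 0


print(pop_variance([1, 2, 3]), sample_variance([1, 2, 3]))
print(pop_skewness([1, 1, 1, 10]))  # one large value on the right: positive skewness
```

The only difference between the `_pop` and `_samp` variants is the divisor, so they converge as the number of rows grows.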

[back to top](#directory)

In [27]:
from pyspark.sql.functions import countDistinct, approx_count_distinct, col, struct, array
# exact distinct count of a single column
df.select(countDistinct(col('StockCode')).alias('DistinctCount')).show()

# approximate distinct count (HyperLogLog++), much faster on very large data;
# 0.01 is the maximum allowed relative standard deviation
df.select(approx_count_distinct(col('StockCode'), 0.01).alias('DistinctCount')).show()

# countDistinct accepts multiple columns (counts distinct combinations)
df.select(countDistinct(col('StockCode'), col('Quantity')).alias('DistinctCount')).show()

# approx_count_distinct takes a single column, so wrap multiple columns in a struct...
df.select(approx_count_distinct(struct(col('StockCode'), col('Quantity')), 0.01).alias('Approx')).show()

# ...or in an array
df.select(approx_count_distinct(array(col('StockCode'), col('Quantity')), 0.01).alias('Approx')).show()

# reminder: distinct() returns the distinct rows
df.select(col('StockCode'), col('Quantity')).distinct().show(3)
# same result as countDistinct here; note that countDistinct skips rows where any of the columns is null,
# while distinct() keeps them, so the two can differ on data with nulls
df.select(col('StockCode'), col('Quantity')).distinct().count()

+-------------+
|DistinctCount|
+-------------+
|         4070|
+-------------+

+-------------+
|DistinctCount|
+-------------+
|         4079|
+-------------+

+-------------+
|DistinctCount|
+-------------+
|        45280|
+-------------+

+------+
|Approx|
+------+
| 45378|
+------+

+------+
|Approx|
+------+
| 45314|
+------+

+---------+--------+
|StockCode|Quantity|
+---------+--------+
|    21485|       6|
|    84347|       3|
|    22454|       2|
+---------+--------+
only showing top 3 rows



45280

In [30]:
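from pyspark.sql.functions import min, max, first, last, sum, avg, var_pop, skewness, kurtosis
# note: this import shadows Python's built-in min/max/sum for the rest of the notebook

# some column based stats
min_quantity = min(df['Quantity'])  # a Column object works
max_quantity = max('Quantity')  # a column name string also works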
first_quantity = first(df.Quantity)
last_quantity = last(df.Quantity)
sum_quantity = sum(df.Quantity)
avg_quantity = avg(df.Quantity)
var_quantity = var_pop(df.Quantity)
skewness_quantity = skewness(df.Quantity)
kurtosis_quantity = kurtosis(df.Quantity)

df.select(min_quantity.alias('min'), max_quantity.alias('max'), 
          first_quantity.alias('first'), last_quantity.alias('last'),
          sum_quantity.alias('sum'), avg_quantity.alias('avg'),
          var_quantity.alias('var'), skewness_quantity.alias('skewness'),
          kurtosis_quantity.alias('kurtosis')).show()

+------+-----+-----+----+-------+----------------+------------------+-------------------+------------------+
|   min|  max|first|last|    sum|             avg|               var|           skewness|          kurtosis|
+------+-----+-----+----+-------+----------------+------------------+-------------------+------------------+
|-80995|80995|    6|   3|5176450|9.55224954743324|47559.303646609165|-0.2640755761052369|119768.05495536828|
+------+-----+-----+----+-------+----------------+------------------+-------------------+------------------+



In [31]:
from pyspark.sql.functions import corr, covar_pop

# correlation between two columns
cor_qp = corr(df.Quantity, df.UnitPrice)

# correlation is the covariance divided by the product of the two standard deviations
covar_qp = covar_pop(df.Quantity, df.UnitPrice)

# print it out
df.select(cor_qp.alias('Correlation'), covar_qp.alias('Covariance')).show()

+--------------------+-------------------+
|         Correlation|         Covariance|
+--------------------+-------------------+
|-0.00123492454487...|-26.058713170967746|
+--------------------+-------------------+
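The relationship between the two numbers above can be sketched in plain Python: correlation is the covariance divided by the product of the two standard deviations. This is an illustration of the formula, not Spark's implementation, and the function names are hypothetical.

```python
import math


def pop_covariance(xs, ys):
    """Population covariance: mean of the products of deviations (what covar_pop computes)."""
    n = len(xs)
    mx, my = math.fsum(xs) / n, math.fsum(ys) / n
    return math.fsum((x - mx) * (y - my) for x, y in zip(xs, ys)) / n


def pearson_corr(xs, ys):
    """Correlation = covariance / (stddev of xs * stddev of ys)."""
    sx = math.sqrt(pop_covariance(xs, xs))  # population standard deviation of xs
    sy = math.sqrt(pop_covariance(ys, ys))  # population standard deviation of ys
    return pop_covariance(xs, ys) / (sx * sy)


print(pearson_corr([1, 2, 3], [2, 4, 6]))  # perfectly linear
print(pearson_corr([1, 2, 3], [3, 2, 1]))  # perfectly inverse
```

Because the 1/n (or 1/(n - 1)) factors cancel between numerator and denominator, the correlation is the same whether population or sample statistics are used, which fits with Spark offering a single `corr()` but both `covar_pop()` and `covar_samp()`.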



In [32]:
from pyspark.sql.functions import collect_set, collect_list

agged_df = df.agg(collect_set(col('Quantity')), collect_list(col('Quantity')))
agged_df.show()

+---------------------+----------------------+
|collect_set(Quantity)|collect_list(Quantity)|
+---------------------+----------------------+
| [-42, 306, 256, 1...|  [6, 6, 8, 6, 6, 2...|
+---------------------+----------------------+



### 2. GroupBy and Aggregate <a id='sec2'></a>

A more common task is to perform calculations on groups in the data. This is usually a two-stage process:
- group by some keys: `.groupBy(col_names, ...)`, which supports multiple columns
- aggregate with some functions: `.agg(func(col), ...)`, which can take multiple functions

`.groupBy()` returns a `GroupedData` object whose `.agg()` method applies aggregation functions to each group.
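The two stages can be sketched in plain Python on a few hypothetical rows: first bucket the rows by their key tuple, then reduce each bucket. Like Spark's `count(col)` and `avg(col)`, the sketch ignores nulls (`None`) in the value column. `group_count_avg` is a made-up helper, not a Spark API.

```python
import math
from collections import defaultdict


def group_count_avg(rows, keys, value):
    """Roughly what df.groupBy(*keys).agg(count(value), avg(value)) computes, on a list of dicts."""
    # stage 1: group rows by the tuple of key values
    groups = defaultdict(list)
    for row in rows:
        groups[tuple(row[k] for k in keys)].append(row[value])

    # stage 2: aggregate each group, skipping nulls like count(col) / avg(col)
    result = {}
    for key, values in groups.items():
        non_null = [v for v in values if v is not None]
        mean = math.fsum(non_null) / len(non_null) if non_null else None
        result[key] = (len(non_null), mean)
    return result


# hypothetical toy rows, not taken from the retail dataset
rows = [
    {'StockCode': 'A1', 'Country': 'UK', 'UnitPrice': 1.0},
    {'StockCode': 'A1', 'Country': 'UK', 'UnitPrice': 2.0},
    {'StockCode': 'A1', 'Country': 'FR', 'UnitPrice': None},
    {'StockCode': 'B2', 'Country': 'UK', 'UnitPrice': 4.0},
]
print(group_count_avg(rows, ['StockCode', 'Country'], 'UnitPrice'))
```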

[back to top](#directory)

In [33]:
df.show(1)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 1 row



In [34]:
from pyspark.sql.functions import count, expr, col

# group by can work with multiple columns
df.groupBy('StockCode', 'Country').agg(count('StockCode').alias('Count'), 
                                       avg(col('UnitPrice')).alias('Avg')).show(3)
# the same aggregation written as SQL expression strings via expr()
df.groupBy('StockCode', 'Country').agg(expr('count(StockCode)').alias('Count'), 
                                       expr('avg(UnitPrice)').alias('Avg')).show(3)

+---------+--------------+-----+------------------+
|StockCode|       Country|Count|               Avg|
+---------+--------------+-----+------------------+
|    22154|United Kingdom|  170|0.5414117647058824|
|    22478|United Kingdom|  133|1.8110526315789475|
|    22844|United Kingdom|  402|10.921791044776118|
+---------+--------------+-----+------------------+
only showing top 3 rows

+---------+--------------+-----+------------------+
|StockCode|       Country|Count|               Avg|
+---------+--------------+-----+------------------+
|    22154|United Kingdom|  170|0.5414117647058824|
|    22478|United Kingdom|  133|1.8110526315789475|
|    22844|United Kingdom|  402|10.921791044776118|
+---------+--------------+-----+------------------+
only showing top 3 rows



### 3. Window Functions <a id='sec3'></a>

A window is a specification of which rows should be used for the computation (aggregation) of each row.
- for groupBy, each row can only go into one group
- **for a window, a row can go into multiple groups, e.g. a rolling average**

The pipeline for applying a window function is:
- define a window using the `Window` object from `pyspark.sql.window`
    - use `.partitionBy(col_names, ...)` to define the partitions
    - use `.orderBy()` to sort the rows within each partition
    - use `.rowsBetween(start, end)` to define the frame of rows used for each row. E.g. `.rowsBetween(Window.unboundedPreceding, Window.currentRow)` means
        - the frame consists of all preceding rows up to and including the current row (**within the same partition**)
    - the above returns a `WindowSpec` object

Aggregation functions can be applied over a `WindowSpec`, for example:
- `mean(col_name).over(windowSpec)`: simply use `.over()`
- it is safer to pass a column name string or `col('name')` rather than `df.colname`: the window expression is built on its own and only resolved when it is selected from a DataFrame (in the example below, `dfWithDate` rather than `df`)

There are also ranking window functions such as `rank` and `dense_rank`:
- `rank().over(windowSpec)`
- **since this is a window function, the rank is per partition, not a global rank**

A common window function pipeline is:
- create the `WindowSpec`: `.partitionBy(col_names, ...)` -> `.orderBy()` -> `.rowsBetween(start, end)`
- apply functions over the window to get column objects
- select from the DataFrame using the column objects (it is good practice to sort the DataFrame by the partition columns when displaying the data)
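The window pipeline can be sketched in plain Python, assuming non-null ordering values: rows are split by the partition key, sorted by the order key, and each row then sees only the rows from the start of its partition up to itself (the `rowsBetween(Window.unboundedPreceding, Window.currentRow)` frame). The helper name and the toy rows are hypothetical.

```python
from collections import defaultdict


def running_max_and_ranks(rows, partition_keys, order_key):
    """Annotate each row with a running max, rank and dense_rank of order_key within its partition."""
    partitions = defaultdict(list)
    for row in rows:  # partitionBy
        partitions[tuple(row[k] for k in partition_keys)].append(row)

    out = []
    for part in partitions.values():
        part = sorted(part, key=lambda r: r[order_key])  # orderBy (ascending)
        running_max, rank, dense_rank, prev = None, 0, 0, object()
        for position, row in enumerate(part, start=1):
            v = row[order_key]
            # frame = unboundedPreceding .. currentRow, so the max covers the rows seen so far
            if running_max is None or v > running_max:
                running_max = v
            if v != prev:  # new value: rank jumps to the row position, dense_rank adds 1
                rank = position
                dense_rank += 1
                prev = v
            out.append({**row, 'running_max': running_max, 'rank': rank, 'dense_rank': dense_rank})
    return out


toy = [{'cust': 'c1', 'qty': q} for q in (1, -10, 5, 1)] + [{'cust': 'c2', 'qty': q} for q in (4, 3)]
for r in running_max_and_ranks(toy, ['cust'], 'qty'):
    print(r)
```

When the window is ordered ascending by the same column that is maximized, the running max is always the current value, which matches the `Rolling Max Quantity` column equalling `Quantity` in the notebook output. With ties, `rank` skips numbers (1, 2, 2, 4 for partition `c1`) while `dense_rank` does not (1, 2, 2, 3).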

[back to top](#directory)

In [35]:
from pyspark.sql.functions import to_date
df.show(1)
df.printSchema()

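# convert the InvoiceDate string (e.g. '12/1/2010 8:26') to a DateType column
# to_date(col, format): without a matching format these strings are not parsed (null by default)
# 'M/d/yyyy H:mm' accepts one- or two-digit months and days ('MM' can fail on months like '1' with Spark 3's stricter parser)
dfWithDate = df.withColumn('InvoiceDate', to_date(col('InvoiceDate'), 'M/d/yyyy H:mm'))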
dfWithDate.show(3)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 1 row

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: integer (nullable = true)
 |-- Country: string (nullable = true)

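+---------+---------+--------------------+--------+-----------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+-----------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6| 2010-12-01|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6| 2010-12-01|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8| 2010-12-01|     2.75|     17850|United Kingdom|
+---------+---------+--------------------+--------+-----------+---------+----------+--------------+
only showing top 3 rows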

In [36]:
from pyspark.sql.window import Window
from pyspark.sql.functions import max, dense_rank, rank

# Step 1 - create window spec
# pipeline to define a window spec
# define partition -> order -> define range criteria
windowSpecPartitioned = Window.partitionBy('CustomerId', 'InvoiceDate')
windowSpecOrdered = windowSpecPartitioned.orderBy('Quantity')
windowSpec = windowSpecOrdered.rowsBetween(Window.unboundedPreceding, Window.currentRow)  # start, end

# Step 2 - apply function over the windowSpec to get the columns
# These are lazy column expressions; nothing executes yet
# Column names are only resolved against a DataFrame when the expressions are selected and an action runs
rollingMaxQuantity = max(col('Quantity')).over(windowSpec)
quanRank = rank().over(windowSpec)
quanDenseRank = dense_rank().over(windowSpec)

# Step 3 - select columns
# without filtering, rows with a null CustomerId sort first
dfWithDate.orderBy('CustomerId').select('CustomerId', 'InvoiceDate', 'Quantity',
                                rollingMaxQuantity.alias('Rolling Max Quantity'),
                                quanRank.alias('Rank'), quanDenseRank.alias('Dense Rank')).show(3)

# Step 3 - select columns 
# make sure to remove the nulls
dfWithDate.where('CustomerId is not null').orderBy('CustomerId').select('CustomerId', 'InvoiceDate', 'Quantity',
                                                                rollingMaxQuantity.alias('Rolling Max Quantity'),
                                                                quanRank.alias('Rank'), quanDenseRank.alias('Dense Rank')).show()



+----------+-----------+--------+--------------------+----+----------+
|CustomerId|InvoiceDate|Quantity|Rolling Max Quantity|Rank|Dense Rank|
+----------+-----------+--------+--------------------+----+----------+
|      null| 2010-12-01|     -10|                 -10|   1|         1|
|      null| 2010-12-01|       1|                   1|   2|         2|
|      null| 2010-12-01|       1|                   1|   2|         2|
+----------+-----------+--------+--------------------+----+----------+
only showing top 3 rows

+----------+-----------+--------+--------------------+----+----------+
|CustomerId|InvoiceDate|Quantity|Rolling Max Quantity|Rank|Dense Rank|
+----------+-----------+--------+--------------------+----+----------+
|     12346| 2011-01-18|  -74215|              -74215|   1|         1|
|     12346| 2011-01-18|   74215|               74215|   2|         2|
|     12347| 2010-12-07|       3|                   3|   1|         1|
|     12347| 2010-12-07|       4|                   

### 4. Group Sets <a id='sec4'></a>

Given columns (col1, col2, col3), there are three types of grouping aggregation:
- group by: (col1, col2, col3)
- rollup, hierarchical groups: (all, all, all), (col1, all, all), (col1, col2, all), (col1, col2, col3)
- cube, all combinations: (all, all, all), (col1, all, all), (all, col2, all), (all, all, col3), (col1, all, col3), (all, col2, col3), (col1, col2, all), (col1, col2, col3)

A plain group by on (col1, col2, col3) is `df.groupBy(col1, col2, col3)`; it only produces the finest level.

**Rollup (col1, col2, col3)** -> hierarchical from left to right, gives us 4 levels:
- grand total
- subtotal of each (col1) group
- subtotal of each (col1, col2) group
- subtotal of each (col1, col2, col3) group

**Cube (col1, col2, col3)** -> aggregation over all combinations, gives us 8 levels (the order below is not the actual order of levels; check the code example):
- grand total
- subtotal of each (col1) group
- subtotal of each (col2) group
- subtotal of each (col3) group
- subtotal of each (col1, col2) group
- subtotal of each (col1, col3) group
- subtotal of each (col2, col3) group
- subtotal of each (col1, col2, col3) group

After a grouping set operation, we need to query the aggregated information at a specific level. Adding `grouping_id()` inside `.agg()` produces a column that indicates the aggregation level of each row.
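The grouping sets behind rollup and cube, and the `grouping_id()` value of each, can be enumerated in plain Python. This sketch assumes the bit convention Spark documents for `grouping_id()`: one bit per column, set to 1 when the column is aggregated away, with the first column as the most significant bit. The function names are hypothetical (`plain_grouping_id` avoids shadowing the Spark function).

```python
def rollup_sets(cols):
    """ROLLUP grouping sets: every prefix of cols, from all columns down to the grand total ()."""
    return [tuple(cols[:i]) for i in range(len(cols), -1, -1)]


def cube_sets(cols):
    """CUBE grouping sets: all 2**n subsets, listed so that the list index equals the grouping id."""
    n = len(cols)
    return [
        tuple(c for j, c in enumerate(cols) if ((mask >> (n - 1 - j)) & 1) == 0)
        for mask in range(2 ** n)
    ]


def plain_grouping_id(cols, grouping_set):
    """One bit per column (1 = aggregated away, shown as null); the first column is the most significant bit."""
    gid = 0
    for c in cols:
        gid = (gid << 1) | (0 if c in grouping_set else 1)
    return gid


cols = ['CustomerID', 'StockCode', 'InvoiceDate']
print(rollup_sets(cols))
for s in cube_sets(cols):
    print(plain_grouping_id(cols, s), s)
```

For example, a row that keeps only StockCode (CustomerID and InvoiceDate null) gets 0b101 = 5, and the grand total gets 0b111 = 7. The id is a bit mask rather than a depth: id 3 (only CustomerID kept) aggregates more columns than id 4 (only CustomerID aggregated).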

[back to top](#directory)

In [37]:
from pyspark.sql.functions import desc

# groupBy, each group is defined by the (customerID, stockCode)
dfWithDate.groupBy('CustomerID', 'StockCode').agg(sum('Quantity')).orderBy(desc('CustomerID'), desc('StockCode')).show(3)    

+----------+---------+-------------+
|CustomerID|StockCode|sum(Quantity)|
+----------+---------+-------------+
|     18287|    85173|           48|
|     18287|   85040A|           48|
|     18287|   85039B|          120|
+----------+---------+-------------+
only showing top 3 rows



In [38]:
dfWithDate.count()

541909

In [39]:
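# drop rows with any null (e.g. a missing CustomerID), so that nulls in the rollup/cube output only mean "aggregated"
dfWithDateNoNull = dfWithDate.na.drop()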
dfWithDateNoNull.count()

406829

In [42]:
# rollup
# after a rollup, we can also use .agg() when needed
rolledDF = dfWithDateNoNull.rollup('CustomerID', 'StockCode', 'InvoiceDate').count()

# (null, null, null) is the grand total: the count over all rows
# (12346, null, null) is the subtotal for CustomerID 12346 over all StockCodes and dates
rolledDF.orderBy('CustomerID', 'StockCode').show(3)

# rollup is hierarchical from left to right:
# a null in an earlier column implies nulls in all later columns
col1_count = rolledDF.where('CustomerID is null').count()
col2_count = rolledDF.where('StockCode is null').count()
col3_count = rolledDF.where('InvoiceDate is null').count()

# the number of nulls in a column is the number of (sub)total rows at or above that level
# CustomerID null: only the grand total -> 1
# StockCode null: the grand total + one subtotal per CustomerID -> 1 + 4372 = 4373
# InvoiceDate null: all of the above + one subtotal per (CustomerID, StockCode) pair -> 271988
print(col1_count, col2_count, col3_count)

+----------+---------+-----------+------+
|CustomerID|StockCode|InvoiceDate| count|
+----------+---------+-----------+------+
|      null|     null|       null|406829|
|     12346|     null|       null|     2|
|     12346|    23166| 2011-01-18|     2|
+----------+---------+-----------+------+
only showing top 3 rows

1 4373 271988
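The null counts printed above can be reproduced on a tiny, hypothetical data set with a plain-Python sketch of `rollup(...).count()`, where `None` marks a column that has been aggregated away. It assumes the key columns contain no real nulls, which is what `na.drop()` ensures in the notebook. `rollup_count` is a made-up helper.

```python
from collections import Counter


def rollup_count(rows):
    """Roughly df.rollup(col1, ..., colN).count() for rows given as tuples of key values."""
    counts = Counter()
    for row in rows:
        n = len(row)
        for i in range(n, -1, -1):  # each row contributes to every prefix level
            counts[row[:i] + (None,) * (n - i)] += 1
    return counts


# hypothetical (CustomerID, StockCode, InvoiceDate) keys
rows = [('a', 's1', 'd1'), ('a', 's1', 'd2'), ('a', 's2', 'd1'), ('b', 's1', 'd1')]
counts = rollup_count(rows)
null_counts = [len([k for k in counts if k[j] is None]) for j in range(3)]
print(counts[(None, None, None)], null_counts)
```

Here 2 customers give 1 + 2 = 3 rows with a null StockCode, and 3 (customer, stock code) pairs give 1 + 2 + 3 = 6 rows with a null InvoiceDate, the same pattern as 1, 4373 and 271988 on the full data.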


In [43]:
from pyspark.sql.functions import avg
# cube
dfWithDateNoNull.cube('CustomerID', 'StockCode', 'InvoiceDate').agg(avg('Quantity')).orderBy('CustomerID').show(5)

# cube has 8 grouping sets instead of rollup's 4, so it produces many more rows
dfWithDateNoNull.cube('CustomerID', 'StockCode', 'InvoiceDate').agg(avg('Quantity')).count()

+----------+---------+-----------+------------------+
|CustomerID|StockCode|InvoiceDate|     avg(Quantity)|
+----------+---------+-----------+------------------+
|      null|    22451| 2010-12-01|3.3333333333333335|
|      null|    22554| 2010-12-01|              13.0|
|      null|    20982|       null| 8.955555555555556|
|      null|    22147| 2010-12-01|18.666666666666668|
|      null|    22595| 2010-12-01|              58.0|
+----------+---------+-----------+------------------+
only showing top 5 rows



913656

In [44]:
from pyspark.sql.functions import grouping_id

# grouping_id() is a bit vector over the cube columns: a bit is 1 when that column is aggregated (shown as null),
# with CustomerID as the most significant bit, so (null, StockCode, null) -> 0b101 = 5 and the grand total -> 0b111 = 7
dfCubed = dfWithDateNoNull.cube('CustomerID', 'StockCode', 'InvoiceDate').agg(grouping_id().alias('level'), avg('Quantity'), sum('Quantity'))
dfCubed.orderBy('CustomerID').show(4)

# to get the grand total, query the rows with the highest level
# cube over 3 columns gives 8 levels (ids 0 to 7), hence the highest level is 7
dfCubed.where('level == 7').show()  # bingo!

# let's check each level
for i in range(8):
    query = 'level == {}'.format(i)
    print('Level {}'.format(i))
    dfCubed.where(query).show(1)

+----------+---------+-----------+-----+------------------+-------------+
|CustomerID|StockCode|InvoiceDate|level|     avg(Quantity)|sum(Quantity)|
+----------+---------+-----------+-----+------------------+-------------+
|      null|    21527|       null|    5| 2.729559748427673|          434|
|      null|    22437|       null|    5|15.376404494382022|         2737|
|      null|    21390| 2010-12-02|    4|              24.0|           24|
|      null|    20992| 2010-12-01|    4|               9.0|            9|
+----------+---------+-----------+-----+------------------+-------------+
only showing top 4 rows

+----------+---------+-----------+-----+-----------------+-------------+
|CustomerID|StockCode|InvoiceDate|level|    avg(Quantity)|sum(Quantity)|
+----------+---------+-----------+-----+-----------------+-------------+
|      null|     null|       null|    7|12.06130339774205|      4906888|
+----------+---------+-----------+-----+-----------------+-------------+

Level 0
+--------

### 5. Pivot <a id='sec5'></a>

Pivot turns the distinct values of one column into separate columns, each holding an aggregated value, which makes the result easy to query.
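A plain-Python sketch of `groupBy(group_col).pivot(pivot_col).sum(*value_cols)`, run on the same toy table as the example in this section. It assumes Spark's behaviour of sorting the distinct pivot values when none are given and of naming columns `<value>_sum(<col>)` when several columns are summed; a combination with no rows stays `None`, like Spark's null. `pivot_sum` is a hypothetical helper.

```python
from collections import defaultdict


def pivot_sum(rows, group_col, pivot_col, value_cols):
    """Roughly df.groupBy(group_col).pivot(pivot_col).sum(*value_cols) on a list of dicts."""
    pivot_values = sorted({r[pivot_col] for r in rows})  # distinct pivot values, sorted
    columns = [f"{p}_sum({v})" for p in pivot_values for v in value_cols]
    table = defaultdict(dict)
    for r in rows:
        cells = table[r[group_col]]
        for v in value_cols:
            name = f"{r[pivot_col]}_sum({v})"
            cells[name] = cells.get(name, 0) + r[v]
    # missing (group, pivot value) combinations become None, like null in Spark
    return columns, {g: [cells.get(c) for c in columns] for g, cells in table.items()}


data = [('X', 'S', 1, 3), ('X', 'P', 1, 6), ('Y', 'P', 1, 8), ('Y', 'P', 1, 10), ('Y', 'S', 1, 0)]
rows = [dict(zip('ABCD', t)) for t in data]
columns, table = pivot_sum(rows, 'A', 'B', ['C', 'D'])
print(columns)
print(table)
```

In PySpark, passing the values explicitly, e.g. `dfExample.groupBy('A').pivot('B', ['P', 'S']).sum()`, lets Spark skip the extra job it otherwise runs to find the distinct values of `B`.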

[back to top](#directory)

In [45]:
# For example, if we have a table
data = [('X', 'S', 1, 3), ('X', 'P', 1, 6), ('Y', 'P', 1, 8), ('Y', 'P', 1, 10), ('Y', 'S', 1, 0)]
col_names = ['A', 'B', 'C', 'D']
dfExample = spark.createDataFrame(data, col_names)
dfExample.show()

+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|  X|  S|  1|  3|
|  X|  P|  1|  6|
|  Y|  P|  1|  8|
|  Y|  P|  1| 10|
|  Y|  S|  1|  0|
+---+---+---+---+



In [46]:
# if our task frequently demands queries like sum(C) for A == 'X' and B == 'P',
# we could simply create an aggregated table for the queries
dfExampleAgg = dfExample.groupBy('A', 'B').agg(sum(col('C')))
dfExampleAgg.show()

# queries would be
dfExampleAgg.where((col('A') == 'X') & (col('B') == 'P')).show()

# another way is to pivot: the original B column values 'P' and 'S' become columns
# since B has two possible values (P, S) and .sum() aggregates the two numeric columns (C, D),
# pivoting creates 2 x 2 = 4 pivoted columns
pivotExample = dfExample.groupBy('A').pivot('B').sum()
pivotExample.show()

# queries would be
pivotExample.where(col('A') == 'X').select('P_sum(C)').show()

+---+---+------+
|  A|  B|sum(C)|
+---+---+------+
|  X|  S|     1|
|  Y|  P|     2|
|  Y|  S|     1|
|  X|  P|     1|
+---+---+------+

+---+---+------+
|  A|  B|sum(C)|
+---+---+------+
|  X|  P|     1|
+---+---+------+

+---+--------+--------+--------+--------+
|  A|P_sum(C)|P_sum(D)|S_sum(C)|S_sum(D)|
+---+--------+--------+--------+--------+
|  Y|       2|      18|       1|       0|
|  X|       1|       6|       1|       3|
+---+--------+--------+--------+--------+

+--------+
|P_sum(C)|
+--------+
|       1|
+--------+



In [47]:
dfWithDate.show(3)
dfWithDate.printSchema()
dfWithDate = dfWithDate.withColumn('Quantity', col('Quantity').cast('long'))
dfWithDate = dfWithDate.withColumn('CustomerID', col('CustomerID').cast('long'))
dfWithDate.groupBy('InvoiceDate').pivot('Country').sum().printSchema()
dfWithDate.groupBy('InvoiceDate').pivot('Country').sum().show(3)

+---------+---------+--------------------+--------+-----------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+-----------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6| 2010-12-01|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6| 2010-12-01|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8| 2010-12-01|     2.75|     17850|United Kingdom|
+---------+---------+--------------------+--------+-----------+---------+----------+--------------+
only showing top 3 rows

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: date (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: integ