# Window functions

## Task 1:

* split users into 3 groups according to the number of distinct badges they have (fewer than 50, 50-150, more than 150)
* if a user received the same badge more than once, consider only the first date
* for each user, compute the average time between two consecutive badges
 * compute the average for each of the 3 groups
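
The group boundaries are easy to get wrong at the edges. Below is a minimal plain-Python sketch of the mapping. It assumes that both 50 and 150 belong to the middle group, which matches the solution's use of `between` (inclusive on both ends). It also reuses the solution's labels: 3 for fewer than 50, 2 for 50-150, and 1 for more than 150. The function name `badge_category` is hypothetical.

```python
def badge_category(n_distinct_badges: int) -> int:
    """Map a user's distinct-badge count to a group label (hypothetical helper).

    Labels mirror the notebook's solution:
    3 = fewer than 50, 2 = 50 to 150 inclusive, 1 = more than 150.
    """
    if n_distinct_badges < 50:
        return 3
    if n_distinct_badges <= 150:
        return 2
    return 1


# Check the values around both boundaries
for n in (1, 49, 50, 150, 151):
    print(n, badge_category(n))
```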

Note
* In this task you will use:
 * window functions
 * aggregations
 * `when` conditions
 * filtering
 * time manipulation

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, when, count, row_number, lit, unix_timestamp, lead, avg, ceil
)

from pyspark.sql import Window

import os

In [None]:
spark = (
    SparkSession
    .builder
    .appName('WF II')
    .getOrCreate()
)

In [None]:
base_path = os.getcwd()

# Project root: drop the last three path components of the notebook's working directory
project_path = '/'.join(base_path.split('/')[:-3])

data_input_path = os.path.join(project_path, 'data/badges')
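The split/join above keeps everything except the last three path components. In other words, it assumes the notebook runs three directories below the project root. A sketch of the same computation with `pathlib`, using a hypothetical working directory so the result is reproducible:

```python
from pathlib import PurePosixPath

# Hypothetical notebook working directory, three levels below the project root
base_path = "/home/user/project/notebooks/spark/window_functions"

# Approach used in the notebook: drop the last three components
via_split = "/".join(base_path.split("/")[0:-3])

# pathlib equivalent: parents[0] is the parent directory, parents[2] is three levels up
via_pathlib = str(PurePosixPath(base_path).parents[2])

print(via_split)    # /home/user/project
print(via_pathlib)  # /home/user/project
```

In the notebook itself this would be `str(Path.cwd().parents[2] / 'data' / 'badges')`, with `Path` imported from `pathlib`. Converting to `str` keeps the argument passed to `spark.read.parquet` a plain path string.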

In [None]:
badgesDF = (
    spark
    .read
    .parquet(data_input_path)
)

In [None]:
badgesDF.printSchema()

<b>Prepare data</b>

Hint:
* select only the first occurrence of each badge (if there are multiple)
 * create a window per user_id and badge, ordered by date
 * use row_number and keep only the first row
* compute the number of distinct badges for each user
 * create another window per user_id
 * use count over this window
* add a column 'category'
 * use a when condition to assign one of 3 values depending on the badge count
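
The first two hints can be illustrated without Spark. This plain-Python sketch uses invented toy rows of the form `(user_id, name, date)`. It keeps the earliest date per `(user_id, name)`, which is what filtering on `row_number() == 1` over a window ordered by date does. It then counts the remaining badges per user.

```python
from collections import defaultdict
from datetime import datetime

# Toy rows: (user_id, badge name, date); invented for illustration only
rows = [
    (1, "Teacher", datetime(2020, 1, 5)),
    (1, "Teacher", datetime(2020, 3, 1)),   # repeat: only the earliest date is kept
    (1, "Editor", datetime(2020, 2, 10)),
    (2, "Student", datetime(2021, 6, 1)),
]

# Step 1: first occurrence per (user_id, name), like row_number() == 1 ordered by date
first_seen = {}
for user_id, name, date in rows:
    key = (user_id, name)
    if key not in first_seen or date < first_seen[key]:
        first_seen[key] = date

# Step 2: number of distinct badges per user, like count('*') over a window per user_id
badges_per_user = defaultdict(int)
for user_id, _name in first_seen:
    badges_per_user[user_id] += 1

print(dict(badges_per_user))  # {1: 2, 2: 1}
```

Unlike this dictionary, the window version keeps every column of the surviving rows and attaches the count to each of them. That is why the solution uses `count('*').over(w2)` rather than a `groupBy`, which would collapse each user to a single row.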

In [None]:
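# Rows of the same badge for a user, earliest first
w1 = Window.partitionBy('user_id', 'name').orderBy('date')
# All (deduplicated) badges of a user
w2 = Window.partitionBy('user_id')

badges_transformed = (
    badgesDF
    .withColumn('r', row_number().over(w1))
    .filter(col('r') == 1)                          # keep only the first date of each badge
    .withColumn('badges', count('*').over(w2))      # number of distinct badges per user
    .withColumn(
        'category',
        when(col('badges') < 50, lit(3))
        .when(col('badges').between(50, 150), lit(2))  # between() is inclusive
        .otherwise(lit(1))                              # more than 150
    )
).cache()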

In [None]:
badges_transformed.count()

In [None]:
badges_transformed.orderBy('category').show(truncate=False)

<b>Compute the time between two badges</b>

Hint:
* create a new window per user_id, ordered by date
* use the lead function to get the date of the next badge
* use unix_timestamp to get the difference in seconds
* groupBy category
* compute the average
* convert seconds to days
* round up (ceil) to whole days
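
The arithmetic behind these hints can be sketched in plain Python for a single user with invented dates. Pairing each date with the next one mimics `lead('date')`. The last date has no successor, so Spark returns null there and the solution filters that row out. Differences are taken in seconds, as with `unix_timestamp`, then averaged and rounded up to whole days.

```python
import math
from datetime import datetime

# One user's first-occurrence badge dates (invented), already sorted by date
dates = [datetime(2020, 1, 1), datetime(2020, 1, 3), datetime(2020, 1, 10, 12)]

# lead('date'): pair each date with the next one; the last date gets no pair
pairs = list(zip(dates, dates[1:]))

# Difference in seconds, as with unix_timestamp(next) - unix_timestamp(current)
diffs = [(nxt - cur).total_seconds() for cur, nxt in pairs]

avg_seconds = sum(diffs) / len(diffs)
avg_days = math.ceil(avg_seconds / 3600 / 24)  # seconds -> days, rounded up

print(diffs)     # [172800.0, 648000.0]
print(avg_days)  # 5
```

Here the mean gap is 4.75 days, which `ceil` turns into 5.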

In [None]:
# Badges of each user in chronological order
w3 = Window.partitionBy('user_id').orderBy('date')

(
    badges_transformed
    .withColumn('next_badge', lead('date').over(w3))   # date of the user's next badge
    .filter(col('next_badge').isNotNull())              # the last badge has no successor
    .withColumn('diff', unix_timestamp(col('next_badge')) - unix_timestamp(col('date')))  # seconds
    .groupBy('category')
    .agg(
        ceil(avg('diff') / 3600 / 24).alias('avg_diff')  # seconds -> whole days, rounded up
    )
    .orderBy('category')
).show(truncate=False)
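The task asks for an average per user first and then an average per group. The cell above instead averages all gaps of a category at once. The two are not the same, because the pooled version weights each user by their number of gaps. A small plain-Python check on invented numbers shows the difference. In Spark, the per-user reading would need an extra `groupBy('user_id', 'category')` with `avg('diff')` before the per-category `groupBy`.

```python
from statistics import mean

# Invented gaps (in days) between consecutive badges, two users in the same category
gaps = {
    "user_a": [1, 1, 1, 1],  # frequent badges
    "user_b": [20],          # a single long gap
}

# Pooled: average over all gaps in the category (what the cell above computes)
pooled = mean(g for user_gaps in gaps.values() for g in user_gaps)

# Per-user average first, then average across users (literal reading of the task)
per_user = mean(mean(user_gaps) for user_gaps in gaps.values())

print(pooled)    # 4.8
print(per_user)  # 10.5
```

Because consecutive gaps telescope, a user's mean gap equals (last date - first date) / (number of badges - 1). The per-user step could therefore also be computed with `min` and `max` in a `groupBy`, without `lead`.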

In [None]:
spark.stop()