# Window functions

## Task 1:

* split users into 3 groups according to the number of distinct badges they have (fewer than 50, 50–150, more than 150)
* if a user has received the same badge more than once, consider only the first date
* for each user, compute the average time between two consecutive badges
* compute the average of these per-user values for each of the 3 groups

**Note**
* In this task you will use:
  * window functions
  * aggregations
  * `when` conditions
  * filtering
  * time manipulation

In [None]:
# Make the local Spark installation importable; this must run before the
# pyspark imports in the next cell.
import findspark
findspark.init()

In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, desc, count, explode, split, regexp_replace, collect_list, array_sort, reverse, unix_timestamp, row_number,
    when, lit, lead, avg
)

from pyspark.sql import Window

from pyspark.sql.types import StructType, StructField, StringType, LongType, TimestampType

In [None]:
# Single SparkSession shared by every cell below (stopped in the last cell).
spark = SparkSession.builder.appName('WF II').getOrCreate()

In [None]:
import os  # was used below but never imported, so this cell raised NameError

# Resolve paths relative to the notebook's working directory: the project
# root is two directory levels above it, and the input data lives in
# <project>/data/badges. os.path keeps this portable (no hard-coded '/').
base_path = os.getcwd()

project_path = os.path.dirname(os.path.dirname(base_path))

data_input_path = os.path.join(project_path, 'data', 'badges')

In [None]:
# Load the badges table (Parquet) from the data directory resolved above.
badgesDF = spark.read.parquet(data_input_path)

In [None]:
# Inspect the schema of the loaded badges data; the cells below rely on its
# user_id, name and date columns.
badgesDF.printSchema()

In [None]:
# w1 ranks each user's awards of the same badge by date, so row 1 is the
# first time that badge was earned; w2 spans all rows of one user.
w1 = Window.partitionBy('user_id', 'name').orderBy('date')
w2 = Window.partitionBy('user_id')

# Category codes by number of distinct badges:
#   3 -> fewer than 50, 2 -> 50..150 (both ends inclusive), 1 -> more than 150
category_expr = (
    when(col('badges') < 50, lit(3))
    .when(col('badges').between(50, 150), lit(2))
    .otherwise(lit(1))
)

# One row per (user, badge): the earliest award. Helper column 'r' is kept.
first_awards = badgesDF.withColumn('r', row_number().over(w1)).filter(col('r') == 1)

# After de-duplication, the row count per user equals the number of distinct
# badges that user holds.
badges_transformed = (
    first_awards
    .withColumn('badges', count('*').over(w2))
    .withColumn('category', category_expr)
    .cache()
)

In [None]:
# count() is an action: it forces evaluation of badges_transformed, which
# also populates the cache requested with .cache() in the previous cell.
badges_transformed.count()

In [None]:
# Preview the de-duplicated awards grouped visually by category code.
badges_transformed.sort('category').show(truncate=False)

In [None]:
# Orders one user's (first-time) badge awards chronologically.
w3 = Window.partitionBy('user_id').orderBy('date')

SECONDS_PER_DAY = 24 * 3600

# Step 1 of the task: average gap between consecutive badges for each user.
per_user_avg = (
    badges_transformed
    .withColumn('next_badge', lead('date').over(w3))
    # The last badge of a user has no successor; users holding a single
    # badge therefore drop out (their average gap is undefined).
    .filter(col('next_badge').isNotNull())
    .withColumn('diff', unix_timestamp(col('next_badge')) - unix_timestamp(col('date')))
    # category is constant per user (derived from a per-user count), so
    # grouping by it here just carries it along.
    .groupBy('user_id', 'category')
    .agg(avg('diff').alias('user_avg_diff'))  # seconds
)

# Step 2: average of the per-user averages, so every user weighs equally
# regardless of how many badges they have (pooling all gaps per category
# would let prolific users dominate the result).
(
    per_user_avg
    .groupBy('category')
    .agg((avg('user_avg_diff') / SECONDS_PER_DAY).alias('avg_diff'))  # days
    .orderBy('category')
).show(truncate=False)

In [None]:
# Shut down the SparkSession created at the start of the notebook.
spark.stop()