In [25]:
import findspark
# Runs before the pyspark imports in the next cell; presumably this is what
# makes the local Spark installation importable — confirm for your environment.
findspark.init()

In [26]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, desc, count, explode, split, regexp_replace, collect_list, array_sort, reverse, unix_timestamp, row_number,
    when, lit, lead, avg, pandas_udf, PandasUDFType
)

from pyspark.sql import Window

from pyspark.sql.types import StructType, StructField, StringType, LongType, TimestampType, DoubleType, IntegerType

import os

In [27]:
# Obtain the notebook's Spark session, registered under the application name 'UDFs'.
spark = SparkSession.builder.appName('UDFs').getOrCreate()

# Task 1

* Compute the average time (in seconds) between consecutive answers for each user who posted at least two answers.

In [28]:
# The notebook's working directory is the starting point for locating the data.
base_path = os.getcwd()

# The project root sits two directory levels above the working directory.
# os.path.dirname honours the platform's path separator and never collapses an
# absolute path to an empty string, unlike splitting the path on '/'.
project_path = os.path.dirname(os.path.dirname(base_path))

# Location of the answers dataset inside the project.
data_input_path = os.path.join(project_path, 'data/answers')

In [29]:
# Read the answers dataset from data_input_path (no format is given, so the
# reader's default source applies) and cache it, since later cells reuse it.
answersDF = spark.read.option('path', data_input_path).load().cache()

In [30]:
# Window covering every answer of a user: the frame spans the whole partition,
# so an aggregate over it yields the same per-user value on each row.
# Window's builders and frame constants live on the class itself, so they are
# used directly instead of through throwaway Window() instances.
w = (
    Window
    .partitionBy('user_id')
    .orderBy('creation_date')
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

# 'r' is the user's total number of answers; keep only the answers of users
# who wrote more than one, since a single answer has no gap to measure.
data = (
    answersDF
    .withColumn('r', count('*').over(w))
    .filter(col('r') > 1)
)

In [31]:
# Number of answer rows that belong to users with more than one answer.
data.count()

175943

In [56]:
# Output schema of the grouped-map UDF: the answer columns plus the per-user
# average stored in 'result'.
schema = StructType(
    [
        StructField('answer_id', LongType()),
        StructField('creation_date', TimestampType()),
        StructField('body', StringType()),
        StructField('comments', LongType()),
        StructField('user_id', LongType()),
        StructField('score', LongType()),
        StructField('question_id', LongType()),
        StructField('result', DoubleType())
    ]
)

# NOTE(review): PandasUDFType.GROUPED_MAP is the older grouped-map API; newer
# Spark releases recommend groupBy(...).applyInPandas(...) — confirm the target
# Spark version before migrating.
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def compute_avg_response(pdf):
    """Attach the mean gap between consecutive answers to one group's rows.

    Args:
        pdf: pandas DataFrame with the rows of a single group. It must
            contain a datetime ``creation_date`` column.

    Returns:
        A new DataFrame with the rows and columns of ``pdf`` plus a ``result``
        column. ``result`` is the mean difference, in seconds, between
        consecutive ``creation_date`` values after sorting them, repeated on
        every row of the group.
    """
    gaps_in_seconds = (
        pdf.sort_values(by=['creation_date'])
        .creation_date
        .diff()
        .dt.total_seconds()
    )
    # assign() builds a new frame instead of writing into the frame passed in.
    return pdf.assign(result=gaps_in_seconds.mean())

In [57]:
# Run the grouped-map UDF once per user. The UDF repeats the user's average on
# every answer row, so duplicates are dropped after projecting the two columns.
resultDF = (
    data.drop('r')
        .groupBy('user_id').apply(compute_avg_response)
        .select(col('user_id'), col('result'))
        .dropDuplicates()
)

In [58]:
# Display the users with the smallest average gap (in seconds) between answers.
resultDF.orderBy('result').show()

+-------+------------------+
|user_id|            result|
+-------+------------------+
| 142017|            134.81|
|  76258|146.60000000000002|
| 127301|           150.733|
|   7012|151.23000000000002|
|  47471|             154.8|
|  81035|154.85600000000002|
| 155429|164.44400000000002|
|  43758|179.00666666666666|
|  71344|            179.78|
|   4907|          180.8985|
|  59681|           187.866|
|  28339|197.38000000000002|
| 202454|197.92700000000002|
| 105605|           206.097|
| 128124|           210.497|
|  12918|           218.453|
| 174677|            221.15|
| 204225|           223.953|
|  98477|           234.288|
| 207428|234.48000000000002|
+-------+------------------+
only showing top 20 rows



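<b>Verify that the result makes sense:</b> user 142017 has exactly two answers, so their average gap must equal the difference between the two timestamps below (10:02:09.507 − 09:59:54.697 = 134.81 s).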

In [59]:
# Spot-check user 142017: list their answer timestamps in chronological order
# together with the epoch-second value of each.
(
    answersDF
    .where(col('user_id') == 142017)
    .select('creation_date', unix_timestamp('creation_date').alias('t'))
    .orderBy('creation_date')
    .show(truncate=False)
)

+-----------------------+----------+
|creation_date          |t         |
+-----------------------+----------+
|2017-01-14 09:59:54.697|1484384394|
|2017-01-14 10:02:09.507|1484384529|
+-----------------------+----------+

