In [63]:
import findspark
# Per the findspark documentation, init() locates the Spark installation and
# adds its pyspark package to sys.path; it is called here, before the cell
# that imports pyspark.
findspark.init()

In [79]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, desc, count, explode, split, regexp_replace, collect_list, array_sort, reverse, unix_timestamp, row_number,
    when, lit, lead, avg, pandas_udf, PandasUDFType
)

from pyspark.sql import Window

from pyspark.sql.types import StructType, StructField, StringType, LongType, TimestampType, DoubleType, IntegerType

import os
import pandas as pd

In [65]:
# Reuse the active Spark session if there is one, otherwise start a new one
# registered under the application name 'UDFs II'.
spark = SparkSession.builder.appName('UDFs II').getOrCreate()

# Task 1

* Compute the average time between two consecutive answers for each user who answered at least 2 questions.

In [66]:
# Working directory of the notebook kernel.
base_path = os.getcwd()

# The project root is two directory levels above the working directory.
# os.path.dirname is used instead of splitting/joining on a literal '/', so
# the result is correct for any OS path separator and never collapses to an
# empty (relative) path when the working directory is shallow.
project_path = os.path.dirname(os.path.dirname(base_path))

# Location of the answers dataset inside the project tree.
data_input_path = os.path.join(project_path, 'data', 'answers')

In [67]:
# Load the answers dataset from data_input_path (no explicit format is given,
# so the reader's default source format applies) and mark it for caching,
# since more than one later cell reads from it.
answersDF = spark.read.load(data_input_path).cache()

In [68]:
# Window covering every row of a user's partition. No ordering is declared:
# a count over the whole partition does not depend on row order, so sorting
# each partition by creation_date would only add work. The frame bounds are
# class-level constants, so Window does not need to be instantiated.
w = (
    Window
    .partitionBy('user_id')
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

# Attach each user's total number of answers as helper column 'r' and keep
# only users with more than one answer, because a time gap needs at least
# two timestamps.
data = (
    answersDF
    .withColumn('r', count('*').over(w))
    .filter(col('r') > 1)
)

In [69]:
# Number of answer rows that belong to users with more than one answer.
data.count()

175943

In [70]:
# Output layout of the grouped-map UDF: the answer columns passed in, plus a
# double column 'result' holding the user's mean gap in seconds.
schema = StructType([
    StructField('answer_id', LongType()),
    StructField('creation_date', TimestampType()),
    StructField('body', StringType()),
    StructField('comments', LongType()),
    StructField('user_id', LongType()),
    StructField('score', LongType()),
    StructField('question_id', LongType()),
    StructField('result', DoubleType()),
])


@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def compute_avg_response(pdf):
    """Add the mean time between consecutive answers to one group's rows.

    Parameters
    ----------
    pdf : pandas.DataFrame
        Rows of a single group; must contain a datetime column
        'creation_date'.

    Returns
    -------
    pandas.DataFrame
        The same frame with an extra column 'result': the mean difference,
        in seconds, between chronologically adjacent 'creation_date' values.
        The value is identical on every row of the group.
    """
    chronological = pdf.sort_values(by=['creation_date'])['creation_date']
    # diff() leaves NaT on the first row; mean() skips it.
    gaps_in_seconds = chronological.diff().dt.total_seconds()
    pdf['result'] = gaps_in_seconds.mean()
    return pdf

In [None]:
# Run the grouped-map UDF once per user, then collapse the per-answer rows it
# returns into a single (user_id, result) row per user.
resultDF = (
    data
    .drop('r')  # helper count column is not part of the UDF output schema
    .groupBy('user_id')
    .apply(compute_avg_response)
    .select(col('user_id'), col('result'))
    .distinct()
)

In [None]:
# Display the users with the smallest average gap first (ascending sort).
resultDF.sort('result').show()

<b>Verify that the result makes sense:</b>

In [None]:
# Spot check for one user: list that user's answer timestamps in
# chronological order next to their epoch-second value 't', so the average
# gap can be recomputed by hand.
(
    answersDF
    .where(col('user_id') == 142017)
    .select('creation_date', unix_timestamp('creation_date').alias('t'))
    .sort('creation_date')
    .show(truncate=False)
)

In [62]:
# Stop the Spark session created at the top of the notebook.
spark.stop()