# Pandas User Defined Functions

In the previous notebook we saw how Pandas UDFs can be used for statistical modelling: we defined a UDF that took a Pandas Series as input and returned a Pandas Series.

Here we define a Pandas UDF that is used as an aggregation function after calling `groupBy`. This UDF also takes a Pandas Series as input, but it returns a single number.
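
The difference between the two shapes can be seen without Spark. The following is a minimal pandas-only sketch; the names `scores`, `plus_one` and `total` are made up for illustration and are not part of this notebook's code.

```python
import pandas as pd

# Made-up scores, used only for illustration.
scores = pd.Series([3, 1, 4, 1, 5])

def plus_one(s: pd.Series) -> pd.Series:
    # Series in, Series of the same length out (the shape used in the previous notebook).
    return s + 1

def total(s: pd.Series) -> int:
    # Series in, one scalar out: the shape a grouped aggregation needs.
    return int(s.sum())

print(len(plus_one(scores)))  # one value per input row
print(total(scores))          # a single number for the whole Series
```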

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, avg, desc

import os

import pandas as pd

In [None]:
spark = (
    SparkSession
    .builder
    .appName('UDFs III')
    .getOrCreate()
)

In [None]:
base_path = os.getcwd()

project_path = ('/').join(base_path.split('/')[0:-3]) 

answers_input_path = os.path.join(project_path, 'data/answers')
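
The path logic above drops the last three components of the working directory and then appends `data/answers`. A small sketch of the same idea with `pathlib`, using a hypothetical directory layout (the path and the `example_*` names are assumptions, not the real project location):

```python
from pathlib import PurePosixPath

# Hypothetical working directory, used only for illustration.
example_base = PurePosixPath('/home/user/project/notebooks/udfs/part3')

# parents[0] is the direct parent, so parents[2] drops the last three components.
example_project = example_base.parents[2]

# The string-based approach used in the cell above gives the same result.
split_based = '/'.join(str(example_base).split('/')[0:-3])

example_answers = example_project / 'data' / 'answers'
print(example_project)
print(example_answers)
```

The string-based version assumes `/` as the separator, so it may need adjusting on Windows.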

### Task I

1. For each question compute the average **score** of its answers.
2. For each question compute the **median** score of its answers.

In [None]:
# Load the answers dataset (no format is given, so Spark uses its default data source, Parquet unless configured otherwise):

answersDF = (
    spark
    .read
    .option('path', answers_input_path)
    .load()
)

* First compute the average. This is simple: we can use the `avg` aggregation function.

In [None]:
(
    answersDF
    .groupBy('question_id')
    .agg(
        avg('score').alias('avg_score')
    )
    .orderBy(desc('avg_score'))
).show()

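* Now compute the median. This is trickier, since Spark 3.1 has no dedicated `median` function in `pyspark.sql.functions`.
* We can use a Pandas UDF for it:
 * using type hint annotations we can define an aggregation Pandas UDF
 * the input argument is a `pd.Series`
 * the return value is a single number; we declare it as `double`, because the median of an even number of integer scores can be fractional

In [None]:
@pandas_udf('double')
def median_udf(pd_s: pd.Series) -> float:
    # Series-to-scalar UDF: receives all scores of one group and returns one value
    return float(pd_s.median())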
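
The aggregation logic can be checked locally on a plain pandas DataFrame before running it through Spark. This is only a sketch: the `toy` data and the `median_plain` helper are made up for illustration, and pandas `groupby` stands in for Spark's `groupBy`.

```python
import pandas as pd

# Made-up scores for two hypothetical questions.
toy = pd.DataFrame({
    'question_id': [1, 1, 1, 2, 2],
    'score': [0, 2, 7, 1, 2],
})

def median_plain(pd_s: pd.Series) -> float:
    # Same logic as the UDF body, without the Spark decorator.
    return float(pd_s.median())

# One median per question_id, mirroring groupBy(...).agg(...).
medians = toy.groupby('question_id')['score'].apply(median_plain)
print(medians)
```

Note that the second group has an even number of integer scores, so its median falls between two values and is not an integer.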

In [None]:
(
    answersDF
    .groupBy('question_id')
    .agg(
        median_udf('score').alias('median_score')
    )
    .orderBy(desc('median_score'))
).show()

For more information and details about various Pandas UDFs see the [docs](https://spark.apache.org/docs/3.1.1/api/python/user_guide/arrow_pandas.html).

In [None]:
spark.stop()