# Pandas UDFs and UDAFs in PySpark
Links: 
* https://spark.apache.org/docs/2.4.0/sql-pyspark-pandas-with-arrow.html
* https://danvatterott.com/blog/2018/09/06/python-aggregate-udfs-in-pyspark/
* https://databricks.com/blog/2017/10/30/introducing-vectorized-udfs-for-pyspark.html
* How Spark works + UDAFs with Pandas: https://florianwilhelm.info/2017/10/efficient_udfs_with_pyspark/
* https://stackoverflow.com/questions/55506698/writing-custom-udaf-in-pyspark
* https://stackoverflow.com/questions/40006395/applying-udfs-on-groupeddata-in-pyspark-with-functioning-python-example/47497815#47497815
* https://stackoverflow.com/questions/32100973/how-to-define-and-use-a-user-defined-aggregate-function-in-spark-sql

* Version compatibility: https://stackoverflow.com/questions/51713705/python-pandas-udf-spark-error

---
Note: It is important to use a compatible combination of versions for these packages:

* numpy: 1.14.5
* pyarrow: 0.10.0
* pandas: 0.24.2
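
The installed versions can be compared against the pinned ones with a small check. This is a sketch, not part of the original notebook: the helper name `check_versions` is hypothetical, and it assumes Python 3.8+ so that `importlib.metadata` is available in the standard library.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed above as a known-good combination.
PINNED = {"numpy": "1.14.5", "pyarrow": "0.10.0", "pandas": "0.24.2"}


def check_versions(pinned):
    """Return {package: (expected, installed)} for every mismatch.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


# Report every package whose installed version differs from the pinned one.
for name, (expected, installed) in check_versions(PINNED).items():
    print(f"{name}: expected {expected}, found {installed}")
```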
---

## UDFs

In [3]:
import pandas as pd

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

# Declare the function and create the UDF
def multiply_func(a, b):
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

# The function for a pandas_udf should be able to execute with local Pandas data
x = pd.Series([1, 2, 3])
print(multiply_func(x, x))

0    1
1    4
2    9
dtype: int64


In [4]:
# Create a Spark DataFrame, 'spark' is an existing SparkSession
df = spark.createDataFrame(pd.DataFrame(x, columns=["x"]))

# Execute function as a Spark vectorized UDF
df.select(multiply(col("x"), col("x"))).show()
#df.show()

+-------------------+
|multiply_func(x, x)|
+-------------------+
|                  1|
|                  4|
|                  9|
+-------------------+



---
## UDF - Grouped Map

In [20]:
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

df.show()

+---+----+
| id|   v|
+---+----+
|  1| 1.0|
|  1| 2.0|
|  2| 3.0|
|  2| 5.0|
|  2|10.0|
+---+----+



To use groupBy().apply(), the user needs to define the following:

* A *Python function* that defines the computation for each group.
* A *StructType object* or a string that defines the **schema** of the output DataFrame.

*Important: all data for a group is loaded into memory before the function is applied, which can cause out-of-memory errors when group sizes are large or skewed.*

In [21]:
%%time
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    # pdf is a pandas.DataFrame
    v = pdf.v
    return pdf.assign(v=v - v.mean())

df.groupby("id").apply(subtract_mean).show()

+---+----+
| id|   v|
+---+----+
|  1|-0.5|
|  1| 0.5|
|  2|-3.0|
|  2|-1.0|
|  2| 4.0|
+---+----+

CPU times: user 16.3 ms, sys: 4.13 ms, total: 20.4 ms
Wall time: 966 ms
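
The split-apply-combine behaviour of the grouped map above can be mimicked in plain Python, without Spark or pandas. This is only an illustration of the semantics under simplifying assumptions (rows as tuples, groups processed in key order), not how Spark executes it; `grouped_map` and `subtract_mean_rows` are hypothetical names.

```python
from collections import defaultdict

# Same (id, v) rows as the DataFrame above.
rows = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]


def subtract_mean_rows(group_rows):
    """Per-group computation: center v on the mean of its group."""
    mean = sum(v for _, v in group_rows) / len(group_rows)
    return [(key, v - mean) for key, v in group_rows]


def grouped_map(rows, func):
    # Split: collect all rows of each group (the whole group is held in memory).
    groups = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)
    # Apply + combine: call func once per group and concatenate the outputs.
    result = []
    for key in sorted(groups):
        result.extend(func(groups[key]))
    return result


print(grouped_map(rows, subtract_mean_rows))
# [(1, -0.5), (1, 0.5), (2, -3.0), (2, -1.0), (2, 4.0)]
```

The function is called once per group and may return any number of rows, which is why the output schema has to be declared up front.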


---
## UDAF - Grouped Aggregate

*Important: Note that this type of UDF does not support partial aggregation and all data for a group or window will be loaded into memory.*
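
"Partial aggregation" here means reducing each partition to a small intermediate state and merging only those states afterwards. The plain-Python sketch below contrasts the two approaches for a mean. It does not use Spark; the partition layout and both function names are assumptions made for illustration.

```python
# Values of one group, spread over two partitions (hypothetical layout).
partitions = [[3.0, 5.0], [10.0]]


def partial_mean(partitions):
    """Mean with partial aggregation: each partition is reduced to a
    (sum, count) pair, and only those pairs are merged."""
    partials = [(sum(p), len(p)) for p in partitions]
    total = sum(s for s, _ in partials)
    count = sum(c for _, c in partials)
    return total / count


def whole_group_mean(partitions):
    """Mean without partial aggregation: every value of the group is
    gathered into one list first, similar to how a grouped aggregate
    pandas UDF receives one pandas.Series per group."""
    values = [v for p in partitions for v in p]
    return sum(values) / len(values)


print(partial_mean(partitions), whole_group_mean(partitions))
# 6.0 6.0
```

Both give the same result, but the second needs the full group in memory at once, which is the cost the note above warns about.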

In [18]:
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql import Window

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

df.show()

+---+----+
| id|   v|
+---+----+
|  1| 1.0|
|  1| 2.0|
|  2| 3.0|
|  2| 5.0|
|  2|10.0|
+---+----+



In [19]:
%%time
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    return v.mean()

df.groupby("id").agg(mean_udf(df['v'])).show()

+---+-----------+
| id|mean_udf(v)|
+---+-----------+
|  1|        1.5|
|  2|        6.0|
+---+-----------+

CPU times: user 7.82 ms, sys: 12.2 ms, total: 20 ms
Wall time: 939 ms


In [12]:
w = Window \
    .partitionBy('id') \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()

+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.5|
|  1| 2.0|   1.5|
|  2| 3.0|   6.0|
|  2| 5.0|   6.0|
|  2|10.0|   6.0|
+---+----+------+



## Setting Arrow Batch Size

Data partitions in Spark are converted into Arrow record batches, which can temporarily lead to high memory usage in the JVM. *To avoid possible out-of-memory exceptions*, the size of the Arrow record batches can be adjusted by setting the configuration **spark.sql.execution.arrow.maxRecordsPerBatch** to an integer that determines the maximum number of rows in each batch.

The default value is 10,000 records per batch. If the number of columns is large, the value should be adjusted accordingly. Using this limit, each data partition will be made into 1 or more record batches for processing.
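
To get a feel for the effect of the limit, the sketch below computes how a partition of a given size would be cut into batches. It is plain Python and assumes batches are filled up to the limit in order, which is a simplification; `batch_sizes` is a hypothetical helper, not a Spark API. The commented line shows how the limit could be set on an existing SparkSession named `spark`.

```python
# With an existing SparkSession named `spark`, the limit could be set like this:
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")


def batch_sizes(partition_rows, max_records_per_batch=10_000):
    """Row counts of the record batches for one partition, assuming each
    batch is filled up to the limit before the next one is started."""
    full, rest = divmod(partition_rows, max_records_per_batch)
    return [max_records_per_batch] * full + ([rest] if rest else [])


print(batch_sizes(25_000))              # [10000, 10000, 5000]
print(len(batch_sizes(25_000, 5_000)))  # 5
```

Lowering the limit produces more, smaller batches per partition, trading some overhead for a lower peak memory footprint.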