## Examples for PySpark UDFs
Ref : https://changhsinlee.com/pyspark-udf/

In [1]:
# Connect to Spark by creating a Spark session
from pyspark.sql import SparkSession

# Build the session with a parenthesised call chain instead of backslash
# line continuations.
spark = (
    SparkSession.builder
    .appName("PySparkUDF")
    .getOrCreate()
)

In [2]:
import pandas as pd
from pyspark.sql.functions import udf

In [3]:
# Example data: an integer column, a float column and a column holding
# integer lists of different lengths.
df_pd = pd.DataFrame({
    'integers': [1, 2, 3],
    'floats': [-1.0, 0.5, 2.7],
    'integer_arrays': [[1, 2], [3, 4, 5], [6, 7, 8, 9]],
})

In [4]:
# Display the pandas DataFrame built above.
df_pd

Unnamed: 0,floats,integer_arrays,integers
0,-1.0,"[1, 2]",1
1,0.5,"[3, 4, 5]",2
2,2.7,"[6, 7, 8, 9]",3


In [5]:
# Build a Spark DataFrame from the pandas DataFrame; no schema is passed,
# so the column types are derived from the data.
df = spark.createDataFrame(df_pd)

In [6]:
# Print the column names and Spark data types of `df`.
df.printSchema()

root
 |-- floats: double (nullable = true)
 |-- integer_arrays: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- integers: long (nullable = true)



In [7]:
# Print the rows of `df` as a text table.
df.show()

+------+--------------+--------+
|floats|integer_arrays|integers|
+------+--------------+--------+
|  -1.0|        [1, 2]|       1|
|   0.5|     [3, 4, 5]|       2|
|   2.7|  [6, 7, 8, 9]|       3|
+------+--------------+--------+



## Turn python function into a Spark function

### For numeric types

In [8]:
def square(x):
    """Return ``x`` raised to the power of two.

    The result has whatever type ``x ** 2`` produces for the argument
    (an ``int`` for an ``int`` input, a ``float`` for a ``float`` input).
    """
    return pow(x, 2)

In plain Python, this function works for both integers and floats:

In [9]:
# Integer input gives an integer result.
square(5)

25

In [10]:
# Float input gives a float result.
square(3.0)

9.0

In [11]:
# Integer type output
from pyspark.sql.types import IntegerType
# Spark UDF around `square` whose declared return type is IntegerType.
square_udf_int = udf(lambda value: square(value), returnType=IntegerType())

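A Spark UDF returns a column of NULLs when the value returned by the Python function does not match the UDF's declared return type (for example, a Python `float` returned from a UDF declared with `IntegerType()`).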

In [12]:
# Apply the IntegerType UDF to both numeric columns; the float column
# comes back as null in the output below.
(
    df.select(
        'integers',
        'floats',
        square_udf_int('integers').alias('int_squared'),
        square_udf_int('floats').alias('float_squared'),
    )
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|          1|         null|
|       2|   0.5|          4|         null|
|       3|   2.7|          9|         null|
+--------+------+-----------+-------------+



In [13]:
# float type output
from pyspark.sql.types import FloatType
# Spark UDF around `square` whose declared return type is FloatType.
square_udf_float = udf(lambda value: square(value), returnType=FloatType())

In [14]:
# Apply the FloatType UDF to both numeric columns; this time the integer
# column comes back as null in the output below.
(
    df.select(
        'integers',
        'floats',
        square_udf_float('integers').alias('int_squared'),
        square_udf_float('floats').alias('float_squared'),
    )
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|       null|          1.0|
|       2|   0.5|       null|         0.25|
|       3|   2.7|       null|         7.29|
+--------+------+-----------+-------------+



In [15]:
# Force the output to be float
def square_float(x):
    """Return the square of ``x`` converted to a Python ``float``."""
    squared = pow(x, 2)
    return float(squared)
# Spark UDF around `square_float`, which always returns a Python float,
# declared with a FloatType return type.
square_udf_float2 = udf(lambda value: square_float(value), returnType=FloatType())

In [16]:
# With the result forced to float, both columns are populated in the
# output below.
(
    df.select(
        'integers',
        'floats',
        square_udf_float2('integers').alias('int_squared'),
        square_udf_float2('floats').alias('float_squared'),
    )
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|        1.0|          1.0|
|       2|   0.5|        4.0|         0.25|
|       3|   2.7|        9.0|         7.29|
+--------+------+-----------+-------------+



### Array types, or lists
When the input of the Python function is a list, all values in the list have to be of the same type, because a Spark array column has a single element type.

In [17]:
def square_list(x):
    """Return a new list holding the square of every value in ``x``.

    Each value is converted to ``float`` before squaring, so every element
    of the result is a Python ``float``.
    """
    squares = []
    for item in x:
        squares.append(float(item) ** 2)
    return squares

In [18]:
from pyspark.sql.types import ArrayType
# Spark UDF around `square_list`; the return type is an array of floats.
square_list_udf = udf(
    lambda values: square_list(values),
    returnType=ArrayType(elementType=FloatType()),
)

In [19]:
# Show each integer array next to its squared values. The UDF column is not
# aliased, so it appears as `<lambda>(integer_arrays)` in the output.
df.select('integer_arrays', square_list_udf('integer_arrays')).show()

+--------------+------------------------+
|integer_arrays|<lambda>(integer_arrays)|
+--------------+------------------------+
|        [1, 2]|              [1.0, 4.0]|
|     [3, 4, 5]|       [9.0, 16.0, 25.0]|
|  [6, 7, 8, 9]|    [36.0, 49.0, 64.0...|
+--------------+------------------------+



### Struct types, or tuples

If the Python function returns a tuple (or, as in the example below, a list with one value per field), then we can use the `StructType` data type in PySpark.

In [20]:
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType
from pyspark.sql.types import StringType

A `StructType` contains a list of `StructField`s, each of which has a name and a specific data type.

In [21]:
# Struct schema with two non-nullable fields: an integer `number` and a
# string `letters`.
array_schema = (
    StructType()
    .add('number', IntegerType(), False)
    .add('letters', StringType(), False)
)

The plain Python function:

In [22]:
import string

# takes a number and returns a letter from ascii_letters
def convert_ascii(number):
    """Pair ``number`` with the character at that index of ``string.ascii_letters``.

    Returns a two-element list ``[number, letter]``.
    """
    letter = string.ascii_letters[number]
    return [number, letter]

# Quick check in plain Python: index 1 of ascii_letters is 'b'.
convert_ascii(1)

[1, 'b']

In [23]:
# Spark UDF around `convert_ascii`; the result is typed with the struct
# schema `array_schema`.
spark_convert_ascii = udf(lambda value: convert_ascii(value), returnType=array_schema)

In [24]:
# Keep the original integers and add the struct produced by the UDF as
# the column `ascii_map`.
df_ascii = df.select(
    'integers',
    spark_convert_ascii('integers').alias('ascii_map'),
)

In [25]:
# Print the schema of `df_ascii`, including the fields nested in `ascii_map`.
df_ascii.printSchema()

root
 |-- integers: long (nullable = true)
 |-- ascii_map: struct (nullable = true)
 |    |-- number: integer (nullable = false)
 |    |-- letters: string (nullable = false)



In [26]:
# Print the rows of `df_ascii` as a text table.
df_ascii.show()

+--------+---------+
|integers|ascii_map|
+--------+---------+
|       1|   [1, b]|
|       2|   [2, c]|
|       3|   [3, d]|
+--------+---------+



## Wrong type example
Spark doesn't know what to do when the Python function returns NumPy objects such as `np.ndarray`: the UDF below fails with a `PickleException`.

In [27]:
import numpy as np

In [28]:
# A pandas frame with a single column of integer lists, converted to a
# Spark DataFrame and displayed.
d_np = pd.DataFrame(data={'int_arrays': [[1, 2, 3], [4, 5]]})
df_np = spark.createDataFrame(data=d_np)
df_np.show()

+----------+
|int_arrays|
+----------+
| [1, 2, 3]|
|    [4, 5]|
+----------+



In [29]:
# squares with a numpy function, which returns a np.ndarray
def square_array_wrong(x):
    """Square every element of ``x`` with NumPy.

    The result is returned as a ``numpy.ndarray`` and is deliberately not
    converted to a Python list; this is the "wrong" variant of the example.
    """
    squared = np.square(x)
    return squared

In [30]:
# In plain Python the call works, but the result is a numpy array rather
# than a list.
square_array_wrong([1,2,3])

array([1, 4, 9])

In [31]:
# Spark UDF around the ndarray-returning function, declared as returning
# an array of floats.
spark_square_array_wrong = udf(
    square_array_wrong,
    returnType=ArrayType(elementType=FloatType()),
)

In [32]:
# Expected to fail: the UDF returns a numpy.ndarray, which Spark cannot
# turn into an array column (see the PickleException in the output below).
# The new column is named 'squared' because the UDF squares the values;
# it does not double them.
df_np.withColumn('squared', spark_square_array_wrong('int_arrays')).show()

Py4JJavaError: An error occurred while calling o175.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 15.0 failed 1 times, most recent failure: Lost task 0.0 in stage 15.0 (TID 29, localhost, executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:86)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:85)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1602)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1590)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1589)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1589)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1823)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1772)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1761)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:642)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2055)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2074)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:363)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3273)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2484)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2484)
	at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3254)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3253)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2484)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2698)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:254)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:86)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:85)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	... 1 more


We have to make sure the numbers are returned as a list, with values being native Python types:

In [33]:
def square_array_right(x):
    """Square every element of ``x`` with NumPy and return a plain Python list.

    ``ndarray.tolist()`` converts the NumPy scalars to native Python numbers.
    """
    squared = np.square(x)
    return squared.tolist()

In [34]:
# Spark UDF around the list-returning function, declared as returning an
# array of integers.
spark_square_array_right = udf(
    square_array_right,
    returnType=ArrayType(elementType=IntegerType()),
)

Now it returns the desired values.

In [35]:
# Add a 'squared' column computed by the list-returning UDF and display the
# result.
zz = df_np.withColumn('squared', spark_square_array_right('int_arrays'))
zz.show()

+----------+---------+
|int_arrays|  squared|
+----------+---------+
| [1, 2, 3]|[1, 4, 9]|
|    [4, 5]| [16, 25]|
+----------+---------+

