In [1]:
import findspark
findspark.init()

import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

In [2]:
df = spark.sql('''select 'spark' as hello''')
df.show()

+-----+
|hello|
+-----+
|spark|
+-----+



In [3]:
import pandas as pd

In [4]:
pdf = pd.DataFrame(
    data={
            'integers':[1,2,3],
            'floats': [-1.0, 0.5, 2.7],
            'integer_arrays': [[1,2], [3,4,5], [6,7,8,9]]
    }
)

In [5]:
pdf

Unnamed: 0,integers,floats,integer_arrays
0,1,-1.0,"[1, 2]"
1,2,0.5,"[3, 4, 5]"
2,3,2.7,"[6, 7, 8, 9]"


In [6]:
df = spark.createDataFrame(pdf)
df.printSchema()

root
 |-- integers: long (nullable = true)
 |-- floats: double (nullable = true)
 |-- integer_arrays: array (nullable = true)
 |    |-- element: long (containsNull = true)



In [7]:
df.show()

+--------+------+--------------+
|integers|floats|integer_arrays|
+--------+------+--------------+
|       1|  -1.0|        [1, 2]|
|       2|   0.5|     [3, 4, 5]|
|       3|   2.7|  [6, 7, 8, 9]|
+--------+------+--------------+



In [8]:
def square(x):
    return x**2

In [9]:
#User-defined function written to produce Integer type output
#results in NULLs when it produces floats instead.
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
square_udf_int = udf(lambda z: square(z), IntegerType())

In [10]:
(
    df.select(
        'integers',
        'floats',
        square_udf_int('integers').alias('int_squared'),
        square_udf_int('floats').alias('float_squared')
        )
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|          1|         null|
|       2|   0.5|          4|         null|
|       3|   2.7|          9|         null|
+--------+------+-----------+-------------+



In [11]:
from pyspark.sql.types import FloatType
square_udf_float = udf(lambda z: square(z), FloatType())

In [12]:
(
    df.select(
        'integers',
        'floats',
        square_udf_float('integers').alias('int_squared'),
        square_udf_float('floats').alias('float_squared')
        )
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|       null|          1.0|
|       2|   0.5|       null|         0.25|
|       3|   2.7|       null|         7.29|
+--------+------+-----------+-------------+



In [13]:
#Specify the type in the python function to be safe
def square_float(x):
    return float(x**2)
square_udf_float2 = udf(lambda z: square_float(z), FloatType())

In [14]:
(
    df.select(
        'integers',
        'floats',
        square_udf_float2('integers').alias('int_squared_float'),
        square_udf_float2('floats').alias('float_squared')
        )
    .show()
)

+--------+------+-----------------+-------------+
|integers|floats|int_squared_float|float_squared|
+--------+------+-----------------+-------------+
|       1|  -1.0|              1.0|          1.0|
|       2|   0.5|              4.0|         0.25|
|       3|   2.7|              9.0|         7.29|
+--------+------+-----------------+-------------+



In [15]:
#For lists, all the types in the list have to be the same,and specified within ArrayType()
from pyspark.sql.types import ArrayType

def square_list_float(x):
    return [float(val**2) for val in x]
square_list_float_udf = udf(lambda y: square_list_float(y), ArrayType(FloatType()))

In [16]:
(
    df.select(
        'integer_arrays',
        square_list_float_udf('integer_arrays').alias('list_squared_float')
    )
    .show()
)

+--------------+--------------------+
|integer_arrays|  list_squared_float|
+--------------+--------------------+
|        [1, 2]|          [1.0, 4.0]|
|     [3, 4, 5]|   [9.0, 16.0, 25.0]|
|  [6, 7, 8, 9]|[36.0, 49.0, 64.0...|
+--------------+--------------------+



In [17]:
#For functions that return a tuple of mixed types, we have to make a corresponding StructType()
#E.g. a function that returns the position and the letter from ascii_letters
import string

def convert_ascii(number):
    return [number, string.ascii_letters[number]]

In [18]:
print(string.ascii_letters)
convert_ascii(2)

abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


[2, 'c']

In [19]:
from pyspark.sql.types import StructType, StructField, StringType

array_schema = StructType([
    StructField('number', IntegerType(), nullable=False),
    StructField('letters', StringType(), nullable=False)
])

convert_ascii_udf = udf(lambda z: convert_ascii(z), array_schema)

In [20]:
df_ascii = df.select(
        'integers',
        convert_ascii_udf('integers').alias('ascii_map')
    )
df_ascii.show()
df_ascii.printSchema()

+--------+---------+
|integers|ascii_map|
+--------+---------+
|       1|   [1, b]|
|       2|   [2, c]|
|       3|   [3, d]|
+--------+---------+

root
 |-- integers: long (nullable = true)
 |-- ascii_map: struct (nullable = true)
 |    |-- number: integer (nullable = false)
 |    |-- letters: string (nullable = false)



In [21]:
#using non-native types like nparrays will produce Py4JJavaError
import numpy as np

def square_array_wrong(x):
    return np.square(x)

square_array_wrong_udf = udf(square_array_wrong, ArrayType(FloatType()))

In [22]:
(
    df.withColumn(
        'squared',
        square_array_wrong_udf('integer_arrays')
    )
    .show()
)

Py4JJavaError: An error occurred while calling o141.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 20.0 failed 1 times, most recent failure: Lost task 1.0 in stage 20.0 (TID 51, localhost, executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:86)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:85)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	at java.lang.Thread.run(Unknown Source)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1602)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1590)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1589)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1589)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1823)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1772)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1761)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:642)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2055)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2074)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:363)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3273)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2484)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2484)
	at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3254)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3253)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2484)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2698)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:254)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
	at java.lang.reflect.Method.invoke(Unknown Source)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Unknown Source)
Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:86)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:85)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	... 1 more


In [23]:
#convert your output back to python primatives and you can use any non-native types you like in-between
def square_array_right(x):
    return np.square(x).tolist()

square_array_right_udf = udf(square_array_right, ArrayType(IntegerType()))

In [24]:
%%time
(
    df.withColumn(
        'squared',
        square_array_right_udf('integer_arrays')
    )
    .show()
)

+--------+------+--------------+----------------+
|integers|floats|integer_arrays|         squared|
+--------+------+--------------+----------------+
|       1|  -1.0|        [1, 2]|          [1, 4]|
|       2|   0.5|     [3, 4, 5]|     [9, 16, 25]|
|       3|   2.7|  [6, 7, 8, 9]|[36, 49, 64, 81]|
+--------+------+--------------+----------------+

Wall time: 7.04 s


In [25]:
#When your dataframe is too small, Spark sends the entire dataframe to only one executor
#The workaround is to repartition the dataframe before calling the UDF.

df_repartitioned = df.repartition(3)

In [26]:
%%time
(
    df_repartitioned.withColumn(
        'squared',
        square_array_right_udf('integer_arrays')
    )
    .show()
)

+--------+------+--------------+----------------+
|integers|floats|integer_arrays|         squared|
+--------+------+--------------+----------------+
|       2|   0.5|     [3, 4, 5]|     [9, 16, 25]|
|       1|  -1.0|        [1, 2]|          [1, 4]|
|       3|   2.7|  [6, 7, 8, 9]|[36, 49, 64, 81]|
+--------+------+--------------+----------------+

Wall time: 4.97 s


If you are deploying the code to a cluster, you need to submit the packages/scripts (say my_package.py) together with your PySpark script (say myPySparkScript.py) in order to import those packages on a cluster. It goes like `spark-submit [some config stuff] --py-files my_package.py myPySparkScript.py`