In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StructField, StructType
import pyspark.pandas as ps
from pyspark.sql import functions as F
import pandas as pd

In [4]:
builder = SparkSession.builder.appName("pandas-on-spark")
builder = builder.config("spark.sql.execution.arrow.pyspark.enabled", "true")
# Pandas API on Spark automatically uses this Spark session with the configurations set.
spark = builder.getOrCreate()

In [5]:
LENGTH = 1000000

## Test pandas on pyspark schemas

In [6]:
psdf = ps.DataFrame({"id": range(LENGTH), "value": range(10, LENGTH+10), "a": range(LENGTH,LENGTH*2)})

In [7]:
print(psdf.spark.schema())

StructType([StructField('id', LongType(), False), StructField('value', LongType(), False), StructField('a', LongType(), False)])


In [8]:
# psdf.info()

In [9]:
sdf = psdf.to_spark()
sdf.printSchema()

root
 |-- id: long (nullable = false)
 |-- value: long (nullable = false)
 |-- a: long (nullable = false)





In [10]:
schema_new = StructType(
    [StructField("id", LongType(), True), StructField("value", LongType(), True), StructField('a', LongType(), False)]
)
sdf_new_schema = spark.createDataFrame(sdf.rdd, schema_new)

In [11]:
sdf_new_schema.printSchema()

root
 |-- id: long (nullable = true)
 |-- value: long (nullable = true)
 |-- a: long (nullable = false)



In [12]:
psdf_new_schema = sdf_new_schema.pandas_api()
psdf_new_schema.spark.schema()

StructType([StructField('id', LongType(), True), StructField('value', LongType(), True), StructField('a', LongType(), False)])

In [13]:
# psdf_new_schema.info()

## Select and filter columns to pandas on pyspark df

The goal is to compare the performance of three different approaches: native pyspark, pandas on pyspark and pandas with pyarrow. For this we do the same transformation in each approach and check the execution plan as well as the time taken to execute (on a larger dataset on databricks).

The transformations should be two filters which could be optimized and one complex filter.

In [17]:
def multi_when(x: int) -> int: 
    return 100 if x < LENGTH/20 else (50 if x < LENGTH/2 else 0)

In [18]:
def compute_pandas():
    pdf = pd.DataFrame({"id": range(LENGTH), "value": range(10, LENGTH+10), "a": range(LENGTH,LENGTH*2)}) 
    pdf = pdf.loc[pdf['id'] < LENGTH-LENGTH/10]
    pdf["new_value"] = pdf["id"].transform(multi_when)
    pdf = pdf.loc[pdf['id'] < LENGTH-LENGTH/5]
    pdf = pdf.loc[:, ['new_value', 'id']]
    return pdf

In [19]:
pdf = compute_pandas()
pdf

Unnamed: 0,new_value,id
0,100,0
1,100,1
2,100,2
3,100,3
4,100,4
...,...,...
799995,0,799995
799996,0,799996
799997,0,799997
799998,0,799998


In [20]:
psdf_select_filter = ps.DataFrame({"id": range(LENGTH), "value": range(10, LENGTH+10), "a": range(LENGTH,LENGTH*2)})

The [apply and transform](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/transform_apply.html) functions are working on pandas series and thus have the same effect as using pandas udfs.

In [21]:
psdf_select_filter = psdf_select_filter.loc[psdf_select_filter['id'] < (LENGTH - LENGTH/10)]
psdf_select_filter["new_value"] = psdf_select_filter["id"].transform(multi_when)
psdf_select_filter = psdf_select_filter.loc[psdf_select_filter['id'] < (LENGTH - LENGTH/5)]
psdf_select_filter = psdf_select_filter.loc[:, ['new_value', 'id']]

In [22]:
psdf_select_filter.spark.explain(mode='formatted')

== Physical Plan ==
* Project (4)
+- ArrowEvalPython (3)
   +- * Project (2)
      +- * LocalTableScan (1)


(1) LocalTableScan [codegen id : 1]
Output [5]: [__index_level_0__#50L, id#51L, value#52L, a#53L, __natural_order__#76L]
Arguments: [__index_level_0__#50L, id#51L, value#52L, a#53L, __natural_order__#76L]

(2) Project [codegen id : 1]
Output [2]: [__index_level_0__#50L, id#51L]
Input [5]: [__index_level_0__#50L, id#51L, value#52L, a#53L, __natural_order__#76L]

(3) ArrowEvalPython
Input [2]: [__index_level_0__#50L, id#51L]
Arguments: [pudf(__index_level_0__#50L, id#51L)#84L], [pythonUDF0#112L], 200

(4) Project [codegen id : 2]
Output [3]: [__index_level_0__#50L, pythonUDF0#112L AS new_value#87L, id#51L]
Input [3]: [__index_level_0__#50L, id#51L, pythonUDF0#112L]




In [23]:
psdf_select_filter.count()

24/08/01 12:40:41 WARN TaskSetManager: Stage 0 contains a task of very large size (3181 KiB). The maximum recommended task size is 1000 KiB.
24/08/01 12:40:50 WARN TaskMemoryManager: Failed to allocate a page (1048560 bytes), try again.
24/08/01 12:40:50 ERROR Utils: Uncaught exception in thread stdout writer for python3
java.lang.OutOfMemoryError: Java heap space
Exception in thread "stdout writer for python3" 24/08/01 12:40:50 ERROR Utils: Uncaught exception in thread stdout writer for python3
java.lang.OutOfMemoryError: Java heap space
java.lang.OutOfMemoryError: Java heap space
Exception in thread "stdout writer for python3" java.lang.OutOfMemoryError: Java heap space
24/08/01 12:40:50 ERROR Utils: Uncaught exception in thread stdout writer for python3
java.lang.OutOfMemoryError: Java heap space
Exception in thread "stdout writer for python3" 24/08/01 12:40:50 ERROR Utils: Uncaught exception in thread stdout writer for python3
java.lang.OutOfMemoryError: Java heap space
Exception i

ConnectionRefusedError: [Errno 61] Connection refused

ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/Users/mzwiesl/Repos/Free/pyspark-examples/.venv/lib/python3.10/site-packages/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/mzwiesl/Repos/Free/pyspark-examples/.venv/lib/python3.10/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/Users/mzwiesl/Repos/Free/pyspark-examples/.venv/lib/python3.10/site-packages/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving


In [127]:
%timeit psdf_select_filter

7.75 ns ± 0.0349 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


## Doing the same for native pyspark

In [24]:
sdf_select_filter = ps.DataFrame({"id": range(100), "value": range(10, 110), "a": range(100,200)}).to_spark()

ConnectionRefusedError: [Errno 61] Connection refused

In [180]:
sdf_select_filter.printSchema()

root
 |-- id: long (nullable = false)
 |-- value: long (nullable = false)
 |-- a: long (nullable = false)



In [184]:
sdf_select_filter = sdf_select_filter.filter(F.col("id") < LENGTH - LENGTH/10)
sdf_select_filter = sdf_select_filter.withColumn("new_value", F.when(F.col("id") < 40, 100).when(F.col("id") < 60, 50).otherwise(0))
sdf_select_filter = sdf_select_filter.filter(F.col("id") < LENGTH - LENGTH/5)
sdf_select_filter = sdf_select_filter.select("id", "new_value")

In [185]:
sdf_select_filter.explain(mode='formatted')

== Physical Plan ==
LocalTableScan (1)


(1) LocalTableScan
Output [2]: [id#7267L, new_value#7857]
Arguments: [id#7267L, new_value#7857]




In [186]:
%timeit sdf_select_filter.count()

22.9 ms ± 493 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Doing the same without arrow

In [188]:
sdf_select_filter2 = ps.DataFrame({"id": range(LENGTH), "value": range(10, LENGTH+10), "a": range(LENGTH,LENGTH*2)}).to_spark()



In [189]:
multi_when_udf = F.udf(multi_when, LongType())

In [190]:
sdf_select_filter2 = sdf_select_filter2.filter(F.col("id") < 90)
sdf_select_filter2 = sdf_select_filter2.withColumn("new_value", multi_when_udf(F.col("id")))
sdf_select_filter2 = sdf_select_filter2.filter(F.col("id") < 80)
sdf_select_filter2 = sdf_select_filter2.select("id", "new_value")

In [191]:
sdf_select_filter2.explain(mode='formatted')

Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.api.python.PythonSQLUtils.explainString.
: java.lang.OutOfMemoryError: Java heap space


In [130]:
%timeit sdf_select_filter2.show()

+---+---------+
| id|new_value|
+---+---------+
|  0|      100|
|  1|      100|
|  2|      100|
|  3|      100|
|  4|      100|
|  5|      100|
|  6|      100|
|  7|      100|
|  8|      100|
|  9|      100|
| 10|      100|
| 11|      100|
| 12|      100|
| 13|      100|
| 14|      100|
| 15|      100|
| 16|      100|
| 17|      100|
| 18|      100|
| 19|      100|
+---+---------+
only showing top 20 rows

+---+---------+
| id|new_value|
+---+---------+
|  0|      100|
|  1|      100|
|  2|      100|
|  3|      100|
|  4|      100|
|  5|      100|
|  6|      100|
|  7|      100|
|  8|      100|
|  9|      100|
| 10|      100|
| 11|      100|
| 12|      100|
| 13|      100|
| 14|      100|
| 15|      100|
| 16|      100|
| 17|      100|
| 18|      100|
| 19|      100|
+---+---------+
only showing top 20 rows

+---+---------+
| id|new_value|
+---+---------+
|  0|      100|
|  1|      100|
|  2|      100|
|  3|      100|
|  4|      100|
|  5|      100|
|  6|      100|
|  7|      100|
|  8