### Create test data

In [1]:
import numpy as np
import pandas as pd
import threading
import time

from pyspark.sql.functions import col, pandas_udf, spark_partition_id, udf
from pyspark.sql.types import *
from typing import Iterator

In [2]:
# Arrow-backed transfer for pandas UDFs, with each Arrow batch capped at 1024 rows.
for conf_key, conf_value in {
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    "spark.sql.execution.arrow.maxRecordsPerBatch": 1024,
}.items():
    spark.conf.set(conf_key, conf_value)

# Ship the helper modules imported later in this notebook to the executors.
for py_file in ("cache_rdd.py", "cache_pudf1.py", "cache_pudf2.py"):
    spark.sparkContext.addPyFile(py_file)

In [3]:
# Test data: N_ROWS x 2 uniform samples in [0, 1).
# A seeded Generator makes the data identical on every Restart & Run All.
N_ROWS = 100_000
data = np.random.default_rng(seed=42).random((N_ROWS, 2))

In [4]:
# Wrap the raw array in a two-column pandas frame with columns A and B.
pdf = pd.DataFrame(data=data, columns=list("AB"))
pdf

Unnamed: 0,A,B
0,0.501336,0.306995
1,0.540219,0.790344
2,0.693692,0.103611
3,0.366457,0.755743
4,0.672671,0.380752
...,...,...
99995,0.098867,0.541502
99996,0.705354,0.836383
99997,0.787510,0.526729
99998,0.338041,0.767096


In [5]:
# 100 partitions of ~1,000 rows each. With maxRecordsPerBatch=1024, every
# partition should reach a pandas UDF as a single Arrow batch.
df = (
    spark.createDataFrame(pdf)
    .repartition(100)
)
df.show(n=5)

[Stage 0:>                                                          (0 + 2) / 2]

+-------------------+-------------------+
|                  A|                  B|
+-------------------+-------------------+
| 0.1663382356059685|0.08492084724124538|
| 0.3210506256599649|0.16656356211939327|
| 0.7252191545392298|0.35212011267085386|
|0.04806785080230802| 0.8527132166182346|
| 0.6379519872241818| 0.9399026552058006|
+-------------------+-------------------+
only showing top 5 rows



                                                                                

In [6]:
rdd = df.rdd  # Row-based RDD of the same data; input to the mapPartitions experiments below

In [7]:
# Global "cache" slot. The UDFs below read it (`if N:`) and set it (`N = 1.0`).
# Keep it None on the driver. NOTE(review): if a UDF is ever called locally in
# this notebook, N becomes 1.0 here. Functions defined in the notebook appear to
# be shipped with their referenced globals, so the raise-on-cache UDFs would then
# presumably fail on the executors too. Verify before relying on this.
N = None

In [8]:
delay = 5  # seconds each UDF sleeps on its "load" path, simulating an expensive one-time setup

### pandas_udf

In [9]:
@pandas_udf("float")
def addN1(x: pd.Series) -> pd.Series:
    """Add the cached global ``N`` to a batch, raising if the cache was reused.

    This is a tripwire experiment. The first call that finds ``N`` unset pays a
    simulated ``delay``-second load and sets ``N = 1.0``. Any call that finds
    ``N`` already set raises, which proves the global survived from an earlier
    batch.

    Parameters
    ----------
    x : pd.Series
        A batch of the input column.

    Returns
    -------
    pd.Series
        ``x + 1.0``. The declared Spark return type is ``float``.

    Raises
    ------
    Exception
        "USING CACHE!!!" when ``N`` is already set on entry.
    """
    global N
    if N is not None:
        # Deliberate tripwire: reaching this line means N outlived an earlier call.
        raise Exception("USING CACHE!!!")

    print("{} pandas_udf load".format(threading.get_ident()))
    N = 1.0
    time.sleep(delay)
    print("{} pandas_udf done".format(threading.get_ident()))

    print("type(x): {}".format(type(x)))
    return x + N

In [10]:
%%time
foo = df.select(addN1("A")).collect()

                                                                                

CPU times: user 317 ms, sys: 19.5 ms, total: 336 ms
Wall time: 37.3 s


In [11]:
foo[:5]  # the "USING CACHE!!!" tripwire never fired, so every value here is A + 1.0

[Row(addN1(A)=1.216964602470398),
 Row(addN1(A)=1.4214831590652466),
 Row(addN1(A)=1.051087498664856),
 Row(addN1(A)=1.6599854230880737),
 Row(addN1(A)=1.542315125465393)]

### udf

In [12]:
@udf(returnType=FloatType())
def addN2(x):
    """Row-at-a-time UDF that adds a global counter ``N``, which grows by 1.0 per call.

    Reuse of the global is allowed here, so the output itself shows how long
    ``N`` survives. The first call that finds ``N`` unset pays the simulated
    ``delay``-second load and sets ``N = 1.0``. Every later call that still sees
    the same global increments it before adding.

    Parameters
    ----------
    x : float
        One value of the input column.

    Returns
    -------
    float
        ``x + N`` after the update. The declared Spark return type is ``FloatType``.
    """
    global N
    if N is None:
        print("{} udf loading".format(threading.get_ident()))
        N = 1.0
        time.sleep(delay)
        print("{} udf done".format(threading.get_ident()))
    else:
        N = N + 1.0

    return x + N

In [13]:
%%time
foo = df.select(addN2("A")).collect()

                                                                                

CPU times: user 297 ms, sys: 40.9 ms, total: 338 ms
Wall time: 36.4 s


In [14]:
foo[:5]  # N grows by 1.0 per row, so the global survived between calls

[Row(addN2(A)=1.3568682670593262),
 Row(addN2(A)=2.5947766304016113),
 Row(addN2(A)=3.0666472911834717),
 Row(addN2(A)=4.1687726974487305),
 Row(addN2(A)=5.717076778411865)]

In [15]:
# Print the first 5 rows at every 1,000-row stride. With 100 equal partitions of
# 100,000 rows these should be partition starts, where N restarts at 1.0.
for start in range(0, 5 * 1000, 1000):
    print(foo[start:start + 5])

[Row(addN2(A)=1.3568682670593262), Row(addN2(A)=2.5947766304016113), Row(addN2(A)=3.0666472911834717), Row(addN2(A)=4.1687726974487305), Row(addN2(A)=5.717076778411865)]
[Row(addN2(A)=1.7651697397232056), Row(addN2(A)=2.438091278076172), Row(addN2(A)=3.3596861362457275), Row(addN2(A)=4.361378192901611), Row(addN2(A)=5.69951057434082)]
[Row(addN2(A)=1.0654268264770508), Row(addN2(A)=2.522294759750366), Row(addN2(A)=3.9742650985717773), Row(addN2(A)=4.609269142150879), Row(addN2(A)=5.241343021392822)]
[Row(addN2(A)=1.4478042125701904), Row(addN2(A)=2.6139817237854004), Row(addN2(A)=3.4142234325408936), Row(addN2(A)=4.3832502365112305), Row(addN2(A)=5.457306385040283)]
[Row(addN2(A)=1.5085660219192505), Row(addN2(A)=2.6122524738311768), Row(addN2(A)=3.6369450092315674), Row(addN2(A)=4.831283092498779), Row(addN2(A)=5.2967729568481445)]


### pandas_udf iterator
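See the `Iterator[pd.Series] -> Iterator[pd.Series]` pandas UDF variant, which is meant for UDFs that need to initialize some state before processing batches: https://github.com/apache/spark/blob/v3.2.0/python/pyspark/sql/pandas/functions.py#L153-L177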

In [16]:
@pandas_udf("float")
def addN(x: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Iterator-of-Series pandas UDF: pay the simulated load once, then stream batches.

    No global state is involved. The offset lives in a local variable, and the
    ``delay``-second sleep runs once, before the loop over ``x`` starts.
    """
    ident = threading.get_ident
    print("{} udf iterator load".format(ident()))
    offset = 1.0
    time.sleep(delay)
    print("{} udf iterator done".format(ident()))

    for batch in x:
        print("type(item): {}".format(type(batch)))
        yield batch + offset

In [17]:
%%time
foo = df.select(addN("A")).collect()

                                                                                

CPU times: user 311 ms, sys: 13.3 ms, total: 324 ms
Wall time: 36.3 s


In [18]:
foo[:5]  # identical to the addN1 results above: each value is A + 1.0 from a local offset

[Row(addN(A)=1.216964602470398),
 Row(addN(A)=1.4214831590652466),
 Row(addN(A)=1.051087498664856),
 Row(addN(A)=1.6599854230880737),
 Row(addN(A)=1.542315125465393)]

### mapPartitions

In [19]:
def addNR(it):
    """mapPartitions function that adds a global ``N`` to field 1 of every row.

    This is the same tripwire experiment as the pandas UDF version, on the RDD
    API. The first call that finds ``N`` unset pays the ``delay``-second
    simulated load and sets ``N = 1.0``. If ``N`` is already set on entry (the
    global survived from an earlier partition), the function raises instead.

    Parameters
    ----------
    it : iterable of rows
        Rows of one partition, each indexable with ``row[1]``. Here the rows
        come from ``df.rdd`` with columns (A, B), so ``row[1]`` is column B.

    Returns
    -------
    list of float
        ``row[1] + N`` for each input row, in input order.

    Raises
    ------
    Exception
        "USING CACHE!!!" when ``N`` is already set on entry.
    """
    global N
    if N is not None:
        # Deliberate tripwire: reaching this line means N outlived an earlier partition.
        raise Exception("USING CACHE!!!")

    print("{} rdd loading".format(threading.get_ident()))
    N = 1.0
    time.sleep(delay)
    print("{} rdd done".format(threading.get_ident()))

    result = []
    for x in it:
        # Debug output: prints one line per row, which is very noisy for large partitions.
        print("x: {}".format(x))
        result.append(x[1] + N)
    return result

In [20]:
rdd_out = rdd.mapPartitions(addNR)  # addNR gets each partition's row iterator; executed by the collect() below

In [21]:
%%time
# No "USING CACHE!!!" error here means addNR found N unset on every call.
foo = rdd_out.collect()



CPU times: user 55.9 ms, sys: 8.53 ms, total: 64.4 ms
Wall time: 36.3 s


                                                                                

In [22]:
foo[:5]  # plain floats, not Rows: addNR returns x[1] + N, i.e. column B + 1.0

[1.0849208472412455,
 1.1665635621193933,
 1.3521201126708537,
 1.8527132166182345,
 1.9399026552058007]

### mapPartitions with a function from a separate file

In [23]:
from cache_rdd import addNR_dist

In [24]:
rdd_out = rdd.mapPartitions(addNR_dist)  # same experiment with the function imported from cache_rdd.py

In [25]:
%%time
# Compare this wall time with the inline addNR run above. A much shorter time suggests
# the module-level state in cache_rdd.py is reused across partitions (contents not shown here).
foo = rdd_out.collect()

[Stage 16:>                                                      (0 + 16) / 100]

CPU times: user 28.4 ms, sys: 4.01 ms, total: 32.4 ms
Wall time: 5.41 s


                                                                                

In [26]:
foo[:5]  # results of addNR_dist; its exact logic lives in cache_rdd.py (not shown here)

[1.5097103726287262,
 1.3307591283002687,
 1.3852438976526569,
 1.1161459274469387,
 1.5915728662068664]

### pandas_udf defined in a separate file

In [27]:
from cache_pudf1 import addN1_dist

In [28]:
%%time
foo = df.select(addN1_dist("A")).collect()

                                                                                

CPU times: user 257 ms, sys: 18.7 ms, total: 275 ms
Wall time: 36.2 s


In [29]:
foo[:5]  # results of addN1_dist, the pandas_udf defined in cache_pudf1.py (not shown here)

[Row(addN1_dist(A)=1.3568682670593262),
 Row(addN1_dist(A)=1.5947765111923218),
 Row(addN1_dist(A)=1.0666471719741821),
 Row(addN1_dist(A)=1.16877281665802),
 Row(addN1_dist(A)=1.7170766592025757)]

### Function defined in a separate file, wrapped with pandas_udf on the driver

In [30]:
from cache_pudf2 import addN1_dist2

In [31]:
# Wrap the imported plain function as a pandas UDF here on the driver.
pudf = pandas_udf(addN1_dist2, returnType="float")

In [32]:
%%time
foo = df.select(pudf("A")).collect()

                                                                                

CPU times: user 273 ms, sys: 7.65 ms, total: 281 ms
Wall time: 5.97 s


In [33]:
foo[:5]  # same first five values as the inline addN1 / addN runs above

[Row(addN1_dist2(A)=1.216964602470398),
 Row(addN1_dist2(A)=1.4214831590652466),
 Row(addN1_dist2(A)=1.051087498664856),
 Row(addN1_dist2(A)=1.6599854230880737),
 Row(addN1_dist2(A)=1.542315125465393)]