### Create test data

In [1]:
import numpy as np
import pandas as pd
import threading
import time

from pyspark.sql.functions import col, pandas_udf, spark_partition_id, udf
from pyspark.sql.types import *
from typing import Iterator

In [2]:
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 1024)
spark.sparkContext.addPyFile("cache_rdd.py")
spark.sparkContext.addPyFile("cache_pudf1.py")
spark.sparkContext.addPyFile("cache_pudf2.py")

In [3]:
data = np.random.rand(100000,2)

In [4]:
pdf = pd.DataFrame(data,columns=['A', 'B'])
pdf

Unnamed: 0,A,B
0,0.101431,0.545185
1,0.010034,0.316767
2,0.368177,0.229946
3,0.852369,0.203069
4,0.614969,0.984089
...,...,...
99995,0.173051,0.530109
99996,0.162625,0.937435
99997,0.656313,0.553777
99998,0.979720,0.842729


In [5]:
df = spark.createDataFrame(pdf).repartition(10)
df.show(5)

[Stage 0:>                                                          (0 + 2) / 2]

+-------------------+-------------------+
|                  A|                  B|
+-------------------+-------------------+
| 0.6725068566776664| 0.4048381689108268|
| 0.4646654335378224|0.16736603896679958|
| 0.9013817348875479|0.04339854354727446|
|0.00802792044114109| 0.7550677173094015|
| 0.8373676205836582|0.08083523416612326|
+-------------------+-------------------+
only showing top 5 rows



                                                                                

In [6]:
rdd = df.rdd

In [7]:
N = None

In [8]:
delay = 5

### pandas_udf

In [9]:
@pandas_udf("float")
def addN1(x: pd.Series) -> pd.Series:
    global N
    if N:
        # raise Exception("USING CACHE!!!")
        print("USING CACHE!!!")
        N = N + 1.0
    else:
        print("{} pandas_udf load".format(threading.get_ident()))
        N = 1.0
        time.sleep(delay)
        print("{} pandas_udf done".format(threading.get_ident()))
    
    print("type(x): {}".format(type(x)))
    return x + N

In [10]:
%%time
foo = df.select(addN1("A")).collect()

                                                                                

CPU times: user 272 ms, sys: 15 ms, total: 287 ms
Wall time: 27.1 s


In [11]:
foo[:5]

[Row(addN1(A)=1.1130174398422241),
 Row(addN1(A)=1.3608601093292236),
 Row(addN1(A)=1.422696828842163),
 Row(addN1(A)=1.2258188724517822),
 Row(addN1(A)=1.5190355777740479)]

### udf

In [12]:
@udf(returnType=FloatType())
def addN2(x):
    global N
    if N:
        # raise Exception("USING CACHE!!!")
        print("USING CACHE!!!")
        N = N + 1.0
    else:
        print("{} udf loading".format(threading.get_ident()))
        N = 1.0
        time.sleep(delay)
        print("{} udf done".format(threading.get_ident()))

    return x + N

In [13]:
%%time
foo = df.select(addN2("A")).collect()

                                                                                

CPU times: user 283 ms, sys: 14.9 ms, total: 297 ms
Wall time: 26.1 s


In [14]:
foo[:5]

[Row(addN2(A)=1.1130174398422241),
 Row(addN2(A)=2.3608601093292236),
 Row(addN2(A)=3.422696828842163),
 Row(addN2(A)=4.225819110870361),
 Row(addN2(A)=5.519035339355469)]

In [15]:
for i in range(5):
    print(foo[i*1000:i*1000+5])

[Row(addN2(A)=1.1130174398422241), Row(addN2(A)=2.3608601093292236), Row(addN2(A)=3.422696828842163), Row(addN2(A)=4.225819110870361), Row(addN2(A)=5.519035339355469)]
[Row(addN2(A)=1001.0861206054688), Row(addN2(A)=1002.9450073242188), Row(addN2(A)=1003.4807739257812), Row(addN2(A)=1004.1425170898438), Row(addN2(A)=1005.3653564453125)]
[Row(addN2(A)=2001.1534423828125), Row(addN2(A)=2002.98388671875), Row(addN2(A)=2003.001220703125), Row(addN2(A)=2004.948486328125), Row(addN2(A)=2005.0477294921875)]
[Row(addN2(A)=3001.1435546875), Row(addN2(A)=3002.0576171875), Row(addN2(A)=3003.155517578125), Row(addN2(A)=3004.908447265625), Row(addN2(A)=3005.816162109375)]
[Row(addN2(A)=4001.0302734375), Row(addN2(A)=4002.307861328125), Row(addN2(A)=4003.425048828125), Row(addN2(A)=4004.884521484375), Row(addN2(A)=4005.6884765625)]


### pandas_udf iterator
see: https://github.com/apache/spark/blob/v3.2.0/python/pyspark/sql/pandas/functions.py#L153-L177

In [16]:
@pandas_udf("float")
def addN(x: Iterator[pd.Series]) -> Iterator[pd.Series]:
    print("{} udf iterator load".format(threading.get_ident()))
    N_local = 1.0
    time.sleep(delay)
    print("{} udf iterator done".format(threading.get_ident()))
    for item in x:
        print("type(item): {}".format(type(item)))
        yield item + N_local

In [17]:
%%time
foo = df.select(addN("A")).collect()

                                                                                

CPU times: user 304 ms, sys: 10.7 ms, total: 315 ms
Wall time: 25.9 s


In [18]:
foo[:5]

[Row(addN(A)=1.1196389198303223),
 Row(addN(A)=1.805098533630371),
 Row(addN(A)=1.6327793598175049),
 Row(addN(A)=1.6746606826782227),
 Row(addN(A)=1.228914737701416)]

### mapPartitions

In [19]:
def addNR(it):
    result = []
    
    global N
    if N:
        raise Exception("USING CACHE!!!")
        N = N + 1.0
    else:
        print("{} rdd loading".format(threading.get_ident()))
        N = 1.0
        time.sleep(delay)
        print("{} rdd done".format(threading.get_ident()))
        
    for x in it:
        print("x: {}".format(x))
        result.append(x[1] + N)
    return result

In [20]:
rdd_out = rdd.mapPartitions(addNR)

In [21]:
%%time
foo = rdd_out.collect()



CPU times: user 48.6 ms, sys: 229 µs, total: 48.8 ms
Wall time: 26.3 s


                                                                                

In [22]:
foo[:5]

[1.9688157549552834,
 1.6876051989028156,
 1.1418699874472678,
 1.1659114159107062,
 1.8593736801570682]

### mapPartitions in separate file

In [23]:
from cache_rdd import addNR_dist

In [24]:
rdd_out = rdd.mapPartitions(addNR_dist)

In [25]:
%%time
foo = rdd_out.collect()



CPU times: user 19.1 ms, sys: 3.16 ms, total: 22.2 ms
Wall time: 5.7 s


                                                                                

In [26]:
foo[:5]

[1.9688157549552834,
 1.6876051989028156,
 1.1418699874472678,
 1.1659114159107062,
 1.8593736801570682]

### pandas_udf defined in separate file

In [27]:
from cache_pudf1 import addN1_dist

In [28]:
%%time
foo = df.select(addN1_dist("A")).collect()

                                                                                

CPU times: user 245 ms, sys: 6.57 ms, total: 251 ms
Wall time: 25.8 s


In [29]:
foo[:5]

[Row(addN1_dist(A)=1.1130174398422241),
 Row(addN1_dist(A)=1.3608601093292236),
 Row(addN1_dist(A)=1.422696828842163),
 Row(addN1_dist(A)=1.2258188724517822),
 Row(addN1_dist(A)=1.5190355777740479)]

### function in separate file, pandas_udf created in driver

In [30]:
from cache_pudf2 import addN1_dist2

In [31]:
pudf = pandas_udf(addN1_dist2, 'float')

In [32]:
%%time
foo = df.select(pudf("A")).collect()

                                                                                

CPU times: user 288 ms, sys: 6.37 ms, total: 294 ms
Wall time: 5.76 s


In [33]:
foo[:5]

[Row(addN1_dist2(A)=1.1196389198303223),
 Row(addN1_dist2(A)=1.805098533630371),
 Row(addN1_dist2(A)=1.6327793598175049),
 Row(addN1_dist2(A)=1.6746606826782227),
 Row(addN1_dist2(A)=1.228914737701416)]

### function in separate file (wrapped), pandas_udf in driver

In [35]:
from cache_pudf2 import get_fn

In [36]:
pudf2 = pandas_udf(get_fn(), 'float')

In [37]:
%%time
foo = df.select(pudf2("A")).collect()

                                                                                

CPU times: user 275 ms, sys: 30.2 ms, total: 305 ms
Wall time: 25.8 s
