# New Pandas UDFs and Python Type Hints in the Upcoming Release of Apache Spark 3.0

In [1]:
import pyarrow as pa
print(f"pyarrow version is {pa.__version__}")

from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .appName("Python Pandas UDF in Spark 3.0")
    .config('spark.sql.shuffle.partitions', 5)
    .config("spark.sql.execution.arrow.enabled", "true")
    .getOrCreate()
)

spark

pyarrow version is 0.17.1


# New Pandas UDFs

## Series to Series

In [2]:
import pandas as pd
from pyspark.sql.functions import pandas_udf       

@pandas_udf('long')
def pandas_plus_one(s: pd.Series) -> pd.Series:
    return s + 1

spark.range(10).select(pandas_plus_one("id")).show()

+-------------------+
|pandas_plus_one(id)|
+-------------------+
|                  1|
|                  2|
|                  3|
|                  4|
|                  5|
|                  6|
|                  7|
|                  8|
|                  9|
|                 10|
+-------------------+



## Iterator of Series to Iterator of Series

In [3]:
from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf       

@pandas_udf('long')
def pandas_plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    return map(lambda s: s + 1, iterator)

spark.range(10).select(pandas_plus_one("id")).show()

+-------------------+
|pandas_plus_one(id)|
+-------------------+
|                  1|
|                  2|
|                  3|
|                  4|
|                  5|
|                  6|
|                  7|
|                  8|
|                  9|
|                 10|
+-------------------+



In [4]:
# @pandas_udf("long")
# def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
#     # Do some expensive initialization with a state
#     state = very_expensive_initialization()
#     for x in iterator:
#         # Use that state for the whole iterator.
#         yield calculate_with_state(x, state)

# df.select(calculate("value")).show()

### Iterator of Multiple Series to Iterator of Series

In [5]:
from typing import Iterator, Tuple
import pandas as pd
from pyspark.sql.functions import pandas_udf       

@pandas_udf("long")
def multiply_two(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    return (a * b for a, b in iterator)

spark.range(10).select(multiply_two("id", "id")).show()

+--------------------+
|multiply_two(id, id)|
+--------------------+
|                   0|
|                   1|
|                   4|
|                   9|
|                  16|
|                  25|
|                  36|
|                  49|
|                  64|
|                  81|
+--------------------+



## Series to Scalar

In [6]:
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

@pandas_udf("double")
def pandas_mean(v: pd.Series) -> float:
    return v.sum()

df.select(pandas_mean(df['v'])).show()
df.groupby("id").agg(pandas_mean(df['v'])).show()
df.select(pandas_mean(df['v']).over(Window.partitionBy('id'))).show()

+--------------+
|pandas_mean(v)|
+--------------+
|          21.0|
+--------------+

+---+--------------+
| id|pandas_mean(v)|
+---+--------------+
|  2|          18.0|
|  1|           3.0|
+---+--------------+

+---------------------------------------------------------+
|pandas_mean(v) OVER (PARTITION BY id unspecifiedframe$())|
+---------------------------------------------------------+
|                                                     18.0|
|                                                     18.0|
|                                                     18.0|
|                                                      3.0|
|                                                      3.0|
+---------------------------------------------------------+



## New Pandas Function APIs

### Grouped Map

In [7]:
import pandas as pd

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    v = pdf.v
    return pdf.assign(v=v - v.mean())

df.groupby("id").applyInPandas(subtract_mean, schema=df.schema).show()

+---+----+
| id|   v|
+---+----+
|  2|-3.0|
|  2|-1.0|
|  2| 4.0|
|  1|-0.5|
|  1| 0.5|
+---+----+



### Map

In [8]:
from typing import Iterator
import pandas as pd

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def pandas_filter(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(pandas_filter, schema=df.schema).show()

+---+---+
| id|age|
+---+---+
|  1| 21|
+---+---+



### Co-grouped Map

In [9]:
import pandas as pd

df1 = spark.createDataFrame(
    [(1201, 1, 1.0), (1201, 2, 2.0), (1202, 1, 3.0), (1202, 2, 4.0)],
    ("time", "id", "v1"))
df2 = spark.createDataFrame(
    [(1201, 1, "x"), (1201, 2, "y")], ("time", "id", "v2"))

def asof_join(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.merge_asof(left, right, on="time", by="id")

df1.groupby("id").cogroup(
    df2.groupby("id")
).applyInPandas(asof_join, "time int, id int, v1 double, v2 string").show()

+----+---+---+---+
|time| id| v1| v2|
+----+---+---+---+
|1201|  2|2.0|  y|
|1202|  2|4.0|  y|
|1201|  1|1.0|  x|
|1202|  1|3.0|  x|
+----+---+---+---+

