In [1]:
%%capture
!pip install pyspark

In [2]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql import Window
from pyspark.sql.types import LongType

In [3]:
# Create (or reuse) a Spark session on the local master "local[*]" for this notebook.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("SparkTest")
    .getOrCreate()
)

# Pandas function APIs

### Grouped map

In [4]:
# Toy input for the grouped-map example: (id, v) pairs spread over two id groups.
df1 = spark.createDataFrame(
    [
        (1, 10),
        (1, 20),
        (2, 30),
        (2, 50),
        (2, 10),
    ],
    schema=("id", "v"),
)

def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    """Center the ``v`` column of one group on that group's mean.

    Args:
        pdf: pandas DataFrame holding the rows of a single group; must have a
            column named ``v``.

    Returns:
        A new DataFrame equal to ``pdf`` except that ``v`` is replaced by
        ``v - v.mean()``. The input frame is not modified.
    """
    centered = pdf["v"] - pdf["v"].mean()
    return pdf.assign(v=centered)

# Run subtract_mean once per id group and print the combined result.
# NOTE(review): the output schema declares v as integer while v - v.mean() is a
# float in pandas; the group means of this sample (15 and 30) are whole numbers,
# so nothing is lost here -- consider "v double" for general data.
(
    df1.groupby("id")
    .applyInPandas(subtract_mean, schema="id integer, v integer")
    .show()
)

+---+---+
| id|  v|
+---+---+
|  1| -5|
|  1|  5|
|  2|  0|
|  2| 20|
|  2|-20|
+---+---+



### Map

In [6]:
# Toy input for the map example: two (id, age) rows.
df2 = spark.createDataFrame(
    [(1, 21), (2, 30)],
    schema=("id", "age"),
)

def filter_func(iterator):
    """Keep only the rows whose ``id`` equals 1 in every incoming batch.

    Args:
        iterator: iterable of pandas DataFrames, each with an ``id`` column.

    Yields:
        One DataFrame per input batch containing just the rows where
        ``id == 1`` (possibly empty).
    """
    for batch in iterator:
        keep = batch.id == 1
        yield batch[keep]

# Stream df2 through filter_func; the output keeps df2's own schema.
(
    df2
    .mapInPandas(filter_func, schema=df2.schema)
    .show()
)

+---+---+
| id|age|
+---+---+
|  1| 21|
+---+---+



### Cogrouped map

In [7]:
# Left side of the as-of join: one numeric reading v1 per (time, id).
df3 = spark.createDataFrame(
    [
        (20000101, 1, 1.0),
        (20000101, 2, 2.0),
        (20000102, 1, 3.0),
        (20000102, 2, 4.0),
    ],
    schema=("time", "id", "v1"),
)

# Right side: one label v2 per id, present only at the first timestamp.
df4 = spark.createDataFrame(
    [
        (20000101, 1, "x"),
        (20000101, 2, "y"),
    ],
    schema=("time", "id", "v2"),
)

def asof_join(l, r):
    """As-of join two pandas DataFrames on ``time``, matching rows per ``id``.

    Args:
        l: left pandas DataFrame with ``time`` and ``id`` columns.
        r: right pandas DataFrame with ``time`` and ``id`` columns.

    Returns:
        The result of ``pd.merge_asof`` on ``time`` grouped by ``id``: every
        left row, ordered by ``time``, joined with the right row of the same
        ``id`` whose ``time`` is the latest one not after the left row's.
    """
    # pd.merge_asof raises ValueError when the "on" key is not sorted, and the
    # frames passed in are not assumed to arrive ordered. A stable sort keeps
    # the relative order of rows that share a timestamp, so already-sorted
    # input produces exactly the same result as before.
    left_sorted = l.sort_values("time", kind="mergesort")
    right_sorted = r.sort_values("time", kind="mergesort")
    return pd.merge_asof(left_sorted, right_sorted, on="time", by="id")

# Pair up the per-id groups of df3 and df4 and as-of join each pair in pandas.
(
    df3.groupby("id")
    .cogroup(df4.groupby("id"))
    .applyInPandas(asof_join, schema="time int, id int, v1 double, v2 string")
    .show()
)

+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
|20000101|  2|2.0|  y|
|20000102|  2|4.0|  y|
+--------+---+---+---+



# Pandas UDFs (Python)

### Series to Series UDF

In [8]:
def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    """Multiply two pandas Series element by element.

    Args:
        a: first factor.
        b: second factor.

    Returns:
        The Series ``a * b``.
    """
    product = a * b
    return product

In [9]:
# Wrap the plain pandas function as a pandas UDF whose Spark return type is long.
multiply = pandas_udf(
    multiply_func,
    returnType=LongType(),
)

In [10]:
# Call the undecorated function directly on a local pandas Series and print
# the element-wise product; x is reused below to build df5.
x = pd.Series([1, 2, 3])
print(multiply_func(x, x))

0    1
1    4
2    9
dtype: int64


In [11]:
# Turn the local Series into a Spark DataFrame with a single column named "x".
df5 = spark.createDataFrame(
    pd.DataFrame(x, columns=["x"])
)

In [12]:
# Apply the pandas UDF to column x paired with itself and print the result.
df5.select(
    multiply(col("x"), col("x"))
).show()

+-------------------+
|multiply_func(x, x)|
+-------------------+
|                  1|
|                  4|
|                  9|
+-------------------+



### Series to scalar UDF

In [13]:
# Toy input for the aggregate examples: float values v spread over two id groups.
df6 = spark.createDataFrame(
    [
        (1, 1.0),
        (1, 2.0),
        (2, 3.0),
        (2, 5.0),
        (2, 10.0),
    ],
    schema=("id", "v"),
)

In [14]:
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    """Return the arithmetic mean of the values in ``v``.

    Registered through ``pandas_udf`` with return type ``"double"``; the
    notebook uses it with ``select``, ``groupby().agg`` and ``over(window)``.
    """
    average = v.mean()
    return average

In [15]:
# Aggregate the whole v column into a single mean value.
df6.select(
    mean_udf(df6["v"])
).show()

+-----------+
|mean_udf(v)|
+-----------+
|        4.2|
+-----------+



In [16]:
# Same UDF used as a grouped aggregate: one mean per id.
(
    df6.groupby("id")
    .agg(mean_udf(df6["v"]))
    .show()
)

+---+-----------+
| id|mean_udf(v)|
+---+-----------+
|  1|        1.5|
|  2|        6.0|
+---+-----------+



In [17]:
# Unbounded frame per id: each row's window covers every row sharing its id,
# so mean_v repeats the group mean on every row of the group.
w = (
    Window.partitionBy("id")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)
df6.withColumn("mean_v", mean_udf(df6["v"]).over(w)).show()

+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.5|
|  1| 2.0|   1.5|
|  2| 3.0|   6.0|
|  2| 5.0|   6.0|
|  2|10.0|   6.0|
+---+----+------+

