In [36]:
from pyspark.sql import SparkSession

# Reuse the active SparkSession if one exists, otherwise start a new one
# registered under the application name 'grouped_map'.
spark = (
    SparkSession.builder
    .appName('grouped_map')
    .getOrCreate()
)

### Grouped Map

`applyInPandas()` passes each group to the function as a `pandas.DataFrame`. Here, `substract_mean()` is defined as a standard Python function with type hints. This approach is useful for applying a transformation to an entire DataFrame.

In [37]:
import pandas as pd

# Toy input for the grouped-map example: an integer group key `id` and a
# float value `v` (two rows for id 1, three rows for id 2).
df = spark.createDataFrame(
    [(1, 1.0), (2, 5.0), (1, 2.0), (2, 10.0), (2, 3.0)],
    schema=("id", "v"),
)
df.show()


+---+----+
| id|   v|
+---+----+
|  1| 1.0|
|  2| 5.0|
|  1| 2.0|
|  2|10.0|
|  2| 3.0|
+---+----+



In [38]:
def substract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    """Center the ``v`` column of one group around that group's mean.

    Args:
        pdf: pandas DataFrame containing a numeric column ``v``.

    Returns:
        A new DataFrame with the same columns as ``pdf`` in which ``v`` has
        been replaced by ``v - mean(v)``. ``pdf`` itself is not modified,
        because ``DataFrame.assign`` returns a copy.
    """
    values = pdf.v
    centered = values.sub(values.mean())
    return pdf.assign(v=centered)
    
# Apply substract_mean to the rows of each distinct `id`; the result is
# declared with the same schema as the input DataFrame.
(
    df.groupby("id")
    .applyInPandas(substract_mean, schema=df.schema)
    .show()
)

+---+----+
| id|   v|
+---+----+
|  1|-0.5|
|  1| 0.5|
|  2|-1.0|
|  2| 4.0|
|  2|-3.0|
+---+----+



### Map

`mapInPandas()` applies a function to every batch in each partition and transforms each one. The function takes an iterator of `pandas.DataFrame` and returns an iterator of `pandas.DataFrame`. The output length does not need to match the input length. It can operate on individual columns and is useful when you do not need to transform every column in a DataFrame.


In [39]:
from typing import Iterator
import pandas as pd

# Toy input for the map example: four (id, age) rows, two per id.
df = spark.createDataFrame(
    [(2, 31), (1, 22), (2, 30), (1, 21)],
    schema=("id", "age"),
)
df.show()

+---+---+
| id|age|
+---+---+
|  2| 31|
|  1| 22|
|  2| 30|
|  1| 21|
+---+---+



In [40]:
def pandas_filter(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """Keep only the rows whose ``id`` equals 1, batch by batch.

    This is an iterator-of-DataFrames to iterator-of-DataFrames function:
    it consumes pandas DataFrames one at a time and yields exactly one
    (possibly empty) filtered DataFrame for each input batch.

    Args:
        iterator: Iterator of pandas DataFrames, each with an ``id`` column.

    Yields:
        The rows of each batch for which ``id == 1``, with all columns kept.
    """
    for batch in iterator:
        keep = batch.id == 1
        yield batch[keep]
    
# Run pandas_filter over the batches of every partition; the output is
# declared with the same schema as the input DataFrame.
(
    df.mapInPandas(pandas_filter, schema=df.schema)
    .show()
)

+---+---+
| id|age|
+---+---+
|  1| 22|
|  1| 21|
+---+---+



### Co-Grouped Map
Similar to the grouped map, this operation passes each group to the function as a `pandas.DataFrame`, but it first co-groups two DataFrames by their common key(s), and then the function is applied to each co-group. As in `mapInPandas()`, there is no restriction on the output length.

In [41]:
import pandas as pd

# Left side of the co-grouped example: float values `v1` keyed by (time, id).
df1 = spark.createDataFrame(
    [
        (1201, 1, 1.0),
        (1201, 2, 2.0),
        (1202, 1, 3.0),
        (1202, 2, 4.0),
        (1203, 3, 1.5),
    ],
    schema=("time", "id", "v1"),
)
df1.show()

# Right side: string values `v2` keyed by (time, id); it has no row for id 3.
df2 = spark.createDataFrame(
    [(1201, 1, 'x'), (1201, 2, 'y')],
    schema=("time", "id", "v2"),
)
df2.show()

+----+---+---+
|time| id| v1|
+----+---+---+
|1201|  1|1.0|
|1201|  2|2.0|
|1202|  1|3.0|
|1202|  2|4.0|
|1203|  3|1.5|
+----+---+---+

+----+---+---+
|time| id| v2|
+----+---+---+
|1201|  1|  x|
|1201|  2|  y|
+----+---+---+



In [42]:
def asof_join(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    """As-of join the two sides of one co-group on ``time``, matching on ``id``.

    For each row of ``left``, ``pd.merge_asof`` (default backward search)
    selects the row of ``right`` with the same ``id`` whose ``time`` is the
    latest one not after the left row's ``time``. Left rows with no such
    match are kept, with missing values in the columns taken from ``right``.

    Args:
        left: pandas DataFrame with at least the columns ``time`` and ``id``.
        right: pandas DataFrame with at least the columns ``time`` and ``id``.

    Returns:
        A DataFrame with one row per row of ``left``, ordered by ``time``,
        holding the columns of ``left`` followed by the non-key columns of
        ``right``. Neither input is modified.
    """
    # merge_asof raises ValueError unless both inputs are sorted by the `on`
    # key, and the rows handed over for a co-group are not guaranteed to
    # arrive in time order, so sort explicitly. A stable sort keeps rows with
    # equal `time` in their incoming order, which leaves already-sorted input
    # untouched.
    left_sorted = left.sort_values("time", kind="stable")
    right_sorted = right.sort_values("time", kind="stable")
    return pd.merge_asof(left_sorted, right_sorted, on="time", by="id")


# Co-group df1 and df2 by `id`, run asof_join once per co-group, and
# declare the schema of the joined result explicitly.
(
    df1.groupby("id")
    .cogroup(df2.groupby("id"))
    .applyInPandas(asof_join, "time int, id int, v1 double, v2 string")
    .show()
)

+----+---+---+----+
|time| id| v1|  v2|
+----+---+---+----+
|1201|  1|1.0|   x|
|1202|  1|3.0|   x|
|1201|  2|2.0|   y|
|1202|  2|4.0|   y|
|1203|  3|1.5|NULL|
+----+---+---+----+

