This tutorial shows how to bring existing pandas and Python code to Spark. We'll compare the traditional approach (hand-written PySpark) with Fugue.

In [1]:
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression

# Toy training data with two integer features.
X = pd.DataFrame({"x_1": [1, 1, 2, 2], "x_2": [1, 2, 2, 3]})
# The target is an exact linear function of the features, y = 1*x_1 + 2*x_2 + 3,
# so an ordinary linear regression can fit it without error.
coefficients = np.array([1, 2])
y = X.to_numpy() @ coefficients + 3
reg = LinearRegression().fit(X, y)

Next, create the prediction function and test it locally on a pandas DataFrame.

In [2]:
def predict(df: pd.DataFrame, model: LinearRegression) -> pd.DataFrame:
    """Return a copy of ``df`` with the model's output in a ``predicted`` column.

    ``model.predict`` is called on the whole of ``df``, so ``df`` should hold
    only the model's feature columns. ``df`` itself is not modified
    (``DataFrame.assign`` returns a new frame).
    """
    predictions = model.predict(df)
    return df.assign(predicted=predictions)

input_df = pd.DataFrame(
    [[3, 3], [4, 3], [6, 6], [6, 6]],
    columns=["x_1", "x_2"],
)
predict(input_df.copy(), reg)

Unnamed: 0,x_1,x_2,predicted
0,3,3,12.0
1,4,3,13.0
2,6,6,21.0
3,6,6,21.0


In [3]:
from fugue import transform
from fugue_spark import SparkExecutionEngine

spark_engine = SparkExecutionEngine()
result = transform(
    input_df,
    predict,
    # Output schema: every input column ("*") plus a double "predicted" column.
    schema="*,predicted:double",
    # Supplies predict's non-DataFrame argument.
    params={"model": reg},
    engine=spark_engine,
)
result.show()

+---+---+---------+
|x_1|x_2|predicted|
+---+---+---------+
|  3|  3|     12.0|
|  4|  3|     13.0|
|  6|  6|     21.0|
|  6|  6|     21.0|
+---+---+---------+



In [4]:
from typing import Iterator, Any, Union
from pyspark.sql.types import StructType, StructField, DoubleType
from pyspark.sql import DataFrame, SparkSession

# Reuse an already-running SparkSession if there is one, otherwise create a new one.
spark_session = SparkSession.builder.getOrCreate()

def predict_wrapper(dfs: Iterator[pd.DataFrame], model):
    """Adapt ``predict`` to the ``mapInPandas`` calling convention.

    Takes an iterator of pandas batches and lazily yields one scored batch
    per input batch, each produced by ``predict(batch, model)``.
    """
    yield from (predict(batch, model) for batch in dfs)

def run_predict(input_df: Union[DataFrame, pd.DataFrame], model) -> DataFrame:
    """Score ``input_df`` with ``model`` on Spark using ``mapInPandas``.

    Args:
        input_df: Either a pandas DataFrame, which is converted to a Spark
            DataFrame through the module-level ``spark_session``, or a Spark
            DataFrame, which is used as is.
        model: Object forwarded to ``predict`` for every pandas batch.

    Returns:
        A Spark DataFrame with all input columns plus a double ``predicted``
        column.

    Raises:
        TypeError: If ``input_df`` is neither a pandas nor a Spark DataFrame.
    """
    # conversion
    if isinstance(input_df, pd.DataFrame):
        sdf = spark_session.createDataFrame(input_df.copy())
    elif isinstance(input_df, DataFrame):
        # Spark DataFrames are immutable and have no .copy() method; calling it
        # raised AttributeError. Use the frame directly.
        sdf = input_df
    else:
        raise TypeError(
            "input_df must be a pandas or Spark DataFrame, "
            f"got {type(input_df).__name__}"
        )

    # Output schema = input schema + the new prediction column.
    schema = StructType(sdf.schema.fields + [StructField("predicted", DoubleType())])
    return sdf.mapInPandas(lambda batches: predict_wrapper(batches, model),
                           schema=schema)

# Same prediction as the Fugue version, with the Spark plumbing written by hand.
result = run_predict(model=reg, input_df=input_df.copy())
result.show()

+---+---+---------+
|x_1|x_2|predicted|
+---+---+---------+
|  3|  3|     12.0|
|  4|  3|     13.0|
|  6|  6|     21.0|
|  6|  6|     21.0|
+---+---+---------+

