In [1]:
from pyspark.sql import SparkSession
import os

# Hostname or IP address of the Spark master, taken from the SPARK_CLUSTER
# environment variable (a KeyError is raised when the variable is not set).
spark_cluster = os.environ['SPARK_CLUSTER']

# Create the session, or reuse an existing one, against the master
# listening on port 7077 of that host.
spark = (
    SparkSession.builder
    .appName('Spark-UDF-Demo')
    .master('spark://{cluster}:7077'.format(cluster=spark_cluster))
    .getOrCreate()
)

In [2]:
from random import gauss
import numpy as np
from sklearn import linear_model
from random import gauss

# 100 training samples, each with 3 features drawn from a standard normal.
trainX = np.array([[gauss(0, 1) for _ in range(3)] for _ in range(100)])

# Target is a linear combination of the features (weights 1, 2, 3) plus
# standard-normal noise, so a fitted linear model should recover ~[1, 2, 3].
trainY = np.array([
    1.0 * f0 + 2.0 * f1 + 3.0 * f2 + gauss(0, 1)
    for f0, f1, f2 in trainX
])

In [3]:
# Fit a linear regression to the synthetic training set; fit() returns the
# estimator itself, so construction and fitting can be chained.
regr = linear_model.LinearRegression().fit(trainX, trainY)

# Show the learned coefficients (the targets were generated with weights 1, 2, 3).
regr.coef_

array([1.03441194, 1.94478251, 2.98967767])

In [4]:
from pyspark.sql.types import *

def predict(x):
    """Score one feature vector with the fitted ``regr`` model.

    Args:
        x: Sequence of feature values for a single sample (one row's ``x``
            array column), or ``None`` when the SQL value is NULL.

    Returns:
        The model's prediction as a built-in ``float`` (matching the
        ``DoubleType`` the UDF is registered with), or ``None`` for a NULL
        input so the row yields NULL instead of failing the task.
    """
    if x is None:
        return None
    # regr.predict returns a 1-element array for a single sample; index it
    # explicitly, because calling float() on an array with ndim > 0 is
    # deprecated since NumPy 1.25.
    return float(regr.predict([x])[0])

# Expose the function to Spark SQL under the name "predict", returning a double.
spark.udf.register("predict", predict, DoubleType())

In [5]:
# 500,000 single-column rows; each row holds one 3-element feature array.
xdata = [[[gauss(0, 1), gauss(0, 1), gauss(0, 1)]] for _ in range(500000)]

# Schema shared by both DataFrames: one column "x" of type array<double>.
x_schema = StructType([StructField("x", ArrayType(DoubleType()))])

# The same rows distributed over 1 and over 2 partitions, so the queries
# that follow can be timed against each layout.
dataDF1 = spark.sparkContext.parallelize(xdata, 1).toDF(x_schema).cache()
dataDF2 = spark.sparkContext.parallelize(xdata, 2).toDF(x_schema).cache()

# createOrReplaceTempView supersedes registerTempTable, which has been
# deprecated since Spark 2.0; the names visible to SQL are unchanged.
dataDF1.createOrReplaceTempView("xdata1")
dataDF2.createOrReplaceTempView("xdata2")

dataDF1.show(5)
dataDF2.show(5)

+--------------------+
|                   x|
+--------------------+
|[0.28882317135799...|
|[-1.1035224904509...|
|[0.91096151183319...|
|[-0.1946786346608...|
|[-0.5178082287386...|
+--------------------+
only showing top 5 rows

+--------------------+
|                   x|
+--------------------+
|[0.28882317135799...|
|[-1.1035224904509...|
|[0.91096151183319...|
|[-0.1946786346608...|
|[-0.5178082287386...|
+--------------------+
only showing top 5 rows



In [6]:
%%time
# Query xdata1 (the copy of the data parallelized into 1 partition), applying
# the registered "predict" UDF to column x and exposing the result as y.
predictDF1 = spark.sql("select x, predict(x) as y from xdata1")
# Print the first 5 rows; the cell's reported wall time includes whatever
# work Spark performs to produce them.
predictDF1.show(5)

+--------------------+-------------------+
|                   x|                  y|
+--------------------+-------------------+
|[0.28882317135799...| -5.201442738873413|
|[-1.1035224904509...| -7.095088720788476|
|[0.91096151183319...|  3.157461208548457|
|[-0.1946786346608...|-1.5128694613688918|
|[-0.5178082287386...|-3.2399367860365227|
+--------------------+-------------------+
only showing top 5 rows

CPU times: user 3.29 ms, sys: 1.41 ms, total: 4.7 ms
Wall time: 6.48 s


In [7]:
%%time
# Same query as for xdata1, but against xdata2 (the copy of the data
# parallelized into 2 partitions), for a timing comparison.
predictDF2 = spark.sql("select x, predict(x) as y from xdata2")
# Print the first 5 rows; the cell's reported wall time includes whatever
# work Spark performs to produce them.
predictDF2.show(5)

+--------------------+-------------------+
|                   x|                  y|
+--------------------+-------------------+
|[0.28882317135799...| -5.201442738873413|
|[-1.1035224904509...| -7.095088720788476|
|[0.91096151183319...|  3.157461208548457|
|[-0.1946786346608...|-1.5128694613688918|
|[-0.5178082287386...|-3.2399367860365227|
+--------------------+-------------------+
only showing top 5 rows

CPU times: user 2.24 ms, sys: 2.01 ms, total: 4.25 ms
Wall time: 3.89 s
