In [1]:
from pyspark.sql import SparkSession
import os

# Hostname or IP address of the Spark master, read from the SPARK_CLUSTER
# environment variable (a KeyError is raised if it is not set).
spark_cluster = os.environ['SPARK_CLUSTER']

# Get (or reuse) a session attached to the master on port 7077.
spark = (
    SparkSession.builder
    .master('spark://{cluster}:7077'.format(cluster=spark_cluster))
    .appName('Spark-UDF-Demo')
    .getOrCreate()
)

In [2]:
from random import gauss
import numpy as np
from sklearn import linear_model
from random import gauss

# Synthetic training set: 100 samples, each with three standard-normal features.
trainX = np.array([[gauss(0, 1) for _ in range(3)] for _ in range(100)])

# Target is 1*x0 + 2*x1 + 3*x2 plus unit-variance Gaussian noise, so a fitted
# linear model should recover coefficients close to [1, 2, 3].
trainY = np.array([
    1.0 * row[0] + 2.0 * row[1] + 3.0 * row[2] + gauss(0, 1)
    for row in trainX
])

In [3]:
# Fit an ordinary least-squares model on the synthetic training data.
# fit() returns the estimator itself, so construction and fitting can be chained.
regr = linear_model.LinearRegression().fit(trainX, trainY)

# Show the learned coefficients as the cell output.
regr.coef_

array([0.83289655, 2.17553368, 3.10220691])

In [4]:
from pyspark.sql.types import *

def predict(x):
    """Predict the target value for a single feature vector.

    Uses the module-level fitted regressor ``regr``.

    Args:
        x: One feature vector (a sequence of numbers), i.e. a single row.

    Returns:
        The prediction for ``x`` as a plain Python ``float``.
    """
    # regr.predict works on a batch, so wrap x as a one-row batch and take the
    # first (only) result. Indexing before float() avoids converting a 1-D
    # array to a scalar, which NumPy >= 1.25 deprecates.
    return float(regr.predict([x])[0])

# Expose predict() to Spark SQL under the name "predict", returning a double.
spark.udf.register("predict", predict, returnType=DoubleType())

In [5]:
# 500,000 single-column rows; each row holds one 3-element feature vector.
xdata = [[[gauss(0, 1), gauss(0, 1), gauss(0, 1)]] for _ in range(500000)]

# Schema shared by both DataFrames: one array<double> column named "x".
x_schema = StructType([StructField("x", ArrayType(DoubleType()))])

# Load the same rows twice, differing only in partition count (1 vs 2).
dataDF1 = spark.sparkContext.parallelize(xdata, 1).toDF(x_schema).cache()
dataDF2 = spark.sparkContext.parallelize(xdata, 2).toDF(x_schema).cache()

# createOrReplaceTempView supersedes registerTempTable, which has been
# deprecated since Spark 2.0; the view names are unchanged.
dataDF1.createOrReplaceTempView("xdata1")
dataDF2.createOrReplaceTempView("xdata2")

dataDF1.show(5)
dataDF2.show(5)

+--------------------+
|                   x|
+--------------------+
|[-0.7195409868932...|
|[0.41342795524975...|
|[0.09828362248218...|
|[-0.9950719961013...|
|[0.73438926972835...|
+--------------------+
only showing top 5 rows

+--------------------+
|                   x|
+--------------------+
|[-0.7195409868932...|
|[0.41342795524975...|
|[0.09828362248218...|
|[-0.9950719961013...|
|[0.73438926972835...|
+--------------------+
only showing top 5 rows



In [6]:
%%time
# Apply the SQL-registered `predict` UDF to every row of the `xdata1` view
# (the single-partition copy of the data) and display the first 5 results.
# NOTE(review): show(5) only needs 5 rows, so the reported wall time may not
# reflect scoring all rows — confirm before drawing conclusions from it.
predictDF1 = spark.sql("select x, predict(x) as y from xdata1")
predictDF1.show(5)

+--------------------+--------------------+
|                   x|                   y|
+--------------------+--------------------+
|[-0.7195409868932...|   2.189687135199884|
|[0.41342795524975...|   7.898904574235032|
|[0.09828362248218...|-0.23878946851847022|
|[-0.9950719961013...|  -4.804233525756058|
|[0.73438926972835...|  0.8826929320066959|
+--------------------+--------------------+
only showing top 5 rows

CPU times: user 3.08 ms, sys: 843 µs, total: 3.92 ms
Wall time: 5.5 s


In [7]:
%%time
# Apply the SQL-registered `predict` UDF to every row of the `xdata2` view
# (the two-partition copy of the data) and display the first 5 results.
# NOTE(review): show(5) only needs 5 rows, so the reported wall time may not
# reflect scoring all rows — confirm before drawing conclusions from it.
predictDF2 = spark.sql("select x, predict(x) as y from xdata2")
predictDF2.show(5)

+--------------------+--------------------+
|                   x|                   y|
+--------------------+--------------------+
|[-0.7195409868932...|   2.189687135199884|
|[0.41342795524975...|   7.898904574235032|
|[0.09828362248218...|-0.23878946851847022|
|[-0.9950719961013...|  -4.804233525756058|
|[0.73438926972835...|  0.8826929320066959|
+--------------------+--------------------+
only showing top 5 rows

CPU times: user 2.71 ms, sys: 1.09 ms, total: 3.8 ms
Wall time: 3.98 s
