In [1]:

from __future__ import print_function
from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession


This notebook is based on Spark 2.0,
where the main emphasis is on DataFrames and Datasets rather than RDDs.
Benefits of DataFrames: a structured schema, a 'mini database' that can be queried with SQL, and optimised data retrieval.

In [2]:

# Create (or reuse an already running) SparkSession, the Spark 2 entry point for DataFrames.
spark = (
    SparkSession.builder
    .appName('linear_regression')
    .getOrCreate()
)

In [3]:

# Load the raw text file as an RDD of lines. Each line is split on ',':
# field 0 becomes the label and field 1 the single feature (any further
# fields are ignored). NOTE(review): assumes both fields are numeric — confirm
# against the data file.
DATA_PATH = 'DataScience-Python3/regression.txt'
input_lines = spark.sparkContext.textFile(DATA_PATH)
data = (
    input_lines
    # Skip blank/whitespace-only lines; float('') would otherwise raise ValueError.
    .filter(lambda line: line.strip())
    .map(lambda line: line.split(','))
    .map(lambda fields: (float(fields[0]), Vectors.dense(float(fields[1]))))
)

In [4]:

# Turn the RDD of (label, features) tuples into a DataFrame: tuple element 0 is
# named 'label' and element 1 'features' (the default label/features column
# names LinearRegression reads).
col_names = ['label', 'features']
df = data.toDF(schema=col_names)

In [5]:

# Split the rows roughly 50/50 into training and test sets. A fixed seed makes
# the split, and therefore every downstream metric, reproducible on
# Restart & Run All (without it Spark picks a new random seed each run).
SPLIT_SEED = 42
train_test = df.randomSplit([0.5, 0.5], seed=SPLIT_SEED)
train_df, test_df = train_test

In [6]:

# Create the linear regression estimator and fit it on the training split
# (evaluation on the test split happens in the following cells).
# regParam scales the regularisation penalty; elasticNetParam mixes L1 and L2
# (0 = pure L2 / ridge, 1 = pure L1 / lasso).
MAX_ITER = 10
REG_PARAM = 0.3
ELASTIC_NET_PARAM = 0.8
clf = LinearRegression(maxIter=MAX_ITER, regParam=REG_PARAM,
                       elasticNetParam=ELASTIC_NET_PARAM)
model = clf.fit(train_df)

In [7]:

# Score the held-out rows; the result is cached because later cells query it
# more than once.
test_scored = model.transform(test_df)
full_predictions = test_scored.cache()

In [8]:

# Pull each column out of the scored test set as its own RDD of bare values
# (model output vs. ground truth).
predictions = full_predictions.select('prediction').rdd.map(lambda row: row.prediction)
labels = full_predictions.select('label').rdd.map(lambda row: row.label)

In [9]:

# Collect (prediction, label) pairs to the driver. Both columns are selected
# together, so each pair comes from the same row in a single pass, instead of
# zipping two separately derived RDDs. RDD.zip requires identical partitioning
# and per-partition element counts.
prediction_vs_label = (
    full_predictions
    .select('prediction', 'label')
    .rdd
    .map(tuple)  # Row -> plain tuple, so printing shows (prediction, label)
    .collect()
)

In [10]:

# Print every (prediction, label) pair on its own line.
for pair in prediction_vs_label:
    print(pair)

(-1.8131204488136754, -2.58)
(-1.6862257623164936, -2.54)
(-1.8554186776460693, -2.36)
(-1.5452316662085135, -2.27)
(-1.3971878652951346, -1.96)
(-1.3971878652951346, -1.87)
(-1.2984919980195488, -1.8)
(-1.1715973115223668, -1.77)
(-1.1504481971061697, -1.65)
(-1.1504481971061697, -1.6)
(-1.143398492300771, -1.59)
(-1.023553510608988, -1.58)
(-1.1645476067169678, -1.58)
(-1.1011002634683769, -1.57)
(-1.1856967211331648, -1.53)
(-0.9742055769711948, -1.48)
(-0.9953546913873919, -1.46)
(-1.122249377884574, -1.42)
(-0.9319073481388009, -1.4)
(-0.9953546913873919, -1.36)
(-0.8050126616416189, -1.29)
(-1.0376529202197857, -1.29)
(-0.8261617760578159, -1.26)
(-0.8543605952794119, -1.26)
(-0.8473108904740129, -1.25)
(-0.769764137614624, -1.24)
(-0.8402611856686139, -1.23)
(-0.8473108904740129, -1.22)
(-0.8825594145010079, -1.17)
(-0.776813842420023, -1.12)
(-0.706316794366033, -1.11)
(-0.748615023198427, -1.09)
(-0.7909132520308211, -1.08)
(-0.8261617760578159, -1.04)
(-0.6499191559228411, -1