In [1]:
from __future__ import print_function

from pyspark.ml.regression import LinearRegression

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors


# Create (or reuse) a SparkSession. No platform-specific .config(...) is set
# here; add one before getOrCreate() if your environment requires it.
spark = SparkSession.builder.appName("LinearRegression").getOrCreate()

try:
    # Load the data and convert it to the (label, features) format MLlib
    # expects. Each line is parsed as "<label>,<feature>": field 0 becomes the
    # float label, field 1 becomes a one-element dense feature vector.
    inputLines = spark.sparkContext.textFile("/home/jovyan/work/data/regression.txt")
    data = (
        inputLines
        # Skip blank lines, which would otherwise make float("") raise.
        .filter(lambda line: line.strip())
        .map(lambda line: line.split(","))
        .map(lambda fields: (float(fields[0]), Vectors.dense(float(fields[1]))))
    )

    # Convert this RDD to a DataFrame.
    colNames = ["label", "features"]
    df = data.toDF(colNames)

    # Note: in many cases you can skip the RDD -> DataFrame step entirely,
    # e.g. when reading from a real database or using Structured Streaming.

    # Split the data 50/50 into training and test sets. No seed is passed, so
    # the split (and therefore the printed results) differs between runs;
    # pass seed=<int> to randomSplit for reproducible output.
    trainingDF, testDF = df.randomSplit([0.5, 0.5])

    # Create the linear regression model (elastic-net regularized).
    lir = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

    # Train the model on the training data.
    model = lir.fit(trainingDF)

    # Generate predictions for every row of the test DataFrame.
    fullPredictions = model.transform(testDF).cache()

    # Pull predictions and their known labels out together, in a single pass,
    # so each prediction stays paired with the label from the same row.
    # (Zipping two separately derived RDDs depends on both having identical
    # partitioning and per-partition ordering.)
    predictionAndLabel = [
        (row["prediction"], row["label"])
        for row in fullPredictions.select("prediction", "label").collect()
    ]

    # Print the predicted and actual value for each test point.
    for pair in predictionAndLabel:
        print(pair)
finally:
    # Always release the session, even if loading, training, or collecting fails.
    spark.stop()


(-2.641222276793917, -3.74)
(-1.8061644640946106, -2.58)
(-1.6787827638523438, -2.54)
(-1.8486250308420331, -2.36)
(-1.6504757193540622, -2.29)
(-1.5372475413609357, -2.27)
(-1.5797081081083582, -2.26)
(-1.3320214687483942, -2.12)
(-1.4310961244923799, -2.07)
(-1.3886355577449574, -1.96)
(-1.3603285132466758, -1.94)
(-1.3886355577449574, -1.94)
(-1.2824841408764016, -1.91)
(-1.3886355577449574, -1.87)
(-1.289560902000972, -1.8)
(-1.183409485132416, -1.75)
(-1.020643979267297, -1.67)
(-1.148025679509564, -1.66)
(-1.1409489183849937, -1.65)
(-1.289560902000972, -1.64)
(-1.2046397685061272, -1.61)
(-1.1409489183849937, -1.6)
(-1.0135672181427267, -1.58)
(-1.1763327240078456, -1.53)
(-1.020643979267297, -1.47)
(-0.985260173644445, -1.46)
(-1.112641873886712, -1.42)
(-0.9215693235233116, -1.39)
(-0.985260173644445, -1.36)
(-0.9074158012741708, -1.34)
(-0.7941876232810443, -1.3)
(-0.8224946677793259, -1.3)
(-1.0135672181427267, -1.3)
(-1.0277207403918673, -1.29)
(-0.8295714289038963, -1.27)
