In [1]:
from __future__ import print_function

from pyspark.ml.regression import LinearRegression

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

In [3]:
if __name__ == "__main__":
    # Fit a linear regression on a random half of ../data/regression.txt and
    # print a (prediction, label) tuple for every row of the held-out half.

    # Fixed seed so the train/test split -- and therefore the printed
    # predictions -- come out the same on every Restart & Run All.
    SPLIT_SEED = 42

    # Create a SparkSession (Note, the config section is only for Windows!)
    spark = (
        SparkSession.builder
        .config("spark.sql.warehouse.dir", "../data/temp")
        .appName("LinearRegression")
        .getOrCreate()
    )

    # try/finally guarantees the session is stopped even if loading, fitting
    # or collecting raises.
    try:
        # Load up our data and convert it to the format MLLib expects:
        # field 0 of each comma-separated line becomes the label, field 1 a
        # one-element dense feature vector.
        inputLines = spark.sparkContext.textFile("../data/regression.txt")
        data = inputLines.map(lambda x: x.split(",")).map(
            lambda x: (float(x[0]), Vectors.dense(float(x[1])))
        )

        # Convert this RDD to a DataFrame
        colNames = ["label", "features"]
        df = data.toDF(colNames)

        # Note, there are lots of cases where you can avoid going from an RDD to a DataFrame.
        # Perhaps you're importing data from a real database. Or you are using structured streaming
        # to get your data.

        # Let's split our data into training data and testing data
        trainingDF, testDF = df.randomSplit([0.5, 0.5], seed=SPLIT_SEED)

        # Now create our linear regression model
        lir = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

        # Train the model using our training data
        model = lir.fit(trainingDF)

        # Now see if we can predict values in our test data.
        # Generate predictions using our linear regression model for all features in our
        # test dataframe:
        fullPredictions = model.transform(testDF)

        # Pull the prediction and the "known" correct label out of the same
        # row in a single pass. This keeps each pair aligned by construction,
        # instead of zipping two separately derived RDDs (RDD.zip requires
        # identical partitioning and per-partition element counts).
        predictionAndLabel = [
            (row[0], row[1])
            for row in fullPredictions.select("prediction", "label").collect()
        ]

        # Print out the predicted and actual values for each point
        for prediction in predictionAndLabel:
            print(prediction)
    finally:
        # Stop the session
        spark.stop()


(-1.8613671803366616, -2.58)
(-1.731340035841882, -2.54)
(-1.9047095618349217, -2.36)
(-1.5868654308476822, -2.27)
(-1.5868654308476822, -2.17)
(-1.4351670956037723, -1.96)
(-1.4062721746049323, -1.94)
(-1.4351670956037723, -1.94)
(-1.3484823326072526, -1.91)
(-1.4351670956037723, -1.87)
(-1.3340348721078326, -1.8)
(-1.2545738393610226, -1.79)
(-1.2040077276130525, -1.74)
(-1.1895602671136327, -1.66)
(-1.1823365368639227, -1.6)
(-1.175112806614213, -1.59)
(-1.052309392369143, -1.58)
(-1.1967839973633427, -1.58)
(-1.1317704251159528, -1.57)
(-1.2184551881124726, -1.53)
(-1.0017432806211728, -1.48)
(-0.958400899122913, -1.4)
(-0.9439534386234931, -1.34)
(-0.8572686756269732, -1.3)
(-1.052309392369143, -1.3)
(-0.8283737546281331, -1.29)
(-1.066756852868563, -1.29)
(-0.8500449453772632, -1.26)
(-0.8789398663761031, -1.26)
(-0.8717161361263931, -1.22)
(-0.8067025638790033, -1.2)
(-0.9006110571252331, -1.17)
(-0.907834787374943, -1.17)
(-0.8355974848778431, -1.14)
(-0.7994788336292933, -1.12