In [1]:
from __future__ import print_function
from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

In [2]:
if __name__ == "__main__":

    # Create (or reuse) the SparkSession that every later cell reaches through
    # the `spark` variable.
    # NOTE(review): an earlier version of this comment mentioned a Windows-only
    # "config section" that is no longer present. If your environment needs
    # extra builder settings, add them with .config(...) before getOrCreate().
    spark = (
        SparkSession.builder
        .appName("LinearRegression")
        .getOrCreate()
    )

In [3]:
    # Load the raw text file and convert each "label,feature" line into the
    # (label, features-vector) tuple layout that pyspark.ml estimators expect.
    # NOTE(review): this is an absolute, machine-specific Windows path carried
    # over from the original setup; point DATA_PATH at your own copy of the file.
    DATA_PATH = "c:/SparkCourse/Machine Learning/Dataset/regression.txt"

    def parseLine(line):
        """Parse one "label,feature" text line.

        Only the first two comma-separated fields are used.

        Returns:
            (float label, DenseVector of the single float feature)
        """
        fields = line.split(",")
        return (float(fields[0]), Vectors.dense(float(fields[1])))

    inputLines = spark.sparkContext.textFile(DATA_PATH)
    # Skip blank / whitespace-only lines: float("") would otherwise raise
    # ValueError inside the Spark job and abort it.
    data = inputLines.filter(lambda line: line.strip()).map(parseLine)

In [4]:
    # Turn the (label, features) RDD into a DataFrame. The column names match
    # the defaults LinearRegression looks for: "label" for the target and
    # "features" for the input vector.
    colNames = ["label", "features"]
    df = spark.createDataFrame(data, colNames)

In [5]:
    # Note: you can often avoid the RDD -> DataFrame round trip entirely, for
    # example when reading straight from a database or consuming data through
    # Structured Streaming.

    # Split the rows roughly 50/50 into training and test sets.
    # A fixed seed makes the split -- and therefore the fitted model and the
    # predictions printed below -- reproducible on Restart & Run All. Without
    # one, randomSplit picks a fresh random seed every time it is called.
    SPLIT_SEED = 42
    trainTest = df.randomSplit([0.5, 0.5], seed=SPLIT_SEED)
    trainingDF, testDF = trainTest

In [6]:
    # Configure an elastic-net regularized linear regression estimator.
    lir = LinearRegression(
        maxIter=10,           # upper bound on optimizer iterations
        regParam=0.3,         # overall regularization strength (>= 0)
        elasticNetParam=0.8,  # L1/L2 mix in [0, 1]: 0 = pure L2, 1 = pure L1
    )

In [7]:
    # Fit the estimator on the training split; the result is a fitted model
    # that can score new rows via transform().
    model = lir.fit(dataset=trainingDF)

In [8]:
    # Score the held-out test rows. The returned DataFrame carries a
    # "prediction" column next to the original "label" column. It is cached
    # because the following cells read it twice (once per column).
    fullPredictions = model.transform(dataset=testDF).cache()

In [9]:
    # Pull the model's predictions and the known true labels out as two
    # separate RDDs of scalar values, read from each Row by column name.
    predictions = fullPredictions.select("prediction").rdd.map(lambda row: row["prediction"])
    labels = fullPredictions.select("label").rdd.map(lambda row: row["label"])

In [12]:
    # Pair each prediction with its true label. RDD.zip requires identical
    # partitioning on both sides; here both RDDs are element-wise maps over the
    # same cached DataFrame. Collect the pairs to the driver as a list of
    # (prediction, label) tuples.
    pairedRDD = predictions.zip(labels)
    predictionAndLabel = pairedRDD.collect()

In [14]:
    # Show every (predicted, actual) pair on its own line.
    # NOTE(review): this dumps the entire test set; consider printing a slice
    # (or a summary metric) for larger datasets.
    for predicted, actual in predictionAndLabel:
        print((predicted, actual))

(-2.6750371879158097, -3.74)
(-1.8344337382277593, -2.58)
(-1.6777110611672754, -2.29)
(-1.5637309323960142, -2.27)
(-1.5637309323960142, -2.17)
(-1.4070082553355303, -2.09)
(-1.456874561672957, -2.07)
(-1.4426270455765495, -2.0)
(-1.414132013383734, -1.94)
(-1.3500181909498998, -1.88)
(-1.414132013383734, -1.87)
(-1.186171755841212, -1.77)
(-1.1790479977930082, -1.58)
(-1.114934175359174, -1.57)
(-1.2004192719376197, -1.53)
(-1.0436965948771355, -1.47)
(-0.9439639822022823, -1.4)
(-0.8940976758648554, -1.37)
(-1.0579441109735432, -1.33)
(-0.8157363373346134, -1.3)
(-0.8157363373346134, -1.29)
(-0.8371076114792249, -1.26)
(-0.8656026436720402, -1.26)
(-0.7801175470935945, -1.24)
(-0.7943650631900021, -1.2)
(-0.8940976758648554, -1.16)
(-0.8228600953828172, -1.14)
(-0.71600372465976, -1.11)
(-0.8798501597684478, -1.11)
(-0.6590136602741296, -1.1)
(-0.7729937890453906, -1.1)
(-0.758746272948983, -1.09)
(-0.801488821238206, -1.08)
(-0.7373749988043715, -1.07)
(-0.71600372465976, -1.05)
(-

In [16]:
    # Shut down the SparkSession created at the top of the notebook and release
    # its resources. Re-run the session-creation cell before using `spark` again.
    spark.stop()
