In [1]:
import findspark
findspark.init()
from pyspark.ml.regression import LinearRegression

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

if __name__ == "__main__":
    # Train a linear regression on regression.txt and print (prediction, label)
    # pairs for a held-out half of the data.

    # Create a SparkSession (Note, the config section is only for Windows!)
    spark = (
        SparkSession.builder
        .config("spark.sql.warehouse.dir", "file:///C:/temp")
        .appName("LinearRegression")
        .getOrCreate()
    )

    # Everything that uses the session lives inside try/finally so the session
    # is stopped even if loading, fitting or predicting raises.
    try:
        # Load up our data and convert it to the format MLLib expects:
        # each input line is "<label>,<feature>", turned into a
        # (label, dense feature vector) pair.
        inputLines = spark.sparkContext.textFile("regression.txt")
        data = inputLines.map(lambda x: x.split(",")).map(
            lambda x: (float(x[0]), Vectors.dense(float(x[1])))
        )

        # Convert this RDD to a DataFrame
        colNames = ["label", "features"]
        df = data.toDF(colNames)

        # Note, there are lots of cases where you can avoid going from an RDD to a DataFrame.
        # Perhaps you're importing data from a real database. Or you are using structured streaming
        # to get your data.

        # Let's split our data into training data and testing data
        trainingDF, testDF = df.randomSplit([0.5, 0.5])

        # Now create our linear regression model
        lir = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

        # Train the model using our training data
        model = lir.fit(trainingDF)

        # Now see if we can predict values in our test data.
        # Generate predictions using our linear regression model for all features in our
        # test dataframe:
        fullPredictions = model.transform(testDF)

        # Pull the prediction and the "known" correct label out of the same row
        # in a single pass. Selecting both columns together keeps each pair
        # aligned by construction, instead of relying on RDD.zip over two
        # separately derived RDDs (which requires identical partitioning).
        predictionAndLabel = fullPredictions.select("prediction", "label").collect()

        # Print out the predicted and actual values for each point.
        # tuple(row) keeps the output as plain "(prediction, label)" tuples.
        for row in predictionAndLabel:
            print(tuple(row))
    finally:
        # Stop the session
        spark.stop()


(-1.8080807849189933, -2.58)
(-1.6804222792333225, -2.54)
(-1.8506336201475502, -2.36)
(-1.6520537224142846, -2.29)
(-1.5811323303666895, -2.26)
(-1.5385794951381326, -2.17)
(-1.3329074582001073, -2.12)
(-1.3825524326334238, -2.09)
(-1.283262483766791, -1.91)
(-1.3045389013810695, -1.91)
(-1.183972534900158, -1.75)
(-1.1485118388763607, -1.66)
(-1.2903546229715506, -1.64)
(-1.2052489525144365, -1.61)
(-1.1768803956953986, -1.53)
(-1.1130511428525633, -1.42)
(-0.9215633843240569, -1.4)
(-0.8719184098907404, -1.37)
(-0.793904878638386, -1.3)
(-1.0137611939859303, -1.3)
(-0.793904878638386, -1.29)
(-0.8293655746621835, -1.27)
(-0.8151812962526646, -1.26)
(-0.8435498530717025, -1.26)
(-0.836457713866943, -1.25)
(-0.9215633843240569, -1.25)
(-0.7584441826145887, -1.24)
(-0.8293655746621835, -1.23)
(-0.7726284610241076, -1.2)
(-0.836457713866943, -1.2)
(-0.864826270685981, -1.17)
(-0.8719184098907404, -1.17)
(-0.8719184098907404, -1.16)
(-0.8009970178431455, -1.14)
(-0.8577341314812215, -1.1