In [1]:
import os

from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession

In [2]:
# Obtain the SparkSession for this notebook (an existing one is reused).
# Windows only: an explicit warehouse directory has to be configured, e.g.
#   SparkSession.builder.config("spark.sql.warehouse.dir", "file:///C:/temp")
spark = (
    SparkSession.builder
    .appName("LinearRegression")
    .getOrCreate()
)



In [3]:

# Load the raw text file and convert every line into the (label, features)
# pair that the Spark ML estimator is trained on.
#
# The dataset location can be overridden through the REGRESSION_DATA_URI
# environment variable; when it is unset, the original local path is used.
REGRESSION_DATA_URI = os.environ.get(
    "REGRESSION_DATA_URI",
    "file:///home/hashimyousaf/spark-2.4.0-bin-hadoop2.7/bin/jupyter-scripts/dataset/regression.txt",
)


def parse_labeled_point(line):
    """Convert one comma-separated text line into a (label, features) pair.

    Args:
        line: A string whose first comma-separated field is the label and
            whose second field is the single feature value.

    Returns:
        A tuple ``(float, DenseVector)``: the label followed by a
        one-element dense feature vector.

    Raises:
        ValueError: If either of the first two fields is not a valid float.
        IndexError: If the line contains fewer than two fields.
    """
    fields = line.split(",")
    return (float(fields[0]), Vectors.dense(float(fields[1])))


inputLines = spark.sparkContext.textFile(REGRESSION_DATA_URI)
# Drop blank / whitespace-only lines first: float("") would raise ValueError
# and abort the whole job.
data = inputLines.filter(lambda line: line.strip()).map(parse_labeled_point)


In [4]:
# Turn the RDD of (label, features) tuples into a two-column DataFrame.
# The column names are given in the same order as the tuple fields.
colNames = ["label", "features"]
df = data.toDF(schema=colNames)



In [5]:
# Note: in many cases you can avoid going from an RDD to a DataFrame at all,
# e.g. when the data is imported from a real database or arrives through
# structured streaming.

# Split the data 50/50 into a training set and a test set.
# A fixed seed is passed so the random assignment is repeatable between runs;
# without it every execution trains and evaluates on different rows.
SPLIT_SEED = 42
trainTest = df.randomSplit([0.5, 0.5], seed=SPLIT_SEED)
trainingDF, testDF = trainTest


In [6]:

# Configure the linear regression estimator: at most 10 iterations,
# regularization parameter 0.3 and elastic-net parameter 0.8.
lir = LinearRegression(
    maxIter=10,
    regParam=0.3,
    elasticNetParam=0.8,
)

# Fit the estimator on the training split to obtain the trained model.
model = lir.fit(trainingDF)



In [7]:

# Apply the trained model to the held-out test DataFrame to generate
# predictions for it. The result is cached because later cells read it
# more than once.
fullPredictions = (
    model
    .transform(testDF)
    .cache()
)



In [8]:
# Pull the model output and the known correct label out as two RDDs of
# plain values (each selected Row has a single field, hence index 0).
predictions = (
    fullPredictions
    .select("prediction")
    .rdd
    .map(lambda row: row[0])
)
labels = (
    fullPredictions
    .select("label")
    .rdd
    .map(lambda row: row[0])
)


In [9]:
# Pair every prediction with its known label.
# Both columns are selected from the same DataFrame in one pass, so each pair
# is guaranteed to come from the same row. Zipping two independently derived
# RDDs only works while both keep identical partitioning, and it evaluates
# the lineage twice.
predictionAndLabel = (
    fullPredictions
    .select("prediction", "label")
    .rdd
    .map(lambda row: (row["prediction"], row["label"]))
    .collect()
)

# Print the (predicted, actual) pair for each test point.
for predicted, actual in predictionAndLabel:
    print((predicted, actual))


# Stop the session
spark.stop()

(-1.7400004370474282, -2.54)
(-1.913711216300565, -2.36)
(-1.5952414543364808, -2.27)
(-1.638669149149765, -2.26)
(-1.385340929405607, -2.12)
(-1.4721963190321756, -2.0)
(-1.443244522489986, -1.96)
(-1.4142927259477964, -1.94)
(-1.3781029802700597, -1.88)
(-1.341913234592323, -1.8)
(-1.2622957941013018, -1.79)
(-1.2116301501524702, -1.77)
(-1.1971542518813754, -1.66)
(-1.341913234592323, -1.64)
(-1.2550578449657543, -1.61)
(-1.1899163027458282, -1.6)
(-1.0596332183059756, -1.58)
(-1.2043922010169228, -1.58)
(-1.1392506587969966, -1.57)
(-1.008967574357144, -1.48)
(-1.1609645062036387, -1.42)
(-1.030681421763786, -1.36)
(-1.0813470657126176, -1.33)
(-0.8352567951040071, -1.3)
(-0.8642085916461966, -1.3)
(-1.0596332183059756, -1.3)
(-1.0741091165770702, -1.29)
(-0.8569706425106491, -1.26)
(-0.8786844899172913, -1.25)
(-0.9655398795438598, -1.25)
(-0.9076362864594808, -1.17)
(-0.9148742355950281, -1.17)
(-0.733925507206344, -1.11)
(-0.9003983373239334, -1.11)
(-0.791829100290723, -1.1)
(-