In [1]:
!pip install --ignore-install -q pyspark
!pip install --ignore-install -q findspark

import findspark
findspark.init()

In [2]:
from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession

# Reuse the active Spark session if one exists, otherwise start a new one
# registered under the application name 'lr_example'.
spark = (
    SparkSession.builder
    .appName('lr_example')
    .getOrCreate()
)

In [3]:
# Load data
# The file is read with the LIBSVM data source, which yields a `label` column
# and a sparse `features` vector column. Declaring numFeatures up front fixes
# the vector width at 10 instead of having Spark infer it from the file.
all_data = (
    spark.read.format("libsvm")
    .option("numFeatures", "10")
    .load('sample_linear_regression_data.txt')
)

# Split into training data and test data (roughly 70% / 30%).
# The split is random; a fixed seed keeps it -- and therefore every
# coefficient, residual and RMSE computed from it -- reproducible across
# kernel restarts.
SPLIT_SEED = 42
train_data, test_data = all_data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)
train_data.show()
test_data.show()

# Keep only the feature vectors of the test split, to stand in for new,
# unlabeled observations that the fitted model will score later.
unlabeled_data = test_data.select("features")
unlabeled_data.show()

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
|-28.571478869743427|(10,[0,1,2,3,4,5,...|
|-28.046018037776633|(10,[0,1,2,3,4,5,...|
|-26.805483428483072|(10,[0,1,2,3,4,5,...|
|-26.736207182601724|(10,[0,1,2,3,4,5,...|
| -23.51088409032297|(10,[0,1,2,3,4,5,...|
|-22.949825936196074|(10,[0,1,2,3,4,5,...|
|-22.837460416919342|(10,[0,1,2,3,4,5,...|
|-21.432387764165806|(10,[0,1,2,3,4,5,...|
|-20.212077258958672|(10,[0,1,2,3,4,5,...|
|-20.057482615789212|(10,[0,1,2,3,4,5,...|
|-19.884560774273424|(10,[0,1,2,3,4,5,...|
|-19.872991038068406|(10,[0,1,2,3,4,5,...|
|-19.782762789614537|(10,[0,1,2,3,4,5,...|
|-19.402336030214553|(10,[0,1,2,3,4,5,...|
| -19.16829262296376|(10,[0,1,2,3,4,5,...|
|-18.845922472898582|(10,[0,1,2,3,4,5,...|
| -18.27521356600463|(10,[0,1,2,3,4,5,...|
|-17.494200356883344|(10,[0,1,2,3,4,5,...|
|-17.428674570939506|(10,[0,1,2,3,4,5,...|
| -17.32672073267595|(10,[0,1,2,3,4,5,...|
+----------

In [4]:
# Configure the estimator: which column holds the inputs, which holds the
# target, and which column name the fitted model should write predictions to.
lr = LinearRegression(
    featuresCol='features',
    labelCol='label',
    predictionCol='prediction',
)

# Fit the model on the training split only.
lr_model = lr.fit(train_data)

In [5]:
# Report the parameters learned from the training data:
# one coefficient per feature, plus the intercept term.
coefficients_text = str(lr_model.coefficients)
intercept_text = str(lr_model.intercept)

print("Coefficients: {}".format(coefficients_text))
print("Intercept: {}".format(intercept_text))

Coefficients: [-0.7342399437195134,0.9858148614728693,-0.33524894479823786,2.4903824460742605,0.027506681467314838,0.8796618872599157,0.5815050301028878,-0.06374800915801716,-0.886096327899207,0.45196963168517956]
Intercept: 0.10277636014353966


In [6]:
# Evaluate the fitted model on the held-out test split.
test_result = lr_model.evaluate(test_data)

# Residuals: the per-row difference between the actual label and the
# model's prediction (label minus prediction).
residuals_df = test_result.residuals
residuals_df.show()

# Root mean squared error over the test split -- lower means better.
rmse = test_result.rootMeanSquaredError
print("RMSE: {}".format(rmse))

+-------------------+
|          residuals|
+-------------------+
|-24.658777747359522|
|-19.899127783333245|
|-17.145449589796733|
|-15.506871293546713|
|-16.082357874643147|
|-15.900387335818884|
|-15.123350377899033|
| -19.04342873217747|
|-13.118879209010803|
|-11.898596740186512|
|-15.018404516411946|
|-12.656466306090525|
|-10.940337187644898|
|-15.313051780640542|
| -9.554371629940118|
| -13.17633815687582|
|-15.428490609635555|
|-11.967895428058299|
|-13.336256224572962|
|-13.140423569891066|
+-------------------+
only showing top 20 rows

RMSE: 9.716757526064667


In [8]:
# Score the label-free feature rows; the result carries the original
# `features` column plus a `prediction` column.
predictions = lr_model.transform(unlabeled_data)
predictions.show()

# Shut down the Spark session now that the notebook's work is finished.
# Any cell that uses `spark` after this point needs the session recreated.
spark.stop()

+--------------------+--------------------+
|            features|          prediction|
+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...|  1.1713376264230104|
|(10,[0,1,2,3,4,5,...| 0.23180916796152456|
|(10,[0,1,2,3,4,5,...|  -0.658176598867783|
|(10,[0,1,2,3,4,5,...|  -1.212225540058377|
|(10,[0,1,2,3,4,5,...|-0.06899147663396418|
|(10,[0,1,2,3,4,5,...|-0.18527170520260608|
|(10,[0,1,2,3,4,5,...| -0.8281621878955391|
|(10,[0,1,2,3,4,5,...|  3.2627436995541683|
|(10,[0,1,2,3,4,5,...| -2.2299919463684508|
|(10,[0,1,2,3,4,5,...| -3.4123838492297764|
|(10,[0,1,2,3,4,5,...|-0.03807845813048635|
|(10,[0,1,2,3,4,5,...| -0.7641284698002314|
|(10,[0,1,2,3,4,5,...| -2.2129984187206317|
|(10,[0,1,2,3,4,5,...|  2.3908286772701217|
|(10,[0,1,2,3,4,5,...|  -3.218855369311078|
|(10,[0,1,2,3,4,5,...|  0.6177623680196308|
|(10,[0,1,2,3,4,5,...|  2.9370485320891415|
|(10,[0,1,2,3,4,5,...|  0.3521201630426726|
|(10,[0,1,2,3,4,5,...|  1.9044538590325018|
|(10,[0,1,2,3,4,5,...|   1.81200