In [1]:
import findspark
findspark.init('/home/dangkhoa/spark-2.3.1-bin-hadoop2.7')

## Session

In [2]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Linear_Regression').getOrCreate()

## Load dataset

In [3]:
dataset = spark.read.format("libsvm").load("sample_linear_regression_data.txt")

dataset.show(5)

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
| -9.490009878824548|(10,[0,1,2,3,4,5,...|
| 0.2577820163584905|(10,[0,1,2,3,4,5,...|
| -4.438869807456516|(10,[0,1,2,3,4,5,...|
|-19.782762789614537|(10,[0,1,2,3,4,5,...|
| -7.966593841555266|(10,[0,1,2,3,4,5,...|
+-------------------+--------------------+
only showing top 5 rows



## Train/Test Splits

In [4]:
# Split dataset
train_data, test_data = dataset.randomSplit([0.7,0.3])

In [5]:
train_data.show(5)

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
|-28.571478869743427|(10,[0,1,2,3,4,5,...|
|-28.046018037776633|(10,[0,1,2,3,4,5,...|
|-26.736207182601724|(10,[0,1,2,3,4,5,...|
|-23.487440120936512|(10,[0,1,2,3,4,5,...|
|-22.949825936196074|(10,[0,1,2,3,4,5,...|
+-------------------+--------------------+
only showing top 5 rows



In [6]:
test_data.show(5)

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
|-26.805483428483072|(10,[0,1,2,3,4,5,...|
| -23.51088409032297|(10,[0,1,2,3,4,5,...|
|-22.837460416919342|(10,[0,1,2,3,4,5,...|
|-21.432387764165806|(10,[0,1,2,3,4,5,...|
|-20.212077258958672|(10,[0,1,2,3,4,5,...|
+-------------------+--------------------+
only showing top 5 rows



## Linear Regression

In [7]:
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(
    featuresCol='features',
    labelCol='label',
    predictionCol='prediction')

In [8]:
# Train
model = lr.fit(train_data)

## Model Summary

In [9]:
# Coefficients and Intercept
print("Coefficients: {}".format(str(model.coefficients))) # For each feature...
print('\n')
print("Intercept:{}".format(str(model.intercept)))

Coefficients: [-0.548341137665528,0.5361261226664056,-0.36310971462438923,2.456256023547675,0.5118479947267605,2.010119946798614,-0.5671473974903186,-0.07283588047135169,-0.12926892094015197,0.8705736028919786]


Intercept:0.08952091951949728


In [10]:
trainingSummary = model.summary

# train residual errors
trainingSummary.residuals.show(5)

# train RMSE and R2 score
print('')
print("RMSE: {}".format(trainingSummary.rootMeanSquaredError))
print("r2: {}".format(trainingSummary.r2))

+-------------------+
|          residuals|
+-------------------+
| -27.69969645706778|
|-26.579507980772025|
|-22.753262454952807|
|-22.963403196681888|
|-25.806423389242333|
+-------------------+
only showing top 5 rows


RMSE: 9.70796713857534
r2: 0.037592824639128164


## Evaluate test data

In [11]:
test_results = model.evaluate(test_data)

# test residual errors
test_results.residuals.show(5)

# test RMSE
print('')
print("RMSE: {}".format(test_results.rootMeanSquaredError))

+-------------------+
|          residuals|
+-------------------+
|-26.663107206813198|
|-22.668727421450107|
| -19.69905482480412|
| -22.27309214271235|
|-22.993800015190374|
+-------------------+
only showing top 5 rows


RMSE: 11.15283465228093


## Prediction on test data

In [12]:
X = test_data.select('features')

predictions = model.transform(X)

In [13]:
# Predict
predictions.show(5)

+--------------------+--------------------+
|            features|          prediction|
+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...|-0.14237622166987565|
|(10,[0,1,2,3,4,5,...| -0.8421566688728634|
|(10,[0,1,2,3,4,5,...|  -3.138405592115222|
|(10,[0,1,2,3,4,5,...|  0.8407043785465442|
|(10,[0,1,2,3,4,5,...|  2.7817227562317006|
+--------------------+--------------------+
only showing top 5 rows



In [14]:
# Real label
test_data.show(5)

+-------------------+--------------------+
|              label|            features|
+-------------------+--------------------+
|-26.805483428483072|(10,[0,1,2,3,4,5,...|
| -23.51088409032297|(10,[0,1,2,3,4,5,...|
|-22.837460416919342|(10,[0,1,2,3,4,5,...|
|-21.432387764165806|(10,[0,1,2,3,4,5,...|
|-20.212077258958672|(10,[0,1,2,3,4,5,...|
+-------------------+--------------------+
only showing top 5 rows

