In [1]:
!pip install findspark
!pip install pyspark
!pip install numpy




In [2]:
# make pyspark importable as a regular library.
import findspark
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession

findspark.init()

In [3]:
#create spark session
spark = SparkSession.builder.getOrCreate()

In [4]:
# load data
data = spark.read.csv('data/boston_housing.csv', header=True, inferSchema=True)

In [5]:
# create features vector
feature_columns = data.columns[:-1] # here we omit the final column

In [6]:
assembler = VectorAssembler(inputCols=feature_columns,outputCol="features")
data_2 = assembler.transform(data)

In [7]:
# train/test split
train, test = data_2.randomSplit([0.7, 0.3])

In [8]:
# define the model

algo = LinearRegression(featuresCol="features", labelCol="medv")

In [9]:
# train the model
model = algo.fit(train)

In [10]:
# evaluation
evaluation_summary = model.evaluate(test)
evaluation_summary.meanAbsoluteError
evaluation_summary.rootMeanSquaredError
evaluation_summary.r2

0.7353727413148607

In [11]:
# predicting values
predictions = model.transform(test)
predictions.select(predictions.columns[13:]).show() # here I am filtering out some col


+----+--------------------+------------------+
|medv|            features|        prediction|
+----+--------------------+------------------+
|32.2|[0.00906,90.0,2.9...|31.373305146273673|
|22.0|[0.01096,55.0,2.2...|27.283523552653804|
|32.7|[0.01301,35.0,1.5...|29.925755421570962|
|35.4|[0.01311,90.0,1.2...| 31.28776428865545|
|50.0|[0.01381,80.0,0.4...| 40.07475133143581|
|24.5|[0.01501,80.0,2.0...| 27.69314651336287|
|32.9|[0.01778,95.0,1.4...|30.325504612569596|
|42.3|[0.02177,82.5,2.0...|36.434044146688244|
|23.9|[0.02543,55.0,3.7...|27.330348688457555|
|34.7|[0.02729,0.0,7.07...| 30.21165530846681|
|21.6|[0.02731,0.0,7.07...| 24.81572362965898|
|18.5|[0.03041,0.0,5.19...|19.963060890535534|
|31.2|[0.03049,55.0,3.7...|28.491844002452936|
|34.9|[0.0315,95.0,1.47...|29.820998857580186|
|33.4|[0.03237,0.0,2.18...|28.710318455640305|
|19.5|[0.03427,0.0,5.19...|20.969566401471187|
|19.4|[0.03466,35.0,6.0...|23.695445539194388|
|22.9|[0.03551,25.0,4.8...| 25.19126856374131|
|45.4|[0.0357