In [1]:
from pyspark.sql import SparkSession
# Start (or reuse) the Spark session used by the rest of this notebook.
spark = (
    SparkSession.builder
    .appName('boston-home')
    .getOrCreate()
)

In [3]:
# Load the Boston housing CSV: the first row supplies the column names and
# Spark is asked to infer each column's type.
data = spark.read.csv(
    "./boston_housing.csv",
    header=True,
    inferSchema=True,
)

In [5]:
# Print the column names, types and nullability of the loaded DataFrame.
data.printSchema()

root
 |-- crim: double (nullable = true)
 |-- zn: double (nullable = true)
 |-- indus: double (nullable = true)
 |-- chas: integer (nullable = true)
 |-- nox: double (nullable = true)
 |-- rm: double (nullable = true)
 |-- age: double (nullable = true)
 |-- dis: double (nullable = true)
 |-- rad: integer (nullable = true)
 |-- tax: integer (nullable = true)
 |-- ptratio: double (nullable = true)
 |-- b: double (nullable = true)
 |-- lstat: double (nullable = true)
 |-- medv: double (nullable = true)



In [14]:
# Summary statistics (count, mean, stddev, min, max) for every column,
# transposed so that each column of the dataset becomes one row of the table.
data.describe(*data.columns).toPandas().T

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
crim,506,3.6135235573122535,8.601545105332491,0.00632,88.9762
zn,506,11.363636363636363,23.32245299451514,0.0,100.0
indus,506,11.136778656126504,6.860352940897589,0.46,27.74
chas,506,0.0691699604743083,0.2539940413404101,0,1
nox,506,0.5546950592885372,0.11587767566755584,0.385,0.871
rm,506,6.284634387351787,0.7026171434153232,3.561,8.78
age,506,68.57490118577078,28.148861406903595,2.9,100.0
dis,506,3.795042687747034,2.10571012662761,1.1296,12.1265
rad,506,9.549407114624506,8.707259384239366,1,24


# Getting the feature columns

In [21]:
# Every column except the regression target `medv` is used as a model input.
# The label is excluded by name rather than by position, so the selection stays
# correct even if the column order of the CSV changes.
features_columns = [column for column in data.columns if column != "medv"]

In [22]:
# Show which columns were selected as model inputs.
print(features_columns)

['crim', 'zn', 'indus', 'chas', 'nox', 'rm', 'age', 'dis', 'rad', 'tax', 'ptratio', 'b', 'lstat']


# Assembling the feature vector

In [20]:
from pyspark.ml.feature import VectorAssembler

In [25]:
# Combine the individual feature columns into a single vector column named
# "features".
assembler = VectorAssembler(
    inputCols=features_columns,
    outputCol="features",
)

In [30]:
# Run the assembler over the full dataset; the result keeps the original
# columns (including the label) alongside the assembled feature vector.
transformed_data_with_feature_and_label = assembler.transform(data)

# Splitting the data into train and test sets

In [32]:
# Split the rows roughly 70% / 30% into training and test sets (the weights are
# given in [train, test] order). A fixed seed makes the split, and every metric
# computed from it, repeatable across kernel restarts.
train, test = transformed_data_with_feature_and_label.randomSplit([0.7, 0.3], seed=42)

In [38]:
# Convert the training split to pandas and transpose it, so each displayed
# column is one training row.
# NOTE(review): this converts the whole split, not a sample; consider
# train.limit(n) first if the dataset grows.
train.toPandas().transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,338,339,340,341,342,343,344,345,346,347
crim,0.00906,0.01301,0.0136,0.01381,0.01432,0.01439,0.01501,0.01501,0.01538,0.01709,...,22.5971,23.6482,24.3938,25.0461,25.9406,28.6558,37.6619,38.3518,41.5292,88.9762
zn,90,35,75,80,100,60,80,90,90,90,...,0,0,0,0,0,0,0,0,0,0
indus,2.97,1.52,4,0.46,1.32,2.93,2.01,1.21,3.75,2.02,...,18.1,18.1,18.1,18.1,18.1,18.1,18.1,18.1,18.1,18.1
chas,0,0,0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
nox,0.4,0.442,0.41,0.422,0.411,0.401,0.435,0.401,0.394,0.41,...,0.7,0.671,0.7,0.693,0.679,0.597,0.679,0.693,0.693,0.671
rm,7.088,7.241,5.888,7.875,6.816,6.604,6.635,7.923,7.454,6.728,...,5,6.38,4.652,5.987,5.304,5.155,6.202,5.453,5.531,6.968
age,20.8,49.3,47.6,32,40.5,18.8,29.7,24.8,34.2,36.1,...,89.5,96.2,100,100,89.1,100,78.7,100,85.4,91.9
dis,7.3073,7.0379,7.3197,5.6484,8.3248,6.2196,8.344,5.885,6.3361,12.1265,...,1.5184,1.3861,1.4672,1.5888,1.6475,1.5894,1.8629,1.4896,1.6074,1.4165
rad,1,1,3,4,5,1,4,1,3,5,...,24,24,24,24,24,24,24,24,24,24
tax,285,284,469,255,256,265,280,198,244,187,...,666,666,666,666,666,666,666,666,666,666


In [39]:
# Number of rows in the training split.
train.count()

348

In [40]:
# Number of rows in the test split.
test.count()

158

# Using Spark ML Regression Model

In [41]:
from pyspark.ml.regression import LinearRegression

In [46]:
# Configure the estimator: inputs come from the assembled "features" vector
# and the target to predict is the "medv" column.
linearRegressionAlgorithm = LinearRegression(
    featuresCol="features",
    labelCol="medv",
)

In [47]:
# Fit the linear regression on the training split.
model = linearRegressionAlgorithm.fit(train)

In [48]:
# Evaluate the fitted model on the held-out test split and keep the result.
# The variable's spelling is left unchanged because other cells may refer to it.
evaluvation_summary = model.evaluate(test)

In [52]:
# Score the test split with the fitted model.
predictions = model.transform(test)

In [63]:
# Show only the label, the assembled feature vector and the model's prediction.
# The raw feature columns are left out so the table fits on screen. Columns are
# selected by name instead of by position, so the view does not depend on how
# many feature columns precede the label.
predictions.select("medv", "features", "prediction").show()


+----+--------------------+------------------+
|medv|            features|        prediction|
+----+--------------------+------------------+
|24.0|[0.00632,18.0,2.3...| 30.90720546820699|
|22.0|[0.01096,55.0,2.2...| 27.79739337041719|
|35.4|[0.01311,90.0,1.2...|    31.47279484776|
|32.9|[0.01778,95.0,1.4...|31.281917417484838|
|31.1|[0.02187,60.0,2.9...| 32.41260619141315|
|23.9|[0.02543,55.0,3.7...| 28.48769142522596|
|21.6|[0.02731,0.0,7.07...| 25.36311123638822|
|25.0|[0.02875,28.0,15....|29.135875961325876|
|26.6|[0.02899,40.0,1.2...|22.191871399937757|
|28.7|[0.02985,0.0,2.18...|25.624454163395058|
|20.6|[0.03306,0.0,5.19...|21.964725603318588|
|19.5|[0.03427,0.0,5.19...|20.003836958310696|
|28.5|[0.03502,80.0,4.9...| 34.03326970904453|
|22.0|[0.03537,34.0,6.0...|29.308107041274084|
|20.9|[0.03548,80.0,3.6...| 21.97570413334767|
|45.4|[0.03578,20.0,3.3...| 39.51064839497955|
|27.9|[0.03615,80.0,4.9...|32.324875191506095|
|34.6|[0.03768,80.0,1.5...|35.569763867356244|
|23.2|[0.0387