In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder \
   .master("local[*]") \
    .appName("ace") \
    .config("spark.driver.memory", "8g") \
    .config("spark.driver.maxResultSize", "10g") \
    .config("spark.network.timeout", "500s") \
    .config("spark.rpc.lookupTimeout", "500s") \
    .config("spark.hadoop.fs.s3a.connection.maximum", "100") \
    .config("spark.sql.execution.arrow.enabled", "true") \
    .config("spark.sql.session.timeZone", "Pacific/Auckland") \
    .getOrCreate()

In [2]:
data = spark.read.csv('data/boston_housing.csv', header=True, inferSchema=True)

In [3]:
data.show()

+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+
|   crim|  zn|indus|chas|  nox|   rm|  age|   dis|rad|tax|ptratio|     b|lstat|medv|
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+
|0.00632|18.0| 2.31|   0|0.538|6.575| 65.2|  4.09|  1|296|   15.3| 396.9| 4.98|24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421| 78.9|4.9671|  2|242|   17.8| 396.9| 9.14|21.6|
|0.02729| 0.0| 7.07|   0|0.469|7.185| 61.1|4.9671|  2|242|   17.8|392.83| 4.03|34.7|
|0.03237| 0.0| 2.18|   0|0.458|6.998| 45.8|6.0622|  3|222|   18.7|394.63| 2.94|33.4|
|0.06905| 0.0| 2.18|   0|0.458|7.147| 54.2|6.0622|  3|222|   18.7| 396.9| 5.33|36.2|
|0.02985| 0.0| 2.18|   0|0.458| 6.43| 58.7|6.0622|  3|222|   18.7|394.12| 5.21|28.7|
|0.08829|12.5| 7.87|   0|0.524|6.012| 66.6|5.5605|  5|311|   15.2| 395.6|12.43|22.9|
|0.14455|12.5| 7.87|   0|0.524|6.172| 96.1|5.9505|  5|311|   15.2| 396.9|19.15|27.1|
|0.21124|12.5| 7.87|   0|0.524|5.631|100.0|6.0821|  5|311|   15.2

In [7]:
feature_columns = data.columns[:-1]

In [8]:
from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(inputCols=feature_columns,outputCol="features")

In [9]:
data_2 = assembler.transform(data)
data_2.show()

+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+--------------------+
|   crim|  zn|indus|chas|  nox|   rm|  age|   dis|rad|tax|ptratio|     b|lstat|medv|            features|
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+--------------------+
|0.00632|18.0| 2.31|   0|0.538|6.575| 65.2|  4.09|  1|296|   15.3| 396.9| 4.98|24.0|[0.00632,18.0,2.3...|
|0.02731| 0.0| 7.07|   0|0.469|6.421| 78.9|4.9671|  2|242|   17.8| 396.9| 9.14|21.6|[0.02731,0.0,7.07...|
|0.02729| 0.0| 7.07|   0|0.469|7.185| 61.1|4.9671|  2|242|   17.8|392.83| 4.03|34.7|[0.02729,0.0,7.07...|
|0.03237| 0.0| 2.18|   0|0.458|6.998| 45.8|6.0622|  3|222|   18.7|394.63| 2.94|33.4|[0.03237,0.0,2.18...|
|0.06905| 0.0| 2.18|   0|0.458|7.147| 54.2|6.0622|  3|222|   18.7| 396.9| 5.33|36.2|[0.06905,0.0,2.18...|
|0.02985| 0.0| 2.18|   0|0.458| 6.43| 58.7|6.0622|  3|222|   18.7|394.12| 5.21|28.7|[0.02985,0.0,2.18...|
|0.08829|12.5| 7.87|   0|0.524|6.012| 66.6|5.5

In [10]:
train, test = data_2.randomSplit([0.7, 0.3])

In [16]:
from pyspark.ml.regression import LinearRegression
algo = LinearRegression(featuresCol="features", labelCol="medv")
model = algo.fit(train)
evaluation_summary = model.evaluate(test)
print(evaluation_summary.rootMeanSquaredError)

5.920213613324541


In [21]:
predictions = model.transform(test)
predictions.select('medv', 'prediction', 'features').show()

+----+------------------+--------------------+
|medv|        prediction|            features|
+----+------------------+--------------------+
|24.0|  29.7403308411066|[0.00632,18.0,2.3...|
|32.7|30.619233275996816|[0.01301,35.0,1.5...|
|18.9|14.928593153582042|[0.0136,75.0,4.0,...|
|29.1| 31.16391628381861|[0.01439,60.0,2.9...|
|30.1|24.753532133544734|[0.01709,90.0,2.0...|
|50.0|42.957767039224066|[0.02009,95.0,2.6...|
|24.7|24.159647934029376|[0.02055,85.0,0.7...|
|25.0|28.383866397384935|[0.02875,28.0,15....|
|28.7| 25.00452730987143|[0.02985,0.0,2.18...|
|34.9| 29.67802562290653|[0.0315,95.0,1.47...|
|33.4| 28.72273521755621|[0.03237,0.0,2.18...|
|48.5| 41.73852981964769|[0.0351,95.0,2.68...|
|23.5|29.376532809908213|[0.03584,80.0,3.3...|
|27.9| 30.94260876061649|[0.03615,80.0,4.9...|
|33.3| 35.86175356535294|[0.04011,80.0,1.5...|
|28.0| 28.19615736591183|[0.04113,25.0,4.8...|
|18.2|13.496246024213594|[0.04301,80.0,1.9...|
|24.8|  30.1632247421383|[0.04417,70.0,2.2...|
|19.8| 21.756