In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [None]:
# Reuse the active Spark session if there is one, otherwise start a new one named "DecisionTree".
sparkSession = (
    SparkSession.builder
    .appName("DecisionTree")
    .getOrCreate()
)

In [None]:
# Path to the housing CSV. It currently points at a Colab Google-Drive mount;
# change this single constant to run the notebook anywhere else.
HOUSING_CSV_PATH = '/content/drive/MyDrive/Colab Notebooks/data/housing.csv'

# Load with a header row and let Spark infer column types (numeric columns come back as double).
df = sparkSession.read.csv(HOUSING_CSV_PATH,
                           header=True,
                           inferSchema=True)
df.printSchema()
df.show(5)

root
 |-- longitude: double (nullable = true)
 |-- latitude: double (nullable = true)
 |-- housing_median_age: double (nullable = true)
 |-- total_rooms: double (nullable = true)
 |-- total_bedrooms: double (nullable = true)
 |-- population: double (nullable = true)
 |-- households: double (nullable = true)
 |-- median_income: double (nullable = true)
 |-- median_house_value: double (nullable = true)
 |-- ocean_proximity: string (nullable = true)

+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+
|longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity|
+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+
|  -122.23|   37.88|              41.0|      880.0|         129.0|     322.0|     126.0|       8.3252|          452600.0|       NEAR B

In [None]:
# Encode the categorical target `ocean_proximity` as a numeric label column for the classifier.
labelIndexer = StringIndexer(inputCol = 'ocean_proximity', outputCol = 'label_ocean_proximity')
# `df` is overwritten below, so on a re-run the label column already exists.
# Drop it first so the cell can be re-executed without colliding with its own output.
if 'label_ocean_proximity' in df.columns:
  df = df.drop('label_ocean_proximity')
#End if
df = labelIndexer.fit(df).transform(df)
df.show(5)

+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+---------------------+
|longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity|label_ocean_proximity|
+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+---------------------+
|  -122.23|   37.88|              41.0|      880.0|         129.0|     322.0|     126.0|       8.3252|          452600.0|       NEAR BAY|                  3.0|
|  -122.22|   37.86|              21.0|     7099.0|        1106.0|    2401.0|    1138.0|       8.3014|          358500.0|       NEAR BAY|                  3.0|
|  -122.24|   37.85|              52.0|     1467.0|         190.0|     496.0|     177.0|       7.2574|          352100.0|       NEAR BAY|                  3.0|
|  -122.25|   37.85|              52.0| 

In [None]:
# Assemble every numeric column into a single `features` vector for Spark ML.
# Drop a `features` column left over from a previous run BEFORE choosing the
# input columns; otherwise it would be selected as an input and then be missing.
if 'features' in df.columns:
  df = df.drop('features')
#End if
# Every column except the raw categorical target and its indexed label is a feature.
# Order follows df.columns, so featureCols[i] names feature index i in the model.
nonFeatureCols = {'ocean_proximity', 'label_ocean_proximity'}
featureCols = [c for c in df.columns if c not in nonFeatureCols]
# handleInvalid='skip' makes the assembler drop rows with invalid (null/NaN) inputs instead of failing.
assembler = VectorAssembler(inputCols = featureCols, outputCol = 'features', handleInvalid = 'skip')
# The assembler is fully configured by its constructor; no extra param map is needed.
df = assembler.transform(df)
df.show(5)

+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+---------------------+--------------------+
|longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity|label_ocean_proximity|            features|
+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+---------------------+--------------------+
|  -122.23|   37.88|              41.0|      880.0|         129.0|     322.0|     126.0|       8.3252|          452600.0|       NEAR BAY|                  3.0|[-122.23,37.88,41...|
|  -122.22|   37.86|              21.0|     7099.0|        1106.0|    2401.0|    1138.0|       8.3014|          358500.0|       NEAR BAY|                  3.0|[-122.22,37.86,21...|
|  -122.24|   37.85|              52.0|     1467.0|         190.0|     496.0|     177.0|       

In [None]:
# Keep only what the classifier consumes: the assembled vector and the numeric label.
modelCols = ['features', 'label_ocean_proximity']
ds = df.select(*modelCols)
ds.show(5)

+--------------------+---------------------+
|            features|label_ocean_proximity|
+--------------------+---------------------+
|[-122.23,37.88,41...|                  3.0|
|[-122.22,37.86,21...|                  3.0|
|[-122.24,37.85,52...|                  3.0|
|[-122.25,37.85,52...|                  3.0|
|[-122.25,37.85,52...|                  3.0|
+--------------------+---------------------+
only showing top 5 rows



In [None]:
# 80/20 train/test split. A fixed seed makes the split, and therefore every
# downstream metric, reproducible on Restart & Run All.
SEED = 42
train, test = ds.randomSplit([0.8, 0.2], seed = SEED)
print(f'train: {train.count()}, test: {test.count()}')

train: 16210, test: 4223


In [None]:
# Fit a decision tree on the training split; only the label and feature
# columns are set, every other hyperparameter is left at Spark's default.
decisionTreeClassifier = DecisionTreeClassifier(
  labelCol = 'label_ocean_proximity',
  featuresCol = 'features',
)
model = decisionTreeClassifier.fit(train)

In [None]:
# Score the held-out split; the result gains rawPrediction, probability and prediction columns.
predictions = model.transform(test)
predictions.show(n = 5)

+--------------------+---------------------+--------------------+--------------------+----------+
|            features|label_ocean_proximity|       rawPrediction|         probability|prediction|
+--------------------+---------------------+--------------------+--------------------+----------+
|[-124.3,41.8,19.0...|                  2.0|[294.0,133.0,97.0...|[0.55893536121673...|       0.0|
|[-124.21,40.75,32...|                  2.0|[294.0,133.0,97.0...|[0.55893536121673...|       0.0|
|[-124.21,41.75,20...|                  2.0|[294.0,133.0,97.0...|[0.55893536121673...|       0.0|
|[-124.17,40.78,39...|                  2.0|[294.0,133.0,97.0...|[0.55893536121673...|       0.0|
|[-124.17,40.79,43...|                  2.0|[294.0,133.0,97.0...|[0.55893536121673...|       0.0|
+--------------------+---------------------+--------------------+--------------------+----------+
only showing top 5 rows



In [None]:
# Accuracy on the test split: share of rows whose `prediction` equals the true label.
evaluator = MulticlassClassificationEvaluator(
  labelCol = 'label_ocean_proximity',
  predictionCol = 'prediction',
  metricName = 'accuracy',
)
accuracy = evaluator.evaluate(predictions)
print(f'accuracy: {accuracy}')

accuracy: 0.8681032441392376


In [None]:
# Pair each importance with its column name. The vector is indexed in the same
# order as `featureCols`, which was the assembler's input-column order.
featureImportances = model.featureImportances
for idx, colName in enumerate(featureCols):
  print(f'{colName}: {featureImportances[idx]}')
#End for

longitude: 0.5311598653890058
latitude: 0.44321342009367054
housing_median_age: 0.001288309834539509
total_rooms: 0.0
total_bedrooms: 0.0
population: 0.0
households: 0.0
median_income: 0.0
median_house_value: 0.024338404682784132


In [None]:
# Print the learned split rules; "feature N" in the output refers to featureCols[N].
treeDescription = model.toDebugString
print(treeDescription)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_524ec32b49ec, depth=5, numNodes=49, numClasses=5, numFeatures=9
  If (feature 1 <= 34.455)
   If (feature 0 <= -117.67500000000001)
    If (feature 0 <= -118.955)
     If (feature 2 <= 36.5)
      Predict: 2.0
     Else (feature 2 > 36.5)
      If (feature 0 <= -119.32499999999999)
       Predict: 0.0
      Else (feature 0 > -119.32499999999999)
       Predict: 2.0
    Else (feature 0 > -118.955)
     If (feature 1 <= 33.845)
      If (feature 0 <= -118.095)
       Predict: 2.0
      Else (feature 0 > -118.095)
       Predict: 0.0
     Else (feature 1 > 33.845)
      If (feature 0 <= -117.845)
       Predict: 0.0
      Else (feature 0 > -117.845)
       Predict: 1.0
   Else (feature 0 > -117.67500000000001)
    If (feature 1 <= 33.724999999999994)
     If (feature 0 <= -116.995)
      If (feature 1 <= 32.974999999999994)
       Predict: 2.0
      Else (feature 1 > 32.974999999999994)
       Predict: 0.0
     Else (feature 0 > 