In [56]:
# Environment bootstrap: locate the Spark installation, import PySpark, and
# start (or reuse) the Spark session used by every cell below.
import os

import findspark

# findspark.init() puts the pyspark package on sys.path, so it has to run before
# the pyspark imports. Prefer an explicitly configured SPARK_HOME so the notebook
# is not tied to one machine's directory layout; fall back to the original
# local install path when the variable is unset or empty.
findspark.init(os.environ.get('SPARK_HOME') or '/home/nick/spark-3.0.1-bin-hadoop2.7')

from pyspark.sql import SparkSession
from pyspark.sql.functions import count, when, isnan, isnull
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, DecisionTreeClassifier, GBTClassifier, DecisionTreeClassificationModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

# getOrCreate() returns the already-running session when the cell is re-executed.
spark = SparkSession.builder.appName('Dog Food').getOrCreate()

In [45]:
# Load the dog food CSV: the first row supplies the column names and the
# column types are inferred from the values.
data = spark.read.options(header=True, inferSchema=True).csv('Tree_Methods/dog_food.csv')

In [46]:
# Check for missing data: one summary row per predicate (NaN first, then NULL).
# when(...) without otherwise() yields a non-null value only for rows matching
# the predicate, and count() skips NULLs, so each cell is the number of
# matching rows in that column.
for is_missing in (isnan, isnull):
    per_column_counts = [
        count(when(is_missing(name), name)).alias(name)
        for name in data.columns
    ]
    data.select(*per_column_counts).show()

+---+---+---+---+-------+
|  A|  B|  C|  D|Spoiled|
+---+---+---+---+-------+
|  0|  0|  0|  0|      0|
+---+---+---+---+-------+

+---+---+---+---+-------+
|  A|  B|  C|  D|Spoiled|
+---+---+---+---+-------+
|  0|  0|  0|  0|      0|
+---+---+---+---+-------+



In [47]:
# Quick look at the loaded frame: column types, a five-row sample, and
# summary statistics for every column.
data.printSchema()
data.show(n=5)
summary_stats = data.describe()
summary_stats.show()

root
 |-- A: integer (nullable = true)
 |-- B: integer (nullable = true)
 |-- C: double (nullable = true)
 |-- D: integer (nullable = true)
 |-- Spoiled: double (nullable = true)

+---+---+----+---+-------+
|  A|  B|   C|  D|Spoiled|
+---+---+----+---+-------+
|  4|  2|12.0|  3|    1.0|
|  5|  6|12.0|  7|    1.0|
|  6|  2|13.0|  6|    1.0|
|  4|  2|12.0|  1|    1.0|
|  4|  2|12.0|  3|    1.0|
+---+---+----+---+-------+
only showing top 5 rows

+-------+------------------+------------------+------------------+------------------+-------------------+
|summary|                 A|                 B|                 C|                 D|            Spoiled|
+-------+------------------+------------------+------------------+------------------+-------------------+
|  count|               490|               490|               490|               490|                490|
|   mean|  5.53469387755102| 5.504081632653061| 9.126530612244897| 5.579591836734694| 0.2857142857142857|
| stddev|2.95152042343

In [48]:
# Display the column names in order; the split into features and label in the
# next cell relies on the label being the last entry of this list.
data.columns

['A', 'B', 'C', 'D', 'Spoiled']

In [49]:
# Split the column list: everything except the final column is a feature,
# and the final column is the label.
*feature_columns, label_column = data.columns

In [50]:
# Pack the individual feature columns into a single vector column called
# 'features', which is added alongside the original columns.
assembler = VectorAssembler(
    outputCol='features',
    inputCols=feature_columns,
)
output = assembler.transform(data)

In [58]:
# Fit a single decision tree on the assembled features. The label column name
# comes from `label_column` (derived from the data's last column) instead of a
# repeated string literal, so the cell keeps working if the label is renamed.
dtc_model = DecisionTreeClassifier(labelCol=label_column).fit(output)
results = dtc_model.transform(output)
results.show(5)

# NOTE: the model is scored on the same rows it was trained on, so this is an
# in-sample accuracy. It is only a sanity check before reading the feature
# importances, not an estimate of performance on unseen data.
multiclass_eval = MulticlassClassificationEvaluator(labelCol=label_column, metricName='accuracy')
multiclass_eval.evaluate(results)

+---+---+----+---+-------+------------------+-------------+-----------+----------+
|  A|  B|   C|  D|Spoiled|          features|rawPrediction|probability|prediction|
+---+---+----+---+-------+------------------+-------------+-----------+----------+
|  4|  2|12.0|  3|    1.0|[4.0,2.0,12.0,3.0]|   [0.0,94.0]|  [0.0,1.0]|       1.0|
|  5|  6|12.0|  7|    1.0|[5.0,6.0,12.0,7.0]|   [0.0,94.0]|  [0.0,1.0]|       1.0|
|  6|  2|13.0|  6|    1.0|[6.0,2.0,13.0,6.0]|   [0.0,94.0]|  [0.0,1.0]|       1.0|
|  4|  2|12.0|  1|    1.0|[4.0,2.0,12.0,1.0]|   [0.0,94.0]|  [0.0,1.0]|       1.0|
|  4|  2|12.0|  3|    1.0|[4.0,2.0,12.0,3.0]|   [0.0,94.0]|  [0.0,1.0]|       1.0|
+---+---+----+---+-------+------------------+-------------+-----------+----------+
only showing top 5 rows



0.9857142857142858

In [60]:
# Now use Random Forest.
# A fixed seed makes the bootstrap sampling and feature subsetting repeatable;
# without it PySpark derives the default seed from a Python string hash, which
# can change between interpreter sessions and shift the importances slightly.
rfc_model = RandomForestClassifier(labelCol=label_column, seed=42).fit(output)
results = rfc_model.transform(output)
results.show(5)

# NOTE: as with the decision tree, this is accuracy on the training rows
# themselves (no held-out split), used only as a sanity check.
multiclass_eval = MulticlassClassificationEvaluator(labelCol=label_column, metricName='accuracy')
multiclass_eval.evaluate(results)

+---+---+----+---+-------+------------------+--------------------+--------------------+----------+
|  A|  B|   C|  D|Spoiled|          features|       rawPrediction|         probability|prediction|
+---+---+----+---+-------+------------------+--------------------+--------------------+----------+
|  4|  2|12.0|  3|    1.0|[4.0,2.0,12.0,3.0]|[0.77037459556500...|[0.03851872977825...|       1.0|
|  5|  6|12.0|  7|    1.0|[5.0,6.0,12.0,7.0]|[0.08734733734733...|[0.00436736686736...|       1.0|
|  6|  2|13.0|  6|    1.0|[6.0,2.0,13.0,6.0]|[0.83770219550129...|[0.04188510977506...|       1.0|
|  4|  2|12.0|  1|    1.0|[4.0,2.0,12.0,1.0]|[0.77037459556500...|[0.03851872977825...|       1.0|
|  4|  2|12.0|  3|    1.0|[4.0,2.0,12.0,3.0]|[0.77037459556500...|[0.03851872977825...|       1.0|
+---+---+----+---+-------+------------------+--------------------+--------------------+----------+
only showing top 5 rows



0.9857142857142858

In [63]:
# Feature importances of the fitted decision tree. Vector indices follow the
# order of `feature_columns` passed to the VectorAssembler, i.e. 0..3 map to
# columns A..D (index 2 is column C).
dtc_model.featureImportances

SparseVector(4, {1: 0.0019, 2: 0.9832, 3: 0.0149})

In [64]:
# Feature importances of the fitted random forest. Vector indices follow the
# order of `feature_columns` passed to the VectorAssembler, i.e. 0..3 map to
# columns A..D (index 2 is column C).
rfc_model.featureImportances

SparseVector(4, {0: 0.0157, 1: 0.0157, 2: 0.9522, 3: 0.0165})

In [None]:
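# Conclusion: feature index 2 (column C) carries ~98% of the importance in the
# decision tree and ~95% in the random forest, so ingredient C is causing the issues.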