# Tree Methods: Random Forest Classification

Scenario:

You've been hired by a dog food company to figure out why some batches of their dog food are spoiling much more quickly than intended! Unfortunately, this dog food company hasn't upgraded to the latest machinery, so the amounts of the five chemicals in their preservative mix can vary a lot. Which chemical has the strongest effect? The company first mixes up a batch of preservative containing four different preservative chemicals (A, B, C, D), which is then completed with a "filler" chemical. The food scientists believe one of the A, B, C, or D preservatives is causing the problem, but need your help to figure out which one! Use machine learning with a Random Forest to find out which parameter has the most predictive power, and thus which chemical causes the early spoiling. In short: build a model, then work out how to use it to decide which chemical is the problem.

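Data description:

Each row describes one batch of dog food. Columns A, B, C and D hold the amount of each of the four preservatives in the batch's preservative mix, and Spoiled is the label: 1.0 if the batch spoiled, 0.0 if it did not.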

In [1]:
# Basic imports
from pyspark.sql import SparkSession

In [2]:
# Create a Spark session (or get the existing one)
spark = SparkSession.builder.appName('rf_project').getOrCreate()

In [3]:
# Reading the data
data = spark.read.csv('dog_food.csv', inferSchema=True, header=True)

In [4]:
data.show()

+---+---+----+---+-------+
|  A|  B|   C|  D|Spoiled|
+---+---+----+---+-------+
|  4|  2|12.0|  3|    1.0|
|  5|  6|12.0|  7|    1.0|
|  6|  2|13.0|  6|    1.0|
|  4|  2|12.0|  1|    1.0|
|  4|  2|12.0|  3|    1.0|
| 10|  3|13.0|  9|    1.0|
|  8|  5|14.0|  5|    1.0|
|  5|  8|12.0|  8|    1.0|
|  6|  5|12.0|  9|    1.0|
|  3|  3|12.0|  1|    1.0|
|  9|  8|11.0|  3|    1.0|
|  1| 10|12.0|  3|    1.0|
|  1|  5|13.0| 10|    1.0|
|  2| 10|12.0|  6|    1.0|
|  1| 10|11.0|  4|    1.0|
|  5|  3|12.0|  2|    1.0|
|  4|  9|11.0|  8|    1.0|
|  5|  1|11.0|  1|    1.0|
|  4|  9|12.0| 10|    1.0|
|  5|  8|10.0|  9|    1.0|
+---+---+----+---+-------+
only showing top 20 rows



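As we can see, the dataset is very simple. We also don't need to normalize the variables: tree-based models split each feature on a threshold, so rescaling a feature does not change the splits they can make.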
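The claim that scaling doesn't matter for trees can be checked with a small plain-Python sketch (no Spark involved). A tree node sends a row left or right by comparing one feature against a threshold, so any strictly increasing rescaling of that feature, with the threshold rescaled the same way, yields exactly the same partition. The values are column C of the first ten rows shown above; the `split` and `rescale` helpers and the scale factor and offset are hypothetical, chosen only for illustration.

```python
def split(values, threshold):
    """Indices of rows a tree node would send to the left child (value <= threshold)."""
    return [i for i, v in enumerate(values) if v <= threshold]

def rescale(x, scale=100.0, offset=-5.0):
    """A strictly increasing transformation (scale > 0), e.g. a change of units."""
    return scale * x + offset

# Amounts of preservative C for the first ten batches shown above.
c_values = [12.0, 12.0, 13.0, 12.0, 12.0, 13.0, 14.0, 12.0, 12.0, 12.0]
threshold = 12.5

left_raw = split(c_values, threshold)
left_scaled = split([rescale(v) for v in c_values], rescale(threshold))

print(left_raw)                 # [0, 1, 3, 4, 7, 8, 9]
print(left_raw == left_scaled)  # True: same partition after rescaling
```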

# Data preparation

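The first step is to vectorize the features: Spark ML estimators expect all predictors in a single vector column, which the VectorAssembler builds from the individual columns.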
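Conceptually, VectorAssembler concatenates the values of the input columns, in the order given by `inputCols`, into one numeric vector per row (Spark stores it as a dense or sparse vector). The plain-Python sketch below mimics that with a hypothetical `assemble` helper, not Spark's implementation, on two rows from the table above. It also shows that position 0 of the vector is A, position 1 is B, and so on, which matters later when reading the feature importances.

```python
def assemble(row, input_cols):
    """Mimic VectorAssembler for one row: collect the input columns, in order, as floats."""
    return [float(row[col]) for col in input_cols]

input_cols = ["A", "B", "C", "D"]
rows = [
    {"A": 4, "B": 2, "C": 12.0, "D": 3, "Spoiled": 1.0},
    {"A": 5, "B": 6, "C": 12.0, "D": 7, "Spoiled": 1.0},
]

for row in rows:
    # The label column is not part of the feature vector.
    print(assemble(row, input_cols), "label:", row["Spoiled"])
# [4.0, 2.0, 12.0, 3.0] label: 1.0
# [5.0, 6.0, 12.0, 7.0] label: 1.0
```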

In [6]:
# Import the VectorAssembler
from pyspark.ml.feature import VectorAssembler

In [7]:
# Instantiate the assembler
# Pass only the four preservative columns as inputs (the label 'Spoiled' is left out)
assembler = VectorAssembler(inputCols=['A', 'B', 'C', 'D'], outputCol='features')

# Transforming the dataset
output = assembler.transform(data)

In [8]:
# Checking if the "features" column has been created
output.printSchema()

root
 |-- A: integer (nullable = true)
 |-- B: integer (nullable = true)
 |-- C: double (nullable = true)
 |-- D: integer (nullable = true)
 |-- Spoiled: double (nullable = true)
 |-- features: vector (nullable = true)



Next, we split our "output" DataFrame into a training set and a test set.
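`randomSplit` assigns rows at random, so the 70/30 proportions are only approximate, and the split (and every number that follows, including the accuracy) changes between runs unless you pass its optional `seed` argument, e.g. `output.randomSplit([0.7, 0.3], seed=42)`. The stdlib sketch below illustrates the idea with a hypothetical `random_split` helper that assigns each row independently; Spark's actual implementation differs in its details.

```python
import random

def random_split(rows, weights, seed=None):
    """Assign each row independently to one of len(weights) splits."""
    total = sum(weights)
    # Cumulative boundaries, e.g. [0.7, 1.0] for weights [0.7, 0.3].
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, b in enumerate(bounds):
            if u < b:
                splits[i].append(row)
                break
        else:  # guard against floating-point round-off in the last bound
            splits[-1].append(row)
    return splits

rows = list(range(1000))  # stand-in for 1000 batches
train, test = random_split(rows, [0.7, 0.3], seed=42)
print(len(train), len(test))  # roughly 700 / 300; exact counts depend on the seed
```

Calling `random_split` again with the same seed returns the same split, which is what makes results reproducible.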

In [9]:
# Randomly split the data: roughly 70% for training and 30% for testing
training_data, test_data = output.randomSplit([0.7, 0.3])

In [10]:
# Import the random forest classifier
from pyspark.ml.classification import RandomForestClassifier

In [11]:
# Instantiate the classifier
# Make sure to pass the right label and features columns
rfc = RandomForestClassifier(labelCol='Spoiled', featuresCol='features')

Next, we fit the model on the training data.
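Under the hood, a random forest trains many decision trees, each on a bootstrap sample of the training rows, and combines their outputs; with the defaults used here, Spark's RandomForestClassifier trains 20 trees (`numTrees`). The deliberately tiny plain-Python sketch below illustrates only the bagging-and-voting idea: it uses hypothetical one-split "stumps" with one random feature per tree instead of full trees with per-node feature subsets, and a hard majority vote, whereas Spark sums per-tree class probabilities. The nine rows are copied from the tables in this notebook.

```python
import random
from collections import Counter

def fit_stump(rows, labels, feature):
    """One-split 'tree': pick the threshold on `feature` with the fewest training errors."""
    best = None
    for threshold in sorted({r[feature] for r in rows}):
        left = [y for r, y in zip(rows, labels) if r[feature] <= threshold]
        right = [y for r, y in zip(rows, labels) if r[feature] > threshold]
        left_label = Counter(left).most_common(1)[0][0]
        right_label = Counter(right).most_common(1)[0][0] if right else left_label
        errors = sum(y != left_label for y in left) + sum(y != right_label for y in right)
        if best is None or errors < best[0]:
            best = (errors, threshold, left_label, right_label)
    _, threshold, left_label, right_label = best
    return lambda r: left_label if r[feature] <= threshold else right_label

def fit_forest(rows, labels, num_trees=20, seed=0):
    """Bagging: each stump sees a bootstrap sample and one randomly chosen feature."""
    rng = random.Random(seed)
    n, n_features = len(rows), len(rows[0])
    forest = []
    for _ in range(num_trees):
        idx = [rng.randrange(n) for _ in range(n)]  # sample rows with replacement
        feature = rng.randrange(n_features)         # random feature for this stump
        forest.append(fit_stump([rows[i] for i in idx], [labels[i] for i in idx], feature))
    return forest

def predict(forest, row):
    """Hard majority vote over the stumps."""
    return Counter(stump(row) for stump in forest).most_common(1)[0][0]

# [A, B, C, D] rows and Spoiled labels copied from the tables in this notebook.
rows = [[4, 2, 12.0, 3], [5, 6, 12.0, 7], [6, 2, 13.0, 6], [1, 1, 12.0, 2], [1, 1, 13.0, 3],
        [1, 2, 9.0, 4], [1, 3, 9.0, 8], [1, 4, 8.0, 5], [1, 4, 9.0, 3]]
labels = [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]

forest = fit_forest(rows, labels)
print([predict(forest, r) for r in rows])
```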

In [12]:
# Fitting the model on the training data
rfc_model = rfc.fit(training_data)

Now that we have a trained model, we can use it to make predictions on the test data.

In [14]:
# Transforming the test data
rfc_predict = rfc_model.transform(test_data)

In [15]:
# Inspect the predictions alongside the original columns
rfc_predict.show()

+---+---+----+---+-------+-------------------+--------------------+--------------------+----------+
|  A|  B|   C|  D|Spoiled|           features|       rawPrediction|         probability|prediction|
+---+---+----+---+-------+-------------------+--------------------+--------------------+----------+
|  1|  1|12.0|  2|    1.0| [1.0,1.0,12.0,2.0]|          [1.0,19.0]|         [0.05,0.95]|       1.0|
|  1|  1|12.0|  4|    1.0| [1.0,1.0,12.0,4.0]|          [1.0,19.0]|         [0.05,0.95]|       1.0|
|  1|  1|13.0|  3|    1.0| [1.0,1.0,13.0,3.0]|          [1.0,19.0]|         [0.05,0.95]|       1.0|
|  1|  2| 9.0|  4|    0.0|  [1.0,2.0,9.0,4.0]|          [18.0,2.0]|           [0.9,0.1]|       0.0|
|  1|  3| 9.0|  8|    0.0|  [1.0,3.0,9.0,8.0]|          [19.0,1.0]|         [0.95,0.05]|       0.0|
|  1|  4| 8.0|  5|    0.0|  [1.0,4.0,8.0,5.0]|          [20.0,0.0]|           [1.0,0.0]|       0.0|
|  1|  4| 9.0|  3|    0.0|  [1.0,4.0,9.0,3.0]|      [19.775,0.225]|[0.98874999999999...|       0.0|


At first glance, comparing the "Spoiled" column with the "prediction" column, the model seems to work well: every row shown is classified correctly.
Now let's evaluate it properly on the whole test set.
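Before moving on to the evaluator, it helps to see how the `rawPrediction`, `probability` and `prediction` columns in the output above relate. In Spark's random forest classifier, `rawPrediction` holds, per class, the sum of the class probabilities predicted by each tree, so with the default 20 trees every vector sums to 20; `probability` is that vector divided by its sum, and `prediction` is the class with the highest probability. A minimal plain-Python sketch, using the `rawPrediction` values printed above and hypothetical helper names:

```python
def to_probability(raw):
    """Normalize the summed per-tree class scores so they add up to 1."""
    total = sum(raw)
    return [r / total for r in raw]

def to_prediction(probability):
    """Predicted label = index of the most probable class (0.0 or 1.0 here)."""
    return float(max(range(len(probability)), key=lambda i: probability[i]))

for raw in ([1.0, 19.0], [18.0, 2.0], [19.775, 0.225]):
    prob = to_probability(raw)
    print(raw, " ".join(f"{p:.3f}" for p in prob), to_prediction(prob))
# [1.0, 19.0] 0.050 0.950 1.0
# [18.0, 2.0] 0.900 0.100 0.0
# [19.775, 0.225] 0.989 0.011 0.0
```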

In [16]:
# Import the evaluator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [17]:
# Instantiate the evaluator
# Pass the label and prediction columns, and the metric we want to compute
acc_evaluator = MulticlassClassificationEvaluator(labelCol="Spoiled", predictionCol="prediction", 
                                                  metricName="accuracy")

For more information about accuracy and the other multiclass metrics, see: https://spark.apache.org/docs/2.2.0/mllib-evaluation-metrics.html#multiclass-classification
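With `metricName="accuracy"`, the evaluator reports the fraction of test rows whose prediction equals the label; in PySpark the same number can be obtained with `rfc_predict.filter(rfc_predict['Spoiled'] == rfc_predict['prediction']).count() / rfc_predict.count()`. A plain-Python sketch of the definition, with a hypothetical `accuracy` helper, applied to the seven rows shown in the prediction table and to an invented four-row case:

```python
def accuracy(labels, predictions):
    """Fraction of rows whose predicted label equals the true label."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    correct = sum(1 for y, p in zip(labels, predictions) if y == p)
    return correct / len(labels)

# The seven test rows shown in the prediction table (all predicted correctly).
labels = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
predictions = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
print(accuracy(labels, predictions))  # 1.0

# An invented case with one mistake out of four rows.
print(accuracy([1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0]))  # 0.75
```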

In [18]:
# Using the evaluator on the dataframe
acc_evaluator.evaluate(rfc_predict)

0.9724137931034482

Nice result! Our model reaches about 97% accuracy on the test set.

# Discovering the problematic preservative

Now comes the most important part: we have to discover which preservative is causing the problem.
It may sound difficult, but it's actually easy: the fitted model exposes a featureImportances attribute that estimates how much each feature contributed to the splits in its trees.
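According to the Spark API documentation, within one tree the importance of feature j is the sum, over the nodes that split on j, of the impurity gain scaled by the number of training instances passing through that node; each tree's vector is normalized to sum to 1, the vectors are averaged over all trees, and the result is normalized to sum to 1 again. The plain-Python sketch below reproduces only that arithmetic on two hand-made toy trees; the node gains and counts are invented for illustration and are not taken from this model.

```python
def tree_importances(split_nodes, num_features):
    """split_nodes: (feature_index, impurity_gain, num_instances) for each internal node."""
    imp = [0.0] * num_features
    for feature, gain, count in split_nodes:
        imp[feature] += gain * count  # gain scaled by the instances reaching the node
    total = sum(imp)
    return [v / total for v in imp] if total > 0 else imp

def forest_importances(trees, num_features):
    """Average the per-tree normalized importances, then normalize to sum to 1."""
    summed = [0.0] * num_features
    for nodes in trees:
        for j, v in enumerate(tree_importances(nodes, num_features)):
            summed[j] += v
    avg = [v / len(trees) for v in summed]
    total = sum(avg)
    return [v / total for v in avg]

# Two invented trees over 4 features (0=A, 1=B, 2=C, 3=D).
toy_trees = [
    [(2, 0.40, 100), (3, 0.05, 30)],  # root splits on C, a child on D
    [(2, 0.35, 100), (0, 0.10, 40)],  # root splits on C, a child on A
]
print(forest_importances(toy_trees, 4))  # feature 2 (C) gets the largest share
```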

In [19]:
rfc_model.featureImportances

SparseVector(4, {0: 0.0285, 1: 0.0324, 2: 0.9008, 3: 0.0382})

The result is a SparseVector of size 4 whose entries are printed as index: importance pairs, and the importances sum to about 1. The largest value by far is 2: 0.9008, which points to the preservative causing the problem. To map the index back to a column, remember that indices start at 0 and follow the order of the inputCols passed to the VectorAssembler (['A', 'B', 'C', 'D']), so index 2 is the third column from the left: the "C" column.
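Rather than counting columns by eye, the importances can be paired with the column names and sorted. A minimal plain-Python sketch, using the values from the SparseVector printed above and a hypothetical `rank_features` helper; with the real objects you could pass `assembler.getInputCols()` and `rfc_model.featureImportances.toArray()` instead of the hard-coded lists.

```python
def rank_features(input_cols, importances):
    """Pair each column name with its importance, most important first.

    input_cols must be in the same order that was given to VectorAssembler.
    """
    return sorted(zip(input_cols, importances), key=lambda pair: pair[1], reverse=True)

# Values copied from SparseVector(4, {0: 0.0285, 1: 0.0324, 2: 0.9008, 3: 0.0382}).
input_cols = ["A", "B", "C", "D"]
importances = [0.0285, 0.0324, 0.9008, 0.0382]

for name, score in rank_features(input_cols, importances):
    print(f"{name}: {score:.4f}")
# C: 0.9008
# D: 0.0382
# B: 0.0324
# A: 0.0285
```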

In [20]:
data.show()

+---+---+----+---+-------+
|  A|  B|   C|  D|Spoiled|
+---+---+----+---+-------+
|  4|  2|12.0|  3|    1.0|
|  5|  6|12.0|  7|    1.0|
|  6|  2|13.0|  6|    1.0|
|  4|  2|12.0|  1|    1.0|
|  4|  2|12.0|  3|    1.0|
| 10|  3|13.0|  9|    1.0|
|  8|  5|14.0|  5|    1.0|
|  5|  8|12.0|  8|    1.0|
|  6|  5|12.0|  9|    1.0|
|  3|  3|12.0|  1|    1.0|
|  9|  8|11.0|  3|    1.0|
|  1| 10|12.0|  3|    1.0|
|  1|  5|13.0| 10|    1.0|
|  2| 10|12.0|  6|    1.0|
|  1| 10|11.0|  4|    1.0|
|  5|  3|12.0|  2|    1.0|
|  4|  9|11.0|  8|    1.0|
|  5|  1|11.0|  1|    1.0|
|  4|  9|12.0| 10|    1.0|
|  5|  8|10.0|  9|    1.0|
+---+---+----+---+-------+
only showing top 20 rows



# Thanks for your attention!