<a href="https://colab.research.google.com/github/kccchiu/Spark_dataframe_basic/blob/main/binary_tree_with_Dog_Food.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
!pip install pyspark py4j

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Setup

In [23]:
from pyspark.sql import SparkSession

# Build the notebook's Spark session under the application name 'dog-food'.
spark = (
    SparkSession.builder
    .appName('dog-food')
    .getOrCreate()
)

In [24]:
# Load the dog-food dataset: the first row supplies the column names and
# Spark infers each column's type from the data.
df = spark.read.csv(
    'dog_food.csv',
    header=True,
    inferSchema=True,
)

In [25]:
# Print the column names and the types Spark inferred while reading the CSV.
df.printSchema()

root
 |-- A: integer (nullable = true)
 |-- B: integer (nullable = true)
 |-- C: double (nullable = true)
 |-- D: integer (nullable = true)
 |-- Spoiled: double (nullable = true)



### Create Features

In [26]:
from pyspark.ml.feature import VectorAssembler

# Pack the four input columns A-D into a single 'features' vector column,
# keeping 'Spoiled' alongside it as the label.
assembler = VectorAssembler(inputCols=['A', 'B', 'C', 'D'],
                            outputCol='features')
output = assembler.transform(df)
data = output.select('features', 'Spoiled')

# 70/30 train/test split with a fixed seed.
train, test = data.randomSplit([0.7, 0.3], seed=123)

# show() prints the table itself and returns None, so call it directly;
# wrapping the calls in print() only appended a stray "None None" line.
train.show(5)
test.show(5)

+------------------+-------+
|          features|Spoiled|
+------------------+-------+
|[1.0,1.0,10.0,8.0]|    1.0|
|[1.0,1.0,12.0,2.0]|    1.0|
|[1.0,1.0,13.0,3.0]|    1.0|
| [1.0,2.0,9.0,1.0]|    0.0|
| [1.0,2.0,9.0,4.0]|    0.0|
+------------------+-------+
only showing top 5 rows

+------------------+-------+
|          features|Spoiled|
+------------------+-------+
|[1.0,1.0,12.0,4.0]|    1.0|
| [1.0,3.0,8.0,3.0]|    0.0|
| [1.0,4.0,8.0,1.0]|    0.0|
| [1.0,4.0,9.0,3.0]|    0.0|
| [1.0,4.0,9.0,6.0]|    0.0|
+------------------+-------+
only showing top 5 rows

None None


### Classifiers

In [27]:
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier

In [28]:
# Decision tree that learns the 'Spoiled' label from the assembled
# 'features' vector column.
dtc = DecisionTreeClassifier(
    labelCol='Spoiled',
    featuresCol='features',
)

In [30]:
# Fit the decision tree on the training split.
dtc_model = dtc.fit(train)

In [32]:
# Score the held-out test split with the fitted decision tree.
pred = dtc_model.transform(test)

In [38]:
# Preview the first five scored rows (features, label and the model's output columns).
pred.show(5)

+------------------+-------+-------------+-----------+----------+
|          features|Spoiled|rawPrediction|probability|prediction|
+------------------+-------+-------------+-----------+----------+
|[1.0,1.0,12.0,4.0]|    1.0|   [0.0,66.0]|  [0.0,1.0]|       1.0|
| [1.0,3.0,8.0,3.0]|    0.0|  [187.0,0.0]|  [1.0,0.0]|       0.0|
| [1.0,4.0,8.0,1.0]|    0.0|   [33.0,0.0]|  [1.0,0.0]|       0.0|
| [1.0,4.0,9.0,3.0]|    0.0|  [187.0,0.0]|  [1.0,0.0]|       0.0|
| [1.0,4.0,9.0,6.0]|    0.0|  [187.0,0.0]|  [1.0,0.0]|       0.0|
+------------------+-------+-------------+-----------+----------+
only showing top 5 rows



In [43]:
# Check which label values occur in the test split before choosing an evaluator.
test.select('Spoiled').distinct().show()

# The label takes only two values (0.0 and 1.0), so the binary evaluator is
# used here rather than MulticlassClassificationEvaluator.
from pyspark.ml.evaluation import BinaryClassificationEvaluator

+-------+
|Spoiled|
+-------+
|    0.0|
|    1.0|
+-------+



In [39]:
# Evaluator that scores predictions against the 'Spoiled' label column.
# NOTE(review): the name `eval` shadows Python's built-in eval(); renaming it
# would also require updating every later cell that refers to it.
eval = BinaryClassificationEvaluator(labelCol='Spoiled')

In [44]:
# Report area under the ROC curve first, then area under the precision-recall
# curve, by overriding the evaluator's metricName for each call.
for metric in ('areaUnderROC', 'areaUnderPR'):
    print(eval.evaluate(pred, {eval.metricName: metric}))

0.9954322638146167
0.9898120982008025


In [40]:
# Per-feature importance from the fitted decision tree; indices follow the
# assembler's input column order (A, B, C, D).
dtc_model.featureImportances

SparseVector(4, {0: 0.0167, 1: 0.0187, 2: 0.9532, 3: 0.0115})

#### Random Forest Classifier

In [47]:
# Random forests are stochastic, so fix the seed (the same value used for the
# train/test split) to get the same forest on every Restart & Run All.
rfc = RandomForestClassifier(featuresCol='features', labelCol='Spoiled',
                             seed=123)
rfc_model = rfc.fit(train)

# NOTE(review): this reassigns `pred`, replacing the decision-tree predictions;
# from here on `pred` refers to the random-forest output.
pred = rfc_model.transform(test)

In [48]:
# Per-feature importance from the fitted random forest; indices follow the
# assembler's input column order (A, B, C, D).
rfc_model.featureImportances

SparseVector(4, {0: 0.0327, 1: 0.0167, 2: 0.9226, 3: 0.0279})

In [49]:
# Display the random-forest predictions (`pred` was reassigned when the
# forest scored the test split).
pred.show()

+-------------------+-------+--------------------+--------------------+----------+
|           features|Spoiled|       rawPrediction|         probability|prediction|
+-------------------+-------+--------------------+--------------------+----------+
| [1.0,1.0,12.0,4.0]|    1.0|[0.10001193342460...|[0.00500059667123...|       1.0|
|  [1.0,3.0,8.0,3.0]|    0.0|[19.6981854412714...|[0.98490927206357...|       0.0|
|  [1.0,4.0,8.0,1.0]|    0.0|[19.3174128402741...|[0.96587064201370...|       0.0|
|  [1.0,4.0,9.0,3.0]|    0.0|[19.6981854412714...|[0.98490927206357...|       0.0|
|  [1.0,4.0,9.0,6.0]|    0.0|[19.9041793504629...|[0.99520896752314...|       0.0|
|[1.0,5.0,12.0,10.0]|    1.0|[0.07065587281854...|[0.00353279364092...|       1.0|
|  [1.0,6.0,8.0,9.0]|    0.0|[19.9327507790344...|[0.99663753895172...|       0.0|
| [1.0,7.0,11.0,9.0]|    1.0|[0.07065587281854...|[0.00353279364092...|       1.0|
|[1.0,7.0,11.0,10.0]|    1.0|[1.07065587281854...|[0.05353279364092...|       1.0|
|  [