In [96]:
from pyspark.sql import SQLContext, SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.sql.functions import when, sum, col, count, round
import matplotlib.pyplot as plt
import seaborn as sns

In [74]:
# Start (or reuse) the Spark session for this notebook.
spark = SparkSession.builder.appName('DDAM_Project_Mushrooms').getOrCreate()
# SQLContext takes a SparkContext as its first argument (the session is optional
# second argument); it is superseded by SparkSession and kept only for compatibility.
sql_ctx = SQLContext(spark.sparkContext, spark)
# Load the pre-cleaned dataset: comma-separated, header row, schema inferred from the data.
mushroom_df_cleaned = spark.read.options(inferSchema=True, delimiter=',', header=True).csv('dataset/cleaned.csv')



In [75]:
# Print the first 20 rows, then display the total number of rows as the cell output.
mushroom_df_cleaned.show(20)
mushroom_df_cleaned.count()

+-----+------------+---------+-----------+---------+--------------------+---------------+----------+-----------+----------+----------+--------+---------+-------+------+
|class|cap-diameter|cap-shape|cap-surface|cap-color|does-bruise-or-bleed|gill-attachment|gill-color|stem-height|stem-width|stem-color|has-ring|ring-type|habitat|season|
+-----+------------+---------+-----------+---------+--------------------+---------------+----------+-----------+----------+----------+--------+---------+-------+------+
|    e|       10.96|        f|          i|        l|                   f|              s|         b|       7.23|     1.915|         b|       f|        f|      l|     a|
|    e|        9.33|        f|          i|        l|                   f|              s|         b|       7.36|     1.894|         b|       f|        f|      d|     w|
|    e|        11.0|        f|          i|        l|                   f|              s|         b|       8.28|     1.988|         u|       f|        f|  

60014

In [76]:
# String-typed columns are the categorical ones (this includes the 'class' label).
cat_cols = [name for name, dtype in mushroom_df_cleaned.dtypes if dtype == 'string']
print('The categorical columns are')
print(cat_cols)

The numerical columns are
['class', 'cap-shape', 'cap-surface', 'cap-color', 'does-bruise-or-bleed', 'gill-attachment', 'gill-color', 'stem-color', 'has-ring', 'ring-type', 'habitat', 'season']


In [77]:
# Every column whose Spark type is not 'string' is treated as numerical.
num_cols = []
for name, dtype in mushroom_df_cleaned.dtypes:
    if dtype != 'string':
        num_cols.append(name)
print('The numerical columns are')
print(num_cols)

The numerical columns are
['cap-diameter', 'stem-height', 'stem-width']


In [78]:
# Display the distinct values of every categorical column that has at most two
# of them: these are the candidates for a direct 0/1 encoding.
for c in cat_cols:
    distinct_values = mushroom_df_cleaned.select(c).distinct()
    if distinct_values.count() <= 2:
        distinct_values.show()

+-----+
|class|
+-----+
|    e|
|    p|
+-----+

+--------------------+
|does-bruise-or-bleed|
+--------------------+
|                   f|
|                   t|
+--------------------+

+--------+
|has-ring|
+--------+
|       f|
|       t|
+--------+



Feature preprocessing for classification

In [79]:
# Two-valued 't'/'f' columns are encoded directly as 1/0 integers.
bool_cols = ['does-bruise-or-bleed', 'has-ring']
for c in bool_cols:
    # Map both expected values explicitly. Anything else (including nulls) becomes
    # null instead of silently turning into 1, so VectorAssembler (default
    # handleInvalid="error") fails loudly on unexpected data.
    mushroom_df_cleaned = mushroom_df_cleaned.withColumn(
        f"{c}_indexed",
        when(col(c) == 't', 1).when(col(c) == 'f', 0),
    )

# Every other categorical column, including the 'class' label, is label-encoded
# with its own StringIndexer.
remaining_cols = [item for item in cat_cols if item not in bool_cols]

for c in remaining_cols:
    # NOTE(review): indexers are fitted on the full dataset, before the
    # train/validation/test split, so the index ordering also reflects test rows.
    indexer_fitted = StringIndexer(inputCol=c, outputCol=f"{c}_indexed").fit(mushroom_df_cleaned)
    mushroom_df_cleaned = indexer_fitted.transform(mushroom_df_cleaned)

# Drop the original string columns; only the *_indexed encodings remain.
# NOTE: this cell reassigns mushroom_df_cleaned, so it cannot be re-run on its own.
mushroom_df_cleaned = mushroom_df_cleaned.drop(*cat_cols)

In [102]:
# Feature columns: the numerical columns, then the encoded binary columns, then the
# other encoded categorical columns. The list is derived from the lists built
# earlier instead of retyped by hand; the encoded label is excluded so the target
# cannot leak into the features.
valid_cols = num_cols + [f"{c}_indexed" for c in bool_cols + remaining_cols if c != 'class']
assert 'class_indexed' not in valid_cols, "label must not be used as a feature"

assembler = VectorAssembler(inputCols=valid_cols, outputCol="input_features")

dataset = assembler.transform(mushroom_df_cleaned)

# Give the encoded label back its original name for the classifier and evaluator.
dataset = dataset.withColumnRenamed('class_indexed', 'class')

classification_df = dataset.select("input_features", "class")

classification_df.show(truncate=False)


+-----------------------------------------------------------------+-----+
|input_features                                                   |class|
+-----------------------------------------------------------------+-----+
|[10.96,7.23,1.915,0.0,0.0,1.0,9.0,11.0,4.0,11.0,12.0,0.0,2.0,0.0]|1.0  |
|[9.33,7.36,1.894,0.0,0.0,1.0,9.0,11.0,4.0,11.0,12.0,0.0,0.0,2.0] |1.0  |
|[11.0,8.28,1.988,0.0,0.0,1.0,9.0,11.0,4.0,11.0,6.0,0.0,2.0,2.0]  |1.0  |
|[14.75,8.4,2.211,0.0,0.0,1.0,9.0,11.0,4.0,11.0,12.0,0.0,2.0,0.0] |1.0  |
|[11.15,6.98,1.786,0.0,0.0,1.0,9.0,11.0,4.0,11.0,6.0,0.0,0.0,2.0] |1.0  |
|[10.71,8.24,1.989,0.0,0.0,1.0,9.0,11.0,4.0,11.0,12.0,0.0,0.0,0.0]|1.0  |
|[10.07,8.63,2.229,0.0,0.0,1.0,9.0,11.0,4.0,11.0,12.0,0.0,0.0,0.0]|1.0  |
|[9.48,7.28,2.091,0.0,0.0,1.0,9.0,11.0,4.0,11.0,6.0,0.0,2.0,0.0]  |1.0  |
|[8.61,8.41,2.004,0.0,0.0,1.0,9.0,11.0,4.0,11.0,12.0,0.0,0.0,2.0] |1.0  |
|[9.31,8.09,1.804,0.0,0.0,1.0,9.0,11.0,4.0,11.0,12.0,0.0,2.0,2.0] |1.0  |
|[9.73,6.58,1.864,0.0,0.0,1.0,9.0,11.0

In [121]:
def value_counts(df, label_col="class"):
    """Return the fraction of rows taken by each distinct value of a label column.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Input DataFrame; must contain ``label_col``.
    label_col : str, default "class"
        Column whose value distribution is computed.

    Returns
    -------
    pyspark.sql.DataFrame
        Columns ``label_col`` and ``proportions``. ``proportions`` is the row count
        of each value divided by the total row count of ``df``. Rows are ordered by
        ``label_col``, so outputs for different splits line up row by row.
    """
    # Runs a Spark job; the resulting Python int is used as a literal divisor below.
    total_rows = df.count()
    # count("*") counts every row of the group, so a null label group is not reported as 0.
    grouped_df = df.groupBy(label_col).agg(count("*").alias("total_value"))
    return (
        grouped_df
        .withColumn("proportions", col("total_value") / total_rows)
        .select(label_col, "proportions")
        .orderBy(label_col)
    )
 

In [132]:
# Class balance of the full labelled dataset.
value_counts(df=classification_df).show()


+-----+------------------+
|class|       proportions|
+-----+------------------+
|  0.0|0.5546039257506582|
|  1.0|0.4453960742493418|
+-----+------------------+



In [133]:
# Hold out ~30% of the rows as the test set, then split the remaining ~70% again
# 70/30 into training (~49% of all rows) and validation (~21%) sets.
SPLIT_SEED = 0
dev_df, test_df = classification_df.randomSplit([0.7, 0.3], seed=SPLIT_SEED)
training_df, val_df = dev_df.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [134]:
# Verify the splits: each one should roughly preserve the overall class balance.
for split_df in (training_df, val_df, test_df):
    value_counts(split_df).show()

                                                                                

+-----+-------------------+
|class|        proportions|
+-----+-------------------+
|  0.0| 0.5563188681811998|
|  1.0|0.44368113181880015|
+-----+-------------------+



                                                                                

+-----+------------------+
|class|       proportions|
+-----+------------------+
|  0.0|0.5585770750988143|
|  1.0|0.4414229249011858|
+-----+------------------+



                                                                                

+-----+------------------+
|class|       proportions|
+-----+------------------+
|  0.0|0.5489977728285078|
|  1.0|0.4510022271714922|
+-----+------------------+



                                                                                

In [137]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Baseline model: one decision tree with default hyper-parameters, fitted on the
# training split and scored on the validation split.
dt = DecisionTreeClassifier(featuresCol="input_features", labelCol="class")
model = dt.fit(training_df)

# Make predictions.
predictions = model.transform(val_df)

evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="class",
    predictionCol="prediction",
)
accuracy = evaluator.evaluate(predictions)
print(accuracy)

[Stage 1461:>                                                       (0 + 1) / 1]

0.7502766798418973


                                                                                