In [11]:
import os
import sys
import git
from numpy import allclose
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import GBTClassifier

### SparkContext and SparkSession

In [12]:
# create entry points to spark
try:
    sc.stop()
except:
    pass
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

In [13]:
sc=SparkContext()
spark = SparkSession(sparkContext=sc)

23/07/27 12:48:53 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


### Load the dataset

In [14]:
base_path = git.Repo('.', search_parent_directories=True).working_tree_dir
data_path = "data/iris.csv"
path = os.path.join(base_path, data_path)

In [43]:
iris = spark.read.csv(path, header=True, inferSchema=True)

In [44]:
iris.show(5)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
|         4.7|        3.2|         1.3|        0.2| setosa|
|         4.6|        3.1|         1.5|        0.2| setosa|
|         5.0|        3.6|         1.4|        0.2| setosa|
+------------+-----------+------------+-----------+-------+
only showing top 5 rows



In [45]:
iris.describe().show()

23/07/27 13:05:33 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+-------+------------------+-------------------+------------------+------------------+---------+
|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|  species|
+-------+------------------+-------------------+------------------+------------------+---------+
|  count|               150|                150|               150|               150|      150|
|   mean| 5.843333333333335| 3.0540000000000007|3.7586666666666693|1.1986666666666672|     null|
| stddev|0.8280661279778637|0.43359431136217375| 1.764420419952262|0.7631607417008414|     null|
|    min|               4.3|                2.0|               1.0|               0.1|   setosa|
|    max|               7.9|                4.4|               6.9|               2.5|virginica|
+-------+------------------+-------------------+------------------+------------------+---------+



### Create feature and target column from the features and label columns

In [46]:
from pyspark.ml.linalg import Vectors
from pyspark.sql import Row

In [47]:
iris2 = iris.rdd.map(lambda x: Row(features=Vectors.dense(x[:-1]), species=x[-1])).toDF()
iris2.show(5)

+-----------------+-------+
|         features|species|
+-----------------+-------+
|[5.1,3.5,1.4,0.2]| setosa|
|[4.9,3.0,1.4,0.2]| setosa|
|[4.7,3.2,1.3,0.2]| setosa|
|[4.6,3.1,1.5,0.2]| setosa|
|[5.0,3.6,1.4,0.2]| setosa|
+-----------------+-------+
only showing top 5 rows



### Encode labels as numbers

In [48]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

In [49]:
stringindexer = StringIndexer(inputCol='species', outputCol='label')
stages = [stringindexer]
pipeline = Pipeline(stages=stages)

In [50]:
iris_df = pipeline.fit(iris2).transform(iris2)
iris_df.show(5)

+-----------------+-------+-----+
|         features|species|label|
+-----------------+-------+-----+
|[5.1,3.5,1.4,0.2]| setosa|  0.0|
|[4.9,3.0,1.4,0.2]| setosa|  0.0|
|[4.7,3.2,1.3,0.2]| setosa|  0.0|
|[4.6,3.1,1.5,0.2]| setosa|  0.0|
|[5.0,3.6,1.4,0.2]| setosa|  0.0|
+-----------------+-------+-----+
only showing top 5 rows



### Split the data into train and test sets

In [51]:
train, test = iris_df.randomSplit([0.8, 0.2], seed=1234)

### Model

In [64]:
from pyspark.ml.classification import NaiveBayes
naivebayes = NaiveBayes(featuresCol="features", labelCol="label")

### Cross-validation

In [65]:
from pyspark.ml.tuning import ParamGridBuilder
param_grid = ParamGridBuilder().\
    addGrid(naivebayes.smoothing, [0, 1, 2, 4, 8]).\
    build()

In [66]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator()

In [67]:
from pyspark.ml.tuning import CrossValidator
crossvalidator = CrossValidator(estimator=naivebayes, estimatorParamMaps=param_grid, evaluator=evaluator)

### Train

In [68]:
crossvalidation_mode = crossvalidator.fit(train)

23/07/27 13:10:45 WARN CacheManager: Asked to cache already cached data.
23/07/27 13:10:45 WARN CacheManager: Asked to cache already cached data.
23/07/27 13:10:45 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


In [69]:
pred_train = crossvalidation_mode.transform(train)
pred_train.show(5)

+-----------------+-------+-----+--------------------+--------------------+----------+
|         features|species|label|       rawPrediction|         probability|prediction|
+-----------------+-------+-----+--------------------+--------------------+----------+
|[4.3,3.0,1.1,0.1]| setosa|  0.0|[-9.9913239014221...|[0.71493036341907...|       0.0|
|[4.4,3.0,1.3,0.2]| setosa|  0.0|[-10.782534742072...|[0.66409361764449...|       0.0|
|[4.4,3.2,1.3,0.2]| setosa|  0.0|[-11.001048101629...|[0.68946803815090...|       0.0|
|[4.6,3.1,1.5,0.2]| setosa|  0.0|[-11.417381411425...|[0.65355055727946...|       0.0|
|[4.6,3.2,1.4,0.2]| setosa|  0.0|[-11.337110473117...|[0.68215432402249...|       0.0|
+-----------------+-------+-----+--------------------+--------------------+----------+
only showing top 5 rows



### Test

In [70]:
pred_test = crossvalidation_mode.transform(test)
pred_test.show(5)

+-----------------+-------+-----+--------------------+--------------------+----------+
|         features|species|label|       rawPrediction|         probability|prediction|
+-----------------+-------+-----+--------------------+--------------------+----------+
|[4.4,2.9,1.4,0.2]| setosa|  0.0|[-10.862805680379...|[0.63470702036359...|       0.0|
|[4.5,2.3,1.3,0.3]| setosa|  0.0|[-10.429893588097...|[0.54449779826769...|       0.0|
|[4.9,3.1,1.5,0.1]| setosa|  0.0|[-11.298295313751...|[0.69093924521370...|       0.0|
|[5.0,3.0,1.6,0.2]| setosa|  0.0|[-11.790721856536...|[0.64102778563998...|       0.0|
|[5.0,3.2,1.2,0.2]| setosa|  0.0|[-11.251124743750...|[0.72700900915421...|       0.0|
+-----------------+-------+-----+--------------------+--------------------+----------+
only showing top 5 rows



In [71]:
print("The parameter smoothing has best value:",
      crossvalidation_mode.bestModel._java_obj.getSmoothing())

The parameter smoothing has best value: 4.0


### Metrics

In [76]:
print(
    'training data (f1):', evaluator.setMetricName('f1').evaluate(pred_train), "\n",
    'training data (weightedPrecision): ', evaluator.setMetricName('weightedPrecision').evaluate(pred_train),"\n",
    'training data (weightedRecall): ', evaluator.setMetricName('weightedRecall').evaluate(pred_train),"\n",
    'training data (accuracy): ', evaluator.setMetricName('accuracy').evaluate(pred_train)
    )

training data (f1): 0.9468488399904329 
 training data (weightedPrecision):  0.947646113629278 
 training data (weightedRecall):  0.9469026548672568 
 training data (accuracy):  0.9469026548672567
