# Classification example with Decision Tree and Naive Bayes

### Import MLlib classes

In [1]:
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Entropy
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.evaluation.MulticlassMetrics

### Read and preprocess the data

The **Iris flower data set** (Fisher's Iris data set) is a multivariate data set introduced by Ronald Fisher in his 1936 paper "The use of multiple measurements in taxonomic problems" as an example of linear discriminant analysis. It consists of 50 samples from each of three species of Iris (Iris setosa, Iris virginica and Iris versicolor), 150 samples in total. Four features were measured from each sample: the length and the width of the sepals and petals, in centimetres. Each row of the CSV holds sepal length, sepal width, petal length and petal width, followed by the species name.

In [2]:
val rawData = sc.textFile("data/iris.csv")

In [3]:
val splitlines = rawData.map(line => line.split(','))
splitlines.first()

Array(5.1, 3.5, 1.4, 0.2, Iris-setosa)

In [4]:
val Data = splitlines.map { col =>
  // The last column is the species name; the first four are the measurements
  val species = col.last
  // Encode species as labels: versicolor -> 0.0, setosa -> 1.0, anything else (virginica) -> 2.0
  val label = if (species == "Iris-versicolor") 0.0 else if (species == "Iris-setosa") 1.0 else 2.0
  val features = col.init.map(_.toDouble)
  LabeledPoint(label, Vectors.dense(features))
}
Data.take(5).foreach(println)

(1.0,[5.1,3.5,1.4,0.2])
(1.0,[4.9,3.0,1.4,0.2])
(1.0,[4.7,3.2,1.3,0.2])
(1.0,[4.6,3.1,1.5,0.2])
(1.0,[5.0,3.6,1.4,0.2])
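The chained `if`/`else` above maps every species string other than versicolor and setosa to 2.0, so a blank line in the file would silently become a virginica example with an empty feature vector, and a header row would only fail later with a `NumberFormatException`. A stricter variant looks the species up in a map and skips rows it cannot parse. The sketch below is plain Scala (no Spark) so it can be checked on single lines; `IrisParsing` and `parse` are hypothetical names, and the label encoding is the one used in the cell above.

```scala
// Hypothetical stricter parser for one CSV line of the Iris file (plain Scala, no Spark).
object IrisParsing {
  // Same encoding as the notebook: versicolor -> 0.0, setosa -> 1.0, virginica -> 2.0
  val labelOf: Map[String, Double] = Map(
    "Iris-versicolor" -> 0.0,
    "Iris-setosa"     -> 1.0,
    "Iris-virginica"  -> 2.0
  )

  // Returns None for blank lines, header rows, unknown species or non-numeric
  // measurements, instead of silently assigning them to class 2.0.
  def parse(line: String): Option[(Double, Array[Double])] = {
    val cols = line.trim.split(',')
    if (cols.length != 5) None
    else labelOf.get(cols.last).flatMap { label =>
      scala.util.Try(cols.init.map(_.toDouble)).toOption.map(features => (label, features))
    }
  }
}
```

Assuming the object is defined in the notebook, the parsing cell could become something like `rawData.flatMap(line => IrisParsing.parse(line).toList).map { case (label, features) => LabeledPoint(label, Vectors.dense(features)) }`; this has not been run against the notebook's Spark version.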


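### Split the data into training and test sets (about 40% held out for testing)

`randomSplit` assigns each row to a split at random according to the weights, so the resulting sizes only approximate 60/40; with this seed it produced 100 training rows and 50 test rows.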

In [5]:
val splits = Data.randomSplit(Array(0.6, 0.4), seed = 11L)
val trainingData = splits(0).cache()
val testData = splits(1)
println("Training Data")
trainingData.take(5).foreach(println)
println("Test Data")
testData.take(5).foreach(println)

Training Data
(1.0,[5.1,3.5,1.4,0.2])
(1.0,[4.9,3.0,1.4,0.2])
(1.0,[4.7,3.2,1.3,0.2])
(1.0,[4.6,3.1,1.5,0.2])
(1.0,[5.0,3.6,1.4,0.2])
Test Data
(1.0,[4.9,3.1,1.5,0.1])
(1.0,[5.4,3.7,1.5,0.2])
(1.0,[5.8,4.0,1.2,0.2])
(1.0,[5.4,3.9,1.3,0.4])
(1.0,[5.1,3.8,1.5,0.3])
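A simplified sketch of weight-based random splitting in plain Scala, to show why the split sizes are only approximate: each row gets one uniform draw and lands in the split whose cumulative-weight interval contains it. Spark's `randomSplit` applies this kind of per-row sampling within each partition but differs in detail; `SplitSketch` is a hypothetical name and this is an illustration, not Spark's code.

```scala
import scala.util.Random

// Illustrative weight-based random split (not Spark's implementation).
object SplitSketch {
  def randomSplit[T](rows: Seq[T], weights: Array[Double], seed: Long): Array[Seq[T]] = {
    require(weights.nonEmpty && weights.forall(_ >= 0.0) && weights.sum > 0.0,
      "weights must be nonnegative with a positive sum")
    val total = weights.sum
    // Normalized cumulative upper bounds, e.g. Array(0.6, 0.4) -> Array(0.6, 1.0)
    val bounds = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    val rng = new Random(seed)
    // One uniform draw per row decides which split the row lands in
    val assigned = rows.toVector.map { row =>
      val u = rng.nextDouble()
      val i = bounds.indexWhere(u < _)
      (if (i == -1) bounds.length - 1 else i, row) // -1 only on floating-point round-off
    }
    Array.tabulate(weights.length)(k => assigned.filter(_._1 == k).map(_._2))
  }
}
```

With 150 rows and weights 0.6/0.4 the first split has about 90 rows on average, but any particular seed can land noticeably above or below that.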


### Train a Decision Tree

Decision trees are widely used because they are easy to interpret, handle categorical features, extend to the multiclass setting, do not require feature scaling, and can capture nonlinearities and feature interactions. Tree ensemble algorithms such as random forests and boosting are among the top performers for classification and regression tasks.
MLlib supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions of instances. This notebook uses entropy as the impurity measure, so each split is chosen to maximize information gain.
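With `impurity = "entropy"`, a node's impurity is the base-2 entropy of its class proportions, and a candidate split is scored by its information gain: the parent's entropy minus the size-weighted entropies of the two children. A small plain Scala sketch of both quantities follows; `EntropySketch` is a hypothetical name, and the counts in the usage note are illustrative, not the actual training-set counts.

```scala
// Entropy and information gain for per-class counts (illustrative sketch).
object EntropySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Contribution of one class with proportion p (> 0) to the node entropy
  private def term(p: Double): Double = -p * log2(p)

  // Base-2 entropy of a node given its per-class counts; empty classes contribute 0
  def entropy(counts: Seq[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0 else counts.filter(_ > 0.0).map(c => term(c / total)).sum
  }

  // Information gain of splitting a node into `left` and `right`
  // (per-class counts in the same class order for both children)
  def infoGain(left: Seq[Double], right: Seq[Double]): Double = {
    val parent = left.zip(right).map { case (l, r) => l + r }
    val n = parent.sum
    entropy(parent) - (left.sum / n) * entropy(left) - (right.sum / n) * entropy(right)
  }
}
```

For example, assuming all 150 rows (50 per class, label order versicolor, setosa, virginica) and a split that isolates every setosa, `EntropySketch.infoGain(Seq(0.0, 50.0, 0.0), Seq(50.0, 0.0, 50.0))` is about 0.918 bits, out of a parent entropy of log2(3) ≈ 1.585 bits.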

In [6]:
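val numClasses = 3                            // labels 0.0, 1.0 and 2.0
val categoricalFeaturesInfo = Map[Int, Int]() // empty map: all four features are continuous
val impurity = "entropy"                      // split quality measured by information gain
val maxDepth = 3                              // at most 3 levels of splits below the root
val maxBins = 10                              // max bins used to discretize continuous features when searching for splits
val dtModel = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)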
println(dtModel.toDebugString)


DecisionTreeModel classifier of depth 3 with 9 nodes
  If (feature 2 <= 1.7)
   Predict: 1.0
  Else (feature 2 > 1.7)
   If (feature 3 <= 1.7)
    If (feature 2 <= 5.0)
     Predict: 0.0
    Else (feature 2 > 5.0)
     Predict: 2.0
   Else (feature 3 > 1.7)
    If (feature 0 <= 6.0)
     Predict: 2.0
    Else (feature 0 > 6.0)
     Predict: 2.0
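The printed tree is easier to read with the feature indices named: they follow the CSV column order (0 = sepal length, 1 = sepal width, 2 = petal length, 3 = petal width, all in cm). The sketch below hand-translates the debug string above into a plain Scala function; `PrintedTree` is a hypothetical name and it mirrors only this particular trained model. The last split, on feature 0 at 6.0, predicts 2.0 on both sides, so it never changes a prediction; presumably it was kept because it lowered entropy on the training data.

```scala
// Hand translation of the DecisionTreeModel printed above (this model only).
object PrintedTree {
  // f(0) = sepal length, f(1) = sepal width, f(2) = petal length, f(3) = petal width (cm)
  def predict(f: Array[Double]): Double =
    if (f(2) <= 1.7) 1.0                   // setosa
    else if (f(3) <= 1.7) {
      if (f(2) <= 5.0) 0.0 else 2.0        // versicolor vs. virginica
    } else {
      // The printed "feature 0 <= 6.0" split predicts 2.0 on both sides, so it collapses here
      2.0                                  // virginica
    }
}
```

For instance, the first row of the file, `Array(5.1, 3.5, 1.4, 0.2)`, has petal length 1.4 ≤ 1.7 and is predicted as 1.0 (setosa).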



In [7]:
val dtTotalCorrect = trainingData.map { point =>
  if (dtModel.predict(point.features) == point.label) 1 else 0
  }.sum

println(dtTotalCorrect)
println(trainingData.count)

97.0
100


In [8]:
val dtAccuracy = dtTotalCorrect / trainingData.count
println(dtAccuracy)

0.97
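The cells above count correct predictions by summing a 0/1 indicator. `sum` on a numeric RDD comes from Spark's `DoubleRDDFunctions` and returns a `Double` (hence 97.0), and dividing it by the `Long` returned by `count` gives a `Double` accuracy. The same computation on an in-memory list of `(prediction, label)` pairs, as a plain Scala sketch with a hypothetical `AccuracySketch` object:

```scala
// Accuracy from (prediction, label) pairs (illustrative, no Spark).
object AccuracySketch {
  // Fraction of pairs whose prediction equals the label
  def accuracy(predictionAndLabels: Seq[(Double, Double)]): Double = {
    require(predictionAndLabels.nonEmpty, "need at least one example")
    predictionAndLabels.count { case (p, l) => p == l }.toDouble / predictionAndLabels.size
  }
}
```

On an RDD, an equivalent formulation is `trainingData.filter(p => dtModel.predict(p.features) == p.label).count.toDouble / trainingData.count`.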


### Evaluate the Decision Tree on the test set

In [9]:
val dtTotalCorrect = testData.map { point =>
  if (dtModel.predict(point.features) == point.label) 1 else 0
  }.sum
println(dtTotalCorrect)
println(testData.count)

47.0
50


In [10]:
val dtAccuracy = dtTotalCorrect / testData.count
println(dtAccuracy)

0.94


### Train Naive Bayes

Naive Bayes is a simple multiclass classification algorithm that assumes every pair of features is independent given the label. It can be trained very efficiently: in a single pass over the training data it computes the conditional probability distribution of each feature given the label, and at prediction time it applies Bayes' theorem to compute the conditional probability distribution of the label given an observation. Called as `NaiveBayes.train(trainingData)`, MLlib fits its default multinomial model with additive smoothing λ = 1.0; the multinomial model requires nonnegative feature values, which holds for the Iris measurements.
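To make the single-pass training concrete, here is a minimal multinomial Naive Bayes with additive smoothing in plain Scala. It uses the standard smoothed estimates, log prior log((n_c + λ) / (N + Kλ)) and per-feature log conditional log((Σ x_j + λ) / (Σ x + dλ)) summed over the rows of class c. It is an illustrative sketch, not MLlib's code; `MultinomialNBSketch` and its methods are hypothetical names, and the toy data below is made up. In MLlib, the trained model's log priors and log conditional probabilities are exposed as `nbModel.pi` and `nbModel.theta`.

```scala
// Minimal multinomial Naive Bayes with additive smoothing (illustrative sketch).
object MultinomialNBSketch {
  final case class Model(labels: Array[Double], logPi: Array[Double], logTheta: Array[Array[Double]])

  def train(data: Seq[(Double, Array[Double])], lambda: Double = 1.0): Model = {
    require(data.nonEmpty, "need at least one training example")
    require(data.forall(_._2.forall(_ >= 0.0)), "multinomial Naive Bayes needs nonnegative features")
    val numFeatures = data.head._2.length
    // Group rows by label, in ascending label order
    val byLabel = data.groupBy(_._1).toArray.sortBy(_._1)
    val numLabels = byLabel.length
    // Smoothed log prior: log((n_c + lambda) / (N + K * lambda))
    val logPi = byLabel.map { case (_, rows) =>
      math.log(rows.size + lambda) - math.log(data.size + numLabels * lambda)
    }
    // Smoothed log conditional: log((sum of x_j in class c + lambda) / (sum of all x in class c + d * lambda))
    val logTheta = byLabel.map { case (_, rows) =>
      val sums = Array.tabulate(numFeatures)(j => rows.map(_._2(j)).sum)
      val logDenom = math.log(sums.sum + numFeatures * lambda)
      sums.map(s => math.log(s + lambda) - logDenom)
    }
    Model(byLabel.map(_._1), logPi, logTheta)
  }

  // Pick the label with the largest log prior + sum_j x_j * logTheta(c)(j)
  def predict(model: Model, x: Array[Double]): Double = {
    val scores = model.logPi.indices.map { c =>
      model.logPi(c) + x.indices.map(j => x(j) * model.logTheta(c)(j)).sum
    }
    model.labels(scores.indexOf(scores.max))
  }
}
```

For instance, training on the four toy rows `(0.0, [5, 1])`, `(0.0, [4, 2])`, `(1.0, [1, 5])`, `(1.0, [2, 4])` and predicting `[6, 1]` returns 0.0, because most of class 0's total comes from the first feature.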

In [11]:
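val nbModel = NaiveBayes.train(trainingData)
// Prints only the object reference; the fitted log priors and log conditional
// probabilities are available as nbModel.pi and nbModel.theta
println(nbModel)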

org.apache.spark.mllib.classification.NaiveBayesModel@18dfc198


In [12]:
val nbTotalCorrect = trainingData.map { point =>
    if (nbModel.predict(point.features) == point.label) 1 else 0
}.sum
println(nbTotalCorrect)
println(trainingData.count)

96.0
100


In [13]:
val nbAccuracy = nbTotalCorrect / trainingData.count
println(nbAccuracy)

0.96


### Evaluate Naive Bayes on the test set

In [14]:
val nbTotalCorrect = testData.map { point =>
    if (nbModel.predict(point.features) == point.label) 1 else 0
}.sum

println(nbTotalCorrect)
println(testData.count)

48.0
50


In [15]:
val nbAccuracy = nbTotalCorrect / testData.count
println(nbAccuracy)

0.96


### Complete evaluation of the Decision Tree model on the test set

In [16]:

val predictionAndLabels = testData.map { case LabeledPoint(label, features) =>
  val prediction = dtModel.predict(features)
  (prediction, label)
}

// Instantiate metrics object
val metrics = new MulticlassMetrics(predictionAndLabels)

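// Confusion matrix: rows are actual labels, columns are predicted labels,
// both in ascending label order (0.0, 1.0, 2.0)
println("Confusion matrix:")
println(metrics.confusionMatrix)

// Overall statistics. For single-label multiclass data, overall precision,
// recall and F1 all equal accuracy; from Spark 2.0 these no-argument methods
// are deprecated in favor of metrics.accuracy.
val precision = metrics.precision
val recall = metrics.recall // same as true positive rate
val f1Score = metrics.fMeasure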
println("Summary Statistics")
println(s"Precision = $precision")
println(s"Recall = $recall")
println(s"F1 Score = $f1Score")

// Precision by label
val labels = metrics.labels
labels.foreach { l =>
    println(s"Precision($l) = " + metrics.precision(l))
}

// Recall by label
labels.foreach { l =>
    println(s"Recall($l) = " + metrics.recall(l))
}

// False positive rate by label
labels.foreach { l =>
    println(s"FPR($l) = " + metrics.falsePositiveRate(l))
}

// F-measure by label
labels.foreach { l =>
    println(s"F1-Score($l) = " + metrics.fMeasure(l))
}

// Weighted stats
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")

Confusion matrix:
15.0  0.0   0.0   
2.0   18.0  0.0   
1.0   0.0   14.0  
Summary Statistics
Precision = 0.94
Recall = 0.94
F1 Score = 0.94
Precision(0.0) = 0.8333333333333334
Precision(1.0) = 1.0
Precision(2.0) = 1.0
Recall(0.0) = 1.0
Recall(1.0) = 0.9
Recall(2.0) = 0.9333333333333333
FPR(0.0) = 0.08571428571428572
FPR(1.0) = 0.0
FPR(2.0) = 0.0
F1-Score(0.0) = 0.9090909090909091
F1-Score(1.0) = 0.9473684210526316
F1-Score(2.0) = 0.9655172413793104
Weighted precision: 0.9500000000000001
Weighted recall: 0.9400000000000001
Weighted F1 score: 0.9413298135621185
Weighted false positive rate: 0.025714285714285717
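The per-label numbers can be recomputed by hand from the confusion matrix, which is a useful check on how they are defined: rows are actual labels and columns are predicted labels; precision divides the diagonal entry by its column sum and recall by its row sum; the false positive rate divides the off-diagonal entries of the column by the number of examples whose actual label is different; the weighted versions average the per-label values using each label's share of actual examples. A plain Scala sketch (hypothetical `ConfusionMetrics` object) that reproduces the Decision Tree figures above:

```scala
// Per-label metrics from a confusion matrix (illustrative sketch, no Spark).
object ConfusionMetrics {
  // Rows = actual labels, columns = predicted labels, in ascending label order
  type Matrix = Array[Array[Double]]

  private def total(m: Matrix): Double = m.map(_.sum).sum
  private def actual(m: Matrix, k: Int): Double = m(k).sum           // row sum
  private def predicted(m: Matrix, k: Int): Double = m.map(_(k)).sum // column sum
  private def safeDiv(a: Double, b: Double): Double = if (b == 0.0) 0.0 else a / b

  def precision(m: Matrix, k: Int): Double = safeDiv(m(k)(k), predicted(m, k))
  def recall(m: Matrix, k: Int): Double = safeDiv(m(k)(k), actual(m, k))
  // False positives for label k divided by all examples whose actual label is not k
  def falsePositiveRate(m: Matrix, k: Int): Double =
    safeDiv(predicted(m, k) - m(k)(k), total(m) - actual(m, k))
  def f1(m: Matrix, k: Int): Double =
    safeDiv(2 * precision(m, k) * recall(m, k), precision(m, k) + recall(m, k))

  // Average of a per-label metric, weighted by each label's share of actual examples
  def weighted(m: Matrix)(metric: (Matrix, Int) => Double): Double =
    m.indices.map(k => actual(m, k) / total(m) * metric(m, k)).sum
}
```

With `dt = Array(Array(15.0, 0.0, 0.0), Array(2.0, 18.0, 0.0), Array(1.0, 0.0, 14.0))`, `ConfusionMetrics.precision(dt, 0)` is 15/18 ≈ 0.833 and `ConfusionMetrics.falsePositiveRate(dt, 0)` is 3/35 ≈ 0.0857, matching the printed `Precision(0.0)` and `FPR(0.0)`.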


### Complete evaluation of the Naive Bayes model on the test set

In [17]:
val predictionAndLabels = testData.map { case LabeledPoint(label, features) =>
  val prediction = nbModel.predict(features)
  (prediction, label)
}

// Instantiate metrics object
val metrics = new MulticlassMetrics(predictionAndLabels)

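// Confusion matrix: rows are actual labels, columns are predicted labels,
// both in ascending label order (0.0, 1.0, 2.0)
println("Confusion matrix:")
println(metrics.confusionMatrix)

// Overall statistics. For single-label multiclass data, overall precision,
// recall and F1 all equal accuracy; from Spark 2.0 these no-argument methods
// are deprecated in favor of metrics.accuracy.
val precision = metrics.precision
val recall = metrics.recall // same as true positive rate
val f1Score = metrics.fMeasure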
println("Summary Statistics")
println(s"Precision = $precision")
println(s"Recall = $recall")
println(s"F1 Score = $f1Score")

// Precision by label
val labels = metrics.labels
labels.foreach { l =>
    println(s"Precision($l) = " + metrics.precision(l))
}

// Recall by label
labels.foreach { l =>
    println(s"Recall($l) = " + metrics.recall(l))
}

// False positive rate by label
labels.foreach { l =>
    println(s"FPR($l) = " + metrics.falsePositiveRate(l))
}

// F-measure by label
labels.foreach { l =>
    println(s"F1-Score($l) = " + metrics.fMeasure(l))
}

// Weighted stats
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")

Confusion matrix:
14.0  0.0   1.0   
0.0   20.0  0.0   
1.0   0.0   14.0  
Summary Statistics
Precision = 0.96
Recall = 0.96
F1 Score = 0.96
Precision(0.0) = 0.9333333333333333
Precision(1.0) = 1.0
Precision(2.0) = 0.9333333333333333
Recall(0.0) = 0.9333333333333333
Recall(1.0) = 1.0
Recall(2.0) = 0.9333333333333333
FPR(0.0) = 0.02857142857142857
FPR(1.0) = 0.0
FPR(2.0) = 0.02857142857142857
F1-Score(0.0) = 0.9333333333333333
F1-Score(1.0) = 1.0
F1-Score(2.0) = 0.9333333333333333
Weighted precision: 0.9600000000000001
Weighted recall: 0.9600000000000001
Weighted F1 score: 0.9600000000000001
Weighted false positive rate: 0.01714285714285714
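In both evaluations the overall precision, recall and F1 are identical, and the weighted recall matches them as well. That is expected for single-label multiclass data: each misclassified example counts as exactly one false positive (for the predicted label) and one false negative (for the actual label), so pooled (micro-averaged) precision and recall both reduce to accuracy, the diagonal of the confusion matrix divided by its total. Weighted recall is Σ_k (n_k/N)(TP_k/n_k) = Σ_k TP_k/N, the same quantity. A plain Scala check against the printed matrices (hypothetical `MicroAverages` object):

```scala
// Pooled (micro-averaged) precision and recall vs. accuracy (illustrative sketch).
object MicroAverages {
  private def diag(m: Array[Array[Double]]): Double = m.indices.map(k => m(k)(k)).sum
  private def total(m: Array[Array[Double]]): Double = m.map(_.sum).sum

  def accuracy(m: Array[Array[Double]]): Double = diag(m) / total(m)

  // Pool TP and FP over all labels: FP for label k = column k sum minus its diagonal entry
  def microPrecision(m: Array[Array[Double]]): Double = {
    val fp = m.indices.map(k => m.map(_(k)).sum - m(k)(k)).sum
    diag(m) / (diag(m) + fp)
  }

  // Pool TP and FN over all labels: FN for label k = row k sum minus its diagonal entry
  def microRecall(m: Array[Array[Double]]): Double = {
    val fn = m.indices.map(k => m(k).sum - m(k)(k)).sum
    diag(m) / (diag(m) + fn)
  }
}
```

For the Naive Bayes matrix all three return 48/50 = 0.96; for the Decision Tree matrix, 47/50 = 0.94.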
