From 1dd2735d095a46c19a1811c22a65ca211268eedd Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 17 May 2014 23:50:40 -0700 Subject: [PATCH] bin search logic for multiclass --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 9ebb1d25ffa02..f1a3aea1f8c6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -549,7 +549,9 @@ object DecisionTree extends Serializable with Logging { * Sequential search helper method to find bin for categorical feature in multiclass * classification. Dummy value of 0 used since it is not used in future calculation */ - def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = 0 + def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } /** * Sequential search helper method to find bin for categorical feature. @@ -662,7 +664,7 @@ object DecisionTree extends Serializable with Logging { label.toInt match { case n: Int => val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous && strategy.isMultiClassification) { + if (!isFeatureContinuous && strategy.isMultiClassification) { // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1