diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 8e7a6917946b8..f5054dbf0d769 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -516,7 +516,7 @@ object DecisionTree extends Serializable with Logging { * Find bin for one feature. */ def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint, - isFeatureContinuous: Boolean): Int = { + isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -550,14 +550,14 @@ object DecisionTree extends Serializable with Logging { * splits. The actual left/right child allocation per split is performed in the * sequential phase of the bin aggregate operation. */ - def sequentialBinSearchForCategoricalFeatureInMulticlassClassification(): Int = { + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { labeledPoint.features(featureIndex).toInt } /** * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { + def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 @@ -583,10 +583,10 @@ object DecisionTree extends Serializable with Logging { } else { // Perform sequential search to find bin for categorical features. 
val binIndex = { - if (isMulticlassClassification) { - sequentialBinSearchForCategoricalFeatureInMulticlassClassification() + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { - sequentialBinSearchForCategoricalFeatureInBinaryClassification() + sequentialBinSearchForOrderedCategoricalFeatureInClassification() } } if (binIndex == -1){ @@ -622,8 +622,19 @@ object DecisionTree extends Serializable with Logging { } else { var featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isSpaceSufficientForAllCategoricalSplits) + } featureIndex += 1 } } @@ -731,12 +742,19 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. 
var featureIndex = 0 while (featureIndex < numFeatures) { - val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isContinuousFeature) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) } else { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift) + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift) + } else { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) } + } featureIndex += 1 } } @@ -1093,7 +1111,14 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } else { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } } } else { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -1168,7 +1193,9 @@ object DecisionTree extends Serializable with Logging { numBins - 1 } else { // Categorical feature val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - if (isMulticlassClassification) { + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 
1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { math.pow(2.0, featureCategories - 1).toInt - 1 } else { // Binary classification featureCategories @@ -1289,11 +1316,6 @@ object DecisionTree extends Serializable with Logging { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") - if (isMulticlassClassification) { - require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1, - "numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" + - " with categorical variables") - } } @@ -1332,10 +1354,12 @@ object DecisionTree extends Serializable with Logging { } } else { // Categorical feature val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification - if (isMulticlassClassification) { + // classification that satisfy the space constraint + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -1360,14 +1384,29 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { // regression or binary classification - - // For categorical variables, each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. 
- val centroidForCategories = - sampledInput.map(lp => (lp.features(featureIndex),lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } else { + + val centroidForCategories = { + if (isMulticlassClassification) { + // For categorical variables in multiclass classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the impurity of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) + .map(x => (x._1, x._2.values.toArray)) + .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } + } + + logDebug("centriod for categories = " + centroidForCategories.mkString(",")) // Check for missing categorical variables and putting them last in the sorted list. 
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index c06ad055afeea..6a6ad5b871320 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -239,7 +239,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) } - test("split and bin calculations for categorical variables with multiclass classification") { + test("split and bin calculations for unordered categorical variables with multiclass " + + "classification") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -332,6 +333,62 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } + test("split and bin calculations for ordered categorical variables with multiclass " + + "classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + assert(arr.length === 3000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + numClassesForClassification = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // 2^(10 - 1) - 1 = 511 > 100, so categorical variables will be ordered + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + 
assert(splits(0)(1).categories.contains(2.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(2.0)) + assert(splits(0)(2).categories.contains(1.0)) + + assert(splits(0)(10) === null) + assert(splits(1)(10) === null) + + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + assert(bins(0)(1).category === 2.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(2.0)) + + assert(bins(0)(10) === null) + + } + + test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) @@ -547,7 +604,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } - test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) @@ -568,7 +624,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) @@ -588,6 +643,26 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.threshold < 2020) } + test("stump with categorical variables for ordered multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val input = 
sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + println(bestSplit) + assert(bestSplit.categories.contains(1.0)) + assert(bestSplit.featureType === Categorical) + } + + } object DecisionTreeSuite { @@ -662,5 +737,20 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPointsForMulticlassForOrderedFeatures(): + Array[WeightedLabeledPoint] = { + val arr = new Array[WeightedLabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 2.0)) + } + } + arr + } + }