From 8cfd3b6405891334a89f834c8c1fedbb3eb0868a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 21 May 2014 18:15:19 -0700 Subject: [PATCH] working for categorical multiclass classification --- .../spark/mllib/tree/DecisionTree.scala | 364 +++++++++++------- .../spark/mllib/tree/DecisionTreeSuite.scala | 43 ++- 2 files changed, 266 insertions(+), 141 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 861b35124368d..c6a306d436339 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -647,9 +647,9 @@ object DecisionTree extends Serializable with Logging { * numClasses * numSplits * numFeatures*numNodes for classification * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures * numNodes for classification + * 2 * numSplits * numFeatures * numNodes for classification */ - def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -666,27 +666,11 @@ object DecisionTree extends Serializable with Logging { val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label.toInt match { - case n: Int => - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (!isFeatureContinuous && isMulticlassClassification) { - // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){ - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + labelWeights.getOrElse(binIndex, 1) - } - binIndex += 1 - } - } else { - agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) - } + case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) } featureIndex += 1 } @@ -695,6 +679,77 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numClasses * numSplits * numFeatures * numNodes for classification + */ + def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + val rightChildShift = numClasses * numBins * numFeatures * numNodes + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isContinuousFeature) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + label.toInt match { + case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) + } + } else { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + label.toInt match { + case n: Int => + // Find all matching bins and increment their values + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)) { + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + labelWeights.getOrElse(n, 1) + } else { + agg(rightChildShift + aggIndex + binIndex) + = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(n, 1) + + } + binIndex += 1 + } + } + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + /** * Performs a sequential aggregation over a partition for regression. For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. @@ -739,7 +794,12 @@ object DecisionTree extends Serializable with Logging { */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { - case Classification => classificationBinSeqOp(arr, agg) + case Classification => + if(isMulticlassClassification) { + multiClassificationBinSeqOp(arr, agg) + } else { + binaryClassificationBinSeqOp(arr, agg) + } case Regression => regressionBinSeqOp(arr, agg) } agg @@ -747,7 +807,12 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { - case Classification => numClasses * numBins * numFeatures * numNodes + case Classification => + if (isMulticlassClassification){ + 2 * numClasses * numBins * numFeatures * numNodes + } else { + 2 * numBins * numFeatures * numNodes + } case Regression => 3 * numBins * numFeatures * numNodes } logDebug("binAggregateLength = " + binAggregateLength) @@ -920,80 +985,139 @@ object DecisionTree extends Serializable with Logging { */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + + def findAggForOrderedFeature( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + + 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = + binData(shift + (2 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = + binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + + splitIndex += 1 + } + } + + def extractAggForCategoricalFeature( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + val rightChildShift = numClasses * numBins * numFeatures + var splitIndex = 0 + while (splitIndex < numBins - 1) { + var classIndex = 0 + while (classIndex < numClasses) { + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + splitIndex * numClasses + val leftBinValue = binData(shift + classIndex) + val rightBinValue = binData(rightChildShift + shift + classIndex) + leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue + rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue + classIndex += 1 + } + splitIndex += 1 + } + } + + def findAggForRegression( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(numBins - 2)(2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(splitIndex - 1)(2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2) + + splitIndex += 1 + } + } + strategy.algo match { case Classification => - // Initialize left and right split aggregates. val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - - if (isMulticlassClassification) { - var featureIndex = 0 - while (featureIndex < numFeatures){ - var splitIndex = 0 - while (splitIndex < numBins - 1) { - val totalNodeAgg = Array.ofDim[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins - val binValue = binData(shift + classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) = binValue - totalNodeAgg(classIndex) = binValue - classIndex += 1 - } - // Calculate rightNodeAgg - classIndex = 0 - while (classIndex < numClasses) { - rightNodeAgg(featureIndex)(splitIndex)(classIndex) - = totalNodeAgg(classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) - classIndex += 1 - } - splitIndex += 1 - } - featureIndex += 1 - } - } else { - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) - = binData(shift + (2 * (numBins - 1)) + 1) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(0) - leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + - 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = - binData(shift + (2 *(numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = - binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) - - splitIndex += 1 + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (isMulticlassClassification){ + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + extractAggForCategoricalFeature(leftNodeAgg, rightNodeAgg, featureIndex) } - featureIndex += 1 + } else { + findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) } + featureIndex += 1 } + (leftNodeAgg, rightNodeAgg) case Regression => // Initialize left and right split aggregates. @@ -1002,47 +1126,7 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(numBins - 2)(2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(0) - leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) + - leftNodeAgg(featureIndex)(splitIndex - 1)(1) - leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) + - leftNodeAgg(featureIndex)(splitIndex - 1)(2) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = - binData(shift + (3 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = - binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) = - binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2) - - splitIndex += 1 - } + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) featureIndex += 1 } (leftNodeAgg, rightNodeAgg) @@ -1134,9 +1218,23 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + if (isMulticlassClassification) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData + } + binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } case Regression => val shift = 3 * node * numBins * numFeatures val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 664abf742d4a1..41cf5a120bac8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -529,10 +529,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("stump with continuous variables for multiclass classification") { - assert(true==true) - } - test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) @@ -547,11 +543,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val bestSplit = bestSplits(0)._1 assert(bestSplit.feature === 0) assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(0)) + assert(bestSplit.categories.contains(1)) assert(bestSplit.featureType === Categorical) + } + + + test("stump with continuous variables for multiclass classification") { + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + //assert(bestSplit.feature === 1) + //assert(bestSplit.featureType == Continuous) + //assert(bestSplit.threshold > 1000) println(bestSplit) + } + test("stump with continuous + categorical variables for multiclass classification") { assert(true==true) } @@ -615,10 +632,20 @@ object DecisionTreeSuite { arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } } - println(arr(0)) - println(arr(1000)) - println(arr(2000)) arr } + def generateContinuousDataPointsForMulticlass(): Array[WeightedLabeledPoint] = { + val arr = new Array[WeightedLabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 2000) { + arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 100)) + } else { + arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 3000)) + } + } + arr + } + + }