From 6c7af2206e6bd16e8bcc4feb4626bfccb5837c55 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 23:09:46 -0700 Subject: [PATCH] prepared for multiclass without breaking binary classification --- .../spark/mllib/tree/DecisionTree.scala | 189 ++++++++++-------- 1 file changed, 107 insertions(+), 82 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 49b821d589071..0ca4366ae6e84 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -385,6 +385,8 @@ object DecisionTree extends Serializable with Logging { logDebug("numFeatures = " + numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification + logDebug("numClasses = " + numClasses) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -545,10 +547,10 @@ object DecisionTree extends Serializable with Logging { * incremented based upon whether the feature is classified as 0 or 1. * * @param agg Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures*numNodes for classification + * numClasses * numSplits * numFeatures*numNodes for classification * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification + * numClasses * numSplits * numFeatures * numNodes for classification */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { // Iterate over all nodes. @@ -562,16 +564,16 @@ object DecisionTree extends Serializable with Logging { val label = arr(0) // Iterate over all features. var featureIndex = 0 - // TODO: Multiclass modification here while (featureIndex < numFeatures) { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = 2 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - label match { - case n: Double => agg(aggIndex) = agg(aggIndex + n.toInt) + 1 + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + label.toInt match { + case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1 } featureIndex += 1 } @@ -632,7 +634,7 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { - case Classification => 2 * numBins * numFeatures * numNodes + case Classification => numClasses * numBins * numFeatures * numNodes case Regression => 3 * numBins * numFeatures * numNodes } logDebug("binAggregateLength = " + binAggregateLength) @@ -672,20 +674,20 @@ object DecisionTree extends Serializable with Logging { * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], featureIndex: Int, splitIndex: Int, - rightNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Array[Double]]], topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => // TODO: Modify here - val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) - val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) + val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0) + val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1) val leftCount = left0Count + left1Count - val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) - val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) + val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0) + val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1) val rightCount = right0Count + right1Count val impurity = { @@ -722,13 +724,13 @@ object DecisionTree extends Serializable with Logging { new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) case Regression => - val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) - val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) - val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) + val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) + val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) - val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) - val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) - val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) + val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) + val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) val impurity = { if (level > 0) { @@ -777,73 +779,96 @@ object DecisionTree extends Serializable with Logging { * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { strategy.algo match { case Classification => // TODO: Multiclass modification here - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(2 * (numBins - 2)) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) - = binData(shift + (2 * (numBins - 1)) + 1) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = - binData(shift + (2 *(numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) - - splitIndex += 1 + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + + if (strategy.isMultiClassification) { + var featureIndex = 0 + while (featureIndex < numFeatures){ + val numCategories = strategy.categoricalFeaturesInfo(featureIndex) + val maxSplits = math.pow(2, numCategories) - 1 + var i = 0 + // TODO: Add multiclass case here + while (i < maxSplits) { + var classIndex = 0 + while (classIndex < numClasses) { + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + + classIndex += 1 + } + i += 1 + } + featureIndex += 1 + } + } else { + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + + 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = + binData(shift + (2 *(numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = + binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + + splitIndex += 1 + } + featureIndex += 1 } - featureIndex += 1 } (leftNodeAgg, rightNodeAgg) case Regression => // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { // shift for this featureIndex val shift = 3 * featureIndex * numBins // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(2) = binData(shift + 2) + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) // right node aggregate for the highest split - rightNodeAgg(featureIndex)(3 * (numBins - 2)) = + rightNodeAgg(featureIndex)(numBins - 2)(0) = binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = + rightNodeAgg(featureIndex)(numBins - 2)(1) = binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = + rightNodeAgg(featureIndex)(numBins - 2)(2) = binData(shift + (3 * (numBins - 1)) + 2) // Iterate over all splits. @@ -851,24 +876,24 @@ object DecisionTree extends Serializable with Logging { while (splitIndex < numBins - 1) { // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3) - leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) - leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(splitIndex - 1)(2) // calculating right node aggregate for a split as a sum of right node aggregate of a // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = binData(shift + (3 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2) splitIndex += 1 } @@ -882,8 +907,8 @@ object DecisionTree extends Serializable with Logging { * Calculates information gain for all nodes splits. */ def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Double]], - rightNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)