From f5f6b833d62d7fba982c62971dc373c70363385e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 21 May 2014 18:40:24 -0700 Subject: [PATCH] multiclass for continous variables --- .../spark/mllib/tree/DecisionTree.scala | 56 +++++++++---------- .../spark/mllib/tree/DecisionTreeSuite.scala | 12 ++-- 2 files changed, 32 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c6a306d436339..82fd719c75990 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -987,48 +987,44 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { - def findAggForOrderedFeature( + def findAggForOrderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], featureIndex: Int) { // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) - = binData(shift + (2 * (numBins - 1)) + 1) + val shift = numClasses * featureIndex * numBins + + var classIndex = 0 + while (classIndex < numClasses) { + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(classIndex) + = binData(shift + (numClasses * (numBins - 1)) + classIndex) + classIndex += 1 + } // Iterate over all splits. var splitIndex = 1 while (splitIndex < numBins - 1) { // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(0) - leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + - 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = - binData(shift + (2 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = - binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) - + var innerClassIndex = 0 + while (innerClassIndex < numClasses) { + leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) + = binData(shift + numClasses * splitIndex + innerClassIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = + binData(shift + (numClasses * (numBins - 2 - splitIndex) + innerClassIndex)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) + innerClassIndex += 1 + } splitIndex += 1 } } - def extractAggForCategoricalFeature( + def findAggregateForCategoricalFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], featureIndex: Int) { @@ -1108,12 +1104,12 @@ object DecisionTree extends Serializable with Logging { if (isMulticlassClassification){ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { - findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } else { - extractAggForCategoricalFeature(leftNodeAgg, rightNodeAgg, featureIndex) + findAggregateForCategoricalFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } } else { - findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } featureIndex += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 41cf5a120bac8..12477c1fc1b07 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -561,10 +561,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 - //assert(bestSplit.feature === 1) - //assert(bestSplit.featureType == Continuous) - //assert(bestSplit.threshold > 1000) - println(bestSplit) + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) } @@ -639,9 +639,9 @@ object DecisionTreeSuite { val arr = new Array[WeightedLabeledPoint](3000) for (i <- 0 until 3000) { if (i < 2000) { - arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 100)) + arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, i)) } else { - arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 3000)) + arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, i)) } } arr