From 237762d3186c2f271e26a9a8bb61899016290312 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 25 May 2014 19:30:47 -0700 Subject: [PATCH] renaming functions --- .../spark/mllib/tree/DecisionTree.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 12eb09ffc39c5..61975fa69c681 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -547,16 +547,18 @@ object DecisionTree extends Serializable with Logging { /** * Sequential search helper method to find bin for categorical feature in multiclass - * classification. Dummy value of 0 used since it is not used in future calculation + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. */ - def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { + def sequentialBinSearchForCategoricalFeatureInMulticlassClassification(): Int = { labeledPoint.features(featureIndex).toInt } /** * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeatureInMultiClassClassification(): Int = { + def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 @@ -583,9 +585,9 @@ object DecisionTree extends Serializable with Logging { // Perform sequential search to find bin for categorical features. val binIndex = { if (isMulticlassClassification) { - sequentialBinSearchForCategoricalFeatureInBinaryClassification() + sequentialBinSearchForCategoricalFeatureInMulticlassClassification() } else { - sequentialBinSearchForCategoricalFeatureInMultiClassClassification() + sequentialBinSearchForCategoricalFeatureInBinaryClassification() } } if (binIndex == -1){ @@ -684,7 +686,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 2 * numSplits * numFeatures * numNodes for classification */ - def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -716,7 +718,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 2 * numClasses * numSplits * numFeatures * numNodes for classification */ - def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -789,9 +791,9 @@ object DecisionTree extends Serializable with Logging { strategy.algo match { case Classification => if(isMulticlassClassificationWithCategoricalFeatures) { - multiClassificationBinSeqOp(arr, agg) + unorderedClassificationBinSeqOp(arr, agg) } else { - binaryClassificationBinSeqOp(arr, agg) + orderedClassificationBinSeqOp(arr, agg) } case Regression => regressionBinSeqOp(arr, agg) }