diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b0f0bec899c81..fb53b588cdce9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -259,37 +259,6 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param labelWeights A map storing weights for each label to handle unbalanced - * datasets. For example, an entry (n -> k) implies the a weight of k is - * applied to an instance with label n. It's important to note that labels - * are zero-index and take values 0, 1, 2, ... , numClasses - 1. - * @return a DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int, - numClassesForClassification: Int, - labelWeights: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, - labelWeights = labelWeights) - new DecisionTree(strategy).train(input) - } - /** * Method to train a decision tree model where the instances are represented as an RDD of * (label, features) pairs. The decision tree method supports binary classification and @@ -303,10 +272,6 @@ object DecisionTree extends Serializable with Logging { * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param labelWeights A map storing weights applied to each label for handling unbalanced - * datasets. For example, an entry (n -> k) implies the a weight of k is - * applied to an instance with label n. It's important to note that labels - * are zero-index and take values 0, 1, 2, ... , numClasses - 1. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -322,12 +287,11 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int, numClassesForClassification: Int, - labelWeights: Map[Int,Int], maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, - quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights) + quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).train(input) } @@ -442,8 +406,6 @@ object DecisionTree extends Serializable with Logging { logDebug("numBins = " + numBins) val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) - val labelWeights = strategy.labelWeights - logDebug("labelWeights = " + labelWeights) val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) val isMulticlassClassificationWithCategoricalFeatures @@ -647,7 +609,7 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1) + agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 } def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], @@ -667,10 +629,10 @@ object DecisionTree extends Serializable with Logging { val labelInt = label.toInt if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1) + = agg(aggIndex + binIndex) + 1 } else { agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1) + = agg(rightChildShift + aggIndex + binIndex) + 1 } binIndex += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7aec14d293ec2..7c027ac2fda6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -39,10 +39,6 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * @param labelWeights A map storing weights applied to each label for handling unbalanced - * datasets. For example, an entry (n -> k) implies the a weight of k is - * applied to an instance with label n. It's important to note that labels - * are zero-index and take values 0, 1, 2, ... , numClasses. * */ @Experimental @@ -54,8 +50,7 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128, - val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable { + val maxMemoryInMB: Int = 128) extends Serializable { require(numClassesForClassification >= 2) val isMulticlassClassification = numClassesForClassification > 2