From d8e4a11833b5a7e5a6e4f0f72d203fbf8e0bb0ed Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 13 May 2014 09:20:40 -0700 Subject: [PATCH] sample weights --- .../spark/mllib/tree/DecisionTree.scala | 48 +++++++++++++++++-- .../mllib/tree/configuration/Strategy.scala | 7 ++- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 072651dbf1732..c467d5ba65d94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -268,7 +268,39 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } - // TODO: Add sample weight support + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value of 2. + * @param labelWeights A map storing weights applied to each label for handling unbalanced + * datasets. For example, an entry (n -> k) implies that a weight of k is + * applied to an instance with label n. It's important to note that labels + * are zero-indexed and take values 0, 1, 2, ... , numClasses - 1. 
+ * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int, + labelWeights: Map[Int,Int]): DecisionTreeModel = { + val strategy + = new Strategy(algo, impurity, maxDepth, numClassesForClassification, + labelWeights = labelWeights) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) + } /** * Method to train a decision tree model where the instances are represented as an RDD of @@ -283,6 +315,10 @@ object DecisionTree extends Serializable with Logging { * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree * @param numClassesForClassification number of classes for classification. Default value of 2. + * @param labelWeights A map storing weights applied to each label for handling unbalanced + * datasets. For example, an entry (n -> k) implies that a weight of k is + * applied to an instance with label n. It's important to note that labels + * are zero-indexed and take values 0, 1, 2, ... , numClasses - 1. 
* @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -298,11 +334,12 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int, numClassesForClassification: Int, + labelWeights: Map[Int,Int], maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, - quantileCalculationStrategy, categoricalFeaturesInfo) + quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) @@ -419,6 +456,9 @@ object DecisionTree extends Serializable with Logging { logDebug("numBins = " + numBins) val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) + val labelWeights = strategy.labelWeights + logDebug("labelWeights = " + labelWeights) + // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -605,7 +645,8 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses label.toInt match { - case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1 + case n: Int => + agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1) } featureIndex += 1 } @@ -1010,6 +1051,7 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. 
var splitIndex = 0 + // TODO: Modify this for categorical variables to go over only valid splits while (splitIndex < numBins - 1) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index c397a889f2605..89daaaeccdca6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -39,6 +39,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. + * @param labelWeights A map storing weights applied to each label for handling unbalanced + * datasets. For example, an entry (n -> k) implies that a weight of k is + * applied to an instance with label n. It's important to note that labels + * are zero-indexed and take values 0, 1, 2, ... , numClasses - 1. * */ @Experimental @@ -50,7 +54,8 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable { + val maxMemoryInMB: Int = 128, + val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable { require(numClassesForClassification >= 2) val isMultiClassification = numClassesForClassification > 2