Commit

sample weights
manishamde committed May 13, 2014
1 parent ed5a2df commit d8e4a11
Showing 2 changed files with 51 additions and 4 deletions.
@@ -268,7 +268,39 @@ object DecisionTree extends Serializable with Logging {
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
}

// TODO: Add sample weight support

/**
* Method to train a decision tree model where the instances are represented as an RDD of
* (label, features) pairs. The method supports binary classification and regression. For
* binary classification, the label for each instance should be either 0 or 1 to denote the two
* classes.
*
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
* training data
* @param algo algorithm, classification or regression
* @param impurity impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @param labelWeights A map storing weights applied to each label for handling unbalanced
* datasets. For example, an entry (n -> k) implies that a weight of k is
* applied to an instance with label n. It's important to note that labels
* are zero-indexed and take values 0, 1, 2, ..., numClasses - 1.
* @return a DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClassesForClassification: Int,
labelWeights: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification,
labelWeights = labelWeights)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
}
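
For illustration (not part of the commit), here is a minimal sketch of how the new overload might be invoked. The trainingData RDD, the Gini impurity, and the weight values are assumptions for the example; per the aggregation change further down in this diff, labels missing from the map default to a weight of 1.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.rdd.RDD

// trainingData is assumed to be an existing RDD[LabeledPoint] with labels 0 and 1.
// Up-weight the minority class (label 1) by 5x; label 0 keeps the default weight of 1.
val model = DecisionTree.train(trainingData, Classification, Gini, maxDepth = 4,
  numClassesForClassification = 2, labelWeights = Map(1 -> 5))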

/**
* Method to train a decision tree model where the instances are represented as an RDD of
@@ -283,6 +315,10 @@ object DecisionTree extends Serializable with Logging {
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @param labelWeights A map storing weights applied to each label for handling unbalanced
* datasets. For example, an entry (n -> k) implies that a weight of k is
* applied to an instance with label n. It's important to note that labels
* are zero-indexed and take values 0, 1, 2, ..., numClasses - 1.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -298,11 +334,12 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int,
numClassesForClassification: Int,
labelWeights: Map[Int,Int],
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
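
A similarly hedged sketch of the extended overload, which also sets binning and categorical-feature options (reusing the imports and the assumed trainingData from the sketch above; the specific values are illustrative only):

import org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort

// 100 bins, sort-based quantile calculation, no categorical features,
// and a 10x weight on the rare positive class (label 1).
val model2 = DecisionTree.train(trainingData, Classification, Gini, 5, 2,
  Map(1 -> 10), 100, Sort, Map[Int, Int]())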
@@ -419,6 +456,9 @@ object DecisionTree extends Serializable with Logging {
logDebug("numBins = " + numBins)
val numClasses = strategy.numClassesForClassification
logDebug("numClasses = " + numClasses)
val labelWeights = strategy.labelWeights
logDebug("labelWeights = " + labelWeights)


// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex
@@ -605,7 +645,8 @@ object DecisionTree extends Serializable with Logging {
val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
label.toInt match {
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1
case n: Int =>
agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1)
}
featureIndex += 1
}
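
The weighting itself happens in the aggregation step above: an instance's contribution to its label's bin count is multiplied by that label's weight, and getOrElse(n, 1) falls back to a weight of 1 for labels not present in the map. A tiny standalone illustration of that lookup (values made up):

val labelWeights = Map(1 -> 5)                      // only label 1 is re-weighted
val weightForLabel1 = labelWeights.getOrElse(1, 1)  // 5, from the explicit entry
val weightForLabel0 = labelWeights.getOrElse(0, 1)  // 1, the default weight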
@@ -1010,6 +1051,7 @@ object DecisionTree extends Serializable with Logging {
while (featureIndex < numFeatures) {
// Iterate over all splits.
var splitIndex = 0
// TODO: Modify this for categorical variables to go over only valid splits
while (splitIndex < numBins - 1) {
val gainStats = gains(featureIndex)(splitIndex)
if (gainStats.gain > bestGainStats.gain) {
@@ -39,6 +39,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* zero-indexed.
* @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
* 128 MB.
* @param labelWeights A map storing weights applied to each label for handling unbalanced
* datasets. For example, an entry (n -> k) implies that a weight of k is
* applied to an instance with label n. It's important to note that labels
* are zero-indexed and take values 0, 1, 2, ..., numClasses - 1.
*
*/
@Experimental
@@ -50,7 +54,8 @@ class Strategy (
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val maxMemoryInMB: Int = 128) extends Serializable {
val maxMemoryInMB: Int = 128,
val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable {

require(numClassesForClassification >= 2)
val isMultiClassification = numClassesForClassification > 2
