Skip to content

Commit

Permalink
removed label weights support
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Jul 14, 2014
1 parent 2d85a48 commit afced16
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,37 +259,6 @@ object DecisionTree extends Serializable with Logging {
new DecisionTree(strategy).train(input)
}


/**
 * Trains a decision tree model on instances given as an RDD of (label, features) pairs,
 * applying a per-label weight to handle unbalanced datasets. Binary classification and
 * regression are supported; for binary classification every label should be either 0 or 1.
 *
 * The arguments are packed into a [[Strategy]] (all remaining settings keep the defaults
 * declared on that class) and training is delegated to a new [[DecisionTree]] instance.
 *
 * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
 *              training data
 * @param algo algorithm, classification or regression
 * @param impurity impurity criterion used for information gain calculation
 * @param maxDepth maximum depth of the tree
 * @param numClassesForClassification number of classes for classification
 * @param labelWeights a map storing the weight for each label. An entry (n -> k) implies that
 *                     a weight of k is applied to an instance with label n. Labels are
 *                     zero-indexed and take values 0, 1, 2, ... , numClasses - 1.
 * @return a DecisionTreeModel that can be used for prediction
 */
def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int,
    numClassesForClassification: Int,
    labelWeights: Map[Int,Int]): DecisionTreeModel = {
  // labelWeights is passed by name so the intermediate Strategy parameters keep their defaults.
  val weightedStrategy = new Strategy(
    algo,
    impurity,
    maxDepth,
    numClassesForClassification,
    labelWeights = labelWeights)
  val learner = new DecisionTree(weightedStrategy)
  learner.train(input)
}

/**
* Method to train a decision tree model where the instances are represented as an RDD of
* (label, features) pairs. The decision tree method supports binary classification and
Expand All @@ -303,10 +272,6 @@ object DecisionTree extends Serializable with Logging {
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @param labelWeights A map storing weights applied to each label for handling unbalanced
* datasets. For example, an entry (n -> k) implies that a weight of k is
* applied to an instance with label n. It's important to note that labels
* are zero-indexed and take values 0, 1, 2, ... , numClasses - 1.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
Expand All @@ -322,12 +287,11 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int,
numClassesForClassification: Int,
labelWeights: Map[Int,Int],
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights)
quantileCalculationStrategy, categoricalFeaturesInfo)
new DecisionTree(strategy).train(input)
}

Expand Down Expand Up @@ -442,8 +406,6 @@ object DecisionTree extends Serializable with Logging {
logDebug("numBins = " + numBins)
val numClasses = strategy.numClassesForClassification
logDebug("numClasses = " + numClasses)
val labelWeights = strategy.labelWeights
logDebug("labelWeights = " + labelWeights)
val isMulticlassClassification = strategy.isMulticlassClassification
logDebug("isMulticlassClassification = " + isMulticlassClassification)
val isMulticlassClassificationWithCategoricalFeatures
Expand Down Expand Up @@ -647,7 +609,7 @@ object DecisionTree extends Serializable with Logging {
val aggIndex = aggShift + numClasses * featureIndex * numBins
+ arr(arrIndex).toInt * numClasses
val labelInt = label.toInt
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1)
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
}

def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
Expand All @@ -667,10 +629,10 @@ object DecisionTree extends Serializable with Logging {
val labelInt = label.toInt
if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
agg(aggIndex + binIndex)
= agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
= agg(aggIndex + binIndex) + 1
} else {
agg(rightChildShift + aggIndex + binIndex)
= agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
= agg(rightChildShift + aggIndex + binIndex) + 1
}
binIndex += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* zero-indexed.
* @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
* 128 MB.
* @param labelWeights A map storing weights applied to each label for handling unbalanced
* datasets. For example, an entry (n -> k) implies that a weight of k is
* applied to an instance with label n. It's important to note that labels
* are zero-indexed and take values 0, 1, 2, ... , numClasses - 1.
*
*/
@Experimental
Expand All @@ -54,8 +50,7 @@ class Strategy (
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val maxMemoryInMB: Int = 128,
val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable {
val maxMemoryInMB: Int = 128) extends Serializable {

require(numClassesForClassification >= 2)
val isMulticlassClassification = numClassesForClassification > 2
Expand Down

0 comments on commit afced16

Please sign in to comment.