Skip to content

Commit

Permalink
tests and use multiclass binaggregate length when atleast one categor…
Browse files Browse the repository at this point in the history
…ical feature is present
  • Loading branch information
manishamde committed May 23, 2014
1 parent f5f6b83 commit 1892a2c
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 120 deletions.
6 changes: 3 additions & 3 deletions docs/mllib-decision-tree.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ bins if the condition is not satisfied.

**Categorical features**

For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
For `$M$` categorical features, one could come up with `$2^(M-1)-1$` split candidates. For
binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
categorical feature values by the proportion of labels falling in one of the two classes (see
Section 9.2.4 in
[Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
details). For example, for a binary classification problem with one categorical feature with three
categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B
and A , B \| C where \| denotes the split.

<!-- -->
### Stopping rule

The recursive tree construction is stopped at a node when one of the two conditions is met:
Expand Down
234 changes: 118 additions & 116 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Max memory usage for aggregates
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
val numElementsPerNode =
strategy.algo match {
case Classification => 2 * numBins * numFeatures
case Regression => 3 * numBins * numFeatures
}
val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
strategy.algo)

logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
Expand Down Expand Up @@ -144,7 +142,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
new DecisionTreeModel(topNode, strategy.algo)
}

// TODO: Unit test this
/**
* Extract the decision tree node information for the given tree level and node index
*/
Expand All @@ -162,7 +159,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
nodes(nodeIndex) = node
}

// TODO: Unit test this
/**
* Extract the decision tree node information for the children of the node
*/
Expand Down Expand Up @@ -290,12 +286,12 @@ object DecisionTree extends Serializable with Logging {
* @return a DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClassesForClassification: Int,
labelWeights: Map[Int,Int]): DecisionTreeModel = {
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClassesForClassification: Int,
labelWeights: Map[Int,Int]): DecisionTreeModel = {
val strategy
= new Strategy(algo, impurity, maxDepth, numClassesForClassification,
labelWeights = labelWeights)
Expand Down Expand Up @@ -462,7 +458,9 @@ object DecisionTree extends Serializable with Logging {
logDebug("labelWeights = " + labelWeights)
val isMulticlassClassification = strategy.isMulticlassClassification
logDebug("isMulticlassClassification = " + isMulticlassClassification)

val isMulticlassClassificationWithCategoricalFeatures
= strategy.isMulticlassWithCategoricalFeatures
logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassClassificationWithCategoricalFeatures)

// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex
Expand Down Expand Up @@ -518,9 +516,7 @@ object DecisionTree extends Serializable with Logging {
/**
* Find bin for one feature.
*/
def findBin(
featureIndex: Int,
labeledPoint: WeightedLabeledPoint,
def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint,
isFeatureContinuous: Boolean): Int = {
val binForFeatures = bins(featureIndex)
val feature = labeledPoint.features(featureIndex)
Expand Down Expand Up @@ -636,9 +632,48 @@ object DecisionTree extends Serializable with Logging {
}

// Find feature bins for all nodes at a level.
val binMappedRDD = input.map(x => findBinsForLevel(x))
val binMappedRDD = input.map(x => findBinsForLevel(x))

def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
label: Double, featureIndex: Int) = {

// Find the bin index for this feature.
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
val aggShift = numClasses * numBins * numFeatures * nodeIndex
val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
val labelInt = label.toInt
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1)
}

/**
def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
label: Double, agg: Array[Double], rightChildShift: Int) = {
// Find the bin index for this feature.
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
val aggShift = numClasses * numBins * numFeatures * nodeIndex
val aggIndex
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
// Find all matching bins and increment their values
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
val labelInt = label.toInt
if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
agg(aggIndex + binIndex)
= agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
} else {
agg(rightChildShift + aggIndex + binIndex)
= agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
}
binIndex += 1
}
}

/**
* Performs a sequential aggregation over a partition for classification. For l nodes,
* k features, either the left count or the right count of one of the p bins is
* incremented based upon whether the feature is classified as 0 or 1.
Expand All @@ -649,7 +684,7 @@ object DecisionTree extends Serializable with Logging {
* @return Array[Double] storing aggregate calculation of size
* 2 * numSplits * numFeatures * numNodes for classification
*/
def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
Expand All @@ -662,93 +697,51 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
// Find the bin index for this feature.
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
val aggShift = 2 * numBins * numFeatures * nodeIndex
val aggIndex
= aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
label.toInt match {
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1)
}
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
featureIndex += 1
}
}
nodeIndex += 1
}
}

/**
* Performs a sequential aggregation over a partition for classification. For l nodes,
* k features, either the left count or the right count of one of the p bins is
* incremented based upon whether the feature is classified as 0 or 1.
*
* @param agg Array[Double] storing aggregate calculation of size
* numClasses * numSplits * numFeatures*numNodes for classification
* @param arr Array[Double] of size 1 + (numFeatures * numNodes)
* @return Array[Double] storing aggregate calculation of size
* 2 * numClasses * numSplits * numFeatures * numNodes for classification
*/
def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
// Check whether the instance was valid for this nodeIndex.
val validSignalIndex = 1 + numFeatures * nodeIndex
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
if (isSampleValidForNode) {
val rightChildShift = numClasses * numBins * numFeatures * numNodes
// actual class label
val label = arr(0)
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isContinuousFeature) {
// Find the bin index for this feature.
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
val aggShift = numClasses * numBins * numFeatures * nodeIndex
val aggIndex
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
label.toInt match {
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1)
}
} else {
// Find the bin index for this feature.
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
val aggShift = numClasses * numBins * numFeatures * nodeIndex
val aggIndex
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
label.toInt match {
case n: Int =>
// Find all matching bins and increment their values
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)) {
agg(aggIndex + binIndex)
= agg(aggIndex + binIndex) + labelWeights.getOrElse(n, 1)
} else {
agg(rightChildShift + aggIndex + binIndex)
= agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(n, 1)

}
binIndex += 1
}
}
}
featureIndex += 1
}
/**
* Performs a sequential aggregation over a partition for classification. For l nodes,
* k features, either the left count or the right count of one of the p bins is
* incremented based upon whether the feature is classified as 0 or 1.
*
* @param agg Array[Double] storing aggregate calculation of size
* numClasses * numSplits * numFeatures*numNodes for classification
* @param arr Array[Double] of size 1 + (numFeatures * numNodes)
* @return Array[Double] storing aggregate calculation of size
* 2 * numClasses * numSplits * numFeatures * numNodes for classification
*/
def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
// Check whether the instance was valid for this nodeIndex.
val validSignalIndex = 1 + numFeatures * nodeIndex
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
if (isSampleValidForNode) {
val rightChildShift = numClasses * numBins * numFeatures * numNodes
// actual class label
val label = arr(0)
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isContinuousFeature) {
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
} else {
updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
}
featureIndex += 1
}
nodeIndex += 1
}
nodeIndex += 1
}
}

/**
* Performs a sequential aggregation over a partition for regression. For l nodes, k features,
Expand All @@ -760,7 +753,7 @@ object DecisionTree extends Serializable with Logging {
* @return Array[Double] storing aggregate calculation of size
* 3 * numSplits * numFeatures * numNodes for regression
*/
def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
Expand Down Expand Up @@ -795,7 +788,7 @@ object DecisionTree extends Serializable with Logging {
def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
strategy.algo match {
case Classification =>
if(isMulticlassClassification) {
if(isMulticlassClassificationWithCategoricalFeatures) {
multiClassificationBinSeqOp(arr, agg)
} else {
binaryClassificationBinSeqOp(arr, agg)
Expand All @@ -806,15 +799,8 @@ object DecisionTree extends Serializable with Logging {
}

// Calculate bin aggregate length for classification or regression.
val binAggregateLength = strategy.algo match {
case Classification =>
if (isMulticlassClassification){
2 * numClasses * numBins * numFeatures * numNodes
} else {
2 * numBins * numFeatures * numNodes
}
case Regression => 3 * numBins * numFeatures * numNodes
}
val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
logDebug("binAggregateLength = " + binAggregateLength)

/**
Expand Down Expand Up @@ -1024,7 +1010,7 @@ object DecisionTree extends Serializable with Logging {
}
}

def findAggregateForCategoricalFeatureClassification(
def findAggForUnorderedFeatureClassification(
leftNodeAgg: Array[Array[Array[Double]]],
rightNodeAgg: Array[Array[Array[Double]]],
featureIndex: Int) {
Expand Down Expand Up @@ -1101,12 +1087,12 @@ object DecisionTree extends Serializable with Logging {
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
var featureIndex = 0
while (featureIndex < numFeatures) {
if (isMulticlassClassification){
if (isMulticlassClassificationWithCategoricalFeatures){
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
} else {
findAggregateForCategoricalFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
}
} else {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
Expand Down Expand Up @@ -1214,7 +1200,7 @@ object DecisionTree extends Serializable with Logging {
def getBinDataForNode(node: Int): Array[Double] = {
strategy.algo match {
case Classification =>
if (isMulticlassClassification) {
if (isMulticlassClassificationWithCategoricalFeatures) {
val shift = numClasses * node * numBins * numFeatures
val rightChildShift = numClasses * numBins * numFeatures * numNodes
val binsForNode = {
Expand Down Expand Up @@ -1251,10 +1237,22 @@ object DecisionTree extends Serializable with Logging {
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
node += 1
}

bestSplits
}

private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int,
isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = {
algo match {
case Classification =>
if (isMulticlassClassificationWithCategoricalFeatures) {
2 * numClasses * numBins * numFeatures
} else {
numClasses * numBins * numFeatures
}
case Regression => 3 * numBins * numFeatures
}
}

/**
* Returns split and bins for decision tree calculation.
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
Expand Down Expand Up @@ -1288,9 +1286,12 @@ object DecisionTree extends Serializable with Logging {
*/
if (strategy.categoricalFeaturesInfo.size > 0) {
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
require(numBins > maxCategoriesForFeatures)
require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
"in categorical features")
if (isMulticlassClassification) {
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1)
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1,
"numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" +
" with categorical variables")
}
}

Expand Down Expand Up @@ -1331,7 +1332,8 @@ object DecisionTree extends Serializable with Logging {
} else { // Categorical feature
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)

// Use different bin/split calculation strategy for multiclass classification
// Use different bin/split calculation strategy for categorical features in multiclass
// classification
if (isMulticlassClassification) {
// 2^(maxFeatureValue- 1) - 1 combinations
var index = 0
Expand Down
Loading

0 comments on commit 1892a2c

Please sign in to comment.