Skip to content

Commit

Permalink
adding support for very deep trees
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Apr 20, 2014
1 parent 3a390bf commit 50b143a
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
logDebug("numSplits = " + bins(0).length)
val numBins = bins(0).length
logDebug("numBins = " + numBins)

// depth of the decision tree
val maxDepth = strategy.maxDepth
Expand All @@ -72,7 +73,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
val parentImpurities = new Array[Double](maxNumNodes)
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
val numFeatures = input.take(1)(0).features.size

// Calculate level for single group construction

// Max memory usage for aggregates
val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB
logDebug("max memory usage for aggregates = " + maxMemoryUsage)
val numElementsPerNode = {
strategy.algo match {
case Classification => 2 * numBins * numFeatures
case Regression => 3 * numBins * numFeatures
}
}
logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode //approx. memory usage for bin aggregate array
val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1)
logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
// nodes at a level is 2^(level-1). level is zero indexed.
val maxLevelForSingleGroup = scala.math.max(
(scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0)
logDebug("max level for single group = " + maxLevelForSingleGroup)

/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
Expand All @@ -92,7 +114,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

// Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
level, filters, splits, bins)
level, filters, splits, bins, maxLevelForSingleGroup)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level.
Expand All @@ -110,6 +132,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
}

logDebug("#####################################")
logDebug("Extracting tree model")
logDebug("#####################################")

// Initialize the top or root node of the tree.
val topNode = nodes(0)
// Build the full tree using the node info calculated in the level-wise best split calculations.
Expand Down Expand Up @@ -260,6 +286,7 @@ object DecisionTree extends Serializable with Logging {
* @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
* @param bins possible bins for all features
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @return array of splits with best splits for all nodes at a given level.
*/
protected[tree] def findBestSplits(
Expand All @@ -269,7 +296,50 @@ object DecisionTree extends Serializable with Logging {
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt
logDebug("numGroups = " + numGroups)
var groupIndex = 0
var bestSplits = new Array[(Split, InformationGainStats)](0)
while (groupIndex < numGroups) {
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
filters, splits, bins, numGroups, groupIndex)
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
groupIndex += 1
}
bestSplits
} else {
findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
}
}

/**
* Returns an array of optimal splits for a group of nodes at a given level
*
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
* for DecisionTree
* @param parentImpurities Impurities for all parent nodes for the current level
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
* parameters for construction the DecisionTree
* @param level Level of the tree
* @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
* @param bins possible bins for all features
* @return array of splits with best splits for all nodes at a given level.
*/
private def findBestSplitsPerGroup(
input: RDD[LabeledPoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
numGroups: Int = 1,
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {

/*
* The high-level description for the best split optimizations are noted here.
Expand All @@ -296,20 +366,23 @@ object DecisionTree extends Serializable with Logging {
*/

// common calculations for multiple nested methods
val numNodes = scala.math.pow(2, level).toInt
val numNodes = scala.math.pow(2, level).toInt / numGroups
logDebug("numNodes = " + numNodes)
// Find the number of features by looking at the first sample.
val numFeatures = input.first().features.size
logDebug("numFeatures = " + numFeatures)
val numBins = bins(0).length
logDebug("numBins = " + numBins)

// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex

/** Find the filters used before reaching the current code. */
def findParentFilters(nodeIndex: Int): List[Filter] = {
if (level == 0) {
List[Filter]()
} else {
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift
filters(nodeFilterIndex)
}
}
Expand Down Expand Up @@ -878,7 +951,7 @@ object DecisionTree extends Serializable with Logging {
// Iterating over all nodes at this level
var node = 0
while (node < numNodes) {
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift
val binsForNode: Array[Double] = getBinDataForNode(node)
logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)

val split = bestSplits(0)._1
assert(split.categories.length === 1)
Expand All @@ -281,7 +281,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)

val split = bestSplits(0)._1
assert(split.categories.length === 1)
Expand Down Expand Up @@ -310,7 +310,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
Expand All @@ -333,7 +333,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
Expand All @@ -357,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
Expand All @@ -381,7 +381,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
Expand Down

0 comments on commit 50b143a

Please sign in to comment.