Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 18, 2014
1 parent 7e5f08c commit bce835f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
30 changes: 19 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
new DecisionTreeModel(topNode, strategy.algo)
}

// TODO: Unit test this
/**
* Extract the decision tree node information for the given tree level and node index
*/
Expand All @@ -161,6 +162,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
nodes(nodeIndex) = node
}

// TODO: Unit test this
/**
* Extract the decision tree node information for the children of the node
*/
Expand Down Expand Up @@ -458,6 +460,8 @@ object DecisionTree extends Serializable with Logging {
logDebug("numClasses = " + numClasses)
val labelWeights = strategy.labelWeights
logDebug("labelWeights = " + labelWeights)
val isMulticlassClassification = strategy.isMulticlassClassification
logDebug("isMulticlassClassification = " + isMulticlassClassification)


// shift when more than one group is used at deep tree level
Expand Down Expand Up @@ -582,7 +586,7 @@ object DecisionTree extends Serializable with Logging {
} else {
// Perform sequential search to find bin for categorical features.
val binIndex = {
if (strategy.isMultiClassification) {
if (isMulticlassClassification) {
sequentialBinSearchForCategoricalFeatureInBinaryClassification()
} else {
sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
Expand All @@ -606,7 +610,9 @@ object DecisionTree extends Serializable with Logging {
def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
// Calculate bin index and label per feature per node.
val arr = new Array[Double](1 + (numFeatures * numNodes))
// First element of the array is the label of the instance.
arr(0) = labeledPoint.label
// Iterate over nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
val parentFilters = findParentFilters(nodeIndex)
Expand All @@ -629,7 +635,10 @@ object DecisionTree extends Serializable with Logging {
arr
}

/**
// Find feature bins for all nodes at a level.
val binMappedRDD = input.map(x => findBinsForLevel(x))

/**
* Performs a sequential aggregation over a partition for classification. For l nodes,
* k features, either the left count or the right count of one of the p bins is
* incremented based upon whether the feature is classified as 0 or 1.
Expand Down Expand Up @@ -663,7 +672,7 @@ object DecisionTree extends Serializable with Logging {
label.toInt match {
case n: Int =>
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (!isFeatureContinuous && strategy.isMultiClassification) {
if (!isFeatureContinuous && isMulticlassClassification) {
// Find all matching bins and increment their values
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
Expand Down Expand Up @@ -736,7 +745,6 @@ object DecisionTree extends Serializable with Logging {
agg
}

// TODO: Double-check this
// Calculate bin aggregate length for classification or regression.
val binAggregateLength = strategy.algo match {
case Classification => numClasses * numBins * numFeatures * numNodes
Expand All @@ -760,9 +768,6 @@ object DecisionTree extends Serializable with Logging {
combinedAggregate
}

// Find feature bins for all nodes at a level.
val binMappedRDD = input.map(x => findBinsForLevel(x))

// Calculate bin aggregates.
val binAggregates = {
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
Expand Down Expand Up @@ -922,7 +927,7 @@ object DecisionTree extends Serializable with Logging {
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)

if (strategy.isMultiClassification) {
if (isMulticlassClassification) {
var featureIndex = 0
while (featureIndex < numFeatures){
var splitIndex = 0
Expand Down Expand Up @@ -1096,7 +1101,7 @@ object DecisionTree extends Serializable with Logging {
numBins - 1
} else { // Categorical feature
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
if (strategy.isMultiClassification) {
if (isMulticlassClassification) {
math.pow(2.0, featureCategories - 1).toInt - 1
} else { // Binary classification
featureCategories
Expand Down Expand Up @@ -1177,6 +1182,9 @@ object DecisionTree extends Serializable with Logging {
val maxBins = strategy.maxBins
val numBins = if (maxBins <= count) maxBins else count.toInt
logDebug("numBins = " + numBins)
val isMulticlassClassification = strategy.isMulticlassClassification
logDebug("isMulticlassClassification = " + isMulticlassClassification)


/*
* Ensure #bins is always greater than the categories. For multiclass classification,
Expand All @@ -1187,7 +1195,7 @@ object DecisionTree extends Serializable with Logging {
if (strategy.categoricalFeaturesInfo.size > 0) {
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
require(numBins > maxCategoriesForFeatures)
if (strategy.isMultiClassification) {
if (isMulticlassClassification) {
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1)
}
}
Expand Down Expand Up @@ -1230,7 +1238,7 @@ object DecisionTree extends Serializable with Logging {
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)

// Use different bin/split calculation strategy for multiclass classification
if (strategy.isMultiClassification) {
if (isMulticlassClassification) {
// 2^(maxFeatureValue- 1) - 1 combinations
var index = 0
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ class Strategy (
val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable {

require(numClassesForClassification >= 2)
val isMultiClassification = numClassesForClassification > 2
val isMulticlassClassification = numClassesForClassification > 2

}

0 comments on commit bce835f

Please sign in to comment.