Skip to content

Commit

Permalink
support ordered categorical splits for multiclass classification
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Jun 4, 2014
1 parent e3e8843 commit adc7315
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 31 deletions.
95 changes: 67 additions & 28 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ object DecisionTree extends Serializable with Logging {
* Find bin for one feature.
*/
def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint,
isFeatureContinuous: Boolean): Int = {
isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
val binForFeatures = bins(featureIndex)
val feature = labeledPoint.features(featureIndex)

Expand Down Expand Up @@ -550,14 +550,14 @@ object DecisionTree extends Serializable with Logging {
* splits. The actual left/right child allocation per split is performed in the
* sequential phase of the bin aggregate operation.
*/
def sequentialBinSearchForCategoricalFeatureInMulticlassClassification(): Int = {
def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
labeledPoint.features(featureIndex).toInt
}

/**
* Sequential search helper method to find bin for categorical feature.
*/
def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = {
def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
var binIndex = 0
Expand All @@ -583,10 +583,10 @@ object DecisionTree extends Serializable with Logging {
} else {
// Perform sequential search to find bin for categorical features.
val binIndex = {
if (isMulticlassClassification) {
sequentialBinSearchForCategoricalFeatureInMulticlassClassification()
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
} else {
sequentialBinSearchForCategoricalFeatureInBinaryClassification()
sequentialBinSearchForOrderedCategoricalFeatureInClassification()
}
}
if (binIndex == -1){
Expand Down Expand Up @@ -622,8 +622,19 @@ object DecisionTree extends Serializable with Logging {
} else {
var featureIndex = 0
while (featureIndex < numFeatures) {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
val isFeatureContinuous = featureInfo.isEmpty
if (isFeatureContinuous) {
arr(shift + featureIndex)
= findBin(featureIndex, labeledPoint, isFeatureContinuous, false)
} else {
val featureCategories = featureInfo.get
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
arr(shift + featureIndex)
= findBin(featureIndex, labeledPoint, isFeatureContinuous,
isSpaceSufficientForAllCategoricalSplits)
}
featureIndex += 1
}
}
Expand Down Expand Up @@ -731,12 +742,19 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isContinuousFeature) {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
} else {
updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
if (isSpaceSufficientForAllCategoricalSplits) {
updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
} else {
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
}
}
featureIndex += 1
}
}
Expand Down Expand Up @@ -1093,7 +1111,14 @@ object DecisionTree extends Serializable with Logging {
if (isFeatureContinuous) {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
} else {
findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
if (isSpaceSufficientForAllCategoricalSplits) {
findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
} else {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
}
}
} else {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
Expand Down Expand Up @@ -1168,7 +1193,9 @@ object DecisionTree extends Serializable with Logging {
numBins - 1
} else { // Categorical feature
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
if (isMulticlassClassification) {
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
math.pow(2.0, featureCategories - 1).toInt - 1
} else { // Binary classification
featureCategories
Expand Down Expand Up @@ -1289,11 +1316,6 @@ object DecisionTree extends Serializable with Logging {
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
"in categorical features")
if (isMulticlassClassification) {
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1,
"numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" +
" with categorical variables")
}
}


Expand Down Expand Up @@ -1332,10 +1354,12 @@ object DecisionTree extends Serializable with Logging {
}
} else { // Categorical feature
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1

// Use different bin/split calculation strategy for categorical features in multiclass
// classification
if (isMulticlassClassification) {
// classification that satisfy the space constraint
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
// 2^(maxFeatureValue- 1) - 1 combinations
var index = 0
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
Expand All @@ -1360,14 +1384,29 @@ object DecisionTree extends Serializable with Logging {
}
index += 1
}
} else { // regression or binary classification

// For categorical variables, each bin is a category. The bins are sorted and they
// are ordered by calculating the centroid of their corresponding labels.
val centroidForCategories =
sampledInput.map(lp => (lp.features(featureIndex),lp.label))
.groupBy(_._1)
.mapValues(x => x.map(_._2).sum / x.map(_._1).length)
} else {

val centroidForCategories = {
if (isMulticlassClassification) {
// For categorical variables in multiclass classification,
// each bin is a category. The bins are sorted and they
// are ordered by calculating the impurity of their corresponding labels.
sampledInput.map(lp => (lp.features(featureIndex), lp.label))
.groupBy(_._1)
.mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
.map(x => (x._1, x._2.values.toArray))
.map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum)))
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// each bin is a category. The bins are sorted and they
// are ordered by calculating the centroid of their corresponding labels.
sampledInput.map(lp => (lp.features(featureIndex), lp.label))
.groupBy(_._1)
.mapValues(x => x.map(_._2).sum / x.map(_._1).length)
}
}

logDebug("centriod for categories = " + centroidForCategories.mkString(","))

// Check for missing categorical variables and putting them last in the sorted list.
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
}

test("split and bin calculations for categorical variables with multiclass classification") {
test("split and bin calculations for unordered categorical variables with multiclass " +
"classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
Expand Down Expand Up @@ -332,6 +333,62 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

}

// Checks the splits and bins returned by DecisionTree.findSplitsBins for categorical
// features in multiclass classification when maxBins cannot hold every category subset.
test("split and bin calculations for ordered categorical variables with multiclass " +
"classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
assert(arr.length === 3000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
maxDepth = 3,
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)

// With 10 categories the unordered encoding needs 2^(10 - 1) - 1 = 511 splits, which
// exceeds maxBins = 100, so the categorical features are expected to be treated as ordered.

// Each successive split of feature 0 is expected to hold one more category than the last.
assert(splits(0)(0).feature === 0)
assert(splits(0)(0).threshold === Double.MinValue)
assert(splits(0)(0).featureType === Categorical)
assert(splits(0)(0).categories.length === 1)
assert(splits(0)(0).categories.contains(1.0))

assert(splits(0)(1).feature === 0)
assert(splits(0)(1).threshold === Double.MinValue)
assert(splits(0)(1).featureType === Categorical)
assert(splits(0)(1).categories.length === 2)
assert(splits(0)(1).categories.contains(2.0))

assert(splits(0)(2).feature === 0)
assert(splits(0)(2).threshold === Double.MinValue)
assert(splits(0)(2).featureType === Categorical)
assert(splits(0)(2).categories.length === 3)
assert(splits(0)(2).categories.contains(2.0))
assert(splits(0)(2).categories.contains(1.0))

// Index 10 is expected to be unset for both features
// (presumably because each feature declares only 10 categories; confirm against findSplitsBins).
assert(splits(0)(10) === null)
assert(splits(1)(10) === null)


// Check bins.

// Each bin is expected to carry one category; its highSplit holds one more category
// than its lowSplit.
assert(bins(0)(0).category === 1.0)
assert(bins(0)(0).lowSplit.categories.length === 0)
assert(bins(0)(0).highSplit.categories.length === 1)
assert(bins(0)(0).highSplit.categories.contains(1.0))
assert(bins(0)(1).category === 2.0)
assert(bins(0)(1).lowSplit.categories.length === 1)
assert(bins(0)(1).highSplit.categories.length === 2)
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.contains(2.0))

assert(bins(0)(10) === null)

}


test("classification stump with all categorical variables") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
Expand Down Expand Up @@ -547,7 +604,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.featureType === Categorical)
}


test("stump with continuous variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val input = sc.parallelize(arr)
Expand All @@ -568,7 +624,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

}


test("stump with continuous + categorical variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val input = sc.parallelize(arr)
Expand All @@ -588,6 +643,26 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.threshold < 2020)
}

// Runs findBestSplits at level 0 on the ordered-categorical multiclass data set and checks
// that the single best split is the categorical split on feature 0 containing category 1.0.
test("stump with categorical variables for ordered multiclass classification") {
  val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
  val input = sc.parallelize(arr)
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
    numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
  assert(strategy.isMulticlassClassification)
  val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
  // NOTE(review): 31 presumably equals 2^maxDepth - 1 parent-impurity slots; confirm
  // against the findBestSplits signature.
  val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
    Array[List[Filter]](), splits, bins, 10)

  // Only one node is evaluated at level 0, so exactly one best split is expected.
  assert(bestSplits.length === 1)
  val bestSplit = bestSplits(0)._1
  assert(bestSplit.feature === 0)
  assert(bestSplit.categories.length === 1)
  assert(bestSplit.categories.contains(1.0))
  assert(bestSplit.featureType === Categorical)
}


}

object DecisionTreeSuite {
Expand Down Expand Up @@ -662,5 +737,20 @@ object DecisionTreeSuite {
arr
}

/**
 * Builds a fixed data set of 3000 two-feature points for the ordered categorical
 * multiclass tests.
 *
 * Indices 0-999 get label 2.0 with features (2.0, 2.0), indices 1000-1999 get label 1.0
 * with features (1.0, 2.0), and indices 2000-2999 get label 1.0 with features (2.0, 2.0).
 * The second feature is always 2.0.
 *
 * @return the 3000 generated points, in index order
 */
def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
  Array[WeightedLabeledPoint] = {
  Array.tabulate(3000) { idx =>
    // Only the label and the first feature vary between the three index ranges.
    val (label, firstFeature) =
      if (idx < 1000) (2.0, 2.0)
      else if (idx < 2000) (1.0, 1.0)
      else (1.0, 2.0)
    new WeightedLabeledPoint(label, Vectors.dense(firstFeature, 2.0))
  }
}


}

0 comments on commit adc7315

Please sign in to comment.