Skip to content

Commit

Permalink
multiclass logic
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 18, 2014
1 parent d8e4a11 commit ab5cb21
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 54 deletions.
94 changes: 71 additions & 23 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -545,17 +545,24 @@ object DecisionTree extends Serializable with Logging {
-1
}

/**
* Sequential search helper method to find the bin for a categorical feature in multiclass
* classification. Returns a dummy value of 0, since the bin index is not used in later
* calculations for this case.
*/
def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = 0

/**
* Sequential search helper method to find bin for categorical feature.
*/
def sequentialBinSearchForCategoricalFeature(): Int = {
val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
def sequentialBinSearchForCategoricalFeatureInMultiClassClassification(): Int = {
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
val bin = bins(featureIndex)(binIndex)
val category = bin.category
val categories = bin.highSplit.categories
val features = labeledPoint.features
if (category == features(featureIndex)) {
if (categories.contains(features(featureIndex))) {
return binIndex
}
binIndex += 1
Expand All @@ -572,7 +579,14 @@ object DecisionTree extends Serializable with Logging {
binIndex
} else {
// Perform sequential search to find bin for categorical features.
val binIndex = sequentialBinSearchForCategoricalFeature()
val binIndex = {
if (strategy.isMultiClassification) {
sequentialBinSearchForCategoricalFeatureInBinaryClassification()
}
else {
sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
}
}
if (binIndex == -1){
throw new UnknownError("no bin was found for categorical variable.")
}
Expand All @@ -584,7 +598,8 @@ object DecisionTree extends Serializable with Logging {
* Finds bins for all nodes (and all features) at a given level.
* For l nodes, k features the storage is as follows:
* label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
* where b_ij is an integer between 0 and numBins - 1.
* where b_ij is an integer between 0 and numBins - 1 for regression and binary
* classification, and an invalid value for a categorical feature in multiclass classification.
* An invalid sample is denoted by setting the bin for feature 1 to -1.
*/
def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
Expand Down Expand Up @@ -646,7 +661,22 @@ object DecisionTree extends Serializable with Logging {
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
label.toInt match {
case n: Int =>
agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1)
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous && strategy.isMultiClassification) {
// Find all matching bins and increment their values
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){
agg(aggIndex + binIndex)
= agg(aggIndex + binIndex) + labelWeights.getOrElse(binIndex, 1)
}
binIndex += 1
}
} else {
agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1)
}
}
featureIndex += 1
}
Expand Down Expand Up @@ -705,6 +735,7 @@ object DecisionTree extends Serializable with Logging {
agg
}

// TODO: Double-check this
// Calculate bin aggregate length for classification or regression.
val binAggregateLength = strategy.algo match {
case Classification => numClasses * numBins * numFeatures * numNodes
Expand Down Expand Up @@ -785,10 +816,10 @@ object DecisionTree extends Serializable with Logging {
}

if (leftTotalCount == 0) {
return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1)
}
if (rightTotalCount == 0) {
return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
}

val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
Expand All @@ -812,16 +843,16 @@ object DecisionTree extends Serializable with Logging {
= leftCounts.zip(rightCounts)
.map{case (leftCount, rightCount) => leftCount + rightCount}

def indexOfLargest(array: Seq[Double]): Int = {
def indexOfLargestArrayElement(array: Array[Double]): Int = {
val result = array.foldLeft(-1,Double.MinValue,0) {
case ((maxIndex, maxValue, currentIndex), currentValue) =>
if(currentValue > maxValue) (currentIndex,currentValue,currentIndex+1)
else (maxIndex,maxValue,currentIndex+1)
}
if (result._1 < 0) result._1 else 0
if (result._1 < 0) 0 else result._1
}

val predict = indexOfLargest(leftRightCounts)
val predict = indexOfLargestArrayElement(leftRightCounts)
val prob = leftRightCounts(predict) / totalCount

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
Expand Down Expand Up @@ -1051,8 +1082,20 @@ object DecisionTree extends Serializable with Logging {
while (featureIndex < numFeatures) {
// Iterate over all splits.
var splitIndex = 0
// TODO: Modify this for categorical variables to go over only valid splits
while (splitIndex < numBins - 1) {
val maxSplitIndex : Double = {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
numBins - 1
} else { // Categorical feature
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
if (strategy.isMultiClassification) {
math.pow(2.0, featureCategories - 1).toInt - 1
} else { // Binary classification
featureCategories
}
}
}
while (splitIndex < maxSplitIndex) {
val gainStats = gains(featureIndex)(splitIndex)
if (gainStats.gain > bestGainStats.gain) {
bestGainStats = gainStats
Expand Down Expand Up @@ -1176,24 +1219,29 @@ object DecisionTree extends Serializable with Logging {
splits(featureIndex)(index) = split
}
} else { // Categorical feature
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)

// Use different bin/split calculation strategy for multiclass classification
if (strategy.isMultiClassification) {
// Iterate from 0 to 2^maxFeatureValue - 1 leading to 2^(maxFeatureValue- 1) - 1
// combinations.
// 2^(featureCategories - 1) - 1 combinations
var index = 0
while (index < math.pow(2.0, maxFeatureValue).toInt - 1) {
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
val categories: List[Double]
= extractMultiClassCategories(index + 1, maxFeatureValue)
= extractMultiClassCategories(index + 1, featureCategories)
splits(featureIndex)(index)
= new Split(featureIndex, Double.MinValue, Categorical, categories)
bins(featureIndex)(index) = {
if (index == 0) {
new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0), Categorical, Double.MinValue)
new Bin(
new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0),
Categorical,
Double.MinValue)
} else {
new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical,
new Bin(
splits(featureIndex)(index - 1),
splits(featureIndex)(index),
Categorical,
Double.MinValue)
}
}
Expand All @@ -1210,7 +1258,7 @@ object DecisionTree extends Serializable with Logging {

// Check for missing categorical variables and putting them last in the sorted list.
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
for (i <- 0 until maxFeatureValue) {
for (i <- 0 until featureCategories) {
if (centroidForCategories.contains(i)) {
fullCentroidForCategories(i) = centroidForCategories(i)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ object Entropy extends Impurity {
var impurity = 0.0
var classIndex = 0
while (classIndex < numClasses) {
val freq = counts(classIndex) / totalCount
impurity -= freq * log2(freq)
val classCount = counts(classIndex)
if (classCount != 0) {
val freq = classCount / totalCount
impurity -= freq * log2(freq)
}
classIndex += 1
}
impurity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
* @param category categorical label value accepted in the bin for binary classification
*/
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 100)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(bins.length === 2)
Expand All @@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Classification,
Gini,
maxDepth = 3,
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
Expand Down Expand Up @@ -130,6 +131,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Classification,
Gini,
maxDepth = 3,
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
Expand Down Expand Up @@ -237,20 +239,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
}

test("split and bin calculations for categorical variables wiht multiclass classification") {
test("split and bin calculations for categorical variables with multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
maxDepth = 3,
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
numClassesForClassification = 3)
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)

// Expecting 2^3 - 1 = 7 bins/splits
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
assert(splits(0)(0).threshold === Double.MinValue)
assert(splits(0)(0).featureType === Categorical)
Expand Down Expand Up @@ -287,6 +289,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(1)(2).categories.contains(1.0))

assert(splits(0)(3) === null)
assert(splits(1)(3) === null)


// Check bins.
Expand Down Expand Up @@ -329,29 +332,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

}

test("split and bin calculations for categorical variables with no sample for one category " +
"for multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3),
numClassesForClassification = 3)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)

}

test("classification stump with all categorical variables") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
numClassesForClassification = 2,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
Expand All @@ -367,8 +355,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict > 0.4)
assert(stats.predict < 0.5)
assert(stats.predict === 0)
assert(stats.prob > 0.5)
assert(stats.prob < 0.6)
assert(stats.impurity > 0.2)
}

Expand Down Expand Up @@ -403,7 +392,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 100)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -426,7 +415,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 100)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -450,7 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 100)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -474,7 +463,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 100)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -498,7 +487,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 100)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand Down

0 comments on commit ab5cb21

Please sign in to comment.