Skip to content

Commit

Permalink
prepared for multiclass without breaking binary classification
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent 46e06ee commit 6c7af22
Showing 1 changed file with 107 additions and 82 deletions.
189 changes: 107 additions & 82 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ object DecisionTree extends Serializable with Logging {
logDebug("numFeatures = " + numFeatures)
val numBins = bins(0).length
logDebug("numBins = " + numBins)
val numClasses = strategy.numClassesForClassification
logDebug("numClasses = " + numClasses)

// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex
Expand Down Expand Up @@ -545,10 +547,10 @@ object DecisionTree extends Serializable with Logging {
* incremented based upon whether the feature is classified as 0 or 1.
*
* @param agg Array[Double] storing aggregate calculation of size
* 2 * numSplits * numFeatures*numNodes for classification
* numClasses * numBins * numFeatures * numNodes for classification
* @param arr Array[Double] of size 1 + (numFeatures * numNodes)
* @return Array[Double] storing aggregate calculation of size
* 2 * numSplits * numFeatures * numNodes for classification
* numClasses * numBins * numFeatures * numNodes for classification
*/
def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
// Iterate over all nodes.
Expand All @@ -562,16 +564,16 @@ object DecisionTree extends Serializable with Logging {
val label = arr(0)
// Iterate over all features.
var featureIndex = 0
// TODO: Multiclass modification here
while (featureIndex < numFeatures) {
// Find the bin index for this feature.
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Increment the count for this sample's label within the bin.
val aggShift = 2 * numBins * numFeatures * nodeIndex
val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
label match {
case n: Double => agg(aggIndex) = agg(aggIndex + n.toInt) + 1
val aggShift = numClasses * numBins * numFeatures * nodeIndex
val aggIndex
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
label.toInt match {
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1
}
featureIndex += 1
}
Expand Down Expand Up @@ -632,7 +634,7 @@ object DecisionTree extends Serializable with Logging {

// Calculate bin aggregate length for classification or regression.
val binAggregateLength = strategy.algo match {
case Classification => 2 * numBins * numFeatures * numNodes
case Classification => numClasses * numBins * numFeatures * numNodes
case Regression => 3 * numBins * numFeatures * numNodes
}
logDebug("binAggregateLength = " + binAggregateLength)
Expand Down Expand Up @@ -672,20 +674,20 @@ object DecisionTree extends Serializable with Logging {
* @return information gain and statistics for all splits
*/
def calculateGainForSplit(
leftNodeAgg: Array[Array[Double]],
leftNodeAgg: Array[Array[Array[Double]]],
featureIndex: Int,
splitIndex: Int,
rightNodeAgg: Array[Array[Double]],
rightNodeAgg: Array[Array[Array[Double]]],
topImpurity: Double): InformationGainStats = {
strategy.algo match {
case Classification =>
// TODO: Modify here
val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0)
val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1)
val leftCount = left0Count + left1Count

val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0)
val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1)
val rightCount = right0Count + right1Count

val impurity = {
Expand Down Expand Up @@ -722,13 +724,13 @@ object DecisionTree extends Serializable with Logging {

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
case Regression =>
val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)

val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)

val impurity = {
if (level > 0) {
Expand Down Expand Up @@ -777,98 +779,121 @@ object DecisionTree extends Serializable with Logging {
* Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
*/
def extractLeftRightNodeAggregates(
binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
strategy.algo match {
case Classification =>
// TODO: Multiclass modification here
// Initialize left and right split aggregates.
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
// shift for this featureIndex
val shift = 2 * featureIndex * numBins

// left node aggregate for the lowest split
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
leftNodeAgg(featureIndex)(1) = binData(shift + 1)

// right node aggregate for the highest split
rightNodeAgg(featureIndex)(2 * (numBins - 2))
= binData(shift + (2 * (numBins - 1)))
rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
= binData(shift + (2 * (numBins - 1)) + 1)

// Iterate over all splits.
var splitIndex = 1
while (splitIndex < numBins - 1) {
// calculating left node aggregate for a split as a sum of left node aggregate of a
// lower split and the left bin aggregate of a bin where the split is a high split
leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
leftNodeAgg(featureIndex)(2 * splitIndex - 2)
leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)

// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
binData(shift + (2 *(numBins - 2 - splitIndex))) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)

splitIndex += 1
// Initialize left and right split aggregates.
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)

if (strategy.isMultiClassification) {
var featureIndex = 0
while (featureIndex < numFeatures){
val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
val maxSplits = math.pow(2, numCategories) - 1
var i = 0
// TODO: Add multiclass case here
while (i < maxSplits) {
var classIndex = 0
while (classIndex < numClasses) {
// shift for this featureIndex
val shift = numClasses * featureIndex * numBins

classIndex += 1
}
i += 1
}
featureIndex += 1
}
} else {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
// shift for this featureIndex
val shift = 2 * featureIndex * numBins

// left node aggregate for the lowest split
leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)

// right node aggregate for the highest split
rightNodeAgg(featureIndex)(numBins - 2)(0)
= binData(shift + (2 * (numBins - 1)))
rightNodeAgg(featureIndex)(numBins - 2)(1)
= binData(shift + (2 * (numBins - 1)) + 1)

// Iterate over all splits.
var splitIndex = 1
while (splitIndex < numBins - 1) {
// calculating left node aggregate for a split as a sum of left node aggregate of a
// lower split and the left bin aggregate of a bin where the split is a high split
leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) +
leftNodeAgg(featureIndex)(splitIndex - 1)(0)
leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex +
1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1)

// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
binData(shift + (2 *(numBins - 2 - splitIndex))) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)

splitIndex += 1
}
featureIndex += 1
}
featureIndex += 1
}
(leftNodeAgg, rightNodeAgg)
case Regression =>
// Initialize left and right split aggregates.
val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
// shift for this featureIndex
val shift = 3 * featureIndex * numBins
// left node aggregate for the lowest split
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
leftNodeAgg(featureIndex)(2) = binData(shift + 2)
leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)

// right node aggregate for the highest split
rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
rightNodeAgg(featureIndex)(numBins - 2)(0) =
binData(shift + (3 * (numBins - 1)))
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
rightNodeAgg(featureIndex)(numBins - 2)(1) =
binData(shift + (3 * (numBins - 1)) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
rightNodeAgg(featureIndex)(numBins - 2)(2) =
binData(shift + (3 * (numBins - 1)) + 2)

// Iterate over all splits.
var splitIndex = 1
while (splitIndex < numBins - 1) {
// calculating left node aggregate for a split as a sum of left node aggregate of a
// lower split and the left bin aggregate of a bin where the split is a high split
leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
leftNodeAgg(featureIndex)(3 * splitIndex - 3)
leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) +
leftNodeAgg(featureIndex)(splitIndex - 1)(0)
leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) +
leftNodeAgg(featureIndex)(splitIndex - 1)(1)
leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) +
leftNodeAgg(featureIndex)(splitIndex - 1)(2)

// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
binData(shift + (3 * (numBins - 2 - splitIndex))) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) =
binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2)

splitIndex += 1
}
Expand All @@ -882,8 +907,8 @@ object DecisionTree extends Serializable with Logging {
* Calculates information gain for all nodes splits.
*/
def calculateGainsForAllNodeSplits(
leftNodeAgg: Array[Array[Double]],
rightNodeAgg: Array[Array[Double]],
leftNodeAgg: Array[Array[Array[Double]]],
rightNodeAgg: Array[Array[Array[Double]]],
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)

Expand Down

0 comments on commit 6c7af22

Please sign in to comment.