Skip to content

Commit

Permalink
multiclass for continuous variables
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 22, 2014
1 parent 8cfd3b6 commit f5f6b83
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 36 deletions.
56 changes: 26 additions & 30 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -987,48 +987,44 @@ object DecisionTree extends Serializable with Logging {
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {


def findAggForOrderedFeature(
def findAggForOrderedFeatureClassification(
leftNodeAgg: Array[Array[Array[Double]]],
rightNodeAgg: Array[Array[Array[Double]]],
featureIndex: Int) {

// shift for this featureIndex
val shift = 2 * featureIndex * numBins

// left node aggregate for the lowest split
leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)

// right node aggregate for the highest split
rightNodeAgg(featureIndex)(numBins - 2)(0)
= binData(shift + (2 * (numBins - 1)))
rightNodeAgg(featureIndex)(numBins - 2)(1)
= binData(shift + (2 * (numBins - 1)) + 1)
val shift = numClasses * featureIndex * numBins

var classIndex = 0
while (classIndex < numClasses) {
// left node aggregate for the lowest split
leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
// right node aggregate for the highest split
rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
= binData(shift + (numClasses * (numBins - 1)) + classIndex)
classIndex += 1
}

// Iterate over all splits.
var splitIndex = 1
while (splitIndex < numBins - 1) {
// calculating left node aggregate for a split as a sum of left node aggregate of a
// lower split and the left bin aggregate of a bin where the split is a high split
leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) +
leftNodeAgg(featureIndex)(splitIndex - 1)(0)
leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex +
1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1)

// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
binData(shift + (2 * (numBins - 2 - splitIndex))) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)

var innerClassIndex = 0
while (innerClassIndex < numClasses) {
leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex)
= binData(shift + numClasses * splitIndex + innerClassIndex) +
leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
binData(shift + (numClasses * (numBins - 2 - splitIndex) + innerClassIndex)) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
innerClassIndex += 1
}
splitIndex += 1
}
}

def extractAggForCategoricalFeature(
def findAggregateForCategoricalFeatureClassification(
leftNodeAgg: Array[Array[Array[Double]]],
rightNodeAgg: Array[Array[Array[Double]]],
featureIndex: Int) {
Expand Down Expand Up @@ -1108,12 +1104,12 @@ object DecisionTree extends Serializable with Logging {
if (isMulticlassClassification){
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex)
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
} else {
extractAggForCategoricalFeature(leftNodeAgg, rightNodeAgg, featureIndex)
findAggregateForCategoricalFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
}
} else {
findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex)
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
}
featureIndex += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,10 +561,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1

//assert(bestSplit.feature === 1)
//assert(bestSplit.featureType == Continuous)
//assert(bestSplit.threshold > 1000)
println(bestSplit)
assert(bestSplit.feature === 1)
assert(bestSplit.featureType === Continuous)
assert(bestSplit.threshold > 1980)
assert(bestSplit.threshold < 2020)

}

Expand Down Expand Up @@ -639,9 +639,9 @@ object DecisionTreeSuite {
val arr = new Array[WeightedLabeledPoint](3000)
for (i <- 0 until 3000) {
if (i < 2000) {
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 100))
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, i))
} else {
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 3000))
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, i))
}
}
arr
Expand Down

0 comments on commit f5f6b83

Please sign in to comment.