Skip to content

Commit

Permalink
added gain stats class
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent dad0afc commit 4798aae
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 34 deletions.
45 changes: 23 additions & 22 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ object DecisionTree extends Serializable {
level: Int,
filters : Array[List[Filter]],
splits : Array[Array[Split]],
bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = {
bins : Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = {

//Common calculations for multiple nested methods
val numNodes = scala.math.pow(2, level).toInt
Expand Down Expand Up @@ -241,7 +241,7 @@ object DecisionTree extends Serializable {
featureIndex: Int,
index: Int,
rightNodeAgg: Array[Array[Double]],
topImpurity: Double) : (Double, Long, Long) = {
topImpurity: Double) : InformationGainStats = {

val left0Count = leftNodeAgg(featureIndex)(2 * index)
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
Expand All @@ -251,20 +251,22 @@ object DecisionTree extends Serializable {
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
val rightCount = right0Count + right1Count

if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong)
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)

//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)

if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong)

//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)

val leftWeight = leftCount.toDouble / (leftCount + rightCount)
val rightWeight = rightCount.toDouble / (leftCount + rightCount)

(topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong)
val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity

new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)

}

Expand Down Expand Up @@ -307,9 +309,9 @@ object DecisionTree extends Serializable {
}

def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
: Array[Array[(Double,Long,Long)]] = {
: Array[Array[InformationGainStats]] = {

val gains = Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1)
val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1)

for (featureIndex <- 0 until numFeatures) {
for (index <- 0 until numSplits -1) {
Expand All @@ -325,44 +327,43 @@ object DecisionTree extends Serializable {
@param binData Array[Double] of size 2*numSplits*numFeatures
*/
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = {
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = {
println("node impurity = " + nodeImpurity)
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)

//println("gains.size = " + gains.size)
//println("gains(0).size = " + gains(0).size)

val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = {
val (bestFeatureIndex,bestSplitIndex, gainStats) = {
var bestFeatureIndex = 0
var bestSplitIndex = 0
var maxGain = Double.MinValue
var leftSamples = Long.MinValue
var rightSamples = Long.MinValue
//Initialization with infeasible values
var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0)
// var maxGain = Double.MinValue
// var leftSamples = Long.MinValue
// var rightSamples = Long.MinValue
for (featureIndex <- 0 until numFeatures) {
for (splitIndex <- 0 until numSplits - 1){
val gain = gains(featureIndex)(splitIndex)
val gainStats = gains(featureIndex)(splitIndex)
//println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
if(gain._1 > maxGain) {
maxGain = gain._1
leftSamples = gain._2
rightSamples = gain._3
if(gainStats.gain > bestGainStats.gain) {
bestGainStats = gainStats
bestFeatureIndex = featureIndex
bestSplitIndex = splitIndex
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
+ ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples)
+ ", gain stats = " + bestGainStats)
}
}
}
(bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples)
(bestFeatureIndex,bestSplitIndex,bestGainStats)
}

(splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount)
//TODO: Return array of node stats with split and impurity information
(splits(bestFeatureIndex)(bestSplitIndex),gainStats)
}

//Calculate best splits for all nodes at a given level
val bestSplits = new Array[(Split, Double, Long, Long)](numNodes)
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
for (node <- 0 until numNodes){
val shift = 2*node*numSplits*numFeatures
val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
assert(0==bestSplits(0)._2.gain)
assert(10==bestSplits(0)._2.leftSamples)
assert(0==bestSplits(0)._2.leftImpurity)
assert(990==bestSplits(0)._2.rightSamples)
assert(0==bestSplits(0)._2.rightImpurity)
}

test("stump with fixed label 1 for Gini"){
Expand All @@ -93,9 +95,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
assert(0==bestSplits(0)._2.gain)
assert(10==bestSplits(0)._2.leftSamples)
assert(0==bestSplits(0)._2.leftImpurity)
assert(990==bestSplits(0)._2.rightSamples)
assert(0==bestSplits(0)._2.rightImpurity)
}


Expand All @@ -115,9 +119,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
assert(0==bestSplits(0)._2.gain)
assert(10==bestSplits(0)._2.leftSamples)
assert(0==bestSplits(0)._2.leftImpurity)
assert(990==bestSplits(0)._2.rightSamples)
assert(0==bestSplits(0)._2.rightImpurity)
}

test("stump with fixed label 1 for Entropy"){
Expand All @@ -136,9 +142,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
assert(0==bestSplits(0)._2.gain)
assert(10==bestSplits(0)._2.leftSamples)
assert(0==bestSplits(0)._2.leftImpurity)
assert(990==bestSplits(0)._2.rightSamples)
assert(0==bestSplits(0)._2.rightImpurity)
}


Expand Down

0 comments on commit 4798aae

Please sign in to comment.