Skip to content

Commit

Permalink
changing instance format to weighted labeled point
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent a1a6e09 commit 14aea48
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.mllib.point.WeightedLabeledPoint

/**
* :: Experimental ::
Expand All @@ -47,13 +48,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
*/
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {

// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))

// Cache input RDD for speedup during multiple passes.
input.cache()
weightedInput.cache()
logDebug("algo = " + strategy.algo)

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy)
val numBins = bins(0).length
logDebug("numBins = " + numBins)

Expand All @@ -70,7 +74,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
val numFeatures = input.take(1)(0).features.size
val numFeatures = weightedInput.take(1)(0).features.size

// Calculate level for single group construction

Expand Down Expand Up @@ -109,8 +113,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("#####################################")

// Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
level, filters, splits, bins, maxLevelForSingleGroup)
val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities,
strategy, level, filters, splits, bins, maxLevelForSingleGroup)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level.
Expand Down Expand Up @@ -291,7 +295,7 @@ object DecisionTree extends Serializable with Logging {
* @return array of splits with best splits for all nodes at a given level.
*/
protected[tree] def findBestSplits(
input: RDD[LabeledPoint],
input: RDD[WeightedLabeledPoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
Expand Down Expand Up @@ -339,7 +343,7 @@ object DecisionTree extends Serializable with Logging {
* @return array of splits with best splits for all nodes at a given level.
*/
private def findBestSplitsPerGroup(
input: RDD[LabeledPoint],
input: RDD[WeightedLabeledPoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
Expand Down Expand Up @@ -399,7 +403,7 @@ object DecisionTree extends Serializable with Logging {
* Find whether the sample is valid input for the current node, i.e., whether it passes through
* all the filters for the current node.
*/
def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = {
// leaf
if ((level > 0) & (parentFilters.length == 0)) {
return false
Expand Down Expand Up @@ -438,7 +442,7 @@ object DecisionTree extends Serializable with Logging {
*/
def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
labeledPoint: WeightedLabeledPoint,
isFeatureContinuous: Boolean): Int = {
val binForFeatures = bins(featureIndex)
val feature = labeledPoint.features(featureIndex)
Expand Down Expand Up @@ -509,7 +513,7 @@ object DecisionTree extends Serializable with Logging {
* where b_ij is an integer between 0 and numBins - 1.
* An invalid sample is denoted by marking the bin for feature 1 as -1.
*/
def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
// Calculate bin index and label per feature per node.
val arr = new Array[Double](1 + (numFeatures * numNodes))
arr(0) = labeledPoint.label
Expand Down Expand Up @@ -982,7 +986,7 @@ object DecisionTree extends Serializable with Logging {
* .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
*/
protected[tree] def findSplitsBins(
input: RDD[LabeledPoint],
input: RDD[WeightedLabeledPoint],
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
val count = input.count()

Expand Down

0 comments on commit 14aea48

Please sign in to comment.