Skip to content

Commit

Permalink
changing variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent 5c78e1a commit e006f9d
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
* @return a DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {

// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = {

// Cache input RDD for speedup during multiple passes.
weightedInput.cache()
input.cache()
logDebug("algo = " + strategy.algo)

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val numBins = bins(0).length
logDebug("numBins = " + numBins)

Expand All @@ -74,7 +71,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
val numFeatures = weightedInput.take(1)(0).features.size
val numFeatures = input.take(1)(0).features.size

// Calculate level for single group construction

Expand Down Expand Up @@ -113,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("#####################################")

// Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities,
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
strategy, level, filters, splits, bins, maxLevelForSingleGroup)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
Expand Down Expand Up @@ -216,7 +213,9 @@ object DecisionTree extends Serializable with Logging {
* @return a DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
}

/**
Expand All @@ -238,7 +237,9 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo,impurity,maxDepth)
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
}


Expand Down Expand Up @@ -273,7 +274,9 @@ object DecisionTree extends Serializable with Logging {
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
categoricalFeaturesInfo)
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
}

private val InvalidBinIndex = -1
Expand Down

0 comments on commit e006f9d

Please sign in to comment.