From e006f9d5914b28b30aa8c24b0d1ff9977f23179e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 11 May 2014 21:29:47 -0700 Subject: [PATCH] changing variable names --- .../spark/mllib/tree/DecisionTree.scala | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3da92ed891611..52ae362028f5c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -46,18 +46,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * @return a DecisionTreeModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - - // Converting from standard instance format to weighted input format for tree training - val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = { // Cache input RDD for speedup during multiple passes. - weightedInput.cache() + input.cache() logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -74,7 +71,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = weightedInput.take(1)(0).features.size + val numFeatures = input.take(1)(0).features.size // Calculate level for single group construction @@ -113,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities, + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { @@ -216,7 +213,9 @@ object DecisionTree extends Serializable with Logging { * @return a DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } /** @@ -238,7 +237,9 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } @@ -273,7 +274,9 @@ object DecisionTree extends Serializable with Logging { categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } private val InvalidBinIndex = -1