diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6c99f82f687e8..4d7ac51e2f01e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -275,7 +275,8 @@ object DecisionTree extends Serializable with Logging { private val InvalidBinIndex = -1 /** - * Returns an array of optimal splits for all nodes at a given level + * Returns an array of optimal splits for all nodes at a given level. Splits the tasks into + * multiple groups if the level-wise training tasks could lead to memory overflow. * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4a0b399ca3dde..2155ed7b4a154 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -405,8 +405,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),-1) - val rightFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),1) + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()) ,1) val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter)) val parentImpurities = Array(0.5, 0.5, 0.5) @@ -444,7 +444,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } @@ -453,7 +453,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } @@ -462,7 +462,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { if (i < 600){ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp @@ -476,7 +476,7 @@ object DecisionTreeSuite { def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { if (i < 600){ arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else {