Skip to content

Commit

Permalink
Merge pull request #5 from etrain/deep_tree
Browse files Browse the repository at this point in the history
Parameterizing max memory.
  • Loading branch information
manishamde committed Apr 22, 2014
2 parents 50b143a + abc5a23 commit 2f6072c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.util.Utils.memoryStringToMb
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
Expand Down Expand Up @@ -79,7 +80,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Calculate level for single group construction

// Max memory usage for aggregates
val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB
val maxMemoryUsage = strategy.maxMemory * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage)
val numElementsPerNode = {
strategy.algo match {
Expand Down Expand Up @@ -1158,10 +1159,13 @@ object DecisionTree extends Serializable with Logging {

val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
val maxBins = options.getOrElse('maxBins, "100").toString.toInt
val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString)

val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory=maxMemUsage)
val model = DecisionTree.train(trainData, strategy)



// Load test data.
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ class Strategy (
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val maxMemory: Int = 128) extends Serializable

0 comments on commit 2f6072c

Please sign in to comment.