Skip to content

Commit

Permalink
changing default values for num classes
Browse files: browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent 6b912dc commit 18d2835
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
algo: Algo = Classification,
numClasses: Int = 2,
numClassesForClassification: Int = 2,
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 100)
Expand All @@ -69,9 +69,10 @@ object DecisionTreeRunner {
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
opt[Int]("numClasses")
.text(s"number of classes for classification, default: ${defaultParams.numClasses}")
.action((x, c) => c.copy(numClasses = x))
opt[Int]("numClassesForClassification")
.text(s"number of classes for classification, "
+ s"default: ${defaultParams.numClassesForClassification}")
.action((x, c) => c.copy(numClassesForClassification = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
Expand Down Expand Up @@ -122,7 +123,13 @@ object DecisionTreeRunner {
case Variance => impurity.Variance
}

val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
val strategy
= new Strategy(
algo = params.algo,
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
numClassesForClassification = params.numClassesForClassification)
val model = DecisionTree.train(training, strategy)

if (params.algo == Classification) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ object DecisionTree extends Serializable with Logging {
algo: Algo,
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo,impurity,maxDepth)
val strategy = new Strategy(algo, impurity, maxDepth)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
Expand All @@ -253,16 +253,16 @@ object DecisionTree extends Serializable with Logging {
* @param algo algorithm, classification or regression
* @param impurity impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param numClasses number of classes for classification
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @return a DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int): DecisionTreeModel = {
val strategy = new Strategy(algo,impurity,maxDepth,numClasses)
numClassesForClassification: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
Expand All @@ -282,7 +282,7 @@ object DecisionTree extends Serializable with Logging {
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param numClasses number of classes for classification
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
Expand All @@ -297,11 +297,11 @@ object DecisionTree extends Serializable with Logging {
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int,
numClassesForClassification: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
Expand Down Expand Up @@ -851,10 +851,8 @@ object DecisionTree extends Serializable with Logging {
if (strategy.isMultiClassification) {
var featureIndex = 0
while (featureIndex < numFeatures){
val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
val maxSplits = math.pow(2, numCategories) - 1
var splitIndex = 0
while (splitIndex < maxSplits) {
while (splitIndex < numBins - 1) {
var classIndex = 0
while (classIndex < numClasses) {
// shift for this featureIndex
Expand Down

0 comments on commit 18d2835

Please sign in to comment.