Skip to content

Commit

Permalink
added categorical variable test
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 21, 2014
1 parent bce835f commit 828ff16
Showing 1 changed file with 47 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 3,
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)

// Check splits.
Expand Down Expand Up @@ -483,7 +483,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 1)
}

test("test second level node building with/without groups") {
test("second level node building with/without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
Expand Down Expand Up @@ -529,6 +529,33 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

}

// Placeholder: the continuous-feature multiclass stump test has not been written yet.
// It is marked `pending` so the test report shows it as unimplemented rather than
// as a vacuous pass.
test("stump with continuous variables for multiclass classification") {
  pending
}

// Verifies the best split chosen at the root node for a multiclass problem whose
// features are all categorical.
test("stump with categorical variables for multiclass classification") {
  val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
  val input = sc.parallelize(arr)
  // Gini-impurity classification over 3 classes; features 0 and 1 are each declared
  // categorical with 3 categories.
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
    numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
  assert(strategy.isMulticlassClassification)
  val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
  // NOTE(review): 31 looks like 2^5 - 1, i.e. sized to match maxDepth = 5 -- confirm.
  val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
    Array[List[Filter]](), splits, bins, 10)

  // Exactly one (split, stats) pair is expected back for this call.
  assert(bestSplits.length === 1)
  val bestSplit = bestSplits(0)._1
  // In the generated data feature 1 is constant (2.0), so only feature 0 can
  // separate the labels.
  assert(bestSplit.feature === 0)
  assert(bestSplit.categories.length === 1)
  assert(bestSplit.categories.contains(0))
  assert(bestSplit.featureType === Categorical)
}

// Placeholder: the mixed continuous/categorical multiclass stump test has not been
// written yet. It is marked `pending` so the test report shows it as unimplemented
// rather than as a vacuous pass.
test("stump with continuous + categorical variables for multiclass classification") {
  pending
}

}

object DecisionTreeSuite {
Expand Down Expand Up @@ -576,4 +603,22 @@ object DecisionTreeSuite {
}
arr
}

/**
 * Builds a fixed data set of 3000 weighted labeled points with two features each.
 *
 * Points at indices 1000 until 2000 have label 1.0 and features (1.0, 2.0); every
 * other point has label 2.0 and features (2.0, 2.0).
 *
 * NOTE(review): only labels 1.0 and 2.0 are produced, and the first and last
 * thousand points are identical -- confirm this is the intended multiclass fixture.
 *
 * @return a newly allocated array of 3000 points
 */
def generateCategoricalDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
  Array.tabulate(3000) { i =>
    if (i >= 1000 && i < 2000) {
      new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
    } else {
      new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
    }
  }
}

}

0 comments on commit 828ff16

Please sign in to comment.