From 5841c2838e6834fc8c767f3c83dba7ef99375fa4 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 26 Jan 2014 22:34:49 -0800 Subject: [PATCH] unit tests for categorical features Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTreeSuite.scala | 228 +++++++++++++++--- 1 file changed, 191 insertions(+), 37 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 15b5b40b06532..39635a7e654a2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -27,11 +27,12 @@ import org.apache.spark.SparkContext._ import org.jblas._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import scala.collection.mutable +import org.apache.spark.mllib.tree.configuration.FeatureType._ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ -56,7 +57,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins.length==2) assert(splits(0).length==99) assert(bins(0).length==100) - //println(splits(1)(98)) } test("split and bin calculation for categorical variables"){ @@ -69,13 +69,71 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins.length==2) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(0)(0)) - println(splits(0)(1)) - println(bins(0)(0)) - println(splits(1)(0)) - println(splits(1)(1)) - println(bins(1)(0)) - //TODO: Add asserts + + //Checking splits + + assert(splits(0)(0).feature == 0) + assert(splits(0)(0).threshold == Double.MinValue) + assert(splits(0)(0).featureType == Categorical) + assert(splits(0)(0).categories.length == 1) + assert(splits(0)(0).categories.contains(1.0)) + + + assert(splits(0)(1).feature == 0) + assert(splits(0)(1).threshold == Double.MinValue) + assert(splits(0)(1).featureType == Categorical) + assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2) == null) + + assert(splits(1)(0).feature == 1) + assert(splits(1)(0).threshold == Double.MinValue) + assert(splits(1)(0).featureType == Categorical) + assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).categories.contains(0.0)) + + + assert(splits(1)(1).feature == 1) + assert(splits(1)(1).threshold == Double.MinValue) + assert(splits(1)(1).featureType == Categorical) + assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2) == null) + + + // Checks bins + + assert(bins(0)(0).category == 1.0) + assert(bins(0)(0).lowSplit.categories.length == 0) + assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category == 0.0) + assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category == Double.MaxValue) + + assert(bins(1)(0).category == 0.0) + assert(bins(1)(0).lowSplit.categories.length == 0) + assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category == 1.0) + assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category == Double.MaxValue) } @@ -85,29 +143,106 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(bins.length==2) - assert(splits(0).length==99) - assert(bins(0).length==100) - println(splits(0)(0)) - println(splits(0)(1)) - println(splits(0)(2)) - println(bins(0)(0)) - println(bins(0)(1)) - println(bins(0)(2)) - println(splits(1)(0)) - println(splits(1)(1)) - println(splits(1)(2)) - println(bins(1)(0)) - println(bins(1)(1)) - println(bins(0)(2)) - println(bins(0)(3)) - //TODO: Add asserts - } + //Checking splits + + assert(splits(0)(0).feature == 0) + assert(splits(0)(0).threshold == Double.MinValue) + assert(splits(0)(0).featureType == Categorical) + assert(splits(0)(0).categories.length == 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature == 0) + assert(splits(0)(1).threshold == Double.MinValue) + assert(splits(0)(1).featureType == Categorical) + assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2).feature == 0) + assert(splits(0)(2).threshold == Double.MinValue) + assert(splits(0)(2).featureType == Categorical) + assert(splits(0)(2).categories.length == 3) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(2.0)) + + assert(splits(0)(3) == null) + + assert(splits(1)(0).feature == 1) + assert(splits(1)(0).threshold == Double.MinValue) + assert(splits(1)(0).featureType == Categorical) + assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature == 1) + assert(splits(1)(1).threshold == Double.MinValue) + assert(splits(1)(1).featureType == Categorical) + assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2).feature == 1) + assert(splits(1)(2).threshold == Double.MinValue) + assert(splits(1)(2).featureType == Categorical) + assert(splits(1)(2).categories.length == 3) + assert(splits(1)(2).categories.contains(1.0)) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(2.0)) + + assert(splits(1)(3) == null) + + + // Checks bins + + assert(bins(0)(0).category == 1.0) + assert(bins(0)(0).lowSplit.categories.length == 0) + assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category == 0.0) + assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category == 2.0) + assert(bins(0)(2).lowSplit.categories.length == 2) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).lowSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.length == 3) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.contains(2.0)) + + assert(bins(0)(3).category == Double.MaxValue) + + assert(bins(1)(0).category == 0.0) + assert(bins(1)(0).lowSplit.categories.length == 0) + assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category == 1.0) + assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category == 2.0) + assert(bins(1)(2).lowSplit.categories.length == 2) + assert(bins(1)(2).lowSplit.categories.contains(0.0)) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length == 3) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(2.0)) + + assert(bins(1)(3).category == Double.MaxValue) - //TODO: Test max feature value > num bins + } test("classification stump with all categorical variables"){ val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -117,22 +252,41 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) - println(bestSplits(0)._1) - println(bestSplits(0)._2) - //TODO: Add asserts + + val split = bestSplits(0)._1 + assert(split.categories.length == 1) + assert(split.categories.contains(1.0)) + assert(split.featureType == Categorical) + assert(split.threshold == Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } test("regression stump with all categorical variables"){ val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) - println(bestSplits(0)._1) - println(bestSplits(0)._2) - //TODO: Add asserts + + val split = bestSplits(0)._1 + assert(split.categories.length == 1) + assert(split.categories.contains(1.0)) + assert(split.featureType == Categorical) + assert(split.threshold == Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) } @@ -157,7 +311,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - println(bestSplits(0)._2.predict) + } test("stump with fixed label 1 for Gini"){