From d75ac3211f4b951e6771451894f7b24718f7c08c Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Thu, 10 Jul 2014 17:48:37 -0700
Subject: [PATCH] removed WeightedLabeledPoint from this PR

---
 .../spark/mllib/point/PointConverter.scala    | 35 -----------
 .../mllib/point/WeightedLabeledPoint.scala    | 32 ----------
 .../spark/mllib/tree/DecisionTree.scala       | 16 +++--
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 58 +++++++++----------
 4 files changed, 36 insertions(+), 105 deletions(-)
 delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala
 delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala
deleted file mode 100644
index 1f31c4dadc21c..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.point
-
-import scala.language.implicitConversions
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.regression.LabeledPoint
-
-/**
- * Class to convert between different point formats.
- */
-object PointConverter {
-
-  implicit def LabeledPoint2WeightedLabeledPoint(
-      points : RDD[LabeledPoint]): RDD[WeightedLabeledPoint] = {
-    points.map(point => new WeightedLabeledPoint(point.label,point.features))
-  }
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala
deleted file mode 100644
index cc2d3caa0b86c..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.point
-
-import org.apache.spark.mllib.linalg.Vector
-
-/**
- * Class that represents the features and labels of a data point.
- *
- * @param label Label for this data point.
- * @param features List of features for this data point.
- */
-case class WeightedLabeledPoint(label: Double, features: Vector, weight:Double = 1) {
-  override def toString: String = {
-    "LabeledPoint(%s, %s, %s)".format(label, features, weight)
-  }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 9bda064bee554..a4524319c7ff4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,7 +16,6 @@
 
 package org.apache.spark.mllib.tree
 
-import org.apache.spark.mllib.point.PointConverter._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -29,7 +28,6 @@ import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.mllib.point.WeightedLabeledPoint
 
 /**
  * :: Experimental ::
@@ -47,7 +45,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
    * @return a DecisionTreeModel that can be used for prediction
    */
-  def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = {
+  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
     // Cache input RDD for speedup during multiple passes.
     input.cache()
@@ -352,7 +350,7 @@ object DecisionTree extends Serializable with Logging {
    * @return array of splits with best splits for all nodes at a given level.
    */
  protected[tree] def findBestSplits(
-      input: RDD[WeightedLabeledPoint],
+      input: RDD[LabeledPoint],
       parentImpurities: Array[Double],
       strategy: Strategy,
       level: Int,
@@ -400,7 +398,7 @@ object DecisionTree extends Serializable with Logging {
    * @return array of splits with best splits for all nodes at a given level.
    */
   private def findBestSplitsPerGroup(
-      input: RDD[WeightedLabeledPoint],
+      input: RDD[LabeledPoint],
       parentImpurities: Array[Double],
       strategy: Strategy,
       level: Int,
@@ -469,7 +467,7 @@ object DecisionTree extends Serializable with Logging {
      * Find whether the sample is valid input for the current node, i.e., whether it passes through
      * all the filters for the current node.
      */
-    def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = {
+    def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
       // leaf
       if ((level > 0) && (parentFilters.length == 0)) {
         return false
@@ -506,7 +504,7 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Find bin for one feature.
      */
-    def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint,
+    def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
         isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
       val binForFeatures = bins(featureIndex)
       val feature = labeledPoint.features(featureIndex)
@@ -595,7 +593,7 @@ object DecisionTree extends Serializable with Logging {
      * classification and the categorical feature value in multiclass classification.
      * Invalid sample is denoted by noting bin for feature 1 as -1.
      */
-    def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
+    def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
       // Calculate bin index and label per feature per node.
       val arr = new Array[Double](1 + (numFeatures * numNodes))
       // First element of the array is the label of the instance.
@@ -1283,7 +1281,7 @@ object DecisionTree extends Serializable with Logging {
   *  .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
   */
  protected[tree] def findSplitsBins(
-      input: RDD[WeightedLabeledPoint],
+      input: RDD[LabeledPoint],
       strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
     val count = input.count()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 6b6cab97935b0..5961a618c59d9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.point.WeightedLabeledPoint
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.Filter
 import org.apache.spark.mllib.tree.model.Split
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
@@ -664,86 +664,86 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
 object DecisionTreeSuite {
 
-  def generateOrderedLabeledPointsWithLabel0(): Array[WeightedLabeledPoint] = {
-    val arr = new Array[WeightedLabeledPoint](1000)
+  def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {
-      val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
+      val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
       arr(i) = lp
     }
     arr
   }
 
-  def generateOrderedLabeledPointsWithLabel1(): Array[WeightedLabeledPoint] = {
-    val arr = new Array[WeightedLabeledPoint](1000)
+  def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {
-      val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
+      val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
       arr(i) = lp
     }
     arr
   }
 
-  def generateOrderedLabeledPoints(): Array[WeightedLabeledPoint] = {
-    val arr = new Array[WeightedLabeledPoint](1000)
+  def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {
       if (i < 600) {
-        val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
+        val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
         arr(i) = lp
       } else {
-        val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
+        val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
         arr(i) = lp
       }
     }
     arr
   }
 
-  def generateCategoricalDataPoints(): Array[WeightedLabeledPoint] = {
-    val arr = new Array[WeightedLabeledPoint](1000)
+  def generateCategoricalDataPoints(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {
       if (i < 600) {
-        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(0.0, 1.0))
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
       } else {
-        arr(i) = new WeightedLabeledPoint(0.0, Vectors.dense(1.0, 0.0))
+        arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
       }
     }
     arr
   }
 
-  def generateCategoricalDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
-    val arr = new Array[WeightedLabeledPoint](3000)
+  def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
       if (i < 1000) {
-        arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
       } else if (i < 2000) {
-        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
       } else {
-        arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
       }
     }
     arr
   }
 
-  def generateContinuousDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
-    val arr = new Array[WeightedLabeledPoint](3000)
+  def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
       if (i < 2000) {
-        arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, i))
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i))
       } else {
-        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, i))
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i))
       }
     }
     arr
   }
 
   def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
-    Array[WeightedLabeledPoint] = {
-    val arr = new Array[WeightedLabeledPoint](3000)
+    Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
       if (i < 1000) {
-        arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+        arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
       } else if (i < 2000) {
-        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
       } else {
-        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 2.0))
+        arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
       }
     }
     arr
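
After this patch, callers build training data from plain LabeledPoint rather than WeightedLabeledPoint, and the implicit conversion through PointConverter disappears. Below is a minimal caller sketch, assuming the Spark 1.0-era static overload DecisionTree.train(input, algo, impurity, maxDepth); the object name DecisionTreeExample and the toy dataset are illustrative, not part of this patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Gini

// Hypothetical driver illustrating the post-patch API.
object DecisionTreeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("DecisionTreeExample").setMaster("local[2]"))

    // Training data is plain LabeledPoint; no implicit conversion through
    // the removed PointConverter is needed. The data mirrors the separable
    // shape used by the test suite's generators.
    val data = sc.parallelize(0 until 100).map { i =>
      if (i < 50) LabeledPoint(0.0, Vectors.dense(i.toDouble, 1.0))
      else LabeledPoint(1.0, Vectors.dense(i.toDouble, 0.0))
    }

    // Static convenience overload: (input, algo, impurity, maxDepth).
    val model = DecisionTree.train(data, Algo.Classification, Gini, 3)

    // On such separable data the tree should route this point to label 1.0.
    println(model.predict(Vectors.dense(90.0, 0.0)))

    sc.stop()
  }
}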