Skip to content

Commit

Permalink
removed WeightedLabeledPoint from this PR
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Jul 11, 2014
1 parent 0fecd38 commit d75ac32
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 105 deletions.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.mllib.tree

import org.apache.spark.mllib.point.PointConverter._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
Expand All @@ -29,7 +28,6 @@ import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.mllib.point.WeightedLabeledPoint

/**
* :: Experimental ::
Expand All @@ -47,7 +45,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
* @return a DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = {
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {

// Cache input RDD for speedup during multiple passes.
input.cache()
Expand Down Expand Up @@ -352,7 +350,7 @@ object DecisionTree extends Serializable with Logging {
* @return array of splits with best splits for all nodes at a given level.
*/
protected[tree] def findBestSplits(
input: RDD[WeightedLabeledPoint],
input: RDD[LabeledPoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
Expand Down Expand Up @@ -400,7 +398,7 @@ object DecisionTree extends Serializable with Logging {
* @return array of splits with best splits for all nodes at a given level.
*/
private def findBestSplitsPerGroup(
input: RDD[WeightedLabeledPoint],
input: RDD[LabeledPoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
Expand Down Expand Up @@ -469,7 +467,7 @@ object DecisionTree extends Serializable with Logging {
* Find whether the sample is valid input for the current node, i.e., whether it passes through
* all the filters for the current node.
*/
def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = {
def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
// leaf
if ((level > 0) && (parentFilters.length == 0)) {
return false
Expand Down Expand Up @@ -506,7 +504,7 @@ object DecisionTree extends Serializable with Logging {
/**
* Find bin for one feature.
*/
def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint,
def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
val binForFeatures = bins(featureIndex)
val feature = labeledPoint.features(featureIndex)
Expand Down Expand Up @@ -595,7 +593,7 @@ object DecisionTree extends Serializable with Logging {
* classification and the categorical feature value in multiclass classification.
* Invalid sample is denoted by noting bin for feature 1 as -1.
*/
def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
// Calculate bin index and label per feature per node.
val arr = new Array[Double](1 + (numFeatures * numNodes))
// First element of the array is the label of the instance.
Expand Down Expand Up @@ -1283,7 +1281,7 @@ object DecisionTree extends Serializable with Logging {
* .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
*/
protected[tree] def findSplitsBins(
input: RDD[WeightedLabeledPoint],
input: RDD[LabeledPoint],
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
val count = input.count()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree

import org.scalatest.FunSuite

import org.apache.spark.mllib.point.WeightedLabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.model.Split
Expand All @@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.regression.LabeledPoint

class DecisionTreeSuite extends FunSuite with LocalSparkContext {

Expand Down Expand Up @@ -664,86 +664,86 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

object DecisionTreeSuite {

/**
 * Generates 1000 deterministic instances, all labeled 0.0.
 * Feature 0 increases with the index while feature 1 decreases (1000 - i),
 * producing ordered training data for split-finding tests.
 *
 * @return array of 1000 [[LabeledPoint]]s with label 0.0
 */
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
  // Array.tabulate replaces the manual index loop + mutation of the original.
  Array.tabulate(1000) { i =>
    new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
  }
}

/**
 * Generates 1000 deterministic instances, all labeled 1.0.
 * Feature 0 increases with the index while feature 1 decreases (999 - i),
 * mirroring [[generateOrderedLabeledPointsWithLabel0]] with the opposite label.
 *
 * @return array of 1000 [[LabeledPoint]]s with label 1.0
 */
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
  Array.tabulate(1000) { i =>
    new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
  }
}

/**
 * Generates 1000 ordered instances with a label change at index 600:
 * the first 600 are labeled 0.0, the remaining 400 are labeled 1.0.
 * Both classes share the same feature layout (i, 1000 - i), so only the
 * index position separates them.
 *
 * @return array of 1000 [[LabeledPoint]]s, 600 with label 0.0 then 400 with 1.0
 */
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
  Array.tabulate(1000) { i =>
    // Label flips at the 600th instance; features are identical across classes.
    val label = if (i < 600) 0.0 else 1.0
    new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i))
  }
}

/**
 * Generates 1000 instances with two binary categorical features.
 * The first 600 are labeled 1.0 with features (0.0, 1.0); the remaining
 * 400 are labeled 0.0 with features (1.0, 0.0) — i.e., the label is fully
 * determined by the categorical feature values.
 *
 * @return array of 1000 [[LabeledPoint]]s for binary categorical tests
 */
def generateCategoricalDataPoints(): Array[LabeledPoint] = {
  Array.tabulate(1000) { i =>
    if (i < 600) {
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
    } else {
      new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
    }
  }
}

/**
 * Generates 3000 instances with categorical features for multiclass tests,
 * split into thirds by index:
 *   - [0, 1000):    label 2.0, features (2.0, 2.0)
 *   - [1000, 2000): label 1.0, features (1.0, 2.0)
 *   - [2000, 3000): label 2.0, features (2.0, 2.0)
 * NOTE(review): the first and last thirds are identical (label 2.0) — this
 * matches the original data generator and is preserved deliberately.
 *
 * @return array of 3000 [[LabeledPoint]]s
 */
def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
  Array.tabulate(3000) { i =>
    if (i < 1000) {
      new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
    } else if (i < 2000) {
      new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
    } else {
      new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
    }
  }
}

/**
 * Generates 3000 instances for multiclass tests with one constant categorical
 * feature (2.0) and one continuous feature (the index). The first 2000 are
 * labeled 2.0 and the remaining 1000 are labeled 1.0, so the label is
 * separable on the continuous feature at i = 2000.
 *
 * @return array of 3000 [[LabeledPoint]]s
 */
def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = {
  Array.tabulate(3000) { i =>
    // Label changes at the 2000th instance; feature 1 carries the index.
    val label = if (i < 2000) 2.0 else 1.0
    new LabeledPoint(label, Vectors.dense(2.0, i.toDouble))
  }
}

def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
Array[WeightedLabeledPoint] = {
val arr = new Array[WeightedLabeledPoint](3000)
Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](3000)
for (i <- 0 until 3000) {
if (i < 1000) {
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
} else if (i < 2000) {
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
} else {
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 2.0))
arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
}
}
arr
Expand Down

0 comments on commit d75ac32

Please sign in to comment.