Skip to content

Commit

Permalink
implicit conversion from LabeledPoint to WeightedLabeledPoint
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Jul 7, 2014
1 parent 3d7f911 commit 485eaae
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.point

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

/**
 * Implicit conversions between MLlib point representations.
 *
 * Import `org.apache.spark.mllib.point.PointConverter._` to bring the
 * conversions into scope.
 */
object PointConverter {

  /**
   * Implicitly converts an RDD of [[LabeledPoint]] into an RDD of [[WeightedLabeledPoint]].
   *
   * Every element is mapped to a new WeightedLabeledPoint constructed from the
   * original point's label and features. No weight is passed explicitly, so the
   * resulting weight is whatever WeightedLabeledPoint's two-argument constructor
   * assigns (presumably a default weight -- TODO confirm against that class).
   *
   * @param points the labeled points to convert
   * @return an RDD holding one WeightedLabeledPoint per input point
   */
  implicit def LabeledPoint2WeightedLabeledPoint(
      points: RDD[LabeledPoint]): RDD[WeightedLabeledPoint] =
    points.map { labeled =>
      new WeightedLabeledPoint(labeled.label, labeled.features)
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.mllib.tree

import org.apache.spark.mllib.point.PointConverter._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -211,9 +212,7 @@ object DecisionTree extends Serializable with Logging {
* @return a DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
// NOTE(review): this listing appears to be a diff hunk rendered without +/- markers.
// The next three lines look like the pre-commit body (explicit conversion to the
// weighted format); the final statement looks like the post-commit body.
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
// Post-commit form: passes the RDD[LabeledPoint] directly, presumably relying on the
// implicit conversion imported from PointConverter -- confirm against the instance
// `train` signature. Read literally, both statements would run and only this last
// result would be returned.
new DecisionTree(strategy).train(input)
}

/**
Expand All @@ -235,9 +234,7 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
new DecisionTree(strategy).train(input)
}

/**
Expand All @@ -261,9 +258,7 @@ object DecisionTree extends Serializable with Logging {
maxDepth: Int,
numClassesForClassification: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
new DecisionTree(strategy).train(input)
}


Expand Down Expand Up @@ -294,9 +289,7 @@ object DecisionTree extends Serializable with Logging {
labelWeights: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification,
labelWeights = labelWeights)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
new DecisionTree(strategy).train(input)
}

/**
Expand Down Expand Up @@ -337,9 +330,7 @@ object DecisionTree extends Serializable with Logging {
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights)
// Converting from standard instance format to weighted input format for tree training
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
new DecisionTree(strategy).train(input)
}

private val InvalidBinIndex = -1
Expand Down

0 comments on commit 485eaae

Please sign in to comment.