map ++ paramMap => extractValues
mengxr committed Apr 8, 2015
1 parent 0d3fc5b commit eeeffe8
Showing 13 changed files with 40 additions and 39 deletions.
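In short, this commit replaces the explicit merge `this.paramMap ++ paramMap` at each call site with a new `extractValues` helper (added in params.scala below) that also folds in default values. A self-contained toy sketch of the precedence the helper encodes — plain Scala `Map`s stand in for `ParamMap`, whose `++` likewise lets the right operand win:

    // Toy model of extractValues' merge order; names mirror params.scala below,
    // but String-keyed Maps replace ParamMap for illustration only.
    object PrecedenceSketch extends App {
      val defaultValues = Map("threshold" -> 0.5)  // set via setDefault(...)
      val values        = Map("threshold" -> 0.3)  // explicitly set params (the old paramMap)
      val extraValues   = Map("threshold" -> 0.7)  // caller-supplied ParamMap

      // extractValues(extraValues) = defaultValues ++ values ++ extraValues
      val merged = defaultValues ++ values ++ extraValues
      println(merged("threshold"))  // 0.7: caller overrides beat set values, which beat defaults
    }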
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
    */
   @varargs
   def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
-    val map = new ParamMap().put(paramPairs: _*)
+    val map = ParamMap(paramPairs: _*)
     fit(dataset, map)
   }

8 changes: 4 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -98,7 +98,7 @@ class Pipeline extends Estimator[PipelineModel] {
    */
   override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val theStages = map(stages)
     // Search for the last estimator.
     var indexOfLastEstimator = -1
@@ -135,7 +135,7 @@ class Pipeline extends Estimator[PipelineModel] {
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val theStages = map(stages)
     require(theStages.toSet.size == theStages.size,
       "Cannot have duplicate components in a pipeline.")
@@ -174,14 +174,14 @@ class PipelineModel private[ml] (
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
-    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    val map = fittingParamMap ++ extractValues(paramMap)
     transformSchema(dataset.schema, map, logging = true)
     stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
-    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    val map = fittingParamMap ++ extractValues(paramMap)
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
   }
 }
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -87,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
   protected def validateInputType(inputType: DataType): Unit = {}
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val inputType = schema(map(inputCol)).dataType
     validateInputType(inputType)
     if (schema.fieldNames.contains(map(outputCol))) {
@@ -100,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     dataset.withColumn(map(outputCol),
       callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
   }
@@ -43,7 +43,7 @@ private[spark] trait ClassifierParams extends PredictorParams
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
     val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
   }
 }
@@ -109,7 +109,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
 
     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
 
     // Prepare model
     val tmpModel = if (paramMap.size != 0) {
@@ -117,7 +117,7 @@ class LogisticRegressionModel private[ml] (
     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
 
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
 
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
@@ -178,7 +178,7 @@ class LogisticRegressionModel private[ml] (
    * The behavior of this can be adjusted using [[threshold]].
    */
   override protected def predict(features: Vector): Double = {
-    if (score(features) > paramMap(threshold)) 1 else 0
+    if (score(features) > getThreshold) 1 else 0
   }
 
   override protected def predictProbabilities(features: Vector): Vector = {
@@ -193,7 +193,7 @@ class LogisticRegressionModel private[ml] (
 
   override protected def copy(): LogisticRegressionModel = {
     val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
-    Params.inheritValues(this.paramMap, this, m)
+    Params.inheritValues(this.extractValues(), this, m)
     m
   }
 }
@@ -38,7 +38,7 @@ private[classification] trait ProbabilisticClassifierParams
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
     val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT)
   }
 }
@@ -103,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[
 
     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
 
     // Prepare model
     val tmpModel = if (paramMap.size != 0) {
@@ -36,8 +36,6 @@ import org.apache.spark.sql.types.DoubleType
 class BinaryClassificationEvaluator extends Evaluator with Params
   with HasRawPredictionCol with HasLabelCol {
 
-  setDefault(metricName -> "areaUnderROC")
-
   /**
    * param for metric name in evaluation
    * @group param
@@ -57,8 +55,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params
   /** @group setParam */
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  setDefault(metricName -> "areaUnderROC")
+
   override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
 
     val schema = dataset.schema
     checkInputColumn(schema, map(rawPredictionCol), new VectorUDT)
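Side note on the relocated setDefault: Scala vals initialize in declaration order, so calling setDefault(metricName -> ...) before the metricName val is declared would pass a still-null param and trip the require(param.parent.eq(this)) check; moving the call below the declarations likely avoids exactly that. A minimal plain-Scala illustration of the hazard (no Spark types involved):

    // Reading a val member before its declaration yields its default value (null).
    object InitOrderSketch extends App {
      class Early {
        val snapshot: String = name  // runs first; `name` is not yet initialized
        val name: String = "metricName"
      }
      println(new Early().snapshot)  // prints "null"
    }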
@@ -48,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
 
   override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
     val scaler = new feature.StandardScaler().fit(input)
     val model = new StandardScalerModel(this, map, scaler)
@@ -57,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${map(inputCol)} must be a vector column")
@@ -87,13 +87,13 @@ class StandardScalerModel private[ml] (
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
     dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${map(inputCol)} must be a vector column")
@@ -54,7 +54,7 @@ private[spark] trait PredictorParams extends Params
       paramMap: ParamMap,
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     checkInputColumn(schema, map(featuresCol), featuresDataType)
     if (fitting) {
@@ -99,7 +99,7 @@ private[spark] abstract class Predictor[
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val model = train(dataset, map)
     Params.inheritValues(map, this, model) // copy params to model
     model
@@ -142,7 +142,7 @@ private[spark] abstract class Predictor[
    * and put it in an RDD with strong types.
    */
   protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     dataset.select(map(labelCol), map(featuresCol))
       .map { case Row(label: Double, features: Vector) =>
         LabeledPoint(label, features)
@@ -202,7 +202,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
 
     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
 
     // Prepare model
     val tmpModel = if (paramMap.size != 0) {
17 changes: 9 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -136,7 +136,7 @@ trait Params extends Identifiable with Serializable {
   /** Checks whether a param is explicitly set. */
   def isSet(param: Param[_]): Boolean = {
     require(param.parent.eq(this))
-    paramMap.contains(param)
+    values.contains(param)
   }
 
   /** Gets a param by its name. */
@@ -153,7 +153,7 @@ trait Params extends Identifiable with Serializable {
    */
   protected final def set[T](param: Param[T], value: T): this.type = {
     require(param.parent.eq(this))
-    paramMap.put(param.asInstanceOf[Param[Any]], value)
+    values.put(param.asInstanceOf[Param[Any]], value)
     this
   }
 
@@ -169,26 +169,23 @@ trait Params extends Identifiable with Serializable {
    */
   protected final def get[T](param: Param[T]): T = {
     require(param.parent.eq(this))
-    paramMap(param)
+    values(param)
   }
 
   /**
    * Internal param map.
    */
-  protected final val paramMap: ParamMap = ParamMap.empty
+  private val values: ParamMap = ParamMap.empty
 
   /**
    * Internal param map for default values.
    */
-  protected final val defaultValues: ParamMap = ParamMap.empty
+  private val defaultValues: ParamMap = ParamMap.empty
 
   /**
    * Sets a default value.
    */
   protected final def setDefault[T](param: Param[T], value: T): this.type = {
-    println(s"param: $param")
-    println(param.parent)
-    println(value)
     require(param.parent.eq(this))
     defaultValues.put(param, value)
     this
@@ -206,6 +203,10 @@ trait Params extends Identifiable with Serializable {
     defaultValues.get(param)
   }
 
+  protected final def extractValues(extraValues: ParamMap = ParamMap.empty): ParamMap = {
+    defaultValues ++ values ++ extraValues
+  }
+
   /**
    * Check whether the given schema contains an input column.
    * @param colName Parameter name for the input column.
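To see the whole refactor in miniature: a self-contained mock of the reworked trait, with mutable.Map[String, Any] standing in for ParamMap and illustrative names only — a sketch of the pattern, not Spark's API:

    import scala.collection.mutable

    // Mock of the refactored Params trait: both maps become private, and
    // extractValues is the single merge point subclasses go through.
    trait ParamsSketch {
      private val values = mutable.Map.empty[String, Any]        // was: protected paramMap
      private val defaultValues = mutable.Map.empty[String, Any]

      protected def set(name: String, value: Any): this.type = {
        values(name) = value
        this
      }

      protected def setDefault(name: String, value: Any): this.type = {
        defaultValues(name) = value
        this
      }

      // defaults, then explicitly set values, then caller-supplied extras
      protected def extractValues(extra: Map[String, Any] = Map.empty): Map[String, Any] =
        defaultValues.toMap ++ values.toMap ++ extra
    }

    // Usage mirroring a transform(dataset, paramMap) call site:
    class ScalerSketch extends ParamsSketch {
      setDefault("outputCol", "scaled")
      def transform(extra: Map[String, Any]): String =
        extractValues(extra)("outputCol").toString
    }

    object Demo extends App {
      val s = new ScalerSketch
      println(s.transform(Map.empty))                           // "scaled" (default applies)
      println(s.transform(Map("outputCol" -> "features_out")))  // "features_out" (override wins)
    }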
@@ -146,7 +146,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     assert(schema(map(userCol)).dataType == IntegerType)
     assert(schema(map(itemCol)).dataType == IntegerType)
     val ratingType = schema(map(ratingCol)).dataType
@@ -175,7 +175,7 @@ class ALSModel private[ml] (
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     import dataset.sqlContext.implicits._
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val users = userFactors.toDF("id", "features")
     val items = itemFactors.toDF("id", "features")
 
@@ -287,7 +287,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   setCheckpointInterval(10)
 
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val ratings = dataset
       .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
       .map { row =>
@@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] (
 
   override protected def copy(): LinearRegressionModel = {
     val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
-    Params.inheritValues(this.paramMap, this, m)
+    Params.inheritValues(extractValues(), this, m)
     m
   }
 }
@@ -94,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
   def setNumFolds(value: Int): this.type = set(numFolds, value)
 
   override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     val schema = dataset.schema
     transformSchema(dataset.schema, paramMap, logging = true)
     val sqlCtx = dataset.sqlContext
@@ -132,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractValues(paramMap)
     map(estimator).transformSchema(schema, paramMap)
   }
 }
