[SPARK-5957][ML] better handling of parameters
The design doc was posted on the JIRA page. Python changes will be in a follow-up PR. jkbradley

1. Use codegen for shared params.
1. Move shared params to package `ml.param.shared`.
1. Set default values in `Params` instead of in `Param`.
1. Add a few methods to `Params` and `ParamMap`.
1. Move schema handling to `SchemaUtils` from `Params`.

- [x] check visibility of the methods added
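For items 3 and 4 above, a minimal sketch of the pattern this PR moves to, based on the diffs below. The trait name `HasMyMaxIter` is made up for illustration; `IntParam`, `set`, `setDefault`, and `getOrDefault` are the members exercised in the changed files.

```scala
import org.apache.spark.ml.param.{IntParam, Params}

// Sketch of the new convention: the default lives on the owning Params object via
// setDefault, instead of being baked into the Param constructor as Some(100).
trait HasMyMaxIter extends Params {

  /** Param for the maximum number of iterations. */
  val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")

  /** Returns the explicitly set value if any, otherwise the declared default (100). */
  def getMaxIter: Int = getOrDefault(maxIter)

  /** Sets the value on this instance, shadowing the default. */
  def setMaxIter(value: Int): this.type = { set(maxIter, value); this }

  // Register the default after declaring the param (cf. commit "setDefault after param").
  setDefault(maxIter -> 100)
}
```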

Author: Xiangrui Meng <meng@databricks.com>

Closes apache#5431 from mengxr/SPARK-5957 and squashes the following commits:

d19236d [Xiangrui Meng] fix test
26ae2d7 [Xiangrui Meng] re-gen code and mark clear protected
38b78c7 [Xiangrui Meng] update Param.toString and remove Params.explain()
409e2d5 [Xiangrui Meng] address comments
2d637bd [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957
eec2264 [Xiangrui Meng] make get* public in Params
4090d95 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957
4fee9e7 [Xiangrui Meng] re-gen shared params
2737c2d [Xiangrui Meng] rename SharedParamCodeGen to SharedParamsCodeGen
e938f81 [Xiangrui Meng] update code to set default parameter values
28ed322 [Xiangrui Meng] merge master
55be1f3 [Xiangrui Meng] merge master
d63b5cc [Xiangrui Meng] fix examples
29b004c [Xiangrui Meng] update ParamsSuite
94fd98e [Xiangrui Meng] fix explain params
48d0e84 [Xiangrui Meng] add remove and update explainParams
4ac6348 [Xiangrui Meng] move schema utils to SchemaUtils add a few methods to Params
0d9594e [Xiangrui Meng] add getOrElse to ParamMap
eeeffe8 [Xiangrui Meng] map ++ paramMap => extractValues
0d3fc5b [Xiangrui Meng] setDefault after param
a9dbf59 [Xiangrui Meng] minor updates
d9302b8 [Xiangrui Meng] generate default values
1c72579 [Xiangrui Meng] pass test compile
abb7a3b [Xiangrui Meng] update default values handling
dcab97a [Xiangrui Meng] add codegen for shared params
mengxr committed Apr 14, 2015
1 parent 0ba3fdd commit 971b95b
Showing 27 changed files with 820 additions and 396 deletions.
@@ -116,7 +116,7 @@ class MyJavaLogisticRegression
*/
IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");

int getMaxIter() { return (Integer) get(maxIter); }
int getMaxIter() { return (Integer) getOrDefault(maxIter); }

public MyJavaLogisticRegression() {
setMaxIter(100);
@@ -211,7 +211,7 @@ public Vector predictRaw(Vector features) {
public MyJavaLogisticRegressionModel copy() {
MyJavaLogisticRegressionModel m =
new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
Params$.MODULE$.inheritValues(this.paramMap(), this, m);
Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
return m;
}
}
@@ -99,7 +99,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
def getMaxIter: Int = get(maxIter)
def getMaxIter: Int = getOrDefault(maxIter)
}

/**
@@ -174,11 +174,11 @@ private class MyLogisticRegressionModel(
* Create a copy of the model.
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
*
* This is used for the defaul implementation of [[transform()]].
* This is used for the default implementation of [[transform()]].
*/
override protected def copy(): MyLogisticRegressionModel = {
val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
Params.inheritValues(this.paramMap, this, m)
Params.inheritValues(extractParamMap(), this, m)
m
}
}
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
*/
@varargs
def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
val map = new ParamMap().put(paramPairs: _*)
val map = ParamMap(paramPairs: _*)
fit(dataset, map)
}

10 changes: 5 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -84,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] {
/** param for pipeline stages */
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
def getStages: Array[PipelineStage] = get(stages)
def getStages: Array[PipelineStage] = getOrDefault(stages)

/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
@@ -101,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] {
*/
override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)
val theStages = map(stages)
// Search for the last estimator.
var indexOfLastEstimator = -1
@@ -138,7 +138,7 @@
}

override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
"Cannot have duplicate components in a pipeline.")
@@ -177,14 +177,14 @@ class PipelineModel private[ml] (

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
val map = fittingParamMap ++ extractParamMap(paramMap)
transformSchema(dataset.schema, map, logging = true)
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}

override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
val map = fittingParamMap ++ extractParamMap(paramMap)
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
}
}
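The precedence comment above (`paramMap > this.paramMap > fittingParamMap`) is what the switch to `extractParamMap` preserves. A hedged sketch of the merge order it implies, using `ParamMap.++` where the right-hand map wins; the member names `declaredDefaults` and `currentValues` are placeholders, not the actual internals:

```scala
import org.apache.spark.ml.param.{ParamMap, Params}

// Illustration only: extractParamMap presumably folds defaults, instance values and
// the caller-supplied map together, with later maps taking precedence.
trait ExtractParamMapSketch extends Params {
  protected def declaredDefaults: ParamMap // registered via setDefault(...)
  protected def currentValues: ParamMap    // set explicitly via set(...)

  def extractParamMapSketch(extra: ParamMap): ParamMap =
    declaredDefaults ++ currentValues ++ extra
}
```

In `PipelineModel`, `fittingParamMap ++ extractParamMap(paramMap)` then keeps the fit-time parameters at the lowest precedence, matching the comment.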
5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -22,6 +22,7 @@ import scala.annotation.varargs
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -86,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
protected def validateInputType(inputType: DataType): Unit = {}

override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)
val inputType = schema(map(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains(map(outputCol))) {
@@ -99,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)
dataset.withColumn(map(outputCol),
callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
}
@@ -17,12 +17,14 @@

package org.apache.spark.ml.classification

import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}


@@ -42,8 +44,8 @@ private[spark] trait ClassifierParams extends PredictorParams
fitting: Boolean,
featuresDataType: DataType): StructType = {
val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
val map = this.paramMap ++ paramMap
addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
val map = extractParamMap(paramMap)
SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
}
}

@@ -67,8 +69,7 @@ private[spark] abstract class Classifier[
with ClassifierParams {

/** @group setParam */
def setRawPredictionCol(value: String): E =
set(rawPredictionCol, value).asInstanceOf[E]
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

// TODO: defaultEvaluator (follow-up PR)
}
@@ -109,7 +110,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur

// Check schema
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)

// Prepare model
val tmpModel = if (paramMap.size != 0) {
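Earlier in this file's diff, `validateAndTransformSchema` replaces the old `addOutputColumn` helper with `SchemaUtils.appendColumn` (later files make the analogous `checkInputColumn` → `SchemaUtils.checkColumnType` swap). A minimal sketch of what such helpers could look like; the method names come from the diff, but the bodies are assumptions rather than the actual Spark source:

```scala
import org.apache.spark.sql.types.{DataType, StructField, StructType}

// Illustrative stand-in for the SchemaUtils helpers referenced in the diff.
object SchemaUtilsSketch {

  /** Fails fast if `colName` is missing from the schema or has an unexpected type. */
  def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
    val actual = schema(colName).dataType
    require(actual == dataType,
      s"Column $colName must be of type $dataType but was actually of type $actual.")
  }

  /** Returns a copy of the schema with `colName` appended, refusing to overwrite an existing column. */
  def appendColumn(schema: StructType, colName: String, dataType: DataType): StructType = {
    require(!schema.fieldNames.contains(colName), s"Column $colName already exists.")
    StructType(schema.fields :+ StructField(colName, dataType, nullable = false))
  }
}
```

In `validateAndTransformSchema`, `appendColumn` is what adds the raw-prediction column to the transformed schema.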
@@ -19,20 +19,22 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel


/**
* Params for logistic regression.
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold {

setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)
}

/**
* :: AlphaComponent ::
@@ -45,10 +47,6 @@ class LogisticRegression
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams {

setRegParam(0.1)
setMaxIter(100)
setThreshold(0.5)

/** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)

@@ -100,8 +98,6 @@ class LogisticRegressionModel private[ml] (
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
with LogisticRegressionParams {

setThreshold(0.5)

/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)

@@ -123,7 +119,7 @@ class LogisticRegressionModel private[ml] (
// Check schema
transformSchema(dataset.schema, paramMap, logging = true)

val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)

// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
@@ -184,7 +180,7 @@ class LogisticRegressionModel private[ml] (
* The behavior of this can be adjusted using [[threshold]].
*/
override protected def predict(features: Vector): Double = {
if (score(features) > paramMap(threshold)) 1 else 0
if (score(features) > getThreshold) 1 else 0
}

override protected def predictProbabilities(features: Vector): Vector = {
@@ -199,7 +195,7 @@

override protected def copy(): LogisticRegressionModel = {
val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
Params.inheritValues(this.paramMap, this, m)
Params.inheritValues(this.extractParamMap(), this, m)
m
}
}
@@ -18,13 +18,14 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}


/**
* Params for probabilistic classification.
*/
@@ -37,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams
fitting: Boolean,
featuresDataType: DataType): StructType = {
val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
val map = this.paramMap ++ paramMap
addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT)
val map = extractParamMap(paramMap)
SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT)
}
}

@@ -102,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[

// Check schema
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)

// Prepare model
val tmpModel = if (paramMap.size != 0) {
@@ -20,12 +20,13 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType


/**
* :: AlphaComponent ::
*
@@ -40,10 +41,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params
* @group param
*/
val metricName: Param[String] = new Param(this, "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
"metric name in evaluation (areaUnderROC|areaUnderPR)")

/** @group getParam */
def getMetricName: String = get(metricName)
def getMetricName: String = getOrDefault(metricName)

/** @group setParam */
def setMetricName(value: String): this.type = set(metricName, value)
@@ -54,12 +55,14 @@ class BinaryClassificationEvaluator extends Evaluator with Params
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)

setDefault(metricName -> "areaUnderROC")

override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
val map = this.paramMap ++ paramMap
val map = extractParamMap(paramMap)

val schema = dataset.schema
checkInputColumn(schema, map(rawPredictionCol), new VectorUDT)
checkInputColumn(schema, map(labelCol), DoubleType)
SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT)
SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType)

// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
@@ -35,14 +35,16 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
* number of features
* @group param
*/
val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
val numFeatures = new IntParam(this, "numFeatures", "number of features")

/** @group getParam */
def getNumFeatures: Int = get(numFeatures)
def getNumFeatures: Int = getOrDefault(numFeatures)

/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)

setDefault(numFeatures -> (1 << 18))

override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
hashingTF.transform
@@ -35,19 +35,20 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
* Normalization in L^p^ space, p = 2 by default.
* @group param
*/
val p = new DoubleParam(this, "p", "the p norm value", Some(2))
val p = new DoubleParam(this, "p", "the p norm value")

/** @group getParam */
def getP: Double = get(p)
def getP: Double = getOrDefault(p)

/** @group setParam */
def setP(value: Double): this.type = set(p, value)

setDefault(p -> 2.0)

override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
val normalizer = new feature.Normalizer(paramMap(p))
normalizer.transform
}

override protected def outputDataType: DataType = new VectorUDT()
}

(diff truncated; the remaining changed files are not shown)