Commit abb7a3b

update default values handling

mengxr committed Apr 8, 2015
1 parent dcab97a commit abb7a3b
Showing 15 changed files with 102 additions and 87 deletions.
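The gist of the change: default values move out of the Param constructor (and out of per-estimator setter calls) and into an explicit setDefault registered on the owning Params object. A condensed before/after sketch of the pattern, lifted from the HashingTF diff below:

    // Before: the default value rides along inside the Param itself.
    val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))

    // After: the Param carries only its name and doc; the default is
    // registered separately (the -> syntax builds a ParamPair).
    val numFeatures = new IntParam(this, "numFeatures", "number of features")
    setDefault(numFeatures -> (1 << 18))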
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -22,6 +22,7 @@ import scala.annotation.varargs
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -65,7 +66,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
/** @group setParam */
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]

/** @group setParam */
def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]

/**
@@ -19,7 +19,8 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
import org.apache.spark.ml.param.{Params, ParamMap}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
@@ -36,6 +37,8 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
private[spark] trait ClassifierParams extends PredictorParams
with HasRawPredictionCol {

setDefault(rawPredictionCol, "rawPrediction")

override protected def validateAndTransformSchema(
schema: StructType,
paramMap: ParamMap,
@@ -67,8 +70,7 @@ private[spark] abstract class Classifier[
with ClassifierParams {

/** @group setParam */
def setRawPredictionCol(value: String): E =
set(rawPredictionCol, value).asInstanceOf[E]
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

// TODO: defaultEvaluator (follow-up PR)
}
@@ -19,19 +19,23 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel


/**
* Params for logistic regression.
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasMaxIter with HasThreshold
with HasRegParam with HasMaxIter with HasThreshold {

setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)
}
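For reference, the varargs call above is shorthand for repeated two-argument calls; the -> method on Param builds a ParamPair, and the varargs overload of setDefault unpacks each pair (both appear in the params.scala diff further down):

    // These two forms register the same defaults.
    setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)

    setDefault(regParam, 0.1)
    setDefault(maxIter, 100)
    setDefault(threshold, 0.5)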



/**
@@ -45,10 +49,6 @@ class LogisticRegression
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams {

setRegParam(0.1)
setMaxIter(100)
setThreshold(0.5)

/** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)

@@ -96,8 +96,6 @@ class LogisticRegressionModel private[ml] (
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
with LogisticRegressionParams {

setThreshold(0.5)

/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)

@@ -18,7 +18,8 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
@@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
@@ -40,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
* @group param
*/
val metricName: Param[String] = new Param(this, "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
"metric name in evaluation (areaUnderROC|areaUnderPR)")

/** @group getParam */
def getMetricName: String = get(metricName)
@@ -51,7 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
/** @group setParam */
def setScoreCol(value: String): this.type = set(rawPredictionCol, value)

/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)

override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
@@ -31,11 +31,13 @@ import org.apache.spark.sql.types.DataType
@AlphaComponent
class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {

setDefault(numFeatures -> (1 << 18))

/**
* number of features
* @group param
*/
val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
val numFeatures = new IntParam(this, "numFeatures", "number of features")

/** @group getParam */
def getNumFeatures: Int = get(numFeatures)
@@ -31,11 +31,13 @@ import org.apache.spark.sql.types.DataType
@AlphaComponent
class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {

setDefault(p -> 2.0)

/**
* Normalization in L^p^ space, p = 2 by default.
* @group param
*/
val p = new DoubleParam(this, "p", "the p norm value", Some(2))
val p = new DoubleParam(this, "p", "the p norm value")

/** @group getParam */
def getP: Double = get(p)
@@ -50,4 +52,3 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {

override protected def outputDataType: DataType = new VectorUDT()
}

@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
10 changes: 5 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -52,11 +52,13 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
@AlphaComponent
class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {

setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+")

/**
* param for minimum token length, default is one to avoid returning empty strings
* @group param
*/
val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1))
val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length")

/** @group setParam */
def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
@@ -68,8 +70,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param sets regex as splitting on gaps (true) or matching tokens (false)
* @group param
*/
val gaps: BooleanParam = new BooleanParam(
this, "gaps", "Set regex to match gaps or tokens", Some(false))
val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")

/** @group setParam */
def setGaps(value: Boolean): this.type = set(gaps, value)
@@ -81,8 +82,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param sets regex pattern used by tokenizer
* @group param
*/
val pattern: Param[String] = new Param(
this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))
val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")

/** @group setParam */
def setPattern(value: String): this.type = set(pattern, value)
@@ -20,6 +20,7 @@ package org.apache.spark.ml.impl.estimator
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
91 changes: 48 additions & 43 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -26,7 +26,6 @@ import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.Identifiable
import org.apache.spark.sql.types.{DataType, StructField, StructType}


/**
* :: AlphaComponent ::
* A param with self-contained documentation. Primitive-typed param
@@ -38,12 +37,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
* @tparam T param value type
*/
@AlphaComponent
class Param[T] (
val parent: Params,
val name: String,
val doc: String,
val defaultValue: Option[T] = None)
extends Serializable {
class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable {

/**
* Creates a param pair with the given value (for Java).
@@ -55,58 +49,42 @@ class Param[T] (
*/
def ->(value: T): ParamPair[T] = ParamPair(this, value)

override def toString: String = {
if (defaultValue.isDefined) {
s"$name: $doc (default: ${defaultValue.get})"
} else {
s"$name: $doc"
}
}
override def toString: String = s"$name: $doc"
}

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...

/** Specialized version of [[Param[Double]]] for Java. */
class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double])
extends Param[Double](parent, name, doc, defaultValue) {

def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
class DoubleParam(parent: Params, name: String, doc: String)
extends Param[Double](parent, name, doc) {

override def w(value: Double): ParamPair[Double] = super.w(value)
}

/** Specialized version of [[Param[Int]]] for Java. */
class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int])
extends Param[Int](parent, name, doc, defaultValue) {

def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
class IntParam(parent: Params, name: String, doc: String)
extends Param[Int](parent, name, doc) {

override def w(value: Int): ParamPair[Int] = super.w(value)
}

/** Specialized version of [[Param[Float]]] for Java. */
class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float])
extends Param[Float](parent, name, doc, defaultValue) {

def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
class FloatParam(parent: Params, name: String, doc: String)
extends Param[Float](parent, name, doc) {

override def w(value: Float): ParamPair[Float] = super.w(value)
}

/** Specialized version of [[Param[Long]]] for Java. */
class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long])
extends Param[Long](parent, name, doc, defaultValue) {

def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
class LongParam(parent: Params, name: String, doc: String)
extends Param[Long](parent, name, doc) {

override def w(value: Long): ParamPair[Long] = super.w(value)
}

/** Specialized version of [[Param[Boolean]]] for Java. */
class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean])
extends Param[Boolean](parent, name, doc, defaultValue) {

def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
class BooleanParam(parent: Params, name: String, doc: String)
extends Param[Boolean](parent, name, doc) {

override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
@@ -124,7 +102,10 @@ case class ParamPair[T](param: Param[T], value: T)
@AlphaComponent
trait Params extends Identifiable with Serializable {

/** Returns all params. */
/**
* Returns all params. The default implementation uses Java reflection to list all public methods
* that have return type [[Param]].
*/
def params: Array[Param[_]] = {
val methods = this.getClass.getMethods
methods.filter { m =>
@@ -159,7 +140,7 @@ trait Params extends Identifiable with Serializable {
}

/** Gets a param by its name. */
private[ml] def getParam(paramName: String): Param[Any] = {
protected final def getParam(paramName: String): Param[Any] = {
val m = this.getClass.getMethod(paramName)
assert(Modifier.isPublic(m.getModifiers) &&
classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
@@ -170,7 +151,7 @@
/**
* Sets a parameter in the embedded param map.
*/
protected def set[T](param: Param[T], value: T): this.type = {
protected final def set[T](param: Param[T], value: T): this.type = {
require(param.parent.eq(this))
paramMap.put(param.asInstanceOf[Param[Any]], value)
this
@@ -179,22 +160,48 @@
/**
* Sets a parameter (by name) in the embedded param map.
*/
private[ml] def set(param: String, value: Any): this.type = {
protected final def set(param: String, value: Any): this.type = {
set(getParam(param), value)
}

/**
* Gets the value of a parameter in the embedded param map.
*/
protected def get[T](param: Param[T]): T = {
protected final def get[T](param: Param[T]): T = {
require(param.parent.eq(this))
paramMap(param)
}

/**
* Internal param map.
*/
protected val paramMap: ParamMap = ParamMap.empty
protected final val paramMap: ParamMap = ParamMap.empty

/**
* Internal param map for default values.
*/
protected final val defaultValues: ParamMap = ParamMap.empty

/**
* Sets a default value.
*/
protected final def setDefault[T](param: Param[T], value: T): this.type = {
require(param.parent.eq(this))
defaultValues.put(param, value)
this
}

protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
setDefault(p.param.asInstanceOf[Param[Any]], p.value)
}
this
}

protected final def getDefault[T](param: Param[T]): Option[T] = {
require(param.parent.eq(this))
defaultValues.get(param)
}
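Taken together, a minimal sketch of how a concrete trait would use the new hooks (MyParams and maxDepth are illustrative names, not part of this commit):

    import org.apache.spark.ml.param.{IntParam, Params}

    private[ml] trait MyParams extends Params {

      /** param for maximum depth (illustrative) */
      val maxDepth: IntParam = new IntParam(this, "maxDepth", "maximum depth")

      // Register the default once, instead of passing Some(5) to the
      // Param constructor as was done before this commit.
      setDefault(maxDepth -> 5)

      // Explicitly set values live in paramMap; defaults live in defaultValues.
      def getMaxDepth: Int = get(maxDepth)
      protected def defaultMaxDepth: Option[Int] = getDefault(maxDepth)
    }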

/**
* Check whether the given schema contains an input column.
@@ -283,9 +290,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* Optionally returns the value associated with a param.
*/
def get[T](param: Param[T]): Option[T] = {
map.get(param.asInstanceOf[Param[Any]])
.orElse(param.defaultValue)
.asInstanceOf[Option[T]]
map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]]
}
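Note the behavioral shift: with defaultValue gone from Param, a ParamMap lookup no longer falls back to a default, so get returns None unless a value was explicitly put. Default resolution now presumably happens on the Params side via the defaultValues map introduced above. A small sketch, where someParam is illustrative and assumed to have a registered default but no explicit value:

    val pm = ParamMap.empty
    pm.get(someParam)  // now None; previously Some(default) via param.defaultValue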

/**