forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
363 additions
and
151 deletions.
There are no files selected for viewing
159 changes: 159 additions & 0 deletions
159
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.param.shared | ||
|
||
import java.io.PrintWriter | ||
|
||
import scala.reflect.ClassTag | ||
|
||
/** | ||
* Code generator for shared params (sharedParams.scala). Run under the Spark folder with | ||
* {{{ | ||
* build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamCodeGen" | ||
* }}} | ||
*/ | ||
private[shared] object SharedParamCodeGen { | ||
|
||
def main(args: Array[String]): Unit = { | ||
val params = Seq( | ||
ParamDesc[Double]("regParam", "regularization parameter"), | ||
ParamDesc[Int]("maxIter", "max number of iterations"), | ||
ParamDesc[String]("featuresCol", "features column name"), | ||
ParamDesc[String]("labelCol", "label column name"), | ||
ParamDesc[String]("predictionCol", "prediction column name"), | ||
ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name"), | ||
ParamDesc[String]( | ||
"probabilityCol", "column name for predicted class conditional probabilities"), | ||
ParamDesc[Double]("threshold", "threshold in prediction"), | ||
ParamDesc[String]("inputCol", "input column name"), | ||
ParamDesc[String]("outputCol", "output column name"), | ||
ParamDesc[Int]("checkpointInterval", "checkpoint interval")) | ||
|
||
val code = genSharedParams(params) | ||
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" | ||
val writer = new PrintWriter(file) | ||
writer.write(code) | ||
writer.close() | ||
} | ||
|
||
/** Description of a param. */ | ||
private case class ParamDesc[T: ClassTag](name: String, doc: String) { | ||
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") | ||
require(doc.nonEmpty) // TODO: more rigorous on doc | ||
|
||
def paramTypeName: String = { | ||
val c = implicitly[ClassTag[T]].runtimeClass | ||
c match { | ||
case _ if c == classOf[Int] => "IntParam" | ||
case _ if c == classOf[Long] => "LongParam" | ||
case _ if c == classOf[Float] => "FloatParam" | ||
case _ if c == classOf[Double] => "DoubleParam" | ||
case _ if c == classOf[Boolean] => "BooleanParam" | ||
case _ => s"Param[${getTypeString(c)}]" | ||
} | ||
} | ||
|
||
def valueTypeName: String = { | ||
val c = implicitly[ClassTag[T]].runtimeClass | ||
getTypeString(c) | ||
} | ||
|
||
private def getTypeString(c: Class[_]): String = { | ||
c match { | ||
case _ if c == classOf[Int] => "Int" | ||
case _ if c == classOf[Long] => "Long" | ||
case _ if c == classOf[Float] => "Float" | ||
case _ if c == classOf[Double] => "Double" | ||
case _ if c == classOf[Boolean] => "Boolean" | ||
case _ if c == classOf[String] => "String" | ||
case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" | ||
} | ||
} | ||
} | ||
|
||
/** Generates the HasParam trait code for the input param. */ | ||
private def genHasParamTrait(param: ParamDesc[_]): String = { | ||
val name = param.name | ||
val Name = name(0).toUpper +: name.substring(1) | ||
val Param = param.paramTypeName | ||
val T = param.valueTypeName | ||
val doc = param.doc | ||
|
||
s""" | ||
|/** | ||
| * :: DeveloperApi :: | ||
| * Trait for shared param $name. | ||
| */ | ||
|@DeveloperApi | ||
|trait Has$Name extends Params { | ||
| /** | ||
| * Param for $doc. | ||
| * @group param | ||
| */ | ||
| final val $name: $Param = new $Param(this, "$name", "$doc") | ||
| | ||
| /** @group getParam */ | ||
| final def get$Name: $T = get($name) | ||
| | ||
| /** @group setParam */ | ||
| protected def set$Name(value: $T): this.type = set($name, value) | ||
|} | ||
""".stripMargin | ||
} | ||
|
||
/** Generates Scala source code for the input params with header. */ | ||
private def genSharedParams(params: Seq[ParamDesc[_]]): String = { | ||
val header = | ||
""" | ||
|/* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| | ||
|package org.apache.spark.ml.param.shared | ||
| | ||
|import org.apache.spark.annotation.DeveloperApi | ||
|import org.apache.spark.ml.param._ | ||
| | ||
|// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. | ||
| | ||
|// scalastyle:off | ||
""".stripMargin | ||
|
||
val footer = | ||
""" | ||
|// scalastyle:on | ||
""".stripMargin | ||
|
||
val traits = params.map(genHasParamTrait).mkString | ||
|
||
header + traits + footer | ||
} | ||
} |
204 changes: 204 additions & 0 deletions
204
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
|
||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.param.shared | ||
|
||
import org.apache.spark.annotation.DeveloperApi | ||
import org.apache.spark.ml.param._ | ||
|
||
// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. | ||
|
||
// scalastyle:off | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param regParam. | ||
*/ | ||
@DeveloperApi | ||
trait HasRegParam extends Params { | ||
/** | ||
* Param for regularization parameter. | ||
* @group param | ||
*/ | ||
final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") | ||
|
||
/** @group getParam */ | ||
final def getRegParam: Double = get(regParam) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param maxIter. | ||
*/ | ||
@DeveloperApi | ||
trait HasMaxIter extends Params { | ||
/** | ||
* Param for max number of iterations. | ||
* @group param | ||
*/ | ||
final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") | ||
|
||
/** @group getParam */ | ||
final def getMaxIter: Int = get(maxIter) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param featuresCol. | ||
*/ | ||
@DeveloperApi | ||
trait HasFeaturesCol extends Params { | ||
/** | ||
* Param for features column name. | ||
* @group param | ||
*/ | ||
final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") | ||
|
||
/** @group getParam */ | ||
final def getFeaturesCol: String = get(featuresCol) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param labelCol. | ||
*/ | ||
@DeveloperApi | ||
trait HasLabelCol extends Params { | ||
/** | ||
* Param for label column name. | ||
* @group param | ||
*/ | ||
final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") | ||
|
||
/** @group getParam */ | ||
final def getLabelCol: String = get(labelCol) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param predictionCol. | ||
*/ | ||
@DeveloperApi | ||
trait HasPredictionCol extends Params { | ||
/** | ||
* Param for prediction column name. | ||
* @group param | ||
*/ | ||
final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") | ||
|
||
/** @group getParam */ | ||
final def getPredictionCol: String = get(predictionCol) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param rawPredictionCol. | ||
*/ | ||
@DeveloperApi | ||
trait HasRawPredictionCol extends Params { | ||
/** | ||
* Param for raw prediction (a.k.a. confidence) column name. | ||
* @group param | ||
*/ | ||
final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") | ||
|
||
/** @group getParam */ | ||
final def getRawPredictionCol: String = get(rawPredictionCol) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param probabilityCol. | ||
*/ | ||
@DeveloperApi | ||
trait HasProbabilityCol extends Params { | ||
/** | ||
* Param for column name for predicted class conditional probabilities. | ||
* @group param | ||
*/ | ||
final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") | ||
|
||
/** @group getParam */ | ||
final def getProbabilityCol: String = get(probabilityCol) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param threshold. | ||
*/ | ||
@DeveloperApi | ||
trait HasThreshold extends Params { | ||
/** | ||
* Param for threshold in prediction. | ||
* @group param | ||
*/ | ||
final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") | ||
|
||
/** @group getParam */ | ||
final def getThreshold: Double = get(threshold) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param inputCol. | ||
*/ | ||
@DeveloperApi | ||
trait HasInputCol extends Params { | ||
/** | ||
* Param for input column name. | ||
* @group param | ||
*/ | ||
final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") | ||
|
||
/** @group getParam */ | ||
final def getInputCol: String = get(inputCol) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param outputCol. | ||
*/ | ||
@DeveloperApi | ||
trait HasOutputCol extends Params { | ||
/** | ||
* Param for output column name. | ||
* @group param | ||
*/ | ||
final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") | ||
|
||
/** @group getParam */ | ||
final def getOutputCol: String = get(outputCol) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Trait for shared param checkpointInterval. | ||
*/ | ||
@DeveloperApi | ||
trait HasCheckpointInterval extends Params { | ||
/** | ||
* Param for checkpoint interval. | ||
* @group param | ||
*/ | ||
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") | ||
|
||
/** @group getParam */ | ||
final def getCheckpointInterval: Int = get(checkpointInterval) | ||
} | ||
|
||
// scalastyle:on |
Oops, something went wrong.