Skip to content

Commit

Permalink
add codegen for shared params
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 8, 2015
1 parent fc957dc commit dcab97a
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 151 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.param.shared

import java.io.PrintWriter

import scala.reflect.ClassTag

/**
* Code generator for shared params (sharedParams.scala). Run under the Spark folder with
* {{{
* build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamCodeGen"
* }}}
*/
private[shared] object SharedParamCodeGen {

def main(args: Array[String]): Unit = {
val params = Seq(
ParamDesc[Double]("regParam", "regularization parameter"),
ParamDesc[Int]("maxIter", "max number of iterations"),
ParamDesc[String]("featuresCol", "features column name"),
ParamDesc[String]("labelCol", "label column name"),
ParamDesc[String]("predictionCol", "prediction column name"),
ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name"),
ParamDesc[String](
"probabilityCol", "column name for predicted class conditional probabilities"),
ParamDesc[Double]("threshold", "threshold in prediction"),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[String]("outputCol", "output column name"),
ParamDesc[Int]("checkpointInterval", "checkpoint interval"))

val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
val writer = new PrintWriter(file)
writer.write(code)
writer.close()
}

/** Description of a param. */
private case class ParamDesc[T: ClassTag](name: String, doc: String) {
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
require(doc.nonEmpty) // TODO: more rigorous on doc

def paramTypeName: String = {
val c = implicitly[ClassTag[T]].runtimeClass
c match {
case _ if c == classOf[Int] => "IntParam"
case _ if c == classOf[Long] => "LongParam"
case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ => s"Param[${getTypeString(c)}]"
}
}

def valueTypeName: String = {
val c = implicitly[ClassTag[T]].runtimeClass
getTypeString(c)
}

private def getTypeString(c: Class[_]): String = {
c match {
case _ if c == classOf[Int] => "Int"
case _ if c == classOf[Long] => "Long"
case _ if c == classOf[Float] => "Float"
case _ if c == classOf[Double] => "Double"
case _ if c == classOf[Boolean] => "Boolean"
case _ if c == classOf[String] => "String"
case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]"
}
}
}

/** Generates the HasParam trait code for the input param. */
private def genHasParamTrait(param: ParamDesc[_]): String = {
val name = param.name
val Name = name(0).toUpper +: name.substring(1)
val Param = param.paramTypeName
val T = param.valueTypeName
val doc = param.doc

s"""
|/**
| * :: DeveloperApi ::
| * Trait for shared param $name.
| */
|@DeveloperApi
|trait Has$Name extends Params {
| /**
| * Param for $doc.
| * @group param
| */
| final val $name: $Param = new $Param(this, "$name", "$doc")
|
| /** @group getParam */
| final def get$Name: $T = get($name)
|
| /** @group setParam */
| protected def set$Name(value: $T): this.type = set($name, value)
|}
""".stripMargin
}

/** Generates Scala source code for the input params with header. */
private def genSharedParams(params: Seq[ParamDesc[_]]): String = {
val header =
"""
|/*
| * Licensed to the Apache Software Foundation (ASF) under one or more
| * contributor license agreements. See the NOTICE file distributed with
| * this work for additional information regarding copyright ownership.
| * The ASF licenses this file to You under the Apache License, Version 2.0
| * (the "License"); you may not use this file except in compliance with
| * the License. You may obtain a copy of the License at
| *
| * http://www.apache.org/licenses/LICENSE-2.0
| *
| * Unless required by applicable law or agreed to in writing, software
| * distributed under the License is distributed on an "AS IS" BASIS,
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| * See the License for the specific language governing permissions and
| * limitations under the License.
| */
|
|package org.apache.spark.ml.param.shared
|
|import org.apache.spark.annotation.DeveloperApi
|import org.apache.spark.ml.param._
|
|// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen.
|
|// scalastyle:off
""".stripMargin

val footer =
"""
|// scalastyle:on
""".stripMargin

val traits = params.map(genHasParamTrait).mkString

header + traits + footer
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.param.shared

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._

// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen.

// scalastyle:off

/**
* :: DeveloperApi ::
* Trait for shared param regParam.
*/
@DeveloperApi
trait HasRegParam extends Params {
/**
* Param for regularization parameter.
* @group param
*/
final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")

/** @group getParam */
final def getRegParam: Double = get(regParam)
}

/**
* :: DeveloperApi ::
* Trait for shared param maxIter.
*/
@DeveloperApi
trait HasMaxIter extends Params {
/**
* Param for max number of iterations.
* @group param
*/
final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")

/** @group getParam */
final def getMaxIter: Int = get(maxIter)
}

/**
* :: DeveloperApi ::
* Trait for shared param featuresCol.
*/
@DeveloperApi
trait HasFeaturesCol extends Params {
/**
* Param for features column name.
* @group param
*/
final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name")

/** @group getParam */
final def getFeaturesCol: String = get(featuresCol)
}

/**
* :: DeveloperApi ::
* Trait for shared param labelCol.
*/
@DeveloperApi
trait HasLabelCol extends Params {
/**
* Param for label column name.
* @group param
*/
final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name")

/** @group getParam */
final def getLabelCol: String = get(labelCol)
}

/**
* :: DeveloperApi ::
* Trait for shared param predictionCol.
*/
@DeveloperApi
trait HasPredictionCol extends Params {
/**
* Param for prediction column name.
* @group param
*/
final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name")

/** @group getParam */
final def getPredictionCol: String = get(predictionCol)
}

/**
* :: DeveloperApi ::
* Trait for shared param rawPredictionCol.
*/
@DeveloperApi
trait HasRawPredictionCol extends Params {
/**
* Param for raw prediction (a.k.a. confidence) column name.
* @group param
*/
final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name")

/** @group getParam */
final def getRawPredictionCol: String = get(rawPredictionCol)
}

/**
* :: DeveloperApi ::
* Trait for shared param probabilityCol.
*/
@DeveloperApi
trait HasProbabilityCol extends Params {
/**
* Param for column name for predicted class conditional probabilities.
* @group param
*/
final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities")

/** @group getParam */
final def getProbabilityCol: String = get(probabilityCol)
}

/**
* :: DeveloperApi ::
* Trait for shared param threshold.
*/
@DeveloperApi
trait HasThreshold extends Params {
/**
* Param for threshold in prediction.
* @group param
*/
final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")

/** @group getParam */
final def getThreshold: Double = get(threshold)
}

/**
* :: DeveloperApi ::
* Trait for shared param inputCol.
*/
@DeveloperApi
trait HasInputCol extends Params {
/**
* Param for input column name.
* @group param
*/
final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")

/** @group getParam */
final def getInputCol: String = get(inputCol)
}

/**
* :: DeveloperApi ::
* Trait for shared param outputCol.
*/
@DeveloperApi
trait HasOutputCol extends Params {
/**
* Param for output column name.
* @group param
*/
final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")

/** @group getParam */
final def getOutputCol: String = get(outputCol)
}

/**
* :: DeveloperApi ::
* Trait for shared param checkpointInterval.
*/
@DeveloperApi
trait HasCheckpointInterval extends Params {
/**
* Param for checkpoint interval.
* @group param
*/
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")

/** @group getParam */
final def getCheckpointInterval: Int = get(checkpointInterval)
}

// scalastyle:on
Loading

0 comments on commit dcab97a

Please sign in to comment.