Skip to content

Commit

Permalink
[SW-1259] Unify ratio param across pipeline api (#1211)
Browse files Browse the repository at this point in the history
(cherry picked from commit 415b270)
  • Loading branch information
jakubhava committed May 22, 2019
1 parent 56027bb commit ff2519e
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ trait H2OAutoMLParams extends H2OCommonParams with DeprecatableParams {
setDefault(
allStringColumnsToCategorical -> true,
columnsToCategorical -> Array.empty[String],
ratio -> 1.0, // 1.0 means use whole frame as training frame,
ignoredCols -> Array.empty[String],
includeAlgos -> null,
excludeAlgos -> null,
Expand Down Expand Up @@ -212,7 +211,8 @@ trait H2OAutoMLParams extends H2OCommonParams with DeprecatableParams {

/** Returns the value of the `columnsToCategorical` param as an array of strings. */
def getColumnsToCategorical(): Array[String] = $(columnsToCategorical)

def getRatio(): Double = $(ratio)
/**
 * Deprecated accessor kept next to `getSplitRatio()`; it only forwards to that method.
 *
 * @return whatever `getSplitRatio()` returns
 */
@DeprecatedMethod("getSplitRatio")
def getRatio(): Double = getSplitRatio()

/**
 * Deprecated accessor kept next to `getFoldCol()`; it only forwards to that method.
 *
 * NOTE(review): the return type is inferred from `getFoldCol()`; consider annotating it
 * explicitly once that signature is confirmed.
 */
@DeprecatedMethod("getFoldCol")
def getFoldColumn() = getFoldCol()
Expand Down Expand Up @@ -271,7 +271,8 @@ trait H2OAutoMLParams extends H2OCommonParams with DeprecatableParams {

/**
 * Stores `columns` in the `columnsToCategorical` param.
 *
 * @param columns array of strings to store
 * @return this instance (`this.type`), so calls can be chained
 */
def setColumnsToCategorical(columns: Array[String]): this.type = set(columnsToCategorical, columns)

def setRatio(value: Double): this.type = set(ratio, value)
/**
 * Deprecated setter kept next to `setSplitRatio(...)`; it only forwards to that method.
 *
 * @param value ratio passed through unchanged to `setSplitRatio`
 * @return this instance (`this.type`), so calls can be chained
 */
@DeprecatedMethod("setSplitRatio")
def setRatio(value: Double): this.type = setSplitRatio(value)

/**
 * Deprecated setter kept next to `setFoldCol(...)`; it only forwards to that method.
 *
 * @param value value passed through unchanged to `setFoldCol`
 * @return this instance (`this.type`), so calls can be chained
 */
@DeprecatedMethod("setFoldCol")
def setFoldColumn(value: String): this.type = setFoldCol(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,13 @@ import water.util.DeprecatedMethod
trait H2OAlgoParams[P <: Parameters] extends H2OAlgoParamsHelper[P] with H2OCommonParams with DeprecatableParams {

/**
 * Renamed parameters, keyed by the deprecated name with the current name as the value.
 *
 * Both entries follow the same direction: `predictionCol` was replaced by `labelCol`
 * and `ratio` was replaced by `splitRatio`.
 *
 * NOTE(review): the 'old name' -> 'new name' direction is taken from the existing
 * `predictionCol` entry; confirm against the contract declared in `DeprecatableParams`.
 */
override protected def renamingMap: Map[String, String] = Map(
  "predictionCol" -> "labelCol",
  "ratio" -> "splitRatio"
)

//
// Param definitions
//
private val ratio = doubleParam(
"ratio",
"Determines in which ratios split the dataset")

// Boolean param named "allStringColumnsToCategorical"; per its description it controls
// whether all string columns are transformed to categorical.
private val allStringColumnsToCategorical = booleanParam(
"allStringColumnsToCategorical",
"Transform all strings columns to categorical")
Expand All @@ -60,7 +57,6 @@ trait H2OAlgoParams[P <: Parameters] extends H2OAlgoParamsHelper[P] with H2OComm
// Default values
//
setDefault(
ratio -> 1.0, // 1.0 means use whole frame as training frame
nfolds -> parameters._nfolds,
allStringColumnsToCategorical -> true,
columnsToCategorical -> Array.empty[String],
Expand All @@ -75,7 +71,8 @@ trait H2OAlgoParams[P <: Parameters] extends H2OAlgoParamsHelper[P] with H2OComm
//
// Getters
//
def getTrainRatio(): Double = $(ratio)
/**
 * Deprecated accessor kept next to `getSplitRatio()`; it only forwards to that method.
 *
 * @return whatever `getSplitRatio()` returns
 */
@DeprecatedMethod("getSplitRatio")
def getTrainRatio(): Double = getSplitRatio()

/**
 * Deprecated accessor kept next to `getLabelCol()`; it only forwards to that method.
 *
 * @return whatever `getLabelCol()` returns
 */
@DeprecatedMethod("getLabelCol")
def getPredictionCol(): String = getLabelCol()
Expand All @@ -101,7 +98,8 @@ trait H2OAlgoParams[P <: Parameters] extends H2OAlgoParamsHelper[P] with H2OComm
//
// Setters
//
def setTrainRatio(value: Double): this.type = set(ratio, value)
/**
 * Deprecated setter kept next to `setSplitRatio(...)`; it only forwards to that method.
 *
 * @param value ratio passed through unchanged to `setSplitRatio`
 * @return this instance (`this.type`), so calls can be chained
 */
@DeprecatedMethod("setSplitRatio")
def setTrainRatio(value: Double): this.type = setSplitRatio(value)

/**
 * Deprecated setter kept next to `setLabelCol(...)`; it only forwards to that method.
 *
 * @param value value passed through unchanged to `setLabelCol`
 * @return this instance (`this.type`), so calls can be chained
 */
@DeprecatedMethod("setLabelCol")
def setPredictionCol(value: String): this.type = setLabelCol(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.ml.h2o.param

import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{Param, Params, StringArrayParam}
import org.apache.spark.ml.param.{DoubleParam, Param, Params, StringArrayParam}

/**
* This trait contains parameters that are shared across all algorithms.
Expand All @@ -28,6 +28,9 @@ trait H2OCommonParams extends Params with Logging {
// String param holding the label column name.
private val labelCol = new Param[String](this, "labelCol", "Label column name")
// Param holding the fold column name.
// NOTE(review): NullableStringParam presumably allows a null value — confirm.
private val foldCol = new NullableStringParam(this, "foldCol", "Fold column name")
// Param holding the weight column name (same param type as foldCol).
private val weightCol = new NullableStringParam(this, "weightCol", "Weight column name")
// Double param describing how the dataset is divided between training and validation;
// see the description string for the intended meaning.
// NOTE(review): no validator argument is passed to DoubleParam here, so the documented
// [0, 1.0] range is not enforced at this definition.
private val splitRatio = new DoubleParam(this, "splitRatio",
"Accepts values in range [0, 1.0] which determine how large part of dataset is used for training and for validation. " +
"For example, 0.8 -> 80% training 20% validation.")

//
// Default values
Expand All @@ -36,7 +39,8 @@ trait H2OCommonParams extends Params with Logging {
featuresCols -> Array.empty[String],
labelCol -> "label",
foldCol -> null,
weightCol -> null
weightCol -> null,
splitRatio -> 1.0 // Use whole frame as training frame
)

//
Expand All @@ -53,6 +57,7 @@ trait H2OCommonParams extends Params with Logging {

/** Returns the value of the `weightCol` param. */
def getWeightCol(): String = $(weightCol)

/** Returns the value of the `splitRatio` param. */
def getSplitRatio(): Double = $(splitRatio)
//
// Setters
//
Expand All @@ -71,6 +76,7 @@ trait H2OCommonParams extends Params with Logging {

/**
 * Stores `columnName` in the `weightCol` param.
 *
 * @param columnName value to store
 * @return this instance (`this.type`), so calls can be chained
 */
def setWeightCol(columnName: String): this.type = set(weightCol, columnName)

/**
 * Stores `ratio` in the `splitRatio` param.
 *
 * @param ratio value to store
 * @return this instance (`this.type`), so calls can be chained
 *
 * NOTE(review): no explicit range check on `ratio` is visible in this method — confirm
 * whether out-of-range values are rejected elsewhere.
 */
def setSplitRatio(ratio: Double): this.type = set(splitRatio, ratio)
//
// Other methods
//
Expand Down

0 comments on commit ff2519e

Please sign in to comment.