-
Notifications
You must be signed in to change notification settings - Fork 363
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SW-1207] Implement H2OTargetEncoder for Scala API (#1192)
- Loading branch information
Showing
16 changed files
with
933 additions
and
41 deletions.
There are no files selected for viewing
24 changes: 24 additions & 0 deletions
24
core/src/main/scala/org/apache/spark/sql/DatasetExtensions.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql | ||
|
||
object DatasetExtensions { | ||
implicit class DatasetWrapper(dataset: Dataset[_]) { | ||
def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = dataset.withColumns(colNames, cols) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
ml/src/main/java/ai/h2o/sparkling/ml/features/H2OTargetEncoderHoldoutStrategy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.h2o.sparkling.ml.features; | ||
|
||
public enum H2OTargetEncoderHoldoutStrategy { | ||
LeaveOneOut, | ||
KFold, | ||
None | ||
} |
95 changes: 95 additions & 0 deletions
95
ml/src/main/scala/ai/h2o/sparkling/ml/features/H2OTargetEncoder.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.h2o.sparkling.ml.features | ||
|
||
import ai.h2o.automl.targetencoding._ | ||
import ai.h2o.sparkling.ml.models.H2OTargetEncoderModel | ||
import ai.h2o.sparkling.ml.params.H2OAlgoParamsHelper | ||
import org.apache.spark.h2o.{Frame, H2OContext} | ||
import org.apache.spark.ml.Estimator | ||
import org.apache.spark.ml.param.ParamMap | ||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} | ||
import org.apache.spark.sql.{Dataset, SparkSession} | ||
|
||
class H2OTargetEncoder(override val uid: String) | ||
extends Estimator[H2OTargetEncoderModel] | ||
with H2OTargetEncoderBase | ||
with DefaultParamsWritable { | ||
|
||
def this() = this(Identifiable.randomUID("H2OTargetEncoder")) | ||
|
||
override def fit(dataset: Dataset[_]): H2OTargetEncoderModel = { | ||
val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate()) | ||
val input = h2oContext.asH2OFrame(dataset.toDF()) | ||
convertRelevantColumnsToCategorical(input) | ||
val targetEncoderModel = trainTargetEncodingModel(input) | ||
val model = new H2OTargetEncoderModel(uid, targetEncoderModel).setParent(this) | ||
copyValues(model) | ||
} | ||
|
||
private def trainTargetEncodingModel(trainingFrame: Frame) = try { | ||
val targetEncoderParameters = new TargetEncoderModel.TargetEncoderParameters() | ||
targetEncoderParameters._withBlending = getBlendedAvgEnabled() | ||
targetEncoderParameters._blendingParams = new BlendingParams(getBlendedAvgInflectionPoint(), getBlendedAvgSmoothing()) | ||
targetEncoderParameters._response_column = getLabelCol() | ||
targetEncoderParameters._teFoldColumnName = getFoldCol() | ||
targetEncoderParameters._columnNamesToEncode = getInputCols() | ||
targetEncoderParameters.setTrain(trainingFrame._key) | ||
|
||
val builder = new TargetEncoderBuilder(targetEncoderParameters) | ||
builder.trainModel().get() // Calling get() to wait until the model training is finished. | ||
builder.getTargetEncoderModel() | ||
} catch { | ||
case e: IllegalStateException if e.getMessage.contains("We do not support multi-class target case") => | ||
throw new RuntimeException("The label column can not contain more than two unique values.") | ||
} | ||
|
||
override def copy(extra: ParamMap): H2OTargetEncoder = defaultCopy(extra) | ||
|
||
|
||
// | ||
// Parameter Setters | ||
// | ||
def setFoldCol(value: String): this.type = set(foldCol, value) | ||
|
||
def setLabelCol(value: String): this.type = set(labelCol, value) | ||
|
||
def setInputCols(values: Array[String]): this.type = set(inputCols, values) | ||
|
||
def setHoldoutStrategy(value: String): this.type = { | ||
set(holdoutStrategy, H2OAlgoParamsHelper.getValidatedEnumValue[H2OTargetEncoderHoldoutStrategy](value)) | ||
} | ||
|
||
def setBlendedAvgEnabled(value: Boolean): this.type = set(blendedAvgEnabled, value) | ||
|
||
def setBlendedAvgInflectionPoint(value: Double): this.type = set(blendedAvgInflectionPoint, value) | ||
|
||
def setBlendedAvgSmoothing(value: Double): this.type = { | ||
require(value > 0.0, "The smoothing value has to be a positive number.") | ||
set(blendedAvgSmoothing, value) | ||
} | ||
|
||
def setNoise(value: Double): this.type = { | ||
require(value >= 0.0, "Noise can't be a negative value.") | ||
set(noise, value) | ||
} | ||
|
||
def setNoiseSeed(value: Long): this.type = set(noiseSeed, value) | ||
} | ||
|
||
object H2OTargetEncoder extends DefaultParamsReadable[H2OTargetEncoder] |
52 changes: 52 additions & 0 deletions
52
ml/src/main/scala/ai/h2o/sparkling/ml/features/H2OTargetEncoderBase.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.h2o.sparkling.ml.features | ||
|
||
import ai.h2o.sparkling.ml.params.H2OTargetEncoderParams | ||
import org.apache.spark.h2o.Frame | ||
import org.apache.spark.ml.PipelineStage | ||
import org.apache.spark.sql.types.{DoubleType, StructField, StructType} | ||
|
||
trait H2OTargetEncoderBase extends PipelineStage with H2OTargetEncoderParams { | ||
override def transformSchema(schema: StructType): StructType = { | ||
validateSchema(schema) | ||
StructType(schema.fields ++ getOutputCols().map(StructField(_, DoubleType, nullable = true))) | ||
} | ||
|
||
private def validateSchema(flatSchema: StructType): Unit = { | ||
require(getLabelCol() != null, "Label column can't be null!") | ||
require(getInputCols() != null && getInputCols().nonEmpty, "The list of input columns can't be null or empty!") | ||
val fields = flatSchema.fields | ||
val fieldNames = fields.map(_.name) | ||
require(fieldNames.contains(getLabelCol()), | ||
s"The specified label column '${getLabelCol()}' was not found in the input dataset!") | ||
getInputCols().foreach { inputCol => | ||
require(fieldNames.contains(inputCol), | ||
s"The specified input column '$inputCol' was not found in the input dataset!") | ||
} | ||
val ioIntersection = getInputCols().intersect(getOutputCols()) | ||
require(ioIntersection.isEmpty, | ||
s"""The columns [${ioIntersection.map(i => s"'$i'").mkString(", ")}] are specified | ||
|as input columns and also as output columns. There can't be an overlap.""".stripMargin) | ||
} | ||
|
||
protected def convertRelevantColumnsToCategorical(frame: Frame): Unit = { | ||
val relevantColumns = getInputCols() ++ Array(getLabelCol()) | ||
relevantColumns.foreach(frame.toCategoricalCol(_)) | ||
} | ||
} |
78 changes: 78 additions & 0 deletions
78
ml/src/main/scala/ai/h2o/sparkling/ml/models/H2OTargetEncoderModel.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.h2o.sparkling.ml.models | ||
|
||
import ai.h2o.automl.targetencoding.TargetEncoderModel | ||
import ai.h2o.sparkling.ml.features.{H2OTargetEncoderBase, H2OTargetEncoderHoldoutStrategy} | ||
import org.apache.spark.h2o.H2OContext | ||
import org.apache.spark.h2o.utils.H2OSchemaUtils | ||
import org.apache.spark.ml.Model | ||
import org.apache.spark.ml.param.ParamMap | ||
import org.apache.spark.ml.util.{MLWritable, MLWriter} | ||
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} | ||
import org.apache.spark.sql.functions._ | ||
import water.support.ModelSerializationSupport | ||
|
||
class H2OTargetEncoderModel( | ||
override val uid: String, | ||
targetEncoderModel: TargetEncoderModel) | ||
extends Model[H2OTargetEncoderModel] with H2OTargetEncoderBase with MLWritable { | ||
|
||
lazy val mojoModel: H2OTargetEncoderMojoModel = { | ||
val mojoData = ModelSerializationSupport.getMojoData(targetEncoderModel) | ||
val model = new H2OTargetEncoderMojoModel() | ||
copyValues(model).setMojoData(mojoData) | ||
} | ||
|
||
override def transform(dataset: Dataset[_]): DataFrame = { | ||
if (inTrainingMode) { | ||
transformTrainingDataset(dataset) | ||
} else { | ||
mojoModel.transform(dataset) | ||
} | ||
} | ||
|
||
def transformTrainingDataset(dataset: Dataset[_]): DataFrame = { | ||
val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate()) | ||
val temporaryColumn = getClass.getSimpleName + "_temporary_id" | ||
val withIdDF = dataset.withColumn(temporaryColumn, monotonically_increasing_id) | ||
val flatDF = H2OSchemaUtils.flattenDataFrame(withIdDF) | ||
val relevantColumns = getInputCols() ++ Array(getLabelCol(), getFoldCol(), temporaryColumn).flatMap(Option(_)) | ||
val relevantColumnsDF = flatDF.select(relevantColumns.map(col(_)): _*) | ||
val input = h2oContext.asH2OFrame(relevantColumnsDF) | ||
convertRelevantColumnsToCategorical(input) | ||
val holdoutStrategyId = H2OTargetEncoderHoldoutStrategy.valueOf(getHoldoutStrategy()).ordinal().asInstanceOf[Byte] | ||
val outputFrame = targetEncoderModel.transform(input, holdoutStrategyId, getNoise(), getNoiseSeed()) | ||
val outputColumnsOnlyFrame = outputFrame.subframe(getOutputCols() ++ Array(temporaryColumn)) | ||
val outputColumnsOnlyDF = h2oContext.asDataFrame(outputColumnsOnlyFrame) | ||
withIdDF | ||
.join(outputColumnsOnlyDF, Seq(temporaryColumn), joinType="left") | ||
.drop(temporaryColumn) | ||
} | ||
|
||
private def inTrainingMode: Boolean = { | ||
val stackTrace = Thread.currentThread().getStackTrace() | ||
stackTrace.exists(e => e.getMethodName == "fit" && e.getClassName == "org.apache.spark.ml.Pipeline") | ||
} | ||
|
||
override def copy(extra: ParamMap): H2OTargetEncoderModel = { | ||
defaultCopy[H2OTargetEncoderModel](extra).setParent(parent) | ||
} | ||
|
||
override def write: MLWriter = mojoModel.write | ||
} |
84 changes: 84 additions & 0 deletions
84
ml/src/main/scala/ai/h2o/sparkling/ml/models/H2OTargetEncoderMojoModel.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.h2o.sparkling.ml.models | ||
|
||
|
||
import ai.h2o.sparkling.ml.features.H2OTargetEncoderBase | ||
|
||
import hex.genmodel.easy.EasyPredictModelWrapper | ||
import hex.genmodel.algos.targetencoder.TargetEncoderMojoModel | ||
|
||
import org.apache.spark.h2o.converters.RowConverter | ||
import org.apache.spark.ml.Model | ||
import org.apache.spark.ml.h2o.models.{H2OMOJOFlattenedInput, H2OMOJOReadable, H2OMOJOWritable} | ||
import org.apache.spark.ml.param.ParamMap | ||
import org.apache.spark.ml.util.Identifiable | ||
import org.apache.spark.sql.functions._ | ||
import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
|
||
import water.support.ModelSerializationSupport | ||
|
||
class H2OTargetEncoderMojoModel(override val uid: String) extends Model[H2OTargetEncoderMojoModel] | ||
with H2OTargetEncoderBase with H2OMOJOWritable with H2OMOJOFlattenedInput { | ||
|
||
override protected def inputColumnNames: Array[String] = getInputCols() | ||
|
||
override protected def outputColumnName: String = getClass.getSimpleName + "_output" | ||
|
||
def this() = this(Identifiable.randomUID(getClass.getSimpleName)) | ||
|
||
override def transform(dataset: Dataset[_]): DataFrame = { | ||
import org.apache.spark.sql.DatasetExtensions._ | ||
val outputCols = getOutputCols() | ||
val udfWrapper = H2OTargetEncoderMojoUdfWrapper(getMojoData(), outputCols) | ||
val withPredictionsDF = applyPredictionUdf(dataset, _ => udfWrapper.mojoUdf) | ||
withPredictionsDF | ||
.withColumns(outputCols, outputCols.zipWithIndex.map{ case (c, i) => col(outputColumnName)(i) as c }) | ||
.drop(outputColumnName) | ||
} | ||
|
||
override def copy(extra: ParamMap): H2OTargetEncoderMojoModel = defaultCopy(extra) | ||
} | ||
|
||
/** | ||
* The class holds all necessary dependencies of udf that needs to be serialized. | ||
*/ | ||
case class H2OTargetEncoderMojoUdfWrapper(mojoData: Array[Byte], outputCols: Array[String]) { | ||
@transient private lazy val easyPredictModelWrapper: EasyPredictModelWrapper = { | ||
val model = ModelSerializationSupport | ||
.getMojoModel(mojoData) | ||
.asInstanceOf[TargetEncoderMojoModel] | ||
val config = new EasyPredictModelWrapper.Config() | ||
config.setModel(model) | ||
config.setConvertUnknownCategoricalLevelsToNa(true) | ||
config.setConvertInvalidNumbersToNa(true) | ||
new EasyPredictModelWrapper(config) | ||
} | ||
|
||
val mojoUdf = udf[Array[Option[Double]], Row] { r: Row => | ||
val inputRowData = RowConverter.toH2ORowData(r) | ||
try { | ||
val prediction = easyPredictModelWrapper.transformWithTargetEncoding(inputRowData) | ||
prediction.transformations.map(Some(_)) | ||
} catch { | ||
case _: Throwable => outputCols.map(_ => None) | ||
} | ||
} | ||
} | ||
|
||
object H2OTargetEncoderMojoModel extends H2OMOJOReadable[H2OTargetEncoderMojoModel] |
Oops, something went wrong.