[SW-1207] Implement H2OTargetEncoder for Scala API (#1192)
mn-mikke committed Jul 26, 2019
1 parent 99f37e6 commit 3faa9b5
Showing 16 changed files with 933 additions and 41 deletions.
24 changes: 24 additions & 0 deletions core/src/main/scala/org/apache/spark/sql/DatasetExtensions.scala
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

object DatasetExtensions {
  implicit class DatasetWrapper(dataset: Dataset[_]) {
    // Dataset.withColumns is package-private in Spark; this wrapper lives in
    // org.apache.spark.sql so it can re-export the method to Sparkling Water code.
    def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = dataset.withColumns(colNames, cols)
  }
}
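A minimal usage sketch of the wrapper (assuming a Spark 2.x build, where withColumns is package-private; the data and column expressions are illustrative, not part of this change):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.DatasetExtensions._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(5).toDF("id")

// Adds two derived columns in a single pass via the re-exported method.
val enriched = df.withColumns(
  Seq("id_doubled", "id_plus_one"),
  Seq(col("id") * 2, col("id") + 1))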
@@ -91,6 +91,8 @@ object TestFrameUtils extends Matchers {
  }

  def assertDataFramesAreIdentical(expected: DataFrame, produced: DataFrame): Unit = {
    // Cache both inputs so count() and the subsequent content comparison
    // don't recompute the upstream plans.
    expected.cache()
    produced.cache()
    val expectedCount = expected.count()
    val producedCount = produced.count()
    assert(
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.features;

public enum H2OTargetEncoderHoldoutStrategy {
  LeaveOneOut,
  KFold,
  None
}
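The enum ordinals are what get passed to H2O-3's native target encoder: H2OTargetEncoderModel (below) converts the configured strategy name via valueOf(...).ordinal(). A small sketch of the mapping, given the declaration order above:

// LeaveOneOut -> 0, KFold -> 1, None -> 2
val strategyId: Byte =
  H2OTargetEncoderHoldoutStrategy.valueOf("KFold").ordinal().toByte // 1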
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.features

import ai.h2o.automl.targetencoding._
import ai.h2o.sparkling.ml.models.H2OTargetEncoderModel
import ai.h2o.sparkling.ml.params.H2OAlgoParamsHelper
import org.apache.spark.h2o.{Frame, H2OContext}
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{Dataset, SparkSession}

class H2OTargetEncoder(override val uid: String)
  extends Estimator[H2OTargetEncoderModel]
  with H2OTargetEncoderBase
  with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("H2OTargetEncoder"))

  override def fit(dataset: Dataset[_]): H2OTargetEncoderModel = {
    val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate())
    val input = h2oContext.asH2OFrame(dataset.toDF())
    convertRelevantColumnsToCategorical(input)
    val targetEncoderModel = trainTargetEncodingModel(input)
    val model = new H2OTargetEncoderModel(uid, targetEncoderModel).setParent(this)
    copyValues(model)
  }

  private def trainTargetEncodingModel(trainingFrame: Frame) = try {
    val targetEncoderParameters = new TargetEncoderModel.TargetEncoderParameters()
    targetEncoderParameters._withBlending = getBlendedAvgEnabled()
    targetEncoderParameters._blendingParams = new BlendingParams(getBlendedAvgInflectionPoint(), getBlendedAvgSmoothing())
    targetEncoderParameters._response_column = getLabelCol()
    targetEncoderParameters._teFoldColumnName = getFoldCol()
    targetEncoderParameters._columnNamesToEncode = getInputCols()
    targetEncoderParameters.setTrain(trainingFrame._key)

    val builder = new TargetEncoderBuilder(targetEncoderParameters)
    builder.trainModel().get() // Calling get() waits until the model training is finished.
    builder.getTargetEncoderModel()
  } catch {
    case e: IllegalStateException if e.getMessage.contains("We do not support multi-class target case") =>
      throw new RuntimeException("The label column cannot contain more than two unique values.")
  }

  override def copy(extra: ParamMap): H2OTargetEncoder = defaultCopy(extra)

  //
  // Parameter Setters
  //
  def setFoldCol(value: String): this.type = set(foldCol, value)

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def setInputCols(values: Array[String]): this.type = set(inputCols, values)

  def setHoldoutStrategy(value: String): this.type = {
    set(holdoutStrategy, H2OAlgoParamsHelper.getValidatedEnumValue[H2OTargetEncoderHoldoutStrategy](value))
  }

  def setBlendedAvgEnabled(value: Boolean): this.type = set(blendedAvgEnabled, value)

  def setBlendedAvgInflectionPoint(value: Double): this.type = set(blendedAvgInflectionPoint, value)

  def setBlendedAvgSmoothing(value: Double): this.type = {
    require(value > 0.0, "The smoothing value has to be a positive number.")
    set(blendedAvgSmoothing, value)
  }

  def setNoise(value: Double): this.type = {
    require(value >= 0.0, "Noise can't be a negative value.")
    set(noise, value)
  }

  def setNoiseSeed(value: Long): this.type = set(noiseSeed, value)
}

object H2OTargetEncoder extends DefaultParamsReadable[H2OTargetEncoder]
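A minimal end-to-end sketch of the new estimator inside an ML Pipeline (the dataset and column names are hypothetical; fitting through a Pipeline is what makes H2OTargetEncoderModel apply the holdout strategy during training, see below):

import ai.h2o.sparkling.ml.features.H2OTargetEncoder
import org.apache.spark.ml.Pipeline

val targetEncoder = new H2OTargetEncoder()
  .setInputCols(Array("Embarked", "Sex")) // categorical columns to encode
  .setLabelCol("Survived")                // must not have more than two unique values
  .setHoldoutStrategy("KFold")
  .setFoldCol("fold")
  .setBlendedAvgEnabled(true)
  .setNoise(0.05)

val pipeline = new Pipeline().setStages(Array(targetEncoder))
val model = pipeline.fit(trainingDF)  // trainingDF: a DataFrame with the columns above
val scored = model.transform(testDF)  // appends one nullable double column per input column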
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.features

import ai.h2o.sparkling.ml.params.H2OTargetEncoderParams
import org.apache.spark.h2o.Frame
import org.apache.spark.ml.PipelineStage
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

trait H2OTargetEncoderBase extends PipelineStage with H2OTargetEncoderParams {
  override def transformSchema(schema: StructType): StructType = {
    validateSchema(schema)
    StructType(schema.fields ++ getOutputCols().map(StructField(_, DoubleType, nullable = true)))
  }

  private def validateSchema(flatSchema: StructType): Unit = {
    require(getLabelCol() != null, "Label column can't be null!")
    require(getInputCols() != null && getInputCols().nonEmpty, "The list of input columns can't be null or empty!")
    val fields = flatSchema.fields
    val fieldNames = fields.map(_.name)
    require(fieldNames.contains(getLabelCol()),
      s"The specified label column '${getLabelCol()}' was not found in the input dataset!")
    getInputCols().foreach { inputCol =>
      require(fieldNames.contains(inputCol),
        s"The specified input column '$inputCol' was not found in the input dataset!")
    }
    val ioIntersection = getInputCols().intersect(getOutputCols())
    require(ioIntersection.isEmpty,
      s"""The columns [${ioIntersection.map(i => s"'$i'").mkString(", ")}] are specified
         |as input columns and also as output columns. There can't be an overlap.""".stripMargin)
  }

  protected def convertRelevantColumnsToCategorical(frame: Frame): Unit = {
    val relevantColumns = getInputCols() ++ Array(getLabelCol())
    relevantColumns.foreach(frame.toCategoricalCol(_))
  }
}
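For clarity, transformSchema only appends fields, it never rewrites existing ones: one nullable DoubleType column per entry of getOutputCols(). A sketch of the expected result (column names assumed for illustration):

import org.apache.spark.sql.types._

val input = StructType(Seq(
  StructField("Embarked", StringType),
  StructField("Survived", IntegerType)))

// Assuming getOutputCols() == Array("Embarked_te"), the transformed schema is
// the original fields plus one nullable double column:
val expected = StructType(input.fields :+ StructField("Embarked_te", DoubleType, nullable = true))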
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.models

import ai.h2o.automl.targetencoding.TargetEncoderModel
import ai.h2o.sparkling.ml.features.{H2OTargetEncoderBase, H2OTargetEncoderHoldoutStrategy}
import org.apache.spark.h2o.H2OContext
import org.apache.spark.h2o.utils.H2OSchemaUtils
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import water.support.ModelSerializationSupport

class H2OTargetEncoderModel(
    override val uid: String,
    targetEncoderModel: TargetEncoderModel)
  extends Model[H2OTargetEncoderModel] with H2OTargetEncoderBase with MLWritable {

  lazy val mojoModel: H2OTargetEncoderMojoModel = {
    val mojoData = ModelSerializationSupport.getMojoData(targetEncoderModel)
    val model = new H2OTargetEncoderMojoModel()
    copyValues(model).setMojoData(mojoData)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    if (inTrainingMode) {
      transformTrainingDataset(dataset)
    } else {
      mojoModel.transform(dataset)
    }
  }

  def transformTrainingDataset(dataset: Dataset[_]): DataFrame = {
    val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate())
    val temporaryColumn = getClass.getSimpleName + "_temporary_id"
    val withIdDF = dataset.withColumn(temporaryColumn, monotonically_increasing_id)
    val flatDF = H2OSchemaUtils.flattenDataFrame(withIdDF)
    val relevantColumns = getInputCols() ++ Array(getLabelCol(), getFoldCol(), temporaryColumn).flatMap(Option(_))
    val relevantColumnsDF = flatDF.select(relevantColumns.map(col(_)): _*)
    val input = h2oContext.asH2OFrame(relevantColumnsDF)
    convertRelevantColumnsToCategorical(input)
    val holdoutStrategyId = H2OTargetEncoderHoldoutStrategy.valueOf(getHoldoutStrategy()).ordinal().asInstanceOf[Byte]
    val outputFrame = targetEncoderModel.transform(input, holdoutStrategyId, getNoise(), getNoiseSeed())
    val outputColumnsOnlyFrame = outputFrame.subframe(getOutputCols() ++ Array(temporaryColumn))
    val outputColumnsOnlyDF = h2oContext.asDataFrame(outputColumnsOnlyFrame)
    withIdDF
      .join(outputColumnsOnlyDF, Seq(temporaryColumn), joinType = "left")
      .drop(temporaryColumn)
  }

  private def inTrainingMode: Boolean = {
    // Heuristic: the model is being trained if org.apache.spark.ml.Pipeline.fit
    // appears anywhere on the current call stack.
    val stackTrace = Thread.currentThread().getStackTrace()
    stackTrace.exists(e => e.getMethodName == "fit" && e.getClassName == "org.apache.spark.ml.Pipeline")
  }

  override def copy(extra: ParamMap): H2OTargetEncoderModel = {
    defaultCopy[H2OTargetEncoderModel](extra).setParent(parent)
  }

  override def write: MLWriter = mojoModel.write
}
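Design note: transform() behaves differently depending on the caller. Inside Pipeline.fit (detected via the stack-trace heuristic above) it computes holdout encodings on the H2O backend so downstream stages train on leakage-free values; outside a pipeline fit it delegates to the MOJO model. A sketch of the two paths (model and data are illustrative; fitting requires a running H2O cluster):

val model = new H2OTargetEncoder()
  .setInputCols(Array("Embarked"))
  .setLabelCol("Survived")
  .fit(trainingDF)

// Inside Pipeline.fit(...): inTrainingMode is true, so transform() runs
// transformTrainingDataset and applies the configured holdout strategy.

// Called directly (the scoring path): transform() delegates to the MOJO model.
val scored = model.transform(testDF)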
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.models

import ai.h2o.sparkling.ml.features.H2OTargetEncoderBase

import hex.genmodel.easy.EasyPredictModelWrapper
import hex.genmodel.algos.targetencoder.TargetEncoderMojoModel

import org.apache.spark.h2o.converters.RowConverter
import org.apache.spark.ml.Model
import org.apache.spark.ml.h2o.models.{H2OMOJOFlattenedInput, H2OMOJOReadable, H2OMOJOWritable}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import water.support.ModelSerializationSupport

class H2OTargetEncoderMojoModel(override val uid: String) extends Model[H2OTargetEncoderMojoModel]
  with H2OTargetEncoderBase with H2OMOJOWritable with H2OMOJOFlattenedInput {

  override protected def inputColumnNames: Array[String] = getInputCols()

  override protected def outputColumnName: String = getClass.getSimpleName + "_output"

  def this() = this(Identifiable.randomUID(getClass.getSimpleName))

  override def transform(dataset: Dataset[_]): DataFrame = {
    import org.apache.spark.sql.DatasetExtensions._
    val outputCols = getOutputCols()
    val udfWrapper = H2OTargetEncoderMojoUdfWrapper(getMojoData(), outputCols)
    val withPredictionsDF = applyPredictionUdf(dataset, _ => udfWrapper.mojoUdf)
    withPredictionsDF
      .withColumns(outputCols, outputCols.zipWithIndex.map { case (c, i) => col(outputColumnName)(i) as c })
      .drop(outputColumnName)
  }

  override def copy(extra: ParamMap): H2OTargetEncoderMojoModel = defaultCopy(extra)
}

/**
 * A wrapper holding all dependencies of the UDF that need to be serialized
 * and shipped to executors.
 */
case class H2OTargetEncoderMojoUdfWrapper(mojoData: Array[Byte], outputCols: Array[String]) {
  @transient private lazy val easyPredictModelWrapper: EasyPredictModelWrapper = {
    val model = ModelSerializationSupport
      .getMojoModel(mojoData)
      .asInstanceOf[TargetEncoderMojoModel]
    val config = new EasyPredictModelWrapper.Config()
    config.setModel(model)
    config.setConvertUnknownCategoricalLevelsToNa(true)
    config.setConvertInvalidNumbersToNa(true)
    new EasyPredictModelWrapper(config)
  }

  val mojoUdf = udf[Array[Option[Double]], Row] { r: Row =>
    val inputRowData = RowConverter.toH2ORowData(r)
    try {
      val prediction = easyPredictModelWrapper.transformWithTargetEncoding(inputRowData)
      prediction.transformations.map(Some(_))
    } catch {
      case _: Throwable => outputCols.map(_ => None)
    }
  }
}

object H2OTargetEncoderMojoModel extends H2OMOJOReadable[H2OTargetEncoderMojoModel]
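A persistence round-trip sketch (assuming H2OMOJOReadable provides the standard MLReadable.load; write on H2OTargetEncoderModel delegates to the MOJO writer, see above; the path and data are illustrative):

model.write.overwrite().save("/tmp/h2o-target-encoder")
val restored = H2OTargetEncoderMojoModel.load("/tmp/h2o-target-encoder")
val rescored = restored.transform(testDF)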
