[SW-1207] Implement H2OTargetEncoder for Scala API #1192

Merged 31 commits on Jul 26, 2019

Commits (31)
ebd47f0
[SW-1207] Implement H2OTargetEncoder
mn-mikke May 2, 2019
80f44da
Fixing the build
mn-mikke May 9, 2019
8febad2
An implementation via usage of MOJO models
mn-mikke Jun 5, 2019
f06577a
Update the reference to the transform function
mn-mikke Jun 13, 2019
456d09c
Addressing Kuba's comments
mn-mikke Jun 13, 2019
14ea44e
Missing new line
mn-mikke Jun 13, 2019
7ed0323
Setting the parent on the copy of the target model
mn-mikke Jun 13, 2019
b6fd9a6
Addressing review comments from Andrej
mn-mikke Jun 27, 2019
a857201
Removing reference to h2o build
mn-mikke Jul 12, 2019
4ced88b
Update implementation according to the latest version of TE in H2O-3
mn-mikke Jul 15, 2019
6220d71
Flattening TE parameters
mn-mikke Jul 16, 2019
aea5045
Remove CaseClassParam
mn-mikke Jul 16, 2019
8dd7c74
Revert h2o-3 build inclusion
mn-mikke Jul 16, 2019
4737030
Improving comment for the noise parameter
mn-mikke Jul 16, 2019
735b882
Update tests according to the latest TE parameters
mn-mikke Jul 16, 2019
93b604f
Getting TE model via getTargetEncoderModel
mn-mikke Jul 17, 2019
9cacd08
Moving TargetEncoder to the package ai.h2o.sparkling.ml
mn-mikke Jul 18, 2019
80dff9e
More tests
mn-mikke Jul 18, 2019
0ce13e0
Adding more tests
mn-mikke Jul 19, 2019
285f94b
Remove cache
mn-mikke Jul 19, 2019
48cf158
Remove includeBuild
mn-mikke Jul 19, 2019
a1a4826
More tests
mn-mikke Jul 21, 2019
2b4fa39
Fixing description
mn-mikke Jul 21, 2019
4d13f1d
Fixing scala style
mn-mikke Jul 22, 2019
5ee4228
Addressing review comments
mn-mikke Jul 22, 2019
81d765b
Using enum for holdout strategy
mn-mikke Jul 23, 2019
1a80364
Updating test cases according to the changes in H2O-3
mn-mikke Jul 23, 2019
5d96508
Updating tests
mn-mikke Jul 23, 2019
6067acf
Adding more tests testing Java API scenarios
mn-mikke Jul 23, 2019
e4ab4f2
Updating reference to H2OAlgoParamsHelper
mn-mikke Jul 24, 2019
748ecc7
Fixing problems after rebase
mn-mikke Jul 26, 2019
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

object DatasetExtensions {
implicit class DatasetWrapper(dataset: Dataset[_]) {
def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = dataset.withColumns(colNames, cols)
}
}
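Spark keeps the multi-column `withColumns` method package-private, which is why the wrapper above lives inside `org.apache.spark.sql` and simply re-exports it. A minimal sketch of the implicit-class extension pattern it relies on (the names here are illustrative, not Spark's):

```scala
object StringExtensions {
  // An implicit class wraps the receiver and adds a method to it,
  // just as DatasetWrapper adds withColumns to Dataset.
  implicit class StringWrapper(s: String) {
    def shout: String = s.toUpperCase + "!"
  }
}

import StringExtensions._
// The compiler rewrites this to StringWrapper("hello").shout.
val greeting = "hello".shout
```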
@@ -91,6 +91,8 @@ object TestFrameUtils extends Matchers {
}

def assertDataFramesAreIdentical(expected: DataFrame, produced: DataFrame): Unit = {
expected.cache()
produced.cache()
val expectedCount = expected.count()
val producedCount = produced.count()
assert(
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.features;

public enum H2OTargetEncoderHoldoutStrategy {
LeaveOneOut,
KFold,
None
}
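Elsewhere in this PR, `H2OTargetEncoderModel` converts the chosen strategy to a byte via `valueOf(name).ordinal().toByte` before calling into H2O. A plain-Scala sketch of that name-to-ordinal mapping, with the order mirroring the Java enum declaration above (the helper name is hypothetical):

```scala
// Ordinals mirror the declaration order of H2OTargetEncoderHoldoutStrategy.
val holdoutStrategies = Vector("LeaveOneOut", "KFold", "None")

// Hypothetical helper: what valueOf(name).ordinal().toByte computes.
def holdoutStrategyId(name: String): Byte = {
  val i = holdoutStrategies.indexOf(name)
  require(i >= 0, s"Unknown holdout strategy: $name")
  i.toByte
}
```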
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.features

import ai.h2o.automl.targetencoding._
import ai.h2o.sparkling.ml.models.H2OTargetEncoderModel
import ai.h2o.sparkling.ml.params.H2OAlgoParamsHelper
import org.apache.spark.h2o.{Frame, H2OContext}
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{Dataset, SparkSession}

class H2OTargetEncoder(override val uid: String)
extends Estimator[H2OTargetEncoderModel]
with H2OTargetEncoderBase
with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("H2OTargetEncoder"))

override def fit(dataset: Dataset[_]): H2OTargetEncoderModel = {
val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate())
val input = h2oContext.asH2OFrame(dataset.toDF())
convertRelevantColumnsToCategorical(input)
val targetEncoderModel = trainTargetEncodingModel(input)
val model = new H2OTargetEncoderModel(uid, targetEncoderModel).setParent(this)
copyValues(model)
}

private def trainTargetEncodingModel(trainingFrame: Frame) = try {
val targetEncoderParameters = new TargetEncoderModel.TargetEncoderParameters()
targetEncoderParameters._withBlending = getBlendedAvgEnabled()
targetEncoderParameters._blendingParams = new BlendingParams(getBlendedAvgInflectionPoint(), getBlendedAvgSmoothing())
targetEncoderParameters._response_column = getLabelCol()
targetEncoderParameters._teFoldColumnName = getFoldCol()
targetEncoderParameters._columnNamesToEncode = getInputCols()
targetEncoderParameters.setTrain(trainingFrame._key)

val builder = new TargetEncoderBuilder(targetEncoderParameters)
builder.trainModel().get() // Calling get() to wait until the model training is finished.
builder.getTargetEncoderModel()
} catch {
case e: IllegalStateException if e.getMessage.contains("We do not support multi-class target case") =>
throw new RuntimeException("The label column cannot contain more than two unique values.")
}

override def copy(extra: ParamMap): H2OTargetEncoder = defaultCopy(extra)


//
// Parameter Setters
//
def setFoldCol(value: String): this.type = set(foldCol, value)

def setLabelCol(value: String): this.type = set(labelCol, value)

def setInputCols(values: Array[String]): this.type = set(inputCols, values)

def setHoldoutStrategy(value: String): this.type = {
set(holdoutStrategy, H2OAlgoParamsHelper.getValidatedEnumValue[H2OTargetEncoderHoldoutStrategy](value))
}

def setBlendedAvgEnabled(value: Boolean): this.type = set(blendedAvgEnabled, value)

def setBlendedAvgInflectionPoint(value: Double): this.type = set(blendedAvgInflectionPoint, value)

def setBlendedAvgSmoothing(value: Double): this.type = {
require(value > 0.0, "The smoothing value has to be a positive number.")
set(blendedAvgSmoothing, value)
}

def setNoise(value: Double): this.type = {
require(value >= 0.0, "Noise can't be a negative value.")
set(noise, value)
}

def setNoiseSeed(value: Long): this.type = set(noiseSeed, value)
}

object H2OTargetEncoder extends DefaultParamsReadable[H2OTargetEncoder]
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.features

import ai.h2o.sparkling.ml.params.H2OTargetEncoderParams
import org.apache.spark.h2o.Frame
import org.apache.spark.ml.PipelineStage
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

trait H2OTargetEncoderBase extends PipelineStage with H2OTargetEncoderParams {
override def transformSchema(schema: StructType): StructType = {
validateSchema(schema)
StructType(schema.fields ++ getOutputCols().map(StructField(_, DoubleType, nullable = true)))
}

private def validateSchema(flatSchema: StructType): Unit = {
require(getLabelCol() != null, "Label column can't be null!")
require(getInputCols() != null && getInputCols().nonEmpty, "The list of input columns can't be null or empty!")
val fields = flatSchema.fields
val fieldNames = fields.map(_.name)
require(fieldNames.contains(getLabelCol()),
s"The specified label column '${getLabelCol()}' was not found in the input dataset!")
getInputCols().foreach { inputCol =>
require(fieldNames.contains(inputCol),
s"The specified input column '$inputCol' was not found in the input dataset!")
}
val ioIntersection = getInputCols().intersect(getOutputCols())
require(ioIntersection.isEmpty,
s"""The columns [${ioIntersection.map(i => s"'$i'").mkString(", ")}] are specified
|as input columns and also as output columns. There can't be an overlap.""".stripMargin)
}

protected def convertRelevantColumnsToCategorical(frame: Frame): Unit = {
val relevantColumns = getInputCols() ++ Array(getLabelCol())
relevantColumns.foreach(frame.toCategoricalCol(_))
}
}
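The input/output overlap rule in `validateSchema` can be exercised in isolation. A minimal sketch of the same `intersect`-based check, outside of the trait (helper names are hypothetical):

```scala
// Mirrors the validateSchema rule: input and output column sets must not overlap.
def overlappingColumns(inputCols: Seq[String], outputCols: Seq[String]): Seq[String] =
  inputCols.intersect(outputCols)

def requireNoOverlap(inputCols: Seq[String], outputCols: Seq[String]): Unit = {
  val clash = overlappingColumns(inputCols, outputCols)
  require(clash.isEmpty,
    s"The columns [${clash.map(c => s"'$c'").mkString(", ")}] are specified " +
      "as input columns and also as output columns. There can't be an overlap.")
}
```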
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.models

import ai.h2o.automl.targetencoding.TargetEncoderModel
import ai.h2o.sparkling.ml.features.{H2OTargetEncoderBase, H2OTargetEncoderHoldoutStrategy}
import org.apache.spark.h2o.H2OContext
import org.apache.spark.h2o.utils.H2OSchemaUtils
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import water.support.ModelSerializationSupport

class H2OTargetEncoderModel(
override val uid: String,
targetEncoderModel: TargetEncoderModel)
extends Model[H2OTargetEncoderModel] with H2OTargetEncoderBase with MLWritable {

lazy val mojoModel: H2OTargetEncoderMojoModel = {
val mojoData = ModelSerializationSupport.getMojoData(targetEncoderModel)
val model = new H2OTargetEncoderMojoModel()
copyValues(model).setMojoData(mojoData)
}

override def transform(dataset: Dataset[_]): DataFrame = {
if (inTrainingMode) {
transformTrainingDataset(dataset)
} else {
mojoModel.transform(dataset)
}
}

def transformTrainingDataset(dataset: Dataset[_]): DataFrame = {
val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate())
val temporaryColumn = getClass.getSimpleName + "_temporary_id"
val withIdDF = dataset.withColumn(temporaryColumn, monotonically_increasing_id)
val flatDF = H2OSchemaUtils.flattenDataFrame(withIdDF)
val relevantColumns = getInputCols() ++ Array(getLabelCol(), getFoldCol(), temporaryColumn).flatMap(Option(_))
val relevantColumnsDF = flatDF.select(relevantColumns.map(col(_)): _*)
val input = h2oContext.asH2OFrame(relevantColumnsDF)
convertRelevantColumnsToCategorical(input)
val holdoutStrategyId = H2OTargetEncoderHoldoutStrategy.valueOf(getHoldoutStrategy()).ordinal().asInstanceOf[Byte]
val outputFrame = targetEncoderModel.transform(input, holdoutStrategyId, getNoise(), getNoiseSeed())
val outputColumnsOnlyFrame = outputFrame.subframe(getOutputCols() ++ Array(temporaryColumn))
val outputColumnsOnlyDF = h2oContext.asDataFrame(outputColumnsOnlyFrame)
withIdDF
.join(outputColumnsOnlyDF, Seq(temporaryColumn), joinType="left")
.drop(temporaryColumn)
}

private def inTrainingMode: Boolean = {
val stackTrace = Thread.currentThread().getStackTrace()
stackTrace.exists(e => e.getMethodName == "fit" && e.getClassName == "org.apache.spark.ml.Pipeline")
}

override def copy(extra: ParamMap): H2OTargetEncoderModel = {
defaultCopy[H2OTargetEncoderModel](extra).setParent(parent)
}

override def write: MLWriter = mojoModel.write
}
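`transformTrainingDataset` above tags each row with a monotonically increasing id, sends only the relevant columns through the target encoder, and left-joins the encoded output back on that id before dropping it. A collections-based sketch of the same round trip (plain Scala, no Spark; the `encode` stand-in and all names are hypothetical):

```scala
// Rows as maps of columnName -> value, a stand-in for DataFrame rows.
type Row = Map[String, Any]

val idCol = "temporary_id"

// Step 1: tag each row with a unique id (monotonically_increasing_id analogue).
def withIds(rows: Seq[Row]): Seq[Row] =
  rows.zipWithIndex.map { case (r, i) => r + (idCol -> i.toLong) }

// Step 2: stand-in for the encoder; keeps only the id and the encoded output.
def encode(rows: Seq[Row], inputCol: String, outputCol: String): Seq[Row] =
  rows.map(r => Map(idCol -> r(idCol), outputCol -> r(inputCol).toString.length.toDouble))

// Step 3: left-join the encoded output back on the id, then drop the id.
def joinBack(original: Seq[Row], encoded: Seq[Row]): Seq[Row] = {
  val byId = encoded.map(r => r(idCol) -> (r - idCol)).toMap
  original.map(r => (r ++ byId.getOrElse(r(idCol), Map.empty)) - idCol)
}

val rows = Seq(Map[String, Any]("city" -> "NYC"), Map[String, Any]("city" -> "Prague"))
val tagged = withIds(rows)
val result = joinBack(tagged, encode(tagged, "city", "city_te"))
```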
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.models


import ai.h2o.sparkling.ml.features.H2OTargetEncoderBase

import hex.genmodel.easy.EasyPredictModelWrapper
import hex.genmodel.algos.targetencoder.TargetEncoderMojoModel

import org.apache.spark.h2o.converters.RowConverter
import org.apache.spark.ml.Model
import org.apache.spark.ml.h2o.models.{H2OMOJOFlattenedInput, H2OMOJOReadable, H2OMOJOWritable}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import water.support.ModelSerializationSupport

class H2OTargetEncoderMojoModel(override val uid: String) extends Model[H2OTargetEncoderMojoModel]
with H2OTargetEncoderBase with H2OMOJOWritable with H2OMOJOFlattenedInput {

override protected def inputColumnNames: Array[String] = getInputCols()

override protected def outputColumnName: String = getClass.getSimpleName + "_output"

def this() = this(Identifiable.randomUID(getClass.getSimpleName))

override def transform(dataset: Dataset[_]): DataFrame = {
import org.apache.spark.sql.DatasetExtensions._
val outputCols = getOutputCols()
val udfWrapper = H2OTargetEncoderMojoUdfWrapper(getMojoData(), outputCols)
val withPredictionsDF = applyPredictionUdf(dataset, _ => udfWrapper.mojoUdf)
withPredictionsDF
.withColumns(outputCols, outputCols.zipWithIndex.map{ case (c, i) => col(outputColumnName)(i) as c })
.drop(outputColumnName)
}

override def copy(extra: ParamMap): H2OTargetEncoderMojoModel = defaultCopy(extra)
}

/**
* The class holds all necessary dependencies of udf that needs to be serialized.
*/
case class H2OTargetEncoderMojoUdfWrapper(mojoData: Array[Byte], outputCols: Array[String]) {
@transient private lazy val easyPredictModelWrapper: EasyPredictModelWrapper = {
val model = ModelSerializationSupport
.getMojoModel(mojoData)
.asInstanceOf[TargetEncoderMojoModel]
val config = new EasyPredictModelWrapper.Config()
config.setModel(model)
config.setConvertUnknownCategoricalLevelsToNa(true)
config.setConvertInvalidNumbersToNa(true)
new EasyPredictModelWrapper(config)
}

val mojoUdf = udf[Array[Option[Double]], Row] { r: Row =>
val inputRowData = RowConverter.toH2ORowData(r)
try {
val prediction = easyPredictModelWrapper.transformWithTargetEncoding(inputRowData)
prediction.transformations.map(Some(_))
} catch {
case _: Throwable => outputCols.map(_ => None)
}
}
}

object H2OTargetEncoderMojoModel extends H2OMOJOReadable[H2OTargetEncoderMojoModel]
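The UDF in `H2OTargetEncoderMojoUdfWrapper` maps a failed MOJO prediction to one `None` per output column instead of failing the whole job. A minimal plain-Scala sketch of that fallback shape (`predict` is a hypothetical stand-in for the `transformWithTargetEncoding` call):

```scala
// Hypothetical stand-in for easyPredictModelWrapper.transformWithTargetEncoding:
// on success, wrap each encoded value in Some; on any failure, emit one None
// per output column, matching the udf's behavior.
def safeTransform(outputCols: Array[String], predict: () => Array[Double]): Array[Option[Double]] =
  try {
    predict().map(v => Some(v): Option[Double])
  } catch {
    case _: Throwable => outputCols.map(_ => None: Option[Double])
  }
```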