
An implementation using MOJO models
mn-mikke committed Jun 13, 2019
1 parent f106c5f commit 467aa02
Showing 13 changed files with 451 additions and 82 deletions.
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.h2o.features;

public enum H2OTargetEncoderHoldoutStrategy {
LeaveOneOut,
KFold,
None
}
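
The three values correspond to H2O-3's data-leakage handling strategies for target encoding: LeaveOneOut excludes the current row's target from its own encoding, KFold excludes the row's fold (hence the fold column), and None applies no holdout. The declaration order matters, because the Scala side later hands the strategy to H2O-3 as a byte via its ordinal, roughly:

// Sketch: how the enum travels to H2O-3 (mirrors the call in
// H2OTargetEncoderModel.transformTrainingDataset further down).
val holdoutStrategyId = H2OTargetEncoderHoldoutStrategy.KFold.ordinal().asInstanceOf[Byte]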
@@ -16,10 +16,10 @@
*/
package org.apache.spark.ml.h2o.features

-import ai.h2o.automl.targetencoding.TargetEncoder
-import org.apache.spark.h2o.H2OContext
+import ai.h2o.automl.targetencoding._
+import org.apache.spark.h2o.{Frame, H2OContext}
import org.apache.spark.ml.Estimator
-import org.apache.spark.ml.h2o.models.{H2OTargetEncoderModel, H2OTargetEncoderTrainingModel}
+import org.apache.spark.ml.h2o.models.H2OTargetEncoderModel
import org.apache.spark.ml.h2o.param.H2OTargetEncoderParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
@@ -35,13 +35,49 @@ class H2OTargetEncoder(override val uid: String)
override def fit(dataset: Dataset[_]): H2OTargetEncoderModel = {
val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate())
val input = h2oContext.asH2OFrame(dataset.toDF())
-val targetEncoder = new TargetEncoder(getInputCols())
-val encodingMap = targetEncoder.prepareEncodingMap(input, getLabelCol(), getFoldCol())
-val model = new H2OTargetEncoderTrainingModel(uid, targetEncoder, encodingMap, dataset).setParent(this)
+changeRelevantColumnsToCategorical(input)
+val targetEncoderModel = trainTargetEncodingModel(input)
+val model = new H2OTargetEncoderModel(uid, targetEncoderModel).setParent(this)
copyValues(model)
}

private def trainTargetEncodingModel(trainingFrame: Frame) = {
val targetEncoderParameters = new TargetEncoderModel.TargetEncoderParameters()
val blending = Option(getBlending())
targetEncoderParameters._withBlending = blending.isDefined
targetEncoderParameters._blendingParams = blending.map(_.toBlendingParams()).getOrElse(null)
targetEncoderParameters._response_column = getLabelCol()
targetEncoderParameters._teFoldColumnName = getFoldCol()
targetEncoderParameters._columnNamesToEncode = getInputCols()
targetEncoderParameters.setTrain(trainingFrame._key)

val builder = new TargetEncoderBuilder(targetEncoderParameters)
builder.trainModel().get()
}

override def copy(extra: ParamMap): H2OTargetEncoder = defaultCopy(extra)


//
// Parameter Setters
//
def setFoldCol(value: String): this.type = set(foldCol, value)

def setLabelCol(value: String): this.type = set(labelCol, value)

def setInputCols(values: Array[String]): this.type = set(inputCols, values)

def setHoldoutStrategy(value: H2OTargetEncoderHoldoutStrategy): this.type = set(holdoutStrategy, value)

def setBlending(settings: H2OTargetEncoderBlendingSettings): this.type = set(blending, settings)

def setNoise(settings: H2OTargetEncoderNoiseSettings): this.type = set(noise, settings)
}

object H2OTargetEncoder extends DefaultParamsReadable[H2OTargetEncoder]

case class H2OTargetEncoderBlendingSettings(inflectionPoint: Double, smoothing: Double) {
def toBlendingParams(): BlendingParams = new BlendingParams(inflectionPoint, smoothing)
}

case class H2OTargetEncoderNoiseSettings(amount: Double = 0.01, seed: Long = -1)
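
For orientation, a minimal usage sketch of the estimator defined above; the column names, fold column, and trainingDF are hypothetical:

val encoder = new H2OTargetEncoder()
  .setInputCols(Array("Embarked", "Sex"))    // categorical columns to encode
  .setLabelCol("Survived")                   // response column
  .setHoldoutStrategy(H2OTargetEncoderHoldoutStrategy.KFold)
  .setFoldCol("fold")                        // used by the KFold strategy
  .setBlending(H2OTargetEncoderBlendingSettings(inflectionPoint = 10.0, smoothing = 20.0))
  .setNoise(H2OTargetEncoderNoiseSettings(amount = 0.01, seed = 42))

val model = encoder.fit(trainingDF)          // trains the H2O-3 target encoder via TargetEncoderBuilder
val encoded = model.transform(trainingDF)    // appends the encoded output columns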
@@ -19,12 +19,11 @@ package org.apache.spark.ml.h2o.models

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.h2o.param.H2OMOJOModelParams
-import org.apache.spark.ml.util.{MLWritable, MLWriter}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.ml.{Model => SparkModel}

-abstract class H2OMOJOModelBase[T <: SparkModel[T]]
-extends SparkModel[T] with H2OMOJOModelParams with MLWritable with HasMojoData {
+abstract class H2OMOJOModelBase[T <: H2OMOJOModelBase[T]] extends SparkModel[T]
+with H2OMOJOModelParams with HasMojoData with H2OMOJOWritable {

protected def getPredictionSchema(): Seq[StructField]

@@ -35,6 +34,4 @@ abstract class H2OMOJOModelBase[T <: SparkModel[T]]
// and model will be able to still provide a prediction
StructType(schema.fields ++ getPredictionSchema())
}

-override def write: MLWriter = new H2OMOJOWriter(this, getMojoData)
}
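
Two things change in this signature: persistence moves out into the H2OMOJOWritable trait (next file), and the type bound tightens from SparkModel[T] to the F-bounded H2OMOJOModelBase[T], which lets base-class members refer to the concrete subtype. A minimal sketch of the pattern, with hypothetical names:

abstract class Base[T <: Base[T]] {
  // With the recursive bound, this can return the concrete subtype, not just Base.
  def copyTyped(): T
}
class Concrete extends Base[Concrete] {
  override def copyTyped(): Concrete = new Concrete
}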
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.h2o.models

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{MLWritable, MLWriter}

trait H2OMOJOWritable extends MLWritable with Params with HasMojoData {
override def write: MLWriter = new H2OMOJOWriter(this, getMojoData())
}
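
Any model that mixes in this trait (both MOJO models in this commit do) gets the standard Spark ML persistence entry points from MLWritable, along the lines of the following one-liner with a hypothetical path:

model.write.overwrite().save("/tmp/h2o-mojo-model")  // MOJO bytes written by H2OMOJOWriter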
@@ -16,24 +16,54 @@
*/
package org.apache.spark.ml.h2o.models

+import ai.h2o.automl.targetencoding.TargetEncoderModel
+import org.apache.spark.h2o.H2OContext
import org.apache.spark.ml.Model
import org.apache.spark.ml.h2o.param.H2OTargetEncoderParams
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
+import org.apache.spark.ml.util.{MLWritable, MLWriter}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import water.support.ModelSerializationSupport

class H2OTargetEncoderModel(
override val uid: String,
-encodingMap: Map[String, Map[String, Array[Int]]])
-extends Model[H2OTargetEncoderModel] with H2OTargetEncoderParams {
+targetEncoderModel: TargetEncoderModel)
+extends Model[H2OTargetEncoderModel] with H2OTargetEncoderParams with MLWritable {

lazy val mojoModel: H2OTargetEncoderMojoModel = {
val mojoData = ModelSerializationSupport.getMojoData(targetEncoderModel)
val model = new H2OTargetEncoderMojoModel()
copyValues(model).setMojoData(mojoData)
}

override def transform(dataset: Dataset[_]): DataFrame = {
-// TODO The way how to apply encoding table for individual records should be defined in H2O-3 and exposed out.
-getInputCols().zip(getOutputCols()).foldLeft(dataset.toDF()){
-case (df, (in, out)) => df.withColumn(out, col(in))
+if(inTrainingMode) {
+transformTrainingDataset(dataset)
+} else {
+mojoModel.transform(dataset)
}
}

def transformTrainingDataset(dataset: Dataset[_]): DataFrame = {
val h2oContext = H2OContext.getOrCreate(SparkSession.builder().getOrCreate())
val input = h2oContext.asH2OFrame(dataset.toDF())
changeRelevantColumnsToCategorical(input)
val noise = getNoise()
val holdoutStrategyId = getHoldoutStrategy().ordinal().asInstanceOf[Byte]
val output = if (noise == null) {
targetEncoderModel.transform(input, holdoutStrategyId, 0L)
} else {
targetEncoderModel.transform(input, holdoutStrategyId, noise.amount, 0L)
}
h2oContext.asDataFrame(output)
}

private def inTrainingMode: Boolean = {
val stackTrace = Thread.currentThread().getStackTrace()
stackTrace.exists(e => e.getMethodName == "fit" && e.getClassName == "org.apache.spark.ml.Pipeline")
}

override def copy(extra: ParamMap): H2OTargetEncoderModel = defaultCopy(extra)

override def write: MLWriter = mojoModel.write
}
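
The stack-trace check above is a pragmatic way to tell the two phases apart without changing the Spark Pipeline API: inside Pipeline.fit the encodings must be produced with the holdout strategy and noise so the target does not leak into downstream stages, while at scoring time the plain MOJO transformation applies. A hypothetical pipeline sketch, reusing the encoder from the earlier sketch with gbm standing in for any downstream estimator:

import org.apache.spark.ml.Pipeline

val pipeline = new Pipeline().setStages(Array(encoder, gbm))
val pipelineModel = pipeline.fit(trainingDF)       // encoder model takes the transformTrainingDataset branch
val predictions = pipelineModel.transform(testDF)  // encoder model delegates to mojoModel.transform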
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.h2o.models

import hex.genmodel.algos.targetencoder.TargetEncoderMojoModel
import org.apache.spark.h2o.converters.RowConverter
import org.apache.spark.h2o.utils.H2OSchemaUtils
import org.apache.spark.ml.Model
import org.apache.spark.ml.h2o.param.H2OTargetEncoderParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.util.Identifiable
import water.support.ModelSerializationSupport

class H2OTargetEncoderMojoModel(override val uid: String) extends Model[H2OTargetEncoderMojoModel]
with H2OMOJOWritable with H2OTargetEncoderParams {

def this() = this(Identifiable.randomUID("H2OTargetEncoderMojoModel"))

override def transform(dataset: Dataset[_]): DataFrame = {
val outputCols = getOutputCols()
val udfWrapper = H2OTargetEncoderMojoUdfWrapper(getMojoData(), getOutputCols())
val outputColumnName = this.getClass.getSimpleName + "_output"
val flattenedDF = H2OSchemaUtils.flattenDataFrame(dataset.toDF())
val relevantColumnNames = flattenedDF.columns.intersect(getInputCols())
val args = relevantColumnNames.map(flattenedDF(_))
flattenedDF
.withColumn(outputColumnName, udfWrapper.mojoUdf(struct(args: _*)))
.withColumns(outputCols, outputCols.zipWithIndex.map{ case (c, i) => col(outputColumnName)(i) as c })
.drop(outputColumnName)
}

override def copy(extra: ParamMap): H2OTargetEncoderMojoModel = defaultCopy(extra)

}

/**
* The class holds all dependencies of the UDF that need to be serialized.
*/
case class H2OTargetEncoderMojoUdfWrapper(mojoData: Array[Byte], outputCols: Array[String]) {
@transient private lazy val mojoModel = ModelSerializationSupport
.getMojoModel(mojoData)
.asInstanceOf[TargetEncoderMojoModel]

val mojoUdf = udf[Array[Option[Double]], Row] { r: Row =>
val inputRowData = RowConverter.toH2ORowData(r)
try {
val outputRawData = mojoModel.transform0(inputRowData)
outputCols.map(c => Option(outputRawData.get(c).asInstanceOf[Double]))
} catch {
case _: Throwable => outputCols.map(_ => None)
}
}
}

object H2OTargetEncoderMojoModel extends H2OMOJOReadable[H2OTargetEncoderMojoModel]
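
Note the serialization pattern in the wrapper: only the MOJO byte array is captured by the UDF closure, and the @transient lazy val re-creates the TargetEncoderMojoModel once per executor JVM instead of shipping a live model object. Assuming H2OMOJOReadable follows Spark's usual MLReadable convention, a persisted model can then be loaded back and applied directly (path hypothetical):

val loaded = H2OTargetEncoderMojoModel.load("/tmp/h2o-mojo-model")
val scored = loaded.transform(testDF)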

This file was deleted.

@@ -23,9 +23,10 @@ import org.apache.spark.sql.SparkSession
private[models] trait HasMojoData {

// Called during init of the model
-def setMojoData(mojoData : Array[Byte]): Unit = {
+def setMojoData(mojoData : Array[Byte]): this.type = {
this.mojoData = mojoData
broadcastMojo = SparkSession.builder().getOrCreate().sparkContext.broadcast(this.mojoData)
+this
}

protected def getMojoData(): Array[Byte] = broadcastMojo.value
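Returning this.type instead of Unit lets callers chain the setter; the commit relies on this in H2OTargetEncoderModel, where the MOJO model is assembled as copyValues(model).setMojoData(mojoData).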
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.h2o.param

import org.apache.spark.ml.param.{Param, Params}
import org.json4s._
import org.json4s.jackson.Serialization.{read, write}

class CaseClassParam[T <: AnyRef with Product : Manifest](parent: Params, name: String, doc: String, isValid: T => Boolean)
extends Param[T](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) = this(parent, name, doc, _ => true)

@transient private implicit val formats = DefaultFormats

override def jsonEncode(value: T): String = write[T](value)

override def jsonDecode(json: String): T = {
if (json == null) null.asInstanceOf[T] else read[T](json)
}
}
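
This generic Param lets settings objects such as H2OTargetEncoderBlendingSettings and H2OTargetEncoderNoiseSettings be persisted alongside the other pipeline parameters; presumably the blending and noise params on H2OTargetEncoderParams are declared with it. A small round-trip sketch, where owner is a hypothetical Params instance (e.g. the encoder itself):

val noiseParam = new CaseClassParam[H2OTargetEncoderNoiseSettings](owner, "noise", "noise settings")
val json = noiseParam.jsonEncode(H2OTargetEncoderNoiseSettings(amount = 0.05, seed = 42))
// With json4s DefaultFormats a case class serializes field by field,
// so json is {"amount":0.05,"seed":42}
val restored = noiseParam.jsonDecode(json)  // == H2OTargetEncoderNoiseSettings(0.05, 42)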
