Skip to content

Commit

Permalink
[SW-2474] Fix Monotone Constraints on GBM and XGBoost MOJO Model (#2376)
Browse files Browse the repository at this point in the history
* [SW-2474] Fix Monotone Constraints on GBM and XGBoost MOJO Model

* Fix serialization problem on Spark 2.1

* Revert: Fix serialization problem on Spark 2.1

* Use DictionaryParam on algorithm

* Fallback when reading mojo parameters
  • Loading branch information
mn-mikke committed Oct 30, 2020
1 parent 44c96ef commit 01fcdec
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package ai.h2o.sparkling.ml.params
import java.util

import ai.h2o.sparkling.H2OFrame
import hex.KeyValue

import scala.collection.JavaConverters._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class MOJOParameterTestSuite extends FunSuite with SharedH2OTestContext with Mat
val algorithm = new H2OGBM()
.setLabelCol("CAPSULE")
.setSeed(1)
.setMonotoneConstraints(Map("AGE" -> 1.0, "RACE" -> -1.0))
val mojo = algorithm.fit(dataset)

compareParameterValues(algorithm, mojo)
Expand All @@ -57,7 +58,10 @@ class MOJOParameterTestSuite extends FunSuite with SharedH2OTestContext with Mat
}

test("Test MOJO parameters on XGBoost") {
val algorithm = new H2OXGBoost().setLabelCol("CAPSULE").setSeed(1)
val algorithm = new H2OXGBoost()
.setLabelCol("CAPSULE")
.setSeed(1)
.setMonotoneConstraints(Map("AGE" -> 1.0, "RACE" -> -1.0))
val mojo = algorithm.fit(dataset)

compareParameterValues(algorithm, mojo)
Expand Down
6 changes: 6 additions & 0 deletions py/src/ai/h2o/sparkling/ml/params/H2OTypeConverters.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,12 @@ def scalaMapStringStringToDictStringAny(value):
else:
raise TypeError("Invalid type.")

@staticmethod
def nullableScalaMapStringStringToDictStringAny(value):
if value is None:
return None
else:
H2OTypeConverters.scalaMapStringStringToDictStringAny(value)

@staticmethod
def scalaArrayToPythonArray(array):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
# limitations under the License.
#

from ai.h2o.sparkling.ml.params.H2OTypeConverters import H2OTypeConverters


class HasMonotoneConstraintsOnMOJO:

def getMonotoneConstraints(self):
return self._java_obj.getMonotoneConstraints()
value = self._java_obj.getMonotoneConstraints()
return H2OTypeConverters.nullableScalaMapStringStringToDictStringAny(value)
5 changes: 2 additions & 3 deletions py/tests/unit/with_runtime_sparkling/test_mojo_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

def testGBMParameters(prostateDataset):
features = ['AGE', 'RACE', 'DPROS', 'DCAPS', 'PSA']
algorithm = H2OGBM(seed=1, labelCol="CAPSULE", featuresCols=features)
algorithm = H2OGBM(seed=1, labelCol="CAPSULE", featuresCols=features, monotoneConstraints={'AGE': 1, 'RACE': -1})
model = algorithm.fit(prostateDataset)
compareParameterValues(algorithm, model)

Expand All @@ -33,7 +33,7 @@ def testDRFParameters(prostateDataset):

def testXGBoostParameters(prostateDataset):
features = ['AGE', 'RACE', 'DPROS', 'DCAPS', 'PSA']
algorithm = H2OXGBoost(seed=1, labelCol="CAPSULE", featuresCols=features)
algorithm = H2OXGBoost(seed=1, labelCol="CAPSULE", featuresCols=features, monotoneConstraints={'AGE': 1, 'RACE': -1})
model = algorithm.fit(prostateDataset)
compareParameterValues(algorithm, model)

Expand Down Expand Up @@ -88,7 +88,6 @@ def isMethodRelevant(method):
methods = filter(isMethodRelevant, dir(model))

for method in methods:
print(method)
modelValue = getattr(model, method)()
algorithmValue = getattr(algorithm, method)()
assert(valuesAreEqual(algorithmValue, modelValue))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,34 @@
package ai.h2o.sparkling.ml.params

import ai.h2o.sparkling.ml.models.SpecificMOJOParameters
import hex.KeyValue
import hex.genmodel.MojoModel
import hex.genmodel.attributes.parameters.KeyValue
import org.apache.spark.expose.Logging

import scala.collection.JavaConverters._

trait HasMonotoneConstraintsOnMOJO extends ParameterConstructorMethods with SpecificMOJOParameters with Logging {
private val monotoneConstraints = new NullableDictionaryParam[Double](
private val monotoneConstraints = new NullableMapStringDoubleParam(
this,
"monotoneConstraints",
"A key must correspond to a feature name and value could be 1 or -1")

def getMonotoneConstraints(): Map[String, Double] = {
val value = $(monotoneConstraints)
if (value == null) {
null
} else {
value.asScala.toMap
}
}
def getMonotoneConstraints(): Map[String, Double] = $(monotoneConstraints)

override private[sparkling] def setSpecificParams(h2oMojo: MojoModel): Unit = {
super.setSpecificParams(h2oMojo)
try {
val h2oParameters = h2oMojo._modelAttributes.getModelParameters()
val h2oParametersMap = h2oParameters.map(i => i.name -> i.actual_value).toMap
h2oParametersMap.get("monotone_constraints").foreach { value =>
val keyValues = value.asInstanceOf[Array[KeyValue]]
val javaMap = if (keyValues != null) {
keyValues.map(kv => kv.getKey -> kv.getValue).toMap.asJava
val objectArray = value.asInstanceOf[Array[AnyRef]]
val scalaMap = if (objectArray != null) {
val keyValues = objectArray.map(_.asInstanceOf[KeyValue])
keyValues.map(kv => kv.getKey -> kv.getValue).toMap[String, Double]
} else {
null
}
set(monotoneConstraints, javaMap)
set(monotoneConstraints, scalaMap)
}
} catch {
case e: Throwable => logError("An error occurred during a try to access H2O MOJO parameters.", e)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.params

import org.apache.spark.ml.param.{Param, Params}
import org.json4s._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

class NullableMapStringDoubleParam(parent: Params, name: String, doc: String, isValid: Map[String, Double] => Boolean)
extends Param[Map[String, Double]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, _ => true)

override def jsonEncode(value: Map[String, Double]): String = {
val encoded = if (value == null) {
JNull
} else {
JObject(value.map(p => p._1 -> DoubleParam.jValueEncode(p._2)).toList)
}
compact(render(encoded))
}

override def jsonDecode(json: String): Map[String, Double] = {
parse(json) match {
case JNull => null
case JObject(pairs) =>
pairs.map {
case (name, value) =>
(name, DoubleParam.jValueDecode(value))
}.toMap
case _ =>
throw new IllegalArgumentException(s"Cannot decode $json to Map[String, Double].")
}
}
}
14 changes: 11 additions & 3 deletions scoring/src/main/scala/ai/h2o/sparkling/ml/utils/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ package ai.h2o.sparkling.ml.utils
import java.io.File

import hex.genmodel.{ModelMojoReader, MojoModel, MojoReaderBackendFactory}
import org.apache.spark.expose.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow

object Utils {
object Utils extends Logging {
def getMojoModel(mojoFile: File): MojoModel = {
val reader = MojoReaderBackendFactory.createReaderBackend(mojoFile.getAbsolutePath)
ModelMojoReader.readFrom(reader, true)
try {
val reader = MojoReaderBackendFactory.createReaderBackend(mojoFile.getAbsolutePath)
ModelMojoReader.readFrom(reader, true)
} catch {
case e: Throwable =>
logError(s"Reading a mojo model with metadata failed. Trying to load the model without metadata...", e)
val reader = MojoReaderBackendFactory.createReaderBackend(mojoFile.getAbsolutePath)
ModelMojoReader.readFrom(reader, false)
}
}

def arrayToRow[T](array: Array[T]): Row = new GenericRow(array.map(_.asInstanceOf[Any]))
Expand Down

0 comments on commit 01fcdec

Please sign in to comment.