[SW-1129] Fix support for unsupervised mojo models (#1050)
(cherry picked from commit e37211b)
jakubhava committed Jan 21, 2019
1 parent dc0a8ab commit 9337796
Showing 3 changed files with 31 additions and 6 deletions.
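The fix teaches the MOJO wrapper to handle model categories that have no response column (such as isolation forest / anomaly detection). Below is a minimal usage sketch of what the change enables, mirroring the Python test added at the bottom of this diff. It is not part of the commit: the package path org.apache.spark.ml.h2o.models and the existence of an H2OMOJOModel companion object exposing createFromMojo (via the H2OMOJOModelHelper shown below) are assumptions.

// Hedged sketch, not part of the commit: load the unsupervised isolation
// forest MOJO added as a test resource and score a one-column frame.
import org.apache.spark.ml.h2o.models.H2OMOJOModel  // assumed package path
import org.apache.spark.sql.SparkSession

object UnsupervisedMojoExample extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  // After this fix, an unsupervised MOJO gets all of its column names as
  // feature columns, since there is no response column to exclude.
  val mojo = H2OMOJOModel.createFromMojo("ml/src/test/resources/isolation_forest.mojo")

  val df = Seq(5.1).toDF("V1")
  mojo.transform(df).show(truncate = false)
}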
@@ -37,6 +37,7 @@ import org.apache.spark.{ml, mllib}
import water.support.ModelSerializationSupport

import scala.reflect.{ClassTag, classTag}

class H2OMOJOModel(val mojoData: Array[Byte], override val uid: String)
extends SparkModel[H2OMOJOModel] with H2OModelParams with MLWritable {

@@ -59,6 +60,8 @@ class H2OMOJOModel(val mojoData: Array[Byte], override val uid: String)

case class WordEmbeddingPrediction(wordEmbeddings: util.HashMap[String, Array[Float]])

case class AnomalyPrediction(score: Double, normalizedScore: Double)

def predictionSchema(): Seq[StructField] = {
val fields = getOrCreateEasyModelWrapper().getModelCategory match {
case ModelCategory.Binomial => StructField("p0", DoubleType) :: StructField("p1", DoubleType) :: Nil
@@ -98,6 +101,10 @@ class H2OMOJOModel(val mojoData: Array[Byte], override val uid: String)
implicit def toWordEmbeddingPrediction(pred: AbstractPrediction) = WordEmbeddingPrediction(
pred.asInstanceOf[Word2VecPrediction].wordEmbeddings)

implicit def toAnomalyPrediction(pred: AbstractPrediction) = AnomalyPrediction(
pred.asInstanceOf[AnomalyDetectionPrediction].score,
pred.asInstanceOf[AnomalyDetectionPrediction].normalizedScore
)
def getModelUdf() = {
val modelUdf = {
getOrCreateEasyModelWrapper().getModelCategory match {
@@ -122,7 +129,10 @@ class H2OMOJOModel(val mojoData: Array[Byte], override val uid: String)
case ModelCategory.WordEmbedding => udf[WordEmbeddingPrediction, Row] { r: Row =>
getOrCreateEasyModelWrapper().predict(rowToRowData(r))
}
case _ => throw new RuntimeException("Unknown model category")
case ModelCategory.AnomalyDetection => udf[AnomalyPrediction, Row] { r: Row =>
getOrCreateEasyModelWrapper().predict(rowToRowData(r))
}
case _ => throw new RuntimeException("Unknown model category " + getOrCreateEasyModelWrapper().getModelCategory)
}
}
modelUdf
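For reference, here is a hedged, standalone sketch of what the new AnomalyDetection branch does through the h2o-genmodel Easy Predict API: the generic predict() result is cast to AnomalyDetectionPrediction, whose score and normalizedScore fields feed the AnomalyPrediction case class introduced above. The MOJO path and the column name V1 are taken from the Python test below; everything else is illustrative.

// Illustrative only: the same Easy Predict call the UDF performs, outside Spark.
import hex.genmodel.MojoModel
import hex.genmodel.easy.{EasyPredictModelWrapper, RowData}
import hex.genmodel.easy.prediction.AnomalyDetectionPrediction

val wrapper = new EasyPredictModelWrapper(
  MojoModel.load("ml/src/test/resources/isolation_forest.mojo"))

val row = new RowData()                        // RowData is a HashMap[String, Object]
row.put("V1", java.lang.Double.valueOf(5.1))

val pred = wrapper.predict(row).asInstanceOf[AnomalyDetectionPrediction]
println(s"score=${pred.score}, normalizedScore=${pred.normalizedScore}")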
@@ -232,7 +242,7 @@ private[models] class H2OMOJOModelReader[T <: H2OMOJOModel : ClassTag]
override def load(path: String): T = {
val metadata = DefaultParamsReader.loadMetadata(path, sc)

val inputPath = new Path(path, defaultFileName)
val fs = inputPath.getFileSystem(sc.hadoopConfiguration)
val qualifiedInputPath = inputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val is = fs.open(qualifiedInputPath)
@@ -250,7 +260,7 @@ private[models] class H2OMOJOModelReader
}
}

class H2OMOJOModelHelper[T<: H2OMOJOModel](implicit m: ClassTag[T]) extends MLReadable[T]{
class H2OMOJOModelHelper[T <: H2OMOJOModel](implicit m: ClassTag[T]) extends MLReadable[T] {
val defaultFileName = "mojo_model"

@Since("1.6.0")
@@ -260,7 +270,7 @@ class H2OMOJOModelHelper[T<: H2OMOJOModel](implicit m: ClassTag[T]) extends MLRe
override def load(path: String): T = super.load(path)

def createFromMojo(path: String): T = {
val inputPath = new Path(path)
val fs = inputPath.getFileSystem(SparkSession.builder().getOrCreate().sparkContext.hadoopConfiguration)
val qualifiedInputPath = inputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val is = fs.open(qualifiedInputPath)
@@ -274,8 +284,12 @@ class H2OMOJOModelHelper[T<: H2OMOJOModel](implicit m: ClassTag[T]) extends MLRe
val sparkMojoModel = m.runtimeClass.getConstructor(classOf[Array[Byte]], classOf[String]).
newInstance(mojoData, uid).asInstanceOf[T]
// Reconstruct state of Spark H2O MOJO transformer based on H2O's Mojo
sparkMojoModel.setFeaturesCols(mojoModel.getNames.filter(_ != mojoModel.getResponseName))
sparkMojoModel.setPredictionCol(mojoModel.getResponseName)
if (mojoModel.isSupervised) {
sparkMojoModel.setFeaturesCols(mojoModel.getNames.filter(_ != mojoModel.getResponseName))
sparkMojoModel.setPredictionCol(mojoModel.getResponseName)
} else {
sparkMojoModel.setFeaturesCols(mojoModel.getNames)
}
sparkMojoModel
}
}
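The createFromMojo change above is the heart of the fix: only a supervised MOJO has a response column, so only then can featuresCols and predictionCol be derived from it. A small hedged illustration of the same metadata calls against the bundled isolation forest MOJO (illustrative only; it simply restates the guarded branch):

// Illustrative only: the MojoModel metadata consulted by createFromMojo.
import hex.genmodel.MojoModel

val rawMojo = MojoModel.load("ml/src/test/resources/isolation_forest.mojo")

val featureCols =
  if (rawMojo.isSupervised) rawMojo.getNames.filter(_ != rawMojo.getResponseName)
  else rawMojo.getNames                        // unsupervised: nothing to drop

println(featureCols.mkString(", "))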
Binary file added ml/src/test/resources/isolation_forest.mojo
Binary file not shown.
11 changes: 11 additions & 0 deletions py/tests/tests_unit_mojo_predictions.py
@@ -80,6 +80,17 @@ def test_h2o_mojo_model_serialization_in_pipeline(self):
model.write().overwrite().save( "file://" + os.path.abspath("build/test_spark_pipeline_model_mojo_model"))
PipelineModel.load( "file://" + os.path.abspath("build/test_spark_pipeline_model_mojo_model"))

def test_h2o_mojo_unsupervised(self):
mojo = H2OMOJOModel.create_from_mojo(
"file://" + os.path.abspath("../ml/src/test/resources/isolation_forest.mojo"))

row_for_scoring = Row("V1")

df = self._spark.createDataFrame(self._spark.sparkContext.
parallelize([(5.1,)]).
map(lambda r: row_for_scoring(*r)))
mojo.predict(df).repartition(1).collect()


if __name__ == '__main__':
generic_test_utils.run_tests([H2OMojoPredictionsTest], file_name="py_unit_tests_mojo_predictions_report")
