Skip to content

Commit

Permalink
[SW-824] Fix NPE in mojo pipeline predictions (#689)
Browse files Browse the repository at this point in the history
(cherry picked from commit a0b0029)
  • Loading branch information
jakubhava committed Apr 26, 2018
1 parent bb1ba60 commit c1ce0d7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
Expand Up @@ -24,7 +24,6 @@ import ai.h2o.mojos.runtime.readers.MojoPipelineReaderBackendFactory
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.h2o.utils.H2OSchemaUtils
import org.apache.spark.ml.h2o.models.H2OMOJOModel.createFromMojo
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Model => SparkModel}
Expand Down Expand Up @@ -58,10 +57,10 @@ class H2OMOJOPipelineModel(val mojoData: Array[Byte], override val uid: String)
r: Row =>
val m = getOrCreateModel()
val builder = m.getInputFrameBuilder
val data = r.getValuesMap[Any](names).values.toArray.map(_.toString).zip(r.getValuesMap[Any](names).keys)
// Build (value, columnName) pairs from the same map entries so every value stays attached to
// its own column. Null values are skipped so that `toString` cannot throw; filtering the values
// and then zipping them with the unfiltered keys would shift later values onto the wrong columns.
val data = r.getValuesMap[Any](names).toArray.collect { case (colName, value) if value != null => (value.toString, colName) }
val rowBuilder = builder.getMojoRowBuilder

data.foreach{
data.foreach {
case (colData, colName) =>
rowBuilder.setValue(colName, colData)
}
Expand Down
Expand Up @@ -60,6 +60,23 @@ class H2OMOJOPipelineModelTest extends FunSuite with SparkTestContext {
println(preds.mkString("\n"))
}

/**
 * Regression test: scoring a row that contains a null element must complete
 * without throwing a NullPointerException. No predicted values are checked;
 * the test passes as long as the prediction frame can be materialized.
 */
test("Prediction with null as row element") {
  // The CSV is read only to obtain the column schema for the hand-built row below.
  val df = spark.read.option("header", "true").csv("examples/smalldata/prostate/prostate.csv")
  // Test mojo
  val mojo = H2OMOJOPipelineModel.createFromMojo(
    this.getClass.getClassLoader.getResourceAsStream("mojo2data/pipeline.mojo"),
    "prostate_pipeline.mojo")

  // A single row whose last element is null.
  val rdd = sc.parallelize(Seq(Row("1", "0", "65", "1", "2", "1", "1.4", "0", null)))
  // Use the frame's schema directly instead of fetching the first row just to read its schema.
  val df2 = spark.createDataFrame(rdd, df.schema)
  val preds = mojo.transform(df2)
  // materialize the frame to see that it is passing
  preds.collect()
}
private def assertPredictedValues(preds: Array[Row]): Unit = {
assert(preds(0).getSeq[String](0).head == "65.36320409515132")
assert(preds(1).getSeq[String](0).head == "64.96902128114817")
Expand Down

0 comments on commit c1ce0d7

Please sign in to comment.