From c1ce0d786416cb98776fb598e80c58a56626ac03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20H=C3=A1va?= Date: Thu, 26 Apr 2018 12:45:42 +0200 Subject: [PATCH] [SW-824] Fix NPE in mojo pipeline predictions (#689) (cherry picked from commit a0b0029dd3cd3acef459108eee41c9aeea1fce8e) --- .../ml/h2o/models/H2OMOJOPipelineModel.scala | 5 ++--- .../spark/models/H2OMOJOPipelineModelTest.scala | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/ml/src/main/scala/org/apache/spark/ml/h2o/models/H2OMOJOPipelineModel.scala b/ml/src/main/scala/org/apache/spark/ml/h2o/models/H2OMOJOPipelineModel.scala index b6ecf2e0df..5201784506 100644 --- a/ml/src/main/scala/org/apache/spark/ml/h2o/models/H2OMOJOPipelineModel.scala +++ b/ml/src/main/scala/org/apache/spark/ml/h2o/models/H2OMOJOPipelineModel.scala @@ -24,7 +24,6 @@ import ai.h2o.mojos.runtime.readers.MojoPipelineReaderBackendFactory import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.h2o.utils.H2OSchemaUtils -import org.apache.spark.ml.h2o.models.H2OMOJOModel.createFromMojo import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util._ import org.apache.spark.ml.{Model => SparkModel} @@ -58,10 +57,10 @@ class H2OMOJOPipelineModel(val mojoData: Array[Byte], override val uid: String) r: Row => val m = getOrCreateModel() val builder = m.getInputFrameBuilder - val data = r.getValuesMap[Any](names).values.toArray.map(_.toString).zip(r.getValuesMap[Any](names).keys) + val data = r.getValuesMap[Any](names).filter{case (_, v) => v != null}.values.toArray.map(_.toString).zip(r.getValuesMap[Any](names).keys) val rowBuilder = builder.getMojoRowBuilder - data.foreach{ + data.foreach { case (colData, colName) => rowBuilder.setValue(colName, colData) } diff --git a/ml/src/test/scala/org/apache/spark/ml/spark/models/H2OMOJOPipelineModelTest.scala b/ml/src/test/scala/org/apache/spark/ml/spark/models/H2OMOJOPipelineModelTest.scala index 
105c89f7b9..aece1249a4 100644 --- a/ml/src/test/scala/org/apache/spark/ml/spark/models/H2OMOJOPipelineModelTest.scala +++ b/ml/src/test/scala/org/apache/spark/ml/spark/models/H2OMOJOPipelineModelTest.scala @@ -60,6 +60,23 @@ class H2OMOJOPipelineModelTest extends FunSuite with SparkTestContext { println(preds.mkString("\n")) } + /** + * The purpose of this test is simply to pass and not throw a NullPointerException + */ + test("Prediction with null as row element"){ + val df = spark.read.option("header", "true").csv("examples/smalldata/prostate/prostate.csv") + // Test mojo + val mojo = H2OMOJOPipelineModel.createFromMojo( + this.getClass.getClassLoader.getResourceAsStream("mojo2data/pipeline.mojo"), + "prostate_pipeline.mojo") + + import spark.implicits._ + val rdd = sc.parallelize(Seq(Row("1", "0", "65", "1", "2", "1", "1.4", "0", null))) + val df2 = spark.createDataFrame(rdd, df.first().schema) + val preds = mojo.transform(df2) + // materialize the frame to verify that it passes + preds.collect() + } private def assertPredictedValues(preds: Array[Row]): Unit = { assert(preds(0).getSeq[String](0).head == "65.36320409515132") assert(preds(1).getSeq[String](0).head == "64.96902128114817")