Skip to content

Commit

Permalink
[SW-843] Fix data obtaining for mojo pipeline (#714)
Browse files Browse the repository at this point in the history
(cherry picked from commit 08f915f)
  • Loading branch information
jakubhava committed May 18, 2018
1 parent e7808d2 commit fb34c68
Showing 1 changed file with 33 additions and 22 deletions.
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.ml.h2o.models
import java.io._

import ai.h2o.mojos.runtime.MojoPipeline
import ai.h2o.mojos.runtime.frame.MojoColumn
import ai.h2o.mojos.runtime.readers.MojoPipelineReaderBackendFactory
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
Expand Down Expand Up @@ -52,30 +53,40 @@ class H2OMOJOPipelineModel(val mojoData: Array[Byte], override val uid: String)
val outputCol = "prediction"

case class Mojo2Prediction(preds: List[String])

private val modelUdf = (names: Array[String]) =>
udf[Mojo2Prediction, Row] {
r: Row =>
val m = getOrCreateModel()
val builder = m.getInputFrameBuilder
val data = r.getValuesMap[Any](names).filter{case (_, v) => v != null}.values.toArray.map(_.toString).zip(r.getValuesMap[Any](names).keys)
val rowBuilder = builder.getMojoRowBuilder

data.foreach {
case (colData, colName) =>
rowBuilder.setValue(colName, colData)
}
builder.addRow(rowBuilder)
val output = m.transform(builder.toMojoFrame)
val predictions = output.getColumnNames.zipWithIndex.map { case (_, i) =>
val predictedRows =output.getColumnData(i).asInstanceOf[Array[_]]
if (predictedRows.length != 1) {
throw new RuntimeException("Invalid state, we predict on each row by row, independently at this moment.")
r: Row =>
val m = getOrCreateModel()
val builder = m.getInputFrameBuilder
val rowBuilder = builder.getMojoRowBuilder
val filtered = r.getValuesMap[Any](names).filter { case (n, _) => m.getInputMeta.contains(n) }
filtered.foreach {
case (colName, colData) =>
val data = if (colData == null) {
null
} else if (m.getInputMeta.getColumnType(colName).isnumeric && colData.toString.toLowerCase() == "true") {
1.toString
} else if (m.getInputMeta.getColumnType(colName).isnumeric && colData.toString.toLowerCase() == "false") {
0.toString
} else {
colData.toString
}
rowBuilder.setValue(colName.toString, data)
}
predictedRows(0).toString
}

Mojo2Prediction(predictions.toList)
}
builder.addRow(rowBuilder)
val output = m.transform(builder.toMojoFrame)
val predictions = output.getColumnNames.zipWithIndex.map { case (_, i) =>
val predictedRows = output.getColumnData(i).asInstanceOf[Array[_]]
if (predictedRows.length != 1) {
throw new RuntimeException("Invalid state, we predict on each row by row, independently at this moment.")
}
predictedRows(0).toString
}

Mojo2Prediction(predictions.toList)
}

def defaultFileName: String = H2OMOJOPipelineModel.defaultFileName

Expand Down Expand Up @@ -128,7 +139,7 @@ private[models] class H2OMOJOModelPipelineReader
override def load(path: String): H2OMOJOPipelineModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val inputPath = new Path(path, defaultFileName)
val inputPath = new Path(path, defaultFileName)
val fs = inputPath.getFileSystem(sc.hadoopConfiguration)
val qualifiedInputPath = inputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val is = fs.open(qualifiedInputPath)
Expand Down Expand Up @@ -156,7 +167,7 @@ object H2OMOJOPipelineModel extends MLReadable[H2OMOJOPipelineModel] {
override def load(path: String): H2OMOJOPipelineModel = super.load(path)

def createFromMojo(path: String): H2OMOJOPipelineModel = {
val inputPath = new Path(path)
val inputPath = new Path(path)
val fs = inputPath.getFileSystem(SparkSession.builder().getOrCreate().sparkContext.hadoopConfiguration)
val qualifiedInputPath = inputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val is = fs.open(qualifiedInputPath)
Expand Down

0 comments on commit fb34c68

Please sign in to comment.