Skip to content

Commit

Permalink
[SW-1281] Fix bad representation of predictionCol on H2OMOJOModel (#1199)
Browse files Browse the repository at this point in the history

(cherry picked from commit ae3f9c0)
(cherry picked from commit a0b6c64)
  • Loading branch information
jakubhava committed May 15, 2019
1 parent 52077ee commit b21ac58
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class H2OMOJOModel(override val uid: String)
case _ => throw new RuntimeException("Unknown model category")
}

Seq(StructField(getOutputCol(), StructType(fields), nullable = false))
Seq(StructField(getPredictionCol(), StructType(fields), nullable = false))
}

private def supportsCalibratedProbabilities(): Boolean = {
Expand Down Expand Up @@ -163,7 +163,7 @@ class H2OMOJOModel(override val uid: String)
val flattenedDF = H2OSchemaUtils.flattenDataFrame(dataset.toDF())
val relevantColumnNames = flattenedDF.columns.intersect(getFeaturesCols())
val args = relevantColumnNames.map(flattenedDF(_))
flattenedDF.select(col("*"), getModelUdf()(struct(args: _*)).as(getOutputCol()))
flattenedDF.select(col("*"), getModelUdf()(struct(args: _*)).as(getPredictionCol()))
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class H2OMOJOPipelineModel(override val uid: String)
val args = relevantColumnNames.map(flattenedDF(_))

// get the altered frame
val frameWithPredictions = flattenedDF.select(col("*"), modelUdf(relevantColumnNames)(struct(args: _*)).as(getPredictionCol))
val frameWithPredictions = flattenedDF.select(col("*"), modelUdf(relevantColumnNames)(struct(args: _*)).as(getPredictionCol()))

val fr = if (getNamedMojoOutputColumns()) {

Expand All @@ -150,12 +150,12 @@ class H2OMOJOPipelineModel(override val uid: String)
var frameWithExtractedPredictions: DataFrame = frameWithPredictions
getOutputNames().indices.foreach { idx =>
frameWithExtractedPredictions = frameWithExtractedPredictions.withColumn(tempColNames(idx),
selectFromArray(idx)(frameWithExtractedPredictions.col(getPredictionCol + ".preds")))
selectFromArray(idx)(frameWithExtractedPredictions.col(getPredictionCol() + ".preds")))
}

// Transform the columns at the top level under "output" column
val nestedPredictionCols = tempColNames.indices.map { idx => tempCols(idx).alias(getOutputNames()(idx)) }
val frameWithNestedPredictions = frameWithExtractedPredictions.withColumn(getPredictionCol, struct(nestedPredictionCols: _*))
val frameWithNestedPredictions = frameWithExtractedPredictions.withColumn(getPredictionCol(), struct(nestedPredictionCols: _*))

// Remove the temporary columns at the top level and return
val frameWithoutTempCols = frameWithNestedPredictions.drop(tempColNames: _*)
Expand All @@ -172,7 +172,7 @@ class H2OMOJOPipelineModel(override val uid: String)

def predictionSchema(): Seq[StructField] = {
val fields = StructField("original", ArrayType(DoubleType)) :: Nil
Seq(StructField(getPredictionCol, StructType(fields), nullable = false))
Seq(StructField(getPredictionCol(), StructType(fields), nullable = false))
}

override def transformSchema(schema: StructType): StructType = {
Expand All @@ -193,10 +193,10 @@ class H2OMOJOPipelineModel(override val uid: String)
val func = udf[Double, Double] {
identity
}
func(col(s"$getPredictionCol.`$column`")).alias(column)
func(col(s"${getPredictionCol()}.`$column`")).alias(column)
} else {
val func = selectFromArray(getOutputNames().indexOf(column))
func(col(s"$getPredictionCol.preds")).alias(column)
func(col(s"${getPredictionCol()}.preds")).alias(column)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,48 @@
package org.apache.spark.ml.h2o.models

import org.apache.hadoop.fs.Path
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.util.Utils
import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, JObject, JsonAST}

private[models] class H2OMOJOReader[T <: HasMojoData] extends DefaultParamsReader[T] {

private def getAndSetParams(
instance: Params,
metadata: Metadata,
skipParams: Option[List[String]] = None): Unit = {
implicit val format = DefaultFormats
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
if (skipParams == None || !skipParams.get.contains(paramName)) {
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
}
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}

override def load(path: String): T = {
val model = super.load(path)
val metadata = DefaultParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]

val parsedParams = metadata.params.asInstanceOf[JsonAST.JObject].obj.map(_._1)
val allowedParams = instance.params.map(_.name)
val filteredParams = parsedParams.diff(allowedParams)

getAndSetParams(instance, metadata, Some(filteredParams))
val model = instance.asInstanceOf[T]

val inputPath = new Path(path, H2OMOJOProps.serializedFileName)
val fs = inputPath.getFileSystem(SparkSession.builder().getOrCreate().sparkContext.hadoopConfiguration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,66 +17,41 @@
package org.apache.spark.ml.h2o.param

import org.apache.spark.ml.param._
import water.util.DeprecatedMethod

/**
* Parameters which need to be available on the model itself for prediction purposes. This can't be backed
 * by H2OAlgoParamsHelper as at the time of prediction we might be using a MOJO and the binary parameters are not available.
*/
trait H2OMOJOModelParams extends DeprecatableParams {

override protected def renamingMap: Map[String, String] = Map(
"predictionCol" -> "labelCol"
)
trait H2OMOJOModelParams extends Params {

//
// Param definitions
//
private val labelCol: Param[String] = new Param[String](this, "labelCol", "Label column name")
private val predictionCol: Param[String] = new Param[String](this, "predictionCol", "Prediction column name")
protected final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols", "Name of feature columns")
private val outputCol: Param[String] = new Param[String](this, "outputCol", "Column where predictions are created")
private val convertUnknownCategoricalLevelsToNa = new BooleanParam(this,
"convertUnknownCategoricalLevelsToNa",
"Convert unknown categorical levels to NA during predictions")
"If set to 'true', the model converts unknown categorical levels to NA during making predictions.")
//
// Default values
//
setDefault(labelCol -> "label")
setDefault(featuresCols -> Array.empty[String])
setDefault(outputCol -> "prediction_output")
setDefault(predictionCol -> "prediction_output")
setDefault(convertUnknownCategoricalLevelsToNa -> false)

//
// Getters
//
@DeprecatedMethod
def getPredictionsCol(): String = getLabelCol()

@DeprecatedMethod
def getLabelCol(): String = $(labelCol)
def getPredictionCol(): String = $(predictionCol)

def getFeaturesCols(): Array[String] = $(featuresCols)

@DeprecatedMethod
def getOutputCol(): String = $(outputCol)

def getConvertUnknownCategoricalLevelsToNa(): Boolean = $(convertUnknownCategoricalLevelsToNa)


//
// Setters
//
@DeprecatedMethod
def setFeaturesCols(cols: Array[String]): this.type = set(featuresCols, cols)

@DeprecatedMethod
def setPredictionCol(value: String): this.type = setLabelCol(value)

@DeprecatedMethod
def setLabelCol(value: String): this.type = set(labelCol, value)

@DeprecatedMethod
def setOutputCol(value: String): this.type = set(outputCol, value)

def setConvertUnknownCategoricalLevelsToNa(value: Boolean): this.type = set(convertUnknownCategoricalLevelsToNa, value)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
package org.apache.spark.ml.h2o.param

import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasPredictionCol

/**
* Parameters which need to be available on the model itself for prediction purposes. This can't be backed
 * by H2OAlgoParamsHelper as at the time of prediction we might be using a MOJO and the binary parameters are not available.
*/
trait H2OMOJOPipelineModelParams extends Params with HasPredictionCol {
trait H2OMOJOPipelineModelParams extends Params {

//
// Param definitions
//
private val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name")
private val namedMojoOutputColumns: Param[Boolean] = new BooleanParam(this, "namedMojoOutputColumns", "Mojo Output is not stored" +
" in the array but in the properly named columns")
protected final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols", "Name of feature columns")
Expand All @@ -36,13 +36,18 @@ trait H2OMOJOPipelineModelParams extends Params with HasPredictionCol {
// Default values
//
setDefault(
predictionCol -> "prediction",
namedMojoOutputColumns -> true,
featuresCols -> Array.empty[String]
)


//
// Getters
//

def getPredictionCol(): String = $(predictionCol)

def getNamedMojoOutputColumns() = $(namedMojoOutputColumns)

def getFeaturesCols(): Array[String] = $(featuresCols)
Expand Down

0 comments on commit b21ac58

Please sign in to comment.