New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SW-1281] Fix bad representation of predictionCol on H2OMOJOModel #1199
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small nits, otherwise LGTM.
ml/src/main/scala/org/apache/spark/ml/h2o/models/H2OMOJOModel.scala
Outdated
Show resolved
Hide resolved
@@ -163,7 +163,7 @@ class H2OMOJOModel(override val uid: String) | |||
val flattenedDF = H2OSchemaUtils.flattenDataFrame(dataset.toDF()) | |||
val relevantColumnNames = flattenedDF.columns.intersect(getFeaturesCols()) | |||
val args = relevantColumnNames.map(flattenedDF(_)) | |||
flattenedDF.select(col("*"), getModelUdf()(struct(args: _*)).as(getOutputCol())) | |||
flattenedDF.select(col("*"), getModelUdf()(struct(args: _*)).as(getPredictionCol)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ditto
|
||
val parsedParams = metadata.params.asInstanceOf[JsonAST.JObject].obj.map(_._1) | ||
val allowedParams = instance.params.map(_.name) | ||
val filteredParams = parsedParams.filter(!allowedParams.contains(_)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT about using theintersect
method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks just for reference -> as per our discussion in the office, intersect is not a good fit, but diff
is
ml/src/main/scala/org/apache/spark/ml/h2o/models/H2OMOJOReader.scala
Outdated
Show resolved
Hide resolved
) (cherry picked from commit ae3f9c0)
) (cherry picked from commit ae3f9c0)
) (cherry picked from commit d3f0662)
) (cherry picked from commit d3f0662)
No description provided.