Skip to content

Commit

Permalink
[SW-734] Make sure we use the unique key names in split method (#586)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubhava committed Feb 21, 2018
1 parent f1a994a commit c09705f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 35 deletions.
22 changes: 4 additions & 18 deletions ml/src/main/scala/org/apache/spark/ml/h2o/algos/H2OAlgorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ abstract class H2OAlgorithm[P <: Model.Parameters : ClassTag, M <: SparkModel[M]
// check if we need to do any splitting
if ($(ratio) < 1.0) {
// need to do splitting
val keys = split(input, hc)
getParams._train = keys(0)
val keys = H2OFrameSupport.split(input, Seq(Key.rand(), Key.rand()), Seq($(ratio)))
getParams._train = keys(0)._key
if (keys.length > 1) {
getParams._valid = keys(1)
getParams._valid = keys(1)._key
}
} else {
getParams._train = input._key
Expand Down Expand Up @@ -110,21 +110,7 @@ abstract class H2OAlgorithm[P <: Model.Parameters : ClassTag, M <: SparkModel[M]

/** Returns the `MLWriter` for this algorithm: a new `H2OAlgorithmWriter` constructed around `this`. */
@Since("1.6.0")
override def write: MLWriter = new H2OAlgorithmWriter(this)

/**
 * Splits `fr` into two frames using the `ratio` parameter and returns their keys.
 *
 * The destination keys carry a random suffix. With fixed names such as "train"
 * and "valid" every invocation targets the same two keys, so repeated or
 * concurrent calls would overwrite each other's results.
 *
 * @param fr frame to split
 * @param hc H2O context; not used by this method, kept for interface compatibility
 * @return the destination keys, training key first and validation key second
 */
private def split(fr: H2OFrame, hc: H2OContext): Array[Key[Frame]] = {
  // Unique per call: readable prefix plus a random suffix.
  val trainKey = Key.make[Frame](s"train_${Key.rand()}")
  val validKey = Key.make[Frame](s"valid_${Key.rand()}")
  val keys = Array(trainKey, validKey)
  // A single ratio is passed for the two destination keys.
  val ratios = Array[Double]($(ratio))

  val splitter = new FrameSplitter(fr, ratios, keys, null)
  water.H2O.submitTask(splitter)
  // Result is discarded; the call is kept so the split is finished before the keys
  // are returned. NOTE(review): presumably blocks until completion — confirm.
  splitter.getResult

  keys
}


/**
 * Abstract member to be supplied by concrete algorithms.
 * NOTE(review): presumably the default file name used when the algorithm is persisted — confirm against the writer.
 */
def defaultFileName: String
}

Expand Down
20 changes: 3 additions & 17 deletions ml/src/main/scala/org/apache/spark/ml/h2o/algos/H2OAutoML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ class H2OAutoML(val automlBuildSpec: Option[AutoMLBuildSpec], override val uid:
// check if we need to do any splitting
if (getRatio() < 1.0) {
// need to do splitting
val keys = split(input, hc)
spec.input_spec.training_frame = keys(0)
val keys = H2OFrameSupport.split(input, Seq(Key.rand(), Key.rand()), Seq(getRatio()))
spec.input_spec.training_frame = keys(0)._key
if (keys.length > 1) {
spec.input_spec.validation_frame = keys(1)
spec.input_spec.validation_frame = keys(1)._key
}
} else {
spec.input_spec.training_frame = input._key
Expand Down Expand Up @@ -98,20 +98,6 @@ class H2OAutoML(val automlBuildSpec: Option[AutoMLBuildSpec], override val uid:
model
}

/**
 * Splits `fr` into two frames using `getRatio()` and returns their keys.
 *
 * The destination keys carry a random suffix. With fixed names such as "train"
 * and "valid" every invocation targets the same two keys, so repeated or
 * concurrent calls would overwrite each other's results.
 *
 * @param fr frame to split
 * @param hc H2O context; not used by this method, kept for interface compatibility
 * @return the destination keys, training key first and validation key second
 */
private def split(fr: H2OFrame, hc: H2OContext): Array[Key[Frame]] = {
  // Unique per call: readable prefix plus a random suffix.
  val trainKey = Key.make[Frame](s"train_${Key.rand()}")
  val validKey = Key.make[Frame](s"valid_${Key.rand()}")
  val keys = Array(trainKey, validKey)
  // A single ratio is passed for the two destination keys.
  val ratios = Array[Double](getRatio())

  val splitter = new FrameSplitter(fr, ratios, keys, null)
  water.H2O.submitTask(splitter)
  // Result is discarded; the call is kept so the split is finished before the keys
  // are returned. NOTE(review): presumably blocks until completion — confirm.
  splitter.getResult

  keys
}

@DeveloperApi
override def transformSchema(schema: StructType): StructType = {
schema
Expand Down

0 comments on commit c09705f

Please sign in to comment.