Skip to content

Commit

Permalink
[SW-2445] Add logic of FrameUtils.guessParserSetup to Sparkling Water (
Browse files Browse the repository at this point in the history
…#2329)

* [SW-2445] Add logic of FrameUtils.guessParserSetup to Sparkling Water

* fix mojo parameter tests
  • Loading branch information
mn-mikke committed Sep 24, 2020
1 parent 57bfa19 commit e6b8d74
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
8 changes: 5 additions & 3 deletions extensions/src/main/scala/water/fvec/H2OFrame.scala
Expand Up @@ -24,7 +24,7 @@ import water._
import water.parser.DefaultParserProviders.GUESS_INFO
import water.parser.ParseSetup
import water.parser.ParseSetup._
import water.util.FrameUtils
import water.H2O

/**
* Wrapper around Java H2O Frame
Expand Down Expand Up @@ -179,6 +179,8 @@ object H2OFrame {
* @param uris URIs of files to parse
* @return guessed parser setup
*/
def parserSetup(userSetup: ParseSetup, uris: URI*): ParseSetup =
FrameUtils.guessParserSetup(defaultParserSetup(), uris: _*)
def parserSetup(userSetup: ParseSetup, uris: URI*): ParseSetup = {
val inKeys = uris.map(H2O.getPM.anyURIToKey(_)).toArray
return ParseSetup.guessSetup(inKeys, userSetup)
}
}
Expand Up @@ -84,6 +84,7 @@ class MOJOParameterTestSuite extends FunSuite with SharedH2OTestContext with Mat
val algorithm = new H2OGAM()
.setLabelCol("CAPSULE")
.setSeed(1)
.setLambdaValue(Array(0.5))
.setGamCols(Array("PSA", "AGE"))
.setNumKnots(Array(5, 5))
.setBs(Array(5, 5))
Expand Down
3 changes: 2 additions & 1 deletion py/tests/unit/with_runtime_sparkling/test_mojo_parameters.py
Expand Up @@ -49,7 +49,8 @@ def testGLMParameters(prostateDataset):

def testGAMParameters(prostateDataset):
features = ['AGE', 'RACE', 'DPROS', 'DCAPS', 'PSA']
algorithm = H2OGAM(seed=1, labelCol="CAPSULE", gamCols=["PSA", "AGE"], numKnots=[5, 5], featuresCols=features)
algorithm = H2OGAM(seed=1, labelCol="CAPSULE", gamCols=["PSA", "AGE"], numKnots=[5, 5], lambdaValue=[0.5],
featuresCols=features)
model = algorithm.fit(prostateDataset)
compareParameterValues(algorithm, model, ["getFeaturesCols"])

Expand Down

0 comments on commit e6b8d74

Please sign in to comment.