Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mn-mikke committed Jul 13, 2020
1 parent d5db83a commit d154e3a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
24 changes: 20 additions & 4 deletions ml/src/test/scala/ai/h2o/sparkling/ml/algos/H2OGLRMTestSuite.scala
Expand Up @@ -19,7 +19,8 @@ package ai.h2o.sparkling.ml.algos

import ai.h2o.sparkling.{SharedH2OTestContext, TestUtils}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.bround
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FunSuite, Matchers}
Expand All @@ -33,27 +34,42 @@ class H2OGLRMTestSuite extends FunSuite with Matchers with SharedH2OTestContext
.option("header", "true")
.option("inferSchema", "true")
.csv(TestUtils.locate("smalldata/iris/iris_wheader.csv"))
.drop("class")

// Splits `dataset` into a training and a testing part using weights 0.9 and 0.1.
// The fixed seed (42) makes the split repeatable between runs.
private lazy val Array(trainingDataset, testingDataset) = {
  val splitWeights = Array(0.9, 0.1)
  dataset.randomSplit(splitWeights, 42)
}

// Checks that an H2OGLRM pipeline and its fitted model survive a save/load round trip:
// the predictions of the reloaded model must match those of the model it was saved from.
test("H2OGLRM Pipeline serialization and deserialization") {
  import spark.implicits._

  val algo = new H2OGLRM()
    .setK(3)
    .setLoss("Quadratic")
    .setGammaX(0.5)
    .setGammaY(0.5)
    .setExpandUserY(false)
    .setSeed(42)
    .setTransform("standardize")

  // Keeps the four feature columns and exposes the first three elements of the
  // `prediction` column as separate columns rounded (via `bround`) to 5 decimal places.
  // NOTE(review): presumably the rounding is there to tolerate tiny floating-point
  // differences between the original and the reloaded model - confirm.
  def roundResult(dataFrame: DataFrame): DataFrame = {
    dataFrame.select(
      'sepal_len,
      'sepal_wid,
      'petal_len,
      'petal_wid,
      bround($"prediction".getItem(0), 5) as "prediction0",
      bround($"prediction".getItem(1), 5) as "prediction1",
      bround($"prediction".getItem(2), 5) as "prediction2")
  }

  // Round-trip the unfitted pipeline through disk before fitting it.
  val pipeline = new Pipeline().setStages(Array(algo))
  pipeline.write.overwrite().save("ml/build/glrm_pipeline")
  val loadedPipeline = Pipeline.load("ml/build/glrm_pipeline")
  val model = loadedPipeline.fit(trainingDataset)
  // `expected` is defined exactly once, on the rounded output.
  val expected = roundResult(model.transform(testingDataset))
  expected.show()

  // Round-trip the fitted model and score the same data with the reloaded copy.
  model.write.overwrite().save("ml/build/glrm_pipeline_model")
  val loadedModel = PipelineModel.load("ml/build/glrm_pipeline_model")
  val result = roundResult(loadedModel.transform(testingDataset))

  TestUtils.assertDataFramesAreIdentical(expected, result)
}
Expand Down
Expand Up @@ -48,7 +48,7 @@ trait H2OMOJOPredictionDimReduction {
}

/**
  * Builds the column expression used as the content of the simple prediction column.
  *
  * @return a [[Column]] referencing `dimensions` under the detailed prediction column,
  *         i.e. the column named `<detailedPredictionCol>.dimensions`
  */
def extractDimReductionSimplePredictionColContent(): Column = {
  col(s"${getDetailedPredictionCol()}.dimensions")
}
}

Expand Down

0 comments on commit d154e3a

Please sign in to comment.