Skip to content

Commit

Permalink
Merge branch 'develop' into feature/cara_pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
jsarni committed May 25, 2021
2 parents 709100f + cbcdd91 commit 81edac1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
Expand Up @@ -18,7 +18,8 @@ class CaraParser(caraYaml: CaraYaml) extends ParserUtils with CaraStageMapper{
for {
pipeline <- parsePipeline()
evaluator <- parseEvaluator()
} yield CaraPipeline(pipeline, evaluator)
tunerDesc <- parseTuner()
} yield CaraPipeline(pipeline, evaluator, tunerDesc)
}

private[PipelineParser] def parsePipeline(): Try[Pipeline] = {
Expand Down
@@ -1,6 +1,7 @@
package io.github.jsarni.PipelineParser

import io.github.jsarni.CaraStage.TuningStage.TuningStageDescription
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.Evaluator

case class CaraPipeline(pipeline: Pipeline, evaluator: Evaluator)
case class CaraPipeline(pipeline: Pipeline, evaluator: Evaluator, tuner: TuningStageDescription)
Expand Up @@ -237,5 +237,6 @@ class CaraParserTest extends TestBase {
res.get.evaluator.isInstanceOf[RegressionEvaluator] shouldBe true
res.get.pipeline.getStages.map(_.extractParamMap().toSeq.map(_.value)).head should contain theSameElementsAs
exprectedRes.getStages.map(_.extractParamMap().toSeq.map(_.value)).head
res.get.tuner shouldBe TuningStageDescription("CrossValidator", "NumFolds", "3")
}
}

0 comments on commit 81edac1

Please sign in to comment.