From cbcdd91864354243bbc41573b3357a9d26fec2dd Mon Sep 17 00:00:00 2001 From: jsarni <45796640+jsarni@users.noreply.github.com> Date: Tue, 25 May 2021 18:19:04 +0200 Subject: [PATCH] Feature/yaml parser (#9) * Added Evaluator parser * Evolution of parser * Added tuner parser * Added companion object to CaraParser * Added tuner to CaraPipeline --- .../scala/io/github/jsarni/PipelineParser/CaraParser.scala | 3 ++- .../scala/io/github/jsarni/PipelineParser/CaraPipeline.scala | 3 ++- .../scala/io/github/jsarni/PipelineParser/CaraParserTest.scala | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/scala/io/github/jsarni/PipelineParser/CaraParser.scala b/src/main/scala/io/github/jsarni/PipelineParser/CaraParser.scala index e50a375..3382d1f 100644 --- a/src/main/scala/io/github/jsarni/PipelineParser/CaraParser.scala +++ b/src/main/scala/io/github/jsarni/PipelineParser/CaraParser.scala @@ -18,7 +18,8 @@ class CaraParser(caraYaml: CaraYaml) extends ParserUtils with CaraStageMapper{ for { pipeline <- parsePipeline() evaluator <- parseEvaluator() - } yield CaraPipeline(pipeline, evaluator) + tunerDesc <- parseTuner() + } yield CaraPipeline(pipeline, evaluator, tunerDesc) } private[PipelineParser] def parsePipeline(): Try[Pipeline] = { diff --git a/src/main/scala/io/github/jsarni/PipelineParser/CaraPipeline.scala b/src/main/scala/io/github/jsarni/PipelineParser/CaraPipeline.scala index 2a00924..693acee 100644 --- a/src/main/scala/io/github/jsarni/PipelineParser/CaraPipeline.scala +++ b/src/main/scala/io/github/jsarni/PipelineParser/CaraPipeline.scala @@ -1,6 +1,7 @@ package io.github.jsarni.PipelineParser +import io.github.jsarni.CaraStage.TuningStage.TuningStageDescription import org.apache.spark.ml.Pipeline import org.apache.spark.ml.evaluation.Evaluator -case class CaraPipeline(pipeline: Pipeline, evaluator: Evaluator) +case class CaraPipeline(pipeline: Pipeline, evaluator: Evaluator, tuner: TuningStageDescription) diff --git a/src/test/scala/io/github/jsarni/PipelineParser/CaraParserTest.scala b/src/test/scala/io/github/jsarni/PipelineParser/CaraParserTest.scala index 200247c..d8eb551 100644 --- a/src/test/scala/io/github/jsarni/PipelineParser/CaraParserTest.scala +++ b/src/test/scala/io/github/jsarni/PipelineParser/CaraParserTest.scala @@ -237,5 +237,6 @@ class CaraParserTest extends TestBase { res.get.evaluator.isInstanceOf[RegressionEvaluator] shouldBe true res.get.pipeline.getStages.map(_.extractParamMap().toSeq.map(_.value)).head should contain theSameElementsAs exprectedRes.getStages.map(_.extractParamMap().toSeq.map(_.value)).head + res.get.tuner shouldBe TuningStageDescription("CrossValidator", "NumFolds", "3") } }