Skip to content

Commit

Permalink
Feature/yaml parser (#7)
Browse files Browse the repository at this point in the history
* Added Evaluator parser

* Evolution of parser

* Added tuner parser
  • Loading branch information
jsarni committed May 25, 2021
1 parent 3968405 commit 8b661fb
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 23 deletions.
42 changes: 40 additions & 2 deletions src/main/scala/io/github/jsarni/CaraStage/CaraStageMapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package io.github.jsarni.CaraStage

import io.github.jsarni.CaraStage.DatasetStage.CaraDataset
import io.github.jsarni.CaraStage.ModelStage._
import io.github.jsarni.CaraStage.TuningStage.TuningStageDescription
import org.apache.spark.ml.evaluation._

import scala.util.Try
import scala.util.{Try, Success, Failure}

trait CaraStageMapper {

Expand All @@ -13,7 +15,8 @@ trait CaraStageMapper {

def mapModelStage(stageDescription: CaraStageDescription): CaraModel = {
stageDescription.stageName match {
case "LogisticRegression" => LogisticRegression(stageDescription.params)
case "LogisticRegression" =>
LogisticRegression(stageDescription.params)
case _ => throw
new Exception(s"${stageDescription.stageName} is not a valid Cara Stage name. Please verify your Yaml File")
}
Expand All @@ -27,4 +30,39 @@ trait CaraStageMapper {
}
}

/**
  * Resolves an evaluator name read from the Yaml description into a SparkML evaluator.
  *
  * @param evaluatorName name of the evaluator as written in the Yaml file
  * @return a fresh `RegressionEvaluator` when the name is "RegressionEvaluator"
  * @throws Exception when the name is anything else
  */
def mapEvaluator(evaluatorName: String): Evaluator =
  if (evaluatorName == "RegressionEvaluator") {
    new RegressionEvaluator()
  } else {
    // Only one evaluator is supported; every other name is rejected.
    throw new Exception(s"${evaluatorName} is not a valid SparkML Validator name. Please verify your Yaml File")
  }

/**
  * Validates a tuner description and returns it unchanged when it is valid.
  *
  * Supported tuners and their single allowed parameter:
  *  - "CrossValidator": `NumFolds`, whose value must parse as an Int
  *  - "TrainValidationSplit": `TrainRatio`, whose value must parse as a Double in [0, 1]
  *
  * @param tuningStageDesc tuner name, parameter name and raw (String) parameter value
  * @return the same description when it passes validation
  * @throws IllegalArgumentException when the tuner name is unknown, the parameter name is not
  *                                  the one allowed for the tuner, or the value is invalid
  */
def mapTuner(tuningStageDesc: TuningStageDescription): TuningStageDescription = {
  tuningStageDesc.tuningStage match {
    case "CrossValidator" =>
      // `!=` is null-safe, unlike calling `.equals` on a possibly-null param name
      if (tuningStageDesc.paramName != "NumFolds")
        throw new IllegalArgumentException("The only parameter available for CrossValidator is NumFolds")
      Try(tuningStageDesc.paramValue.toInt) match {
        case Success(_) =>
          tuningStageDesc
        case Failure(_) =>
          throw new IllegalArgumentException("The NumFolds parameter value must be an Integer")
      }

    case "TrainValidationSplit" =>
      if (tuningStageDesc.paramName != "TrainRatio")
        throw new IllegalArgumentException("The only parameter available for TrainValidationSplit is TrainRatio")
      Try(tuningStageDesc.paramValue.toDouble) match {
        // Accept only ratios inside [0, 1]; NaN fails both comparisons and is rejected too
        case Success(ratio) if ratio >= 0 && ratio <= 1 =>
          tuningStageDesc
        case _ =>
          throw new IllegalArgumentException("The TrainRatio parameter value must be a Double between 0 and 1")
      }

    case unknownTuner =>
      // Explicit failure instead of an opaque scala.MatchError
      throw new IllegalArgumentException(
        s"$unknownTuner is not a valid tuner name. Please verify your Yaml File")
  }
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package io.github.jsarni.CaraStage.TuningStage

/**
  * Raw description of a tuning stage, with every field kept as a String.
  *
  * @param tuningStage name of the tuner
  * @param paramName   name of the tuner's parameter
  * @param paramValue  value of that parameter, as text
  */
case class TuningStageDescription(tuningStage: String, paramName: String, paramValue: String)
72 changes: 69 additions & 3 deletions src/main/scala/io/github/jsarni/PipelineParser/CaraParser.scala
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
package io.github.jsarni.PipelineParser

import com.fasterxml.jackson.databind.JsonNode
import io.github.jsarni.CaraStage.TuningStage.TuningStageDescription
import io.github.jsarni.CaraStage.{CaraStage, CaraStageDescription, CaraStageMapper}
import io.github.jsarni.CaraYaml.CaraYaml
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.{Pipeline, PipelineStage}

import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.{Try, Success, Failure}

class CaraParser(caraYaml: CaraYaml) extends ParserUtils with CaraStageMapper{

val contentTry = caraYaml.loadFile()

def parse(): Try[Pipeline] = {
/**
  * Parses the pipeline and the evaluator and bundles them into a `CaraPipeline`.
  *
  * @return a `Success` holding both parts, or the first `Failure` met while parsing
  */
def build(): Try[CaraPipeline] =
  parsePipeline().flatMap { parsedPipeline =>
    // The evaluator is parsed only when the pipeline was parsed successfully
    parseEvaluator().map(parsedEvaluator => CaraPipeline(parsedPipeline, parsedEvaluator))
  }

private[PipelineParser] def parsePipeline(): Try[Pipeline] = {
for {
content <- contentTry
stagesDescriptions <- extractStages(content)
Expand All @@ -22,8 +31,27 @@ class CaraParser(caraYaml: CaraYaml) extends ParserUtils with CaraStageMapper{
} yield pipeline
}

/**
  * Reads the evaluator name from the loaded file content and maps it to an `Evaluator`.
  *
  * @return the mapped evaluator, or a `Failure` when loading, extraction or mapping fails
  */
private[PipelineParser] def parseEvaluator(): Try[Evaluator] =
  contentTry
    .flatMap(loadedContent => extractEvaluator(loadedContent))
    // An exception thrown by mapEvaluator is captured by Try.map as a Failure
    .map(name => mapEvaluator(name))

/**
  * Reads the tuner description from the loaded file content and validates it with `mapTuner`.
  *
  * @return the validated description, or a `Failure` when loading, extraction or validation fails
  */
private[PipelineParser] def parseTuner(): Try[TuningStageDescription] =
  contentTry
    .flatMap(loadedContent => extractTuner(loadedContent))
    // An exception thrown by mapTuner is captured by Try.map as a Failure
    .map(description => mapTuner(description))



private[PipelineParser] def extractStages(fileContent: JsonNode): Try[List[CaraStageDescription]] = Try {
val stagesList = fileContent.at(s"/CaraPipeline").iterator().asScala.toList
val stagesList =
fileContent.at(s"/CaraPipeline").iterator().asScala.toList.filter(_.has("stage"))
val stages = stagesList.map{
stageDesc =>
val name = stageDesc.at("/stage").asText()
Expand All @@ -50,6 +78,44 @@ class CaraParser(caraYaml: CaraYaml) extends ParserUtils with CaraStageMapper{
stages
}

/**
  * Collects the `evaluator` entries found under `/CaraPipeline` and returns the only one.
  *
  * @param fileContent parsed content of the Yaml description file
  * @return the evaluator name, or a `Failure` when zero or several evaluators are declared
  */
private[PipelineParser] def extractEvaluator(fileContent: JsonNode): Try[String] = Try {
  val declaredEvaluators =
    fileContent.at("/CaraPipeline").iterator().asScala
      .filter(_.has("evaluator"))
      .map(_.at("/evaluator").asText())
      .toList

  declaredEvaluators match {
    case onlyEvaluator :: Nil => onlyEvaluator
    case _ =>
      throw new Exception("Error: You must define exactly one SparkML Evaluator")
  }
}

/**
  * Extracts the tuner declared under `/CaraPipeline` together with its single parameter.
  *
  * @param fileContent parsed content of the Yaml description file
  * @return `Success` with the tuner name, parameter name and parameter value (as text), or a
  *         `Failure(IllegalArgumentException)` when the file does not declare exactly one tuner
  *         or when that tuner does not carry exactly one usable parameter
  */
private[PipelineParser] def extractTuner(fileContent: JsonNode): Try[TuningStageDescription] = {

  val tunersList = fileContent.at("/CaraPipeline").iterator().asScala.toList.filter(_.has("tuner"))

  tunersList match {
    // Exactly one tuner: zero tuners used to fall through to `.head` and throw NoSuchElementException
    case tunerJson :: Nil =>
      val tunerName = tunerJson.at("/tuner").textValue()
      val paramList = tunerJson.at("/params").iterator().asScala.toList

      paramList match {
        case param :: Nil =>
          // The param entry is expected to be a one-field object such as `NumFolds: 3`
          param.fieldNames().asScala.toList.headOption match {
            case Some(paramName) =>
              val paramValue = param.at(s"/$paramName").asText()
              Success(TuningStageDescription(tunerName, paramName, paramValue))
            case None =>
              Failure(new IllegalArgumentException("Tuners must have exactly one param"))
          }
        case _ =>
          Failure(new IllegalArgumentException("Tuners must have exactly one param"))
      }

    case _ =>
      Failure(new IllegalArgumentException("Error: You must define exactly one SparkML Tuner"))
  }
}

private[PipelineParser] def parseStage(stageDescription: CaraStageDescription): Try[Any] =
for {
stageClass <- Try(Class.forName(s"io.github.jsarni.CaraStage.ModelStage.${stageDescription.stageName}"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.github.jsarni.PipelineParser

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.Evaluator

/**
  * Pairs a SparkML pipeline with the evaluator that goes with it.
  *
  * @param pipeline  the SparkML `Pipeline`
  * @param evaluator the SparkML `Evaluator`
  */
case class CaraPipeline(pipeline: Pipeline, evaluator: Evaluator)
4 changes: 4 additions & 0 deletions src/test/resources/cara_for_build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ CaraPipeline:
- MaxIter: 10
- RegParam: 0.3
- ElasticNetParam: 0.1
- evaluator: RegressionEvaluator
- tuner: CrossValidator
params:
- NumFolds: 3
8 changes: 8 additions & 0 deletions src/test/resources/cara_two_evaluator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CaraPipeline:
- stage: LogisticRegression
params:
- MaxIter: 10
- RegParam: 0.3
- ElasticNetParam: 0.1
- tuner: Tuner 1
- tuner: Tuner 2
8 changes: 8 additions & 0 deletions src/test/resources/cara_zero_evaluator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CaraPipeline:
- stage: LogisticRegression
params:
- MaxIter: 10
- RegParam: 0.3
- ElasticNetParam: 0.1
- evaluator: RegressionEvaluator
- evaluator: OtherEvaluator
136 changes: 118 additions & 18 deletions src/test/scala/io/github/jsarni/PipelineParser/CaraParserTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ package io.github.jsarni.PipelineParser

import io.github.jsarni.CaraStage.{CaraStage, CaraStageDescription}
import io.github.jsarni.CaraStage.ModelStage.LogisticRegression
import io.github.jsarni.CaraStage.TuningStage.TuningStageDescription
import io.github.jsarni.CaraYaml.CaraYaml
import io.github.jsarni.TestBase
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.{LogisticRegression => SparkLR}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.evaluation.Evaluator

import scala.util.Try

class CaraParserTest extends TestBase {

"extractStages" should "return parse the yaml description file to a json object" in {
"extractTuner" should "return parse the yaml description file to a json object" in {
val caraPath = getClass.getResource("/cara.yaml").getPath
val caraYaml = CaraYaml(caraPath)
val caraParser = new CaraParser(caraYaml)
Expand All @@ -30,21 +34,6 @@ class CaraParserTest extends TestBase {
result.get should contain theSameElementsAs expectedResult
}


// "parseStage" should "return parse the yaml description file to a json object" in {
// val caraPath = getClass.getResource("/cara.yaml").getPath
// val caraYaml = CaraYaml(caraPath)
// val caraParser = new CaraParser(caraYaml)
//
//
// val stageDesc =
// CaraStageDescription("TestStage", Map("MaxIter" -> "10", "RegParam" -> "0.3", "ElasticNetParam" -> "0.1"))
//
// val res = caraParser.parseStage(stageDesc)
// print(res.get)
//
// }

"parseSingleStageMap" should "parse a CaraStageDescription to a CaraStage " in {
val caraPath = getClass.getResource("/cara.yaml").getPath
val caraParser = new CaraParser(CaraYaml(caraPath))
Expand Down Expand Up @@ -126,16 +115,127 @@ class CaraParserTest extends TestBase {
res.get.getStages shouldBe new Pipeline().setStages(stagesList.toArray).getStages
}

"parse" should "build the described Pipeline of the Yaml File" in {
"parsePipeline" should "build the described Pipeline of the Yaml File" in {
val caraPath = getClass.getResource("/cara_for_build.yaml").getPath
val caraYaml = CaraYaml(caraPath)
val caraParser = new CaraParser(caraYaml)

val res = caraParser.parse()

val parsePipeline = PrivateMethod[Try[Pipeline]]('parsePipeline)
val res = caraParser.invokePrivate(parsePipeline())
val exprectedRes = new Pipeline().setStages(Array(new SparkLR().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.1)))

res.isSuccess shouldBe true
res.get.getStages.map(_.extractParamMap().toSeq.map(_.value)).head should contain theSameElementsAs
exprectedRes.getStages.map(_.extractParamMap().toSeq.map(_.value)).head
}

// Subject renamed: this test (and the `it should` cases that follow it) exercises extractEvaluator
"extractEvaluator" should "get the correct Evaluator Name from the Yaml File" in {
  val caraPath = getClass.getResource("/cara_for_build.yaml").getPath
  val caraYaml = CaraYaml(caraPath)
  val caraParser = new CaraParser(caraYaml)

  val myJson = caraYaml.loadFile()

  val extractEvaluator = PrivateMethod[Try[String]]('extractEvaluator)
  val result = caraParser.invokePrivate(extractEvaluator(myJson.get))

  result.isSuccess shouldBe true
  result.get shouldBe "RegressionEvaluator"
}

it should "Raise an exception if there is no evaluator specified" in {
  // NOTE: despite its name, cara_two_evaluator.yaml is the fixture that declares no evaluator
  // (it only holds a stage and two tuners); cara_zero_evaluator.yaml declares two evaluators.
  val caraPath = getClass.getResource("/cara_two_evaluator.yaml").getPath
  val caraYaml = CaraYaml(caraPath)
  val caraParser = new CaraParser(caraYaml)

  val myJson = caraYaml.loadFile()

  val extractEvaluator = PrivateMethod[Try[String]]('extractEvaluator)
  val result = caraParser.invokePrivate(extractEvaluator(myJson.get))

  result.isFailure shouldBe true
}

it should "Raise an exception if there is more than one evaluator specified" in {
  // NOTE: despite its name, cara_zero_evaluator.yaml is the fixture that declares two evaluators;
  // cara_two_evaluator.yaml declares none.
  val caraPath = getClass.getResource("/cara_zero_evaluator.yaml").getPath
  val caraYaml = CaraYaml(caraPath)
  val caraParser = new CaraParser(caraYaml)

  val myJson = caraYaml.loadFile()

  val extractEvaluator = PrivateMethod[Try[String]]('extractEvaluator)
  val result = caraParser.invokePrivate(extractEvaluator(myJson.get))

  result.isFailure shouldBe true
}

"parseEvaluator" should "build the described evaluator of the Yaml File" in {
  val yamlPath = getClass.getResource("/cara_for_build.yaml").getPath
  val parser = new CaraParser(CaraYaml(yamlPath))

  // parseEvaluator takes no argument, hence the empty invocation
  val parseEvaluatorMethod = PrivateMethod[Try[Evaluator]]('parseEvaluator)
  val parsed = parser.invokePrivate(parseEvaluatorMethod())

  parsed.isSuccess shouldBe true
  parsed.get shouldBe a [RegressionEvaluator]
}

"extractTuner" should "get the correct Tuner Description from the Yaml File" in {
  val yamlPath = getClass.getResource("/cara_for_build.yaml").getPath
  val yaml = CaraYaml(yamlPath)
  val parser = new CaraParser(yaml)

  val loadedContent = yaml.loadFile()

  val extractTunerMethod = PrivateMethod[Try[TuningStageDescription]]('extractTuner)
  val extracted = parser.invokePrivate(extractTunerMethod(loadedContent.get))

  // The fixture declares a CrossValidator tuner with a single NumFolds param
  val expectedDescription = TuningStageDescription("CrossValidator", "NumFolds", "3")
  extracted.isSuccess shouldBe true
  extracted.get shouldBe expectedDescription
}

// Checks that extractTuner fails with an IllegalArgumentException when several tuners are declared.
it should "raise an exception ilf there is more than one tuner in the Yaml File" in {
// NOTE(review): despite its name, cara_two_evaluator.yaml is the fixture holding two `tuner`
// entries (and no evaluator) — consider renaming the resource.
val caraPath = getClass.getResource("/cara_two_evaluator.yaml").getPath
val caraYaml = CaraYaml(caraPath)
val caraParser = new CaraParser(caraYaml)

val myJson = caraYaml.loadFile()

val extractTuner = PrivateMethod[Try[TuningStageDescription]]('extractTuner)
val result = caraParser.invokePrivate(extractTuner(myJson.get))

result.isFailure shouldBe true
// Calling get on a Failure rethrows the wrapped exception
an [IllegalArgumentException] should be thrownBy result.get
}

"parseTuner" should "build the described Tuner of the Yaml File" in {
  val caraPath = getClass.getResource("/cara_for_build.yaml").getPath
  val caraYaml = CaraYaml(caraPath)
  val caraParser = new CaraParser(caraYaml)

  // Invoke parseTuner itself (extraction followed by validation); the test previously called
  // extractTuner, which left parseTuner untested.
  val parseTuner = PrivateMethod[Try[TuningStageDescription]]('parseTuner)
  val result = caraParser.invokePrivate(parseTuner())

  result.isSuccess shouldBe true
  result.get shouldBe TuningStageDescription("CrossValidator", "NumFolds", "3")
}

"build" should "build the described Pipeline of the Yaml File" in {
  val yamlPath = getClass.getResource("/cara_for_build.yaml").getPath
  val parser = new CaraParser(CaraYaml(yamlPath))

  val built = parser.build()

  val expectedPipeline =
    new Pipeline().setStages(Array(new SparkLR().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.1)))

  // Param values of the first stage, used to compare pipelines without relying on stage identity
  def firstStageParamValues(p: Pipeline) =
    p.getStages.map(_.extractParamMap().toSeq.map(_.value)).head

  built.isSuccess shouldBe true
  built.get.evaluator shouldBe a [RegressionEvaluator]
  firstStageParamValues(built.get.pipeline) should contain theSameElementsAs
    firstStageParamValues(expectedPipeline)
}
}

0 comments on commit 8b661fb

Please sign in to comment.