diff --git a/src/main/scala/io/github/jsarni/CaraModel.scala b/src/main/scala/io/github/jsarni/CaraModel.scala
index 6b678d7..4e9847e 100644
--- a/src/main/scala/io/github/jsarni/CaraModel.scala
+++ b/src/main/scala/io/github/jsarni/CaraModel.scala
@@ -5,6 +5,7 @@ import io.github.jsarni.DatasetLoader.CaraLoader
 import io.github.jsarni.PipelineParser.{CaraParser, CaraPipeline}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.ml.tuning.{CrossValidator, TrainValidationSplit}
 
 import scala.util.Try
 
@@ -25,9 +26,54 @@ final class CaraModel(yamlPath: String, datasetPath: String, format: String, sav
 
   def generateReport(model: PipelineModel) : Try[Unit] = ???
 
-  private def generateModel(caraPipeline: CaraPipeline): Try[Pipeline] = ???
+  /**
+   * Wraps the parsed pipeline in the tuning stage requested by the tuner description.
+   *
+   * The tuner parameter is applied through its Spark setter, resolved by name as
+   * "set" + paramName. The raw value is parsed as an Int for CrossValidator and as a
+   * Double for TrainValidationSplit, so only setters taking that type can be targeted.
+   *
+   * @param caraPipeline pipeline, evaluator and tuner description to assemble
+   * @return a single-stage Pipeline holding the configured tuning stage, or a Failure when
+   *         the tuning stage is unknown, the value cannot be parsed or no setter matches
+   */
+  private def generateModel(caraPipeline: CaraPipeline): Try[Pipeline] = Try {
+    val tuner = caraPipeline.tuner
+    val setterName = s"set${tuner.paramName}"
+
+    val tuningStage: org.apache.spark.ml.PipelineStage = tuner.tuningStage match {
+      case "CrossValidator" =>
+        val paramValue = tuner.paramValue.toInt
+        val crossValidator = new CrossValidator()
+          .setEstimator(caraPipeline.pipeline)
+          .setEvaluator(caraPipeline.evaluator)
+          .setParallelism(2)
+        // The setter is looked up reflectively because its name comes from the tuner description.
+        crossValidator.getClass
+          .getMethod(setterName, classOf[Int])
+          .invoke(crossValidator, Int.box(paramValue))
+        crossValidator
+
+      case "TrainValidationSplit" =>
+        val paramValue = tuner.paramValue.toDouble
+        val trainValidationSplit = new TrainValidationSplit()
+          .setEstimator(caraPipeline.pipeline)
+          .setEvaluator(caraPipeline.evaluator)
+          .setParallelism(2)
+        trainValidationSplit.getClass
+          .getMethod(setterName, classOf[Double])
+          .invoke(trainValidationSplit, Double.box(paramValue))
+        trainValidationSplit
+
+      case unsupported =>
+        // Fail with an explicit message instead of an opaque scala.MatchError.
+        throw new IllegalArgumentException(s"Unsupported tuning stage: $unsupported")
+    }
+
+    new Pipeline().setStages(Array(tuningStage))
+  }
 
   private def train(pipeline: Pipeline, dataset: Dataset[_]): Try[PipelineModel] = Try {
     pipeline.fit(dataset)
   }
 
diff --git a/src/test/scala/io/github/jsarni/CaraModelTest.scala b/src/test/scala/io/github/jsarni/CaraModelTest.scala
new file mode 100644
index 0000000..e65b79d
--- /dev/null
+++ b/src/test/scala/io/github/jsarni/CaraModelTest.scala
@@ -0,0 +1,67 @@
+package io.github.jsarni
+
+import io.github.jsarni.CaraStage.TuningStage.TuningStageDescription
+import io.github.jsarni.PipelineParser.CaraPipeline
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, RegressionEvaluator}
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{CrossValidator, TrainValidationSplit}
+import org.apache.spark.sql.SparkSession
+
+import scala.util.Try
+
+/** Exercises the private CaraModel.generateModel method through the PrivateMethod helper. */
+class CaraModelTest extends TestBase {
+
+  private val generateModelMethod = PrivateMethod[Try[Pipeline]](Symbol("generateModel"))
+
+  /** Builds a CaraModel with placeholder paths and a local SparkSession. */
+  private def caraModelUnderTest: CaraModel = {
+    val spark = SparkSession.builder()
+      .appName("CaraML")
+      .master("local[1]")
+      .getOrCreate()
+    new CaraModel("YamlPath", "datasetPath", "format", "savePath")(spark)
+  }
+
+  /** Single-stage linear regression pipeline used as the estimator to tune. */
+  private def linearRegressionPipeline: Pipeline =
+    new Pipeline().setStages(Array(new LinearRegression().setMaxIter(10)))
+
+  "generateModel" should "return a CrossValidator stage configured with the requested folds" in {
+    val tuner = TuningStageDescription("CrossValidator", "NumFolds", "2")
+    val caraPipeline = CaraPipeline(linearRegressionPipeline, new BinaryClassificationEvaluator, tuner)
+
+    val model = caraModelUnderTest.invokePrivate(generateModelMethod(caraPipeline))
+
+    model.isSuccess shouldBe true
+    model.get.getStages.length shouldBe 1
+    model.get.getStages.head match {
+      case crossValidator: CrossValidator => crossValidator.getNumFolds shouldBe 2
+      case other => fail(s"Expected a CrossValidator stage but got ${other.getClass.getName}")
+    }
+  }
+
+  it should "return a TrainValidationSplit stage configured with the requested ratio" in {
+    val tuner = TuningStageDescription("TrainValidationSplit", "TrainRatio", "0.6")
+    val caraPipeline = CaraPipeline(linearRegressionPipeline, new RegressionEvaluator, tuner)
+
+    val model = caraModelUnderTest.invokePrivate(generateModelMethod(caraPipeline))
+
+    model.isSuccess shouldBe true
+    model.get.getStages.length shouldBe 1
+    model.get.getStages.head match {
+      case split: TrainValidationSplit => split.getTrainRatio shouldBe 0.6
+      case other => fail(s"Expected a TrainValidationSplit stage but got ${other.getClass.getName}")
+    }
+  }
+
+  it should "fail when the tuning stage is not supported" in {
+    val tuner = TuningStageDescription("UnknownTuner", "NumFolds", "2")
+    val caraPipeline = CaraPipeline(linearRegressionPipeline, new RegressionEvaluator, tuner)
+
+    val model = caraModelUnderTest.invokePrivate(generateModelMethod(caraPipeline))
+
+    model.isFailure shouldBe true
+  }
+}