-
Notifications
You must be signed in to change notification settings - Fork 0
/
CaraModel.scala
75 lines (60 loc) · 2.59 KB
/
CaraModel.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package io.github.jsarni
import io.github.jsarni.CaraYaml.CaraYamlReader
import io.github.jsarni.PipelineParser.{CaraParser, CaraPipeline}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder, TrainValidationSplit}
import scala.util.Try
/**
 * Builds, trains and persists a Spark ML pipeline described by a YAML file.
 *
 * @param yamlPath path of the YAML description handed to `CaraYamlReader`
 * @param dataset  dataset the generated pipeline is fitted on by [[run]]
 * @param savePath location the fitted model is written to by [[run]] and loaded from by [[evaluate]]
 * @param spark    implicit Spark session (not referenced directly by the code in this class)
 */
final class CaraModel(yamlPath: String, dataset: Dataset[_], savePath: String)(implicit spark: SparkSession) {

  val yaml = CaraYamlReader(yamlPath)
  val parser = CaraParser(yaml)

  /**
   * Runs the whole workflow: parse the YAML into a pipeline description, wrap it in a tuning
   * stage, fit it on `dataset` and save the fitted model to `savePath`.
   *
   * @return `Success(())` when every step succeeded, otherwise the `Failure` of the first
   *         step that failed (later steps are not executed)
   */
  def run(): Try[Unit] = for {
    caraPipeline  <- parser.build()
    sparkPipeline <- generateModel(caraPipeline)
    fittedModel   <- train(sparkPipeline, dataset)
    // _ <- generateReport(fittedModel)
    _             <- save(fittedModel)
  } yield ()

  // def generateReport(model: PipelineModel) : Try[Unit] = ???

  /**
   * Loads the model stored at `savePath` and applies it to the given dataset.
   *
   * Note that the `dataset` parameter shadows the constructor's `dataset`; only the argument
   * passed here is transformed. Unlike [[run]], failures are thrown rather than wrapped in a `Try`.
   *
   * @param dataset data to transform with the persisted model
   * @return the result of `PipelineModel.transform` on `dataset`
   */
  def evaluate(dataset: Dataset[_]): Dataset[_] = {
    val model = PipelineModel.load(savePath)
    model.transform(dataset)
  }

  /**
   * Wraps the parsed pipeline in the tuning stage requested by the YAML description and returns
   * it as a single-stage [[Pipeline]].
   *
   * The tuner parameter is applied reflectively through the setter named `"set" + paramName`:
   * an `Int` argument for `CrossValidator`, a `Double` argument for `TrainValidationSplit`.
   *
   * @param caraPipeline parsed description providing the pipeline, the evaluator and the tuner
   * @return the wrapping pipeline, or a `Failure` when the parameter value cannot be converted,
   *         the setter does not exist or rejects the value, or the tuning stage is not supported
   */
  private def generateModel(caraPipeline: CaraPipeline): Try[Pipeline] = Try {
    val pipeline = caraPipeline.pipeline
    val evaluator = caraPipeline.evaluator
    val tuner = caraPipeline.tuner
    val setterName = s"set${tuner.paramName}"

    tuner.tuningStage match {
      case "CrossValidator" =>
        val paramValue = tuner.paramValue.toInt
        val crossValidator = new CrossValidator()
          .setEstimator(pipeline)
          .setEvaluator(evaluator)
          .setEstimatorParamMaps(new ParamGridBuilder().build())
          .setParallelism(2)
        invokeSetter(crossValidator, setterName, classOf[Int], Int.box(paramValue))
        new Pipeline().setStages(Array(crossValidator))

      case "TrainValidationSplit" =>
        val paramValue = tuner.paramValue.toDouble
        val trainValidationSplit = new TrainValidationSplit()
          .setEstimator(pipeline)
          .setEvaluator(evaluator)
          .setEstimatorParamMaps(new ParamGridBuilder().build())
          .setParallelism(2)
        invokeSetter(trainValidationSplit, setterName, classOf[Double], Double.box(paramValue))
        new Pipeline().setStages(Array(trainValidationSplit))

      case other =>
        // Fail with an explicit message instead of an opaque scala.MatchError.
        throw new IllegalArgumentException(
          s"Unsupported tuning stage '$other': expected 'CrossValidator' or 'TrainValidationSplit'"
        )
    }
  }

  /**
   * Reflectively calls the one-argument setter `setterName` on `target`.
   *
   * @param target     object whose setter is invoked
   * @param setterName full method name, e.g. `setNumFolds`
   * @param paramType  declared (primitive) parameter type used to look the method up
   * @param value      boxed argument passed to the setter
   */
  private def invokeSetter(target: AnyRef, setterName: String, paramType: Class[_], value: AnyRef): Unit =
    try {
      target.getClass.getMethod(setterName, paramType).invoke(target, value)
      ()
    } catch {
      // Surface the exception raised by the setter itself rather than the reflection wrapper.
      case e: java.lang.reflect.InvocationTargetException =>
        throw Option(e.getCause).getOrElse(e)
    }

  /**
   * Fits `pipeline` on `dataset`.
   *
   * @return the fitted model, or a `Failure` holding whatever `fit` threw
   */
  private def train(pipeline: Pipeline, dataset: Dataset[_]): Try[PipelineModel] = Try {
    pipeline.fit(dataset)
  }

  /**
   * Writes the fitted model to `savePath`.
   *
   * @return `Success(())` once written, or a `Failure` holding whatever the writer threw
   */
  private def save(model: PipelineModel): Try[Unit] = Try {
    model.write.save(savePath)
  }
}