Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/model schema #3

Merged
merged 4 commits into from
May 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions src/main/scala/io/github/jsarni/CaraStage/CaraStage.scala
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
package io.github.jsarni.CaraStage

import org.apache.spark.ml.PipelineStage
import java.lang.reflect.Method
import scala.util.Try

trait CaraStage {

//TODO: Add builder function
def build(): PipelineStage
def build(): Try[PipelineStage]

// Function to get methode by name and do invoke with the right params types and values
def GetMethode(lr : PipelineStage, field : Any, field_name : String) = {
val MethodeName = "set"+field_name
def getMethode(stage : PipelineStage, field : Any, fieldName : String): Method = {
val methodeName = "set" + fieldName
field match {
case _ : Any if field.getClass == Array[Double]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[Double]].getClass )
case _ : Any if field.getClass == Array[String]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[String]].getClass )
case _ : Any if field.getClass == Array[Float]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[Float]].getClass )
case _ : Any if field.getClass == Array[Short]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[Short]].getClass )
case _ : Any if field.getClass == Array[Char]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[Char]].getClass )
case _ : Any if field.getClass == Array[Byte]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[Byte]].getClass )
case _ : Any if field.getClass == Array[Long]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[Long]].getClass )
case _ : Any if field.getClass == Array[Int]().getClass => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Array[Int]].getClass )
case _ : java.lang.Boolean => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Boolean].getClass )
case _ : java.lang.Double => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Double].getClass )
case _ : java.lang.Float => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Float].getClass )
case _ : java.lang.Short => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Short].getClass )
case _ : java.lang.Character => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Char].getClass )
case _ : java.lang.Byte => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Byte].getClass )
case _ :java.lang.Long => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Long].getClass)
case _: java.lang.Integer => lr.getClass.getMethod(MethodeName, field.asInstanceOf[Int].getClass)
case _ : java.lang.String => lr.getClass.getMethod(MethodeName, field.getClass )
case _ : Any if field.getClass == Array[Double]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[Double]].getClass )
case _ : Any if field.getClass == Array[String]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[String]].getClass )
case _ : Any if field.getClass == Array[Float]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[Float]].getClass )
case _ : Any if field.getClass == Array[Short]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[Short]].getClass )
case _ : Any if field.getClass == Array[Char]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[Char]].getClass )
case _ : Any if field.getClass == Array[Byte]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[Byte]].getClass )
case _ : Any if field.getClass == Array[Long]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[Long]].getClass )
case _ : Any if field.getClass == Array[Int]().getClass => stage.getClass.getMethod(methodeName, field.asInstanceOf[Array[Int]].getClass )
case _ : java.lang.Boolean => stage.getClass.getMethod(methodeName, field.asInstanceOf[Boolean].getClass )
case _ : java.lang.Double => stage.getClass.getMethod(methodeName, field.asInstanceOf[Double].getClass )
case _ : java.lang.Float => stage.getClass.getMethod(methodeName, field.asInstanceOf[Float].getClass )
case _ : java.lang.Short => stage.getClass.getMethod(methodeName, field.asInstanceOf[Short].getClass )
case _ : java.lang.Character => stage.getClass.getMethod(methodeName, field.asInstanceOf[Char].getClass )
case _ : java.lang.Byte => stage.getClass.getMethod(methodeName, field.asInstanceOf[Byte].getClass )
case _ : java.lang.Long => stage.getClass.getMethod(methodeName, field.asInstanceOf[Long].getClass)
case _ : java.lang.Integer => stage.getClass.getMethod(methodeName, field.asInstanceOf[Int].getClass)
case _ : java.lang.String => stage.getClass.getMethod(methodeName, field.getClass )
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package io.github.jsarni.CaraStage.ModelStage
import io.github.jsarni.CaraStage.Annotation.MapperConstructor
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.classification.{LogisticRegression => log}
import org.apache.spark.ml.classification.{LogisticRegression => SparkLR}
import scala.util.Try



Expand All @@ -17,27 +18,27 @@ case class LogisticRegression(MaxIter: Option[Int], RegParam: Option[Double], El
params.get("MaxIter").map(_.toInt),
params.get("RegParam").map(_.toDouble),
params.get("ElasticNetParam").map(_.toDouble),
params.get("Family").map(_.toString),
params.get("FeaturesCol").map(_.toString),
params.get("Family"),
params.get("FeaturesCol"),
params.get("FitIntercept").map(_.toBoolean),
params.get("PredictionCol").map(_.toString),
params.get("ProbabilityCol").map(_.toString),
params.get("RawPredictionCol").map(_.toString),
params.get("PredictionCol"),
params.get("ProbabilityCol"),
params.get("RawPredictionCol"),
params.get("Standardization").map(_.toBoolean),
params.get("Thresholds").map(_.split(",").map(_.toDouble)),
params.get("Tol").map(_.toDouble),
params.get("WeightCol").map(_.toString)
params.get("WeightCol")

)
}

override def build(): PipelineStage = {
val lr = new log()
override def build(): Try[PipelineStage] = Try {
val lr = new SparkLR()
val definedFields = this.getClass.getDeclaredFields.filter(f => f.get(this).asInstanceOf[Option[Any]].isDefined)
val names = definedFields.map(f => f.getName)
val values = definedFields.map(f => f.get(this))
val zipFields = names zip values
zipFields.map(f=> GetMethode(lr,f._2 match {case Some(s) => s },f._1).invoke(lr,(f._2 match {case Some(value) => value.asInstanceOf[f._2.type ] })))
zipFields.map(f=> getMethode(lr,f._2 match {case Some(s) => s },f._1).invoke(lr,(f._2 match {case Some(value) => value.asInstanceOf[f._2.type ] })))
lr

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class CaraParser(caraYaml: CaraYaml) extends ParserUtils with CaraStageMapper{
}

private[PipelineParser] def buildStages(stagesList: List[CaraStage]): Try[List[PipelineStage]] = {
Try(stagesList.map(_.build()))
Try(stagesList.map(_.build().get))
}

private[PipelineParser] def buildPipeline(mlStages: List[PipelineStage]): Try[Pipeline] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.github.jsarni.CaraStage.ModelStage

import io.github.jsarni.TestBase
import org.apache.spark.ml.classification.{LogisticRegression => SparkLR}


class LogisticRegressionTest extends TestBase {

"build" should "Create an lr model and set all parameters with there args values or set default ones" in {
val params = Map(
"MaxIter" -> "10",
"RegParam" -> "0.3",
"ElasticNetParam" -> "0.1",
"Family" -> "multinomial",
"FeaturesCol" -> "FeatureColname",
"FitIntercept" -> "True",
"PredictionCol" -> "Age",
"ProbabilityCol" -> "ProbaColname",
"RawPredictionCol"-> "RawPredictColname",
"Standardization" -> "True",
"Tol" -> "0.13",
"WeightCol" -> "WeightColname"
)
val lr = LogisticRegression(params)
val lrWithTwoParams = new SparkLR()
.setRegParam(0.8)
.setStandardization(false)

val expectedResult = List(
new SparkLR()
.setMaxIter(10)
.setRegParam(0.3)
.setElasticNetParam(0.1)
.setFamily("multinomial")
.setFeaturesCol("FeatureColname")
.setFitIntercept(true)
.setPredictionCol("Age")
.setProbabilityCol("ProbaColname")
.setRawPredictionCol("RawPredictColname")
.setStandardization(true).setTol(0.13)
.setWeightCol("WeightColname")
)
lr.build().isSuccess shouldBe true

val res = List(lr.build().get)
val resParameters = res.map(_.extractParamMap().toSeq.map(_.value))
val expectedParameters = expectedResult.map(_.extractParamMap().toSeq.map(_.value))

resParameters.head should contain theSameElementsAs expectedParameters.head

// Test default values of unset params
lrWithTwoParams.getMaxIter shouldBe 100
lrWithTwoParams.getFamily shouldBe "auto"
lrWithTwoParams.getTol shouldBe 0.000001

}
"GetMethode" should "Return the appropriate methode by it's name" in {
val params = Map(
"MaxIter" -> "10",
"RegParam" -> "0.3",
"ElasticNetParam" -> "0.1",
"Family" -> "multinomial",
"FeaturesCol" -> "FeatureColname",
"FitIntercept" -> "True",
"PredictionCol" -> "Age",
"ProbabilityCol" -> "ProbaColname",
"RawPredictionCol"-> "RawPredictColname",
"Standardization" -> "True",
"Tol" -> "0.13",
"WeightCol" -> "WeightColname"
)
val caraLr = LogisticRegression(params)
val model =caraLr.build().get.asInstanceOf[SparkLR]

caraLr.getMethode(model,10,"MaxIter").getName shouldBe "setMaxIter"
caraLr.getMethode(model,0.0,"RegParam").getName shouldBe "setRegParam"
caraLr.getMethode(model, false ,"Standardization").getName shouldBe "setStandardization"

}
}