From 720d9dd90a4b7303163a1fbc4307f9e827cae64c Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Tue, 28 Sep 2021 20:49:53 +0200 Subject: [PATCH 01/16] added type hints to custom_obj and custom_eval for Spark persistence --- .../scala/spark/params/CustomParams.scala | 37 ++++++-- .../xgboost4j/scala/spark/CustomObj.scala | 89 +++++++++++++++++++ .../xgboost4j/scala/spark/EvalError.scala | 6 +- .../scala/spark/PersistenceSuite.scala | 37 +++++++- 4 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 784be2aa0872..6eb996c2da99 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.TrackerConf -import org.json4s.{DefaultFormats, Extraction, NoTypeHints} + +import org.json4s.{DefaultFormats, Extraction, NoTypeHints, ShortTypeHints, TypeHints} import org.json4s.jackson.JsonMethods.{compact, parse, render} import org.apache.spark.ml.param.{Param, ParamPair, Params} @@ -32,17 +33,28 @@ class CustomEvalParam( override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value) override def jsonEncode(value: EvalTrait): String = { - import org.json4s.jackson.Serialization - implicit val formats = Serialization.formats(NoTypeHints) + implicit val formats = DefaultFormats.withHints(CustomEvalParam.typeHints) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): EvalTrait = { - implicit val formats = DefaultFormats + implicit val formats = DefaultFormats.withHints(CustomEvalParam.typeHints) parse(json).extract[EvalTrait] } } +object CustomEvalParam { + var typeHints: TypeHints = NoTypeHints + + final def addTypeHints(value: TypeHints): Unit = { + typeHints = typeHints + value + } + + final def addShortTypeHint(value: Class[_]): Unit = { + typeHints = typeHints + ShortTypeHints(List(value)) + } +} + class CustomObjParam( parent: Params, name: String, @@ -52,17 +64,28 @@ class CustomObjParam( override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value) override def jsonEncode(value: ObjectiveTrait): String = { - import org.json4s.jackson.Serialization - implicit val formats = Serialization.formats(NoTypeHints) + implicit val formats = DefaultFormats.withHints(CustomObjParam.typeHints) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): ObjectiveTrait = { - implicit val formats = DefaultFormats + implicit val formats = DefaultFormats.withHints(CustomObjParam.typeHints) parse(json).extract[ObjectiveTrait] } } +object CustomObjParam { + var typeHints: TypeHints = NoTypeHints + + final def addTypeHints(value: TypeHints): Unit = { + typeHints = typeHints + value + } + + final def addShortTypeHint(value: Class[_]): Unit = { + typeHints = typeHints + ShortTypeHints(List(value)) + } +} + class TrackerConfParam( parent: Params, name: String, diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala new file mode 100644 index 000000000000..106d73ecaccf --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -0,0 +1,89 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.java.XGBoostError +import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} +import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam.addShortTypeHint +import org.apache.commons.logging.LogFactory +import org.json4s.ShortTypeHints +import scala.collection.mutable.ListBuffer + + +/** + * loglikelihood loss obj function + */ +case class CustomObj() extends ObjectiveTrait { + + val logger = LogFactory.getLog(classOf[CustomObj]) + + addShortTypeHint(classOf[CustomObj]) + + /** + * user define objective function, return gradient and second order gradient + * + * @param predicts untransformed margin predicts + * @param dtrain training data + * @return List with two float array, correspond to first order grad and second order grad + */ + override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix) + : List[Array[Float]] = { + val nrow = predicts.length + val gradients = new ListBuffer[Array[Float]] + var labels: Array[Float] = null + try { + labels = dtrain.getLabel + } catch { + case e: XGBoostError => + logger.error(e) + null + case _: Throwable => + null + } + val grad = new Array[Float](nrow) + val hess = new Array[Float](nrow) + val transPredicts = transform(predicts) + + for (i <- 0 until nrow) { + val predict = transPredicts(i)(0) + grad(i) = predict - labels(i) + hess(i) = predict * (1 - predict) + } + gradients += grad + gradients += hess + gradients.toList + } + + /** + * simple sigmoid func + * + * @param input + * @return Note: this func is not concern about numerical stability, only used as example + */ + def sigmoid(input: Float): Float = { + (1 / (1 + Math.exp(-input))).toFloat + } + + def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = { + val nrow = predicts.length + val transPredicts = Array.fill[Float](nrow, 1)(0) + for (i <- 0 until nrow) { + transPredicts(i)(0) = sigmoid(predicts(i)(0)) + } + transPredicts + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index 91a840911a32..c1f22e8dfcc2 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -18,12 +18,16 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} +import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam.addShortTypeHint import org.apache.commons.logging.LogFactory +import org.json4s.ShortTypeHints -class EvalError extends EvalTrait { +case class EvalError() extends EvalTrait { val logger = LogFactory.getLog(classOf[EvalError]) + addShortTypeHint(classOf[EvalError]) + private[xgboost4j] var evalMetric: String = "custom_error" /** diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index ebe1d8546544..0ba396910121 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -25,6 +25,7 @@ import scala.util.Random import org.apache.spark.ml.feature._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.functions._ +import org.json4s.ShortTypeHints import org.scalatest.FunSuite class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { @@ -92,7 +93,6 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { } test("test persistence of MLlib pipeline with XGBoostClassificationModel") { - val r = new Random(0) // maybe move to shared context, but requires session to import implicits val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))). @@ -133,6 +133,41 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol) } + test("test persistence of XGBoostClassifier and XGBoostClassificationModel " + + "using custom Obj and Eval") { + val eval = new EvalError() + val trainingDF = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "custom_obj" -> new CustomObj, "custom_eval" -> new EvalError, + "num_round" -> "10", "num_workers" -> numWorkers) + val xgb = new XGBoostClassifier(paramMap) + + val xgbc = new XGBoostClassifier(paramMap) + val xgbcPath = new File(tempDir.toFile, "xgbc").getPath + xgbc.write.overwrite().save(xgbcPath) + val xgbc2 = XGBoostClassifier.load(xgbcPath) + val paramMap2 = xgbc2.MLlib2XGBoostParams + paramMap.foreach { + case (k, v) => assert(v.toString == paramMap2(k).toString) + } + + val model = xgbc.fit(trainingDF) + val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults < 0.1) + val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath + model.write.overwrite.save(xgbcModelPath) + val model2 = XGBoostClassificationModel.load(xgbcModelPath) + assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray)) + + assert(model.getEta === model2.getEta) + assert(model.getNumRound === model2.getNumRound) + assert(model.getRawPredictionCol === model2.getRawPredictionCol) + val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults === evalResults2) + } + test("cross-version model loading (0.82)") { val modelPath = getClass.getResource("/model/0.82/model").getPath val model = XGBoostClassificationModel.read.load(modelPath) From f5e475812fd7837edcfbb2c96ea338fbf9c24d32 Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Thu, 30 Sep 2021 11:11:30 +0200 Subject: [PATCH 02/16] fix to ensure the type hints get added only once --- .../scala/spark/params/CustomParams.scala | 4 ++-- .../ml/dmlc/xgboost4j/scala/spark/CustomObj.scala | 14 ++++++++++++-- .../ml/dmlc/xgboost4j/scala/spark/EvalError.scala | 14 ++++++++++++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 6eb996c2da99..b170a405bd4d 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -50,7 +50,7 @@ object CustomEvalParam { typeHints = typeHints + value } - final def addShortTypeHint(value: Class[_]): Unit = { + final def addTypeHintForClass(value: Class[_]): Unit = { typeHints = typeHints + ShortTypeHints(List(value)) } } @@ -81,7 +81,7 @@ object CustomObjParam { typeHints = typeHints + value } - final def addShortTypeHint(value: Class[_]): Unit = { + final def addTypeHintForClass(value: Class[_]): Unit = { typeHints = typeHints + ShortTypeHints(List(value)) } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index 106d73ecaccf..029807c34e73 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} -import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam.addShortTypeHint +import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam._ import org.apache.commons.logging.LogFactory import org.json4s.ShortTypeHints import scala.collection.mutable.ListBuffer @@ -31,7 +31,7 @@ case class CustomObj() extends ObjectiveTrait { val logger = LogFactory.getLog(classOf[CustomObj]) - addShortTypeHint(classOf[CustomObj]) + CustomObj.addTypeHint() /** * user define objective function, return gradient and second order gradient @@ -87,3 +87,13 @@ case class CustomObj() extends ObjectiveTrait { transPredicts } } + +object CustomObj { + private var typeHintAdded = false + def addTypeHint(): Unit = { + if (!typeHintAdded) { + addTypeHintForClass(classOf[CustomObj]) + typeHintAdded = true + } + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index c1f22e8dfcc2..84636575bce4 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} -import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam.addShortTypeHint +import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam._ import org.apache.commons.logging.LogFactory import org.json4s.ShortTypeHints @@ -26,7 +26,7 @@ case class EvalError() extends EvalTrait { val logger = LogFactory.getLog(classOf[EvalError]) - addShortTypeHint(classOf[EvalError]) + EvalError.addTypeHint() private[xgboost4j] var evalMetric: String = "custom_error" @@ -67,3 +67,13 @@ case class EvalError() extends EvalTrait { error / labels.length } } + +object EvalError { + private var typeHintAdded = false + def addTypeHint(): Unit = { + if (!typeHintAdded) { + addTypeHintForClass(classOf[EvalError]) + typeHintAdded = true + } + } +} From e4d4d6b899d205f327097ddfde5ad289fdbc18ce Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Tue, 5 Oct 2021 11:15:35 +0200 Subject: [PATCH 03/16] fixed code comments --- .../test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala | 1 - .../test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala | 1 - .../scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala | 3 +-- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index 029807c34e73..78620163a168 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -20,7 +20,6 @@ import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam._ import org.apache.commons.logging.LogFactory -import org.json4s.ShortTypeHints import scala.collection.mutable.ListBuffer diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index 84636575bce4..f8425c8d79f3 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -20,7 +20,6 @@ import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam._ import org.apache.commons.logging.LogFactory -import org.json4s.ShortTypeHints case class EvalError() extends EvalTrait { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index 0ba396910121..f247aa17d85a 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -142,8 +142,7 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", "custom_obj" -> new CustomObj, "custom_eval" -> new EvalError, "num_round" -> "10", "num_workers" -> numWorkers) - val xgb = new XGBoostClassifier(paramMap) - + val xgbc = new XGBoostClassifier(paramMap) val xgbcPath = new File(tempDir.toFile, "xgbc").getPath xgbc.write.overwrite().save(xgbcPath) From 8b6ee8befe3ac88687301d0eb0116a8ee0374237 Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Tue, 5 Oct 2021 12:28:08 +0200 Subject: [PATCH 04/16] moved adding the type hints to an abstract class --- .../xgboost4j/scala/spark/CustomObj.scala | 17 ++-------- .../xgboost4j/scala/spark/EvalError.scala | 17 ++-------- .../scala/spark/SparkCustomEval.scala | 34 +++++++++++++++++++ .../scala/spark/SparkCustomObj.scala | 34 +++++++++++++++++++ 4 files changed, 72 insertions(+), 30 deletions(-) create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index 78620163a168..39491bca7a3a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -17,8 +17,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError -import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} -import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam._ +import ml.dmlc.xgboost4j.scala.{DMatrix} import org.apache.commons.logging.LogFactory import scala.collection.mutable.ListBuffer @@ -26,12 +25,10 @@ import scala.collection.mutable.ListBuffer /** * loglikelihood loss obj function */ -case class CustomObj() extends ObjectiveTrait { +case class CustomObj() extends SparkCustomObj { val logger = LogFactory.getLog(classOf[CustomObj]) - CustomObj.addTypeHint() - /** * user define objective function, return gradient and second order gradient * @@ -86,13 +83,3 @@ case class CustomObj() extends ObjectiveTrait { transPredicts } } - -object CustomObj { - private var typeHintAdded = false - def addTypeHint(): Unit = { - if (!typeHintAdded) { - addTypeHintForClass(classOf[CustomObj]) - typeHintAdded = true - } - } -} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index f8425c8d79f3..435567fd1c9a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -17,16 +17,13 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError -import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} -import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam._ +import ml.dmlc.xgboost4j.scala.{DMatrix} import org.apache.commons.logging.LogFactory -case class EvalError() extends EvalTrait { +case class EvalError() extends SparkCustomEval { val logger = LogFactory.getLog(classOf[EvalError]) - EvalError.addTypeHint() - private[xgboost4j] var evalMetric: String = "custom_error" /** @@ -66,13 +63,3 @@ case class EvalError() extends EvalTrait { error / labels.length } } - -object EvalError { - private var typeHintAdded = false - def addTypeHint(): Unit = { - if (!typeHintAdded) { - addTypeHintForClass(classOf[EvalError]) - typeHintAdded = true - } - } -} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala new file mode 100644 index 000000000000..b7351c31c7d5 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala @@ -0,0 +1,34 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.scala.EvalTrait +import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam._ + +abstract class SparkCustomEval extends EvalTrait { + SparkCustomEval.addTypeHint(this) +} + +object SparkCustomEval { + private var typeHintAdded = false + def addTypeHint(thiz: SparkCustomEval): Unit = { + if (!typeHintAdded) { + addTypeHintForClass(thiz.getClass()) + typeHintAdded = true + } + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala new file mode 100644 index 000000000000..076a9d138993 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala @@ -0,0 +1,34 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.scala.ObjectiveTrait +import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam._ + +abstract class SparkCustomObj extends ObjectiveTrait { + SparkCustomObj.addTypeHint(this) +} + +object SparkCustomObj { + private var typeHintAdded = false + def addTypeHint(thiz: SparkCustomObj): Unit = { + if (!typeHintAdded) { + addTypeHintForClass(thiz.getClass()) + typeHintAdded = true + } + } +} From 769f132e005c7aa093854527c3c85a0cdaf6ae61 Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Thu, 7 Oct 2021 12:24:25 +0200 Subject: [PATCH 05/16] enabled multiple implementations custom obj/eval --- .../ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala | 7 ++++--- .../ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala index b7351c31c7d5..3312c42fcf21 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala @@ -24,11 +24,12 @@ abstract class SparkCustomEval extends EvalTrait { } object SparkCustomEval { - private var typeHintAdded = false + private var typeHintsAdded = Set[String]() def addTypeHint(thiz: SparkCustomEval): Unit = { - if (!typeHintAdded) { + val className = thiz.getClass().getSimpleName() + if (!typeHintsAdded.contains(className)) { addTypeHintForClass(thiz.getClass()) - typeHintAdded = true + typeHintsAdded += className } } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala index 076a9d138993..11f20a5b688b 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala @@ -24,11 +24,12 @@ abstract class SparkCustomObj extends ObjectiveTrait { } object SparkCustomObj { - private var typeHintAdded = false + private var typeHintsAdded = Set[String]() def addTypeHint(thiz: SparkCustomObj): Unit = { - if (!typeHintAdded) { + val className = thiz.getClass().getSimpleName() + if (!typeHintsAdded.contains(className)) { addTypeHintForClass(thiz.getClass()) - typeHintAdded = true + typeHintsAdded += className } } } From 85aff26c381987e45fdcfa5ad5e6995a323213c9 Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Mon, 11 Oct 2021 11:30:59 +0200 Subject: [PATCH 06/16] impl in GeneralParams, removed abstract classes --- .../scala/spark/params/CustomParams.scala | 48 +++++++++++++++---- .../scala/spark/params/GeneralParams.scala | 6 +++ .../xgboost4j/scala/spark/CustomObj.scala | 4 +- .../xgboost4j/scala/spark/EvalError.scala | 4 +- .../scala/spark/PersistenceSuite.scala | 13 +++-- .../scala/spark/SparkCustomEval.scala | 35 -------------- .../scala/spark/SparkCustomObj.scala | 35 -------------- 7 files changed, 60 insertions(+), 85 deletions(-) delete mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index b170a405bd4d..3c222211bdec 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -45,13 +45,29 @@ class CustomEvalParam( object CustomEvalParam { var typeHints: TypeHints = NoTypeHints - - final def addTypeHints(value: TypeHints): Unit = { - typeHints = typeHints + value + private var typeHintsAdded = Set[String]() + + def addTypeHint(customEval: Any): Unit = { + if (!customEval.isInstanceOf[EvalTrait]) { + throw new IllegalArgumentException( + s"you specified $customEval as custom_eval," + + " but it does not implement EvalTrait." + ) + } + val clazz = customEval.getClass() + val className = clazz.getSimpleName() + if (!typeHintsAdded.contains(className)) { + addTypeHintForClass(clazz) + typeHintsAdded += className + } } final def addTypeHintForClass(value: Class[_]): Unit = { - typeHints = typeHints + ShortTypeHints(List(value)) + addTypeHints(ShortTypeHints(List(value))) + } + + final def addTypeHints(value: TypeHints): Unit = { + typeHints = typeHints + value } } @@ -76,13 +92,29 @@ class CustomObjParam( object CustomObjParam { var typeHints: TypeHints = NoTypeHints - - final def addTypeHints(value: TypeHints): Unit = { - typeHints = typeHints + value + private var typeHintsAdded = Set[String]() + + def addTypeHint(customObj: Any): Unit = { + if (!customObj.isInstanceOf[ObjectiveTrait]) { + throw new IllegalArgumentException( + s"you specified $customObj as custom_obj," + + " but it does not implement ObjectiveTrait." + ) + } + val clazz = customObj.getClass() + val className = clazz.getSimpleName() + if (!typeHintsAdded.contains(className)) { + addTypeHintForClass(clazz) + typeHintsAdded += className + } } final def addTypeHintForClass(value: Class[_]): Unit = { - typeHints = typeHints + ShortTypeHints(List(value)) + addTypeHints(ShortTypeHints(List(value))) + } + + final def addTypeHints(value: TypeHints): Unit = { + typeHints = typeHints + value } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index a75f64dd8aba..1215d087fc5c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -270,6 +270,12 @@ private[spark] trait ParamMapFuncs extends Params { set(name, paramValue.toString.toFloat) case _: LongParam => set(name, paramValue.toString.toLong) + case _: CustomObjParam => + CustomObjParam.addTypeHint(paramValue) + set(name, paramValue) + case _: CustomEvalParam => + CustomEvalParam.addTypeHint(paramValue) + set(name, paramValue) case _: Param[_] => set(name, paramValue) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index 39491bca7a3a..f1ade456acaf 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError -import ml.dmlc.xgboost4j.scala.{DMatrix} +import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} import org.apache.commons.logging.LogFactory import scala.collection.mutable.ListBuffer @@ -25,7 +25,7 @@ import scala.collection.mutable.ListBuffer /** * loglikelihood loss obj function */ -case class CustomObj() extends SparkCustomObj { +class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait { val logger = LogFactory.getLog(classOf[CustomObj]) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index 435567fd1c9a..b223a72952b8 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -17,10 +17,10 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError -import ml.dmlc.xgboost4j.scala.{DMatrix} +import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} import org.apache.commons.logging.LogFactory -case class EvalError() extends SparkCustomEval { +class EvalError() extends EvalTrait { val logger = LogFactory.getLog(classOf[EvalError]) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index f247aa17d85a..c06fa9ad6cca 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -134,13 +134,12 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { } test("test persistence of XGBoostClassifier and XGBoostClassificationModel " + - "using custom Obj and Eval") { + "using custom Eval and Obj") { val eval = new EvalError() val trainingDF = buildDataFrame(Classification.train) val testDM = new DMatrix(Classification.test.iterator) - val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", - "custom_obj" -> new CustomObj, "custom_eval" -> new EvalError, + "custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1), "num_round" -> "10", "num_workers" -> numWorkers) val xgbc = new XGBoostClassifier(paramMap) @@ -149,6 +148,14 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { val xgbc2 = XGBoostClassifier.load(xgbcPath) val paramMap2 = xgbc2.MLlib2XGBoostParams paramMap.foreach { + case ("custom_eval", v) => { + assert(v.isInstanceOf[EvalError]) + } + case ("custom_obj", v) => { + assert(v.isInstanceOf[CustomObj]) + assert(v.asInstanceOf[CustomObj].customParameter == + paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter) + } case (k, v) => assert(v.toString == paramMap2(k).toString) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala deleted file mode 100644 index 3312c42fcf21..000000000000 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomEval.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark - -import ml.dmlc.xgboost4j.scala.EvalTrait -import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam._ - -abstract class SparkCustomEval extends EvalTrait { - SparkCustomEval.addTypeHint(this) -} - -object SparkCustomEval { - private var typeHintsAdded = Set[String]() - def addTypeHint(thiz: SparkCustomEval): Unit = { - val className = thiz.getClass().getSimpleName() - if (!typeHintsAdded.contains(className)) { - addTypeHintForClass(thiz.getClass()) - typeHintsAdded += className - } - } -} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala deleted file mode 100644 index 11f20a5b688b..000000000000 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SparkCustomObj.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark - -import ml.dmlc.xgboost4j.scala.ObjectiveTrait -import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam._ - -abstract class SparkCustomObj extends ObjectiveTrait { - SparkCustomObj.addTypeHint(this) -} - -object SparkCustomObj { - private var typeHintsAdded = Set[String]() - def addTypeHint(thiz: SparkCustomObj): Unit = { - val className = thiz.getClass().getSimpleName() - if (!typeHintsAdded.contains(className)) { - addTypeHintForClass(thiz.getClass()) - typeHintsAdded += className - } - } -} From 57b94292c8d90cb82735c34ba59fe3449ffd411b Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Mon, 11 Oct 2021 11:41:51 +0200 Subject: [PATCH 07/16] reverting unnecessary change to EvalError --- .../test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index b223a72952b8..91a840911a32 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -20,7 +20,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} import org.apache.commons.logging.LogFactory -class EvalError() extends EvalTrait { +class EvalError extends EvalTrait { val logger = LogFactory.getLog(classOf[EvalError]) From 4384bc21167a8b17b05b4484bd3c5f2e93110f4a Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Wed, 13 Oct 2021 10:02:47 +0200 Subject: [PATCH 08/16] using a single SavedTypeHints object --- .../scala/spark/params/CustomParams.scala | 58 ++++--------------- .../scala/spark/params/GeneralParams.scala | 4 +- 2 files changed, 14 insertions(+), 48 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 3c222211bdec..8b62a6815cae 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -33,44 +33,16 @@ class CustomEvalParam( override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value) override def jsonEncode(value: EvalTrait): String = { - implicit val formats = DefaultFormats.withHints(CustomEvalParam.typeHints) + implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): EvalTrait = { - implicit val formats = DefaultFormats.withHints(CustomEvalParam.typeHints) + implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) parse(json).extract[EvalTrait] } } -object CustomEvalParam { - var typeHints: TypeHints = NoTypeHints - private var typeHintsAdded = Set[String]() - - def addTypeHint(customEval: Any): Unit = { - if (!customEval.isInstanceOf[EvalTrait]) { - throw new IllegalArgumentException( - s"you specified $customEval as custom_eval," + - " but it does not implement EvalTrait." - ) - } - val clazz = customEval.getClass() - val className = clazz.getSimpleName() - if (!typeHintsAdded.contains(className)) { - addTypeHintForClass(clazz) - typeHintsAdded += className - } - } - - final def addTypeHintForClass(value: Class[_]): Unit = { - addTypeHints(ShortTypeHints(List(value))) - } - - final def addTypeHints(value: TypeHints): Unit = { - typeHints = typeHints + value - } -} - class CustomObjParam( parent: Params, name: String, @@ -80,40 +52,34 @@ class CustomObjParam( override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value) override def jsonEncode(value: ObjectiveTrait): String = { - implicit val formats = DefaultFormats.withHints(CustomObjParam.typeHints) + implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): ObjectiveTrait = { - implicit val formats = DefaultFormats.withHints(CustomObjParam.typeHints) + implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) parse(json).extract[ObjectiveTrait] } } -object CustomObjParam { +object SavedTypeHints { var typeHints: TypeHints = NoTypeHints private var typeHintsAdded = Set[String]() - def addTypeHint(customObj: Any): Unit = { - if (!customObj.isInstanceOf[ObjectiveTrait]) { - throw new IllegalArgumentException( - s"you specified $customObj as custom_obj," + - " but it does not implement ObjectiveTrait." - ) - } - val clazz = customObj.getClass() - val className = clazz.getSimpleName() + def addClass(customEval: Any): Unit = { + val clazz = customEval.getClass() + val className = clazz.getName() if (!typeHintsAdded.contains(className)) { - addTypeHintForClass(clazz) + addClassForClass(clazz) typeHintsAdded += className } } - final def addTypeHintForClass(value: Class[_]): Unit = { - addTypeHints(ShortTypeHints(List(value))) + final def addClassForClass(value: Class[_]): Unit = { + addClasss(ShortTypeHints(List(value))) } - final def addTypeHints(value: TypeHints): Unit = { + final def addClasss(value: TypeHints): Unit = { typeHints = typeHints + value } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 1215d087fc5c..067cc9e7e623 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -271,10 +271,10 @@ private[spark] trait ParamMapFuncs extends Params { case _: LongParam => set(name, paramValue.toString.toLong) case _: CustomObjParam => - CustomObjParam.addTypeHint(paramValue) + SavedTypeHints.addClass(paramValue) set(name, paramValue) case _: CustomEvalParam => - CustomEvalParam.addTypeHint(paramValue) + SavedTypeHints.addClass(paramValue) set(name, paramValue) case _: Param[_] => set(name, paramValue) From cbac4f967a7628e0ccd04e12d8cde1093380712a Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Wed, 13 Oct 2021 10:53:15 +0200 Subject: [PATCH 09/16] moved logic to AddTypeHints trait --- .../xgboost4j/scala/spark/params/AddTypeHints.scala | 5 +++++ .../xgboost4j/scala/spark/params/CustomParams.scala | 11 +++++++---- .../xgboost4j/scala/spark/params/GeneralParams.scala | 6 ------ .../ml/dmlc/xgboost4j/scala/spark/CustomObj.scala | 3 ++- .../ml/dmlc/xgboost4j/scala/spark/EvalError.scala | 3 ++- 5 files changed, 16 insertions(+), 12 deletions(-) create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala new file mode 100644 index 000000000000..64c4ed39d9a6 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala @@ -0,0 +1,5 @@ +package ml.dmlc.xgboost4j.scala.spark.params + +trait AddTypeHints { + val typeHintAdded = SavedTypeHints.addClassOf(this) +} \ No newline at end of file diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 8b62a6815cae..c9dcab2450fa 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -66,16 +66,19 @@ object SavedTypeHints { var typeHints: TypeHints = NoTypeHints private var typeHintsAdded = Set[String]() - def addClass(customEval: Any): Unit = { - val clazz = customEval.getClass() + def addClassOf(instance: Any): Boolean = { + val clazz = instance.getClass() val className = clazz.getName() if (!typeHintsAdded.contains(className)) { - addClassForClass(clazz) + addClass(clazz) typeHintsAdded += className + true + } else { + false } } - final def addClassForClass(value: Class[_]): Unit = { + final def addClass(value: Class[_]): Unit = { addClasss(ShortTypeHints(List(value))) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 067cc9e7e623..a75f64dd8aba 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -270,12 +270,6 @@ private[spark] trait ParamMapFuncs extends Params { set(name, paramValue.toString.toFloat) case _: LongParam => set(name, paramValue.toString.toLong) - case _: CustomObjParam => - SavedTypeHints.addClass(paramValue) - set(name, paramValue) - case _: CustomEvalParam => - SavedTypeHints.addClass(paramValue) - set(name, paramValue) case _: Param[_] => set(name, paramValue) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index f1ade456acaf..e026b7b97eba 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} +import ml.dmlc.xgboost4j.scala.spark.params.AddTypeHints import org.apache.commons.logging.LogFactory import scala.collection.mutable.ListBuffer @@ -25,7 +26,7 @@ import scala.collection.mutable.ListBuffer /** * loglikelihood loss obj function */ -class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait { +class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait with AddTypeHints { val logger = LogFactory.getLog(classOf[CustomObj]) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index 91a840911a32..a7cd77744b97 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -18,9 +18,10 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} +import ml.dmlc.xgboost4j.scala.spark.params.AddTypeHints import org.apache.commons.logging.LogFactory -class EvalError extends EvalTrait { +class EvalError extends EvalTrait with AddTypeHints { val logger = LogFactory.getLog(classOf[EvalError]) From 643a06464514767523def2166b5ef40e0ba5e0ed Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Wed, 13 Oct 2021 10:58:48 +0200 Subject: [PATCH 10/16] style change --- .../ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala | 2 +- .../ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala index 64c4ed39d9a6..7b6f29a14d8f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala @@ -2,4 +2,4 @@ package ml.dmlc.xgboost4j.scala.spark.params trait AddTypeHints { val typeHintAdded = SavedTypeHints.addClassOf(this) -} \ No newline at end of file +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index c9dcab2450fa..445350e05dc6 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -18,7 +18,6 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.TrackerConf - import org.json4s.{DefaultFormats, Extraction, NoTypeHints, ShortTypeHints, TypeHints} import org.json4s.jackson.JsonMethods.{compact, parse, render} @@ -66,7 +65,7 @@ object SavedTypeHints { var typeHints: TypeHints = NoTypeHints private var typeHintsAdded = Set[String]() - def addClassOf(instance: Any): Boolean = { + final def addClassOf(instance: Any): Boolean = { val clazz = instance.getClass() val className = clazz.getName() if (!typeHintsAdded.contains(className)) { From 8e51b6c38b57f74f71e68fe07d7751ce195eb2ea Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Fri, 15 Oct 2021 09:24:13 +0200 Subject: [PATCH 11/16] fixed review comments --- .../scala/spark/params/AddTypeHints.scala | 5 ---- .../scala/spark/params/CustomParams.scala | 3 +++ .../scala/spark/params/TypeHintsTrait.scala | 24 +++++++++++++++++++ .../xgboost4j/scala/spark/CustomObj.scala | 4 ++-- .../xgboost4j/scala/spark/EvalError.scala | 4 ++-- .../scala/spark/PersistenceSuite.scala | 3 +-- 6 files changed, 32 insertions(+), 11 deletions(-) delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala deleted file mode 100644 index 7b6f29a14d8f..000000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/AddTypeHints.scala +++ /dev/null @@ -1,5 +0,0 @@ -package ml.dmlc.xgboost4j.scala.spark.params - -trait AddTypeHints { - val typeHintAdded = SavedTypeHints.addClassOf(this) -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 445350e05dc6..519011da5875 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -62,6 +62,9 @@ class CustomObjParam( } object SavedTypeHints { + /** + * Stores type hints for (de)serialization of custom objective and eval params. + */ var typeHints: TypeHints = NoTypeHints private var typeHintsAdded = Set[String]() diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala new file mode 100644 index 000000000000..0be6c41a5090 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala @@ -0,0 +1,24 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + + package ml.dmlc.xgboost4j.scala.spark.params + +trait TypeHintsTrait { + /** + * Trait that helps creating type hints for (de)serialization using json4s. + */ + val typeHintAdded = SavedTypeHints.addClassOf(this) +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index e026b7b97eba..1d64f4af433a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} -import ml.dmlc.xgboost4j.scala.spark.params.AddTypeHints +import ml.dmlc.xgboost4j.scala.spark.params.TypeHintsTrait import org.apache.commons.logging.LogFactory import scala.collection.mutable.ListBuffer @@ -26,7 +26,7 @@ import scala.collection.mutable.ListBuffer /** * loglikelihood loss obj function */ -class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait with AddTypeHints { +class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait with TypeHintsTrait { val logger = LogFactory.getLog(classOf[CustomObj]) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index a7cd77744b97..982000a182f3 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -18,10 +18,10 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} -import ml.dmlc.xgboost4j.scala.spark.params.AddTypeHints +import ml.dmlc.xgboost4j.scala.spark.params.TypeHintsTrait import org.apache.commons.logging.LogFactory -class EvalError extends EvalTrait with AddTypeHints { +class EvalError extends EvalTrait with TypeHintsTrait { val logger = LogFactory.getLog(classOf[EvalError]) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index c06fa9ad6cca..396bbe77eff6 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -25,7 +25,6 @@ import scala.util.Random import org.apache.spark.ml.feature._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.functions._ -import org.json4s.ShortTypeHints import org.scalatest.FunSuite class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { @@ -156,7 +155,7 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { assert(v.asInstanceOf[CustomObj].customParameter == paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter) } - case (k, v) => assert(v.toString == paramMap2(k).toString) + case (_, _) => } val model = xgbc.fit(trainingDF) From 6272a721526bd7a60f75147cc116694d33d730bf Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Fri, 15 Oct 2021 09:58:50 +0200 Subject: [PATCH 12/16] scalastyle changes --- .../ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala | 2 +- .../scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala index 0be6c41a5090..bac02c82282b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ - + package ml.dmlc.xgboost4j.scala.spark.params trait TypeHintsTrait { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index 396bbe77eff6..4533e36b74b6 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -140,7 +140,7 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", "custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1), "num_round" -> "10", "num_workers" -> numWorkers) - + val xgbc = new XGBoostClassifier(paramMap) val xgbcPath = new File(tempDir.toFile, "xgbc").getPath xgbc.write.overwrite().save(xgbcPath) From f447df86b5a06dbb4e22ca8ed3115aaf6b23c88b Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Sat, 16 Oct 2021 09:48:10 +0800 Subject: [PATCH 13/16] a new way to handle json encode/decode --- .../scala/spark/params/CustomParams.scala | 45 +++++-------------- .../scala/spark/params/TypeHintsTrait.scala | 24 ---------- .../xgboost4j/scala/spark/CustomObj.scala | 8 ++-- .../xgboost4j/scala/spark/EvalError.scala | 3 +- .../scala/spark/PersistenceSuite.scala | 12 +++-- 5 files changed, 20 insertions(+), 72 deletions(-) delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 519011da5875..1815fc2df00d 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -18,8 +18,10 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.TrackerConf -import org.json4s.{DefaultFormats, Extraction, NoTypeHints, ShortTypeHints, TypeHints} +import org.json4s.JsonAST.JField +import org.json4s.{DefaultFormats, Extraction, FullTypeHints, JValue, NoTypeHints, TypeHints} import org.json4s.jackson.JsonMethods.{compact, parse, render} +import org.json4s.jackson.Serialization import org.apache.spark.ml.param.{Param, ParamPair, Params} @@ -32,13 +34,14 @@ class CustomEvalParam( override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value) override def jsonEncode(value: EvalTrait): String = { - implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) + implicit val format = Serialization.formats(TypeHintsUtil.getTypeHints(value)) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): EvalTrait = { - implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) - parse(json).extract[EvalTrait] + val js = parse(json) + implicit val formats = DefaultFormats.withHints(TypeHintsUtil.extractTypeHint(js)) + js.extract[EvalTrait] } } @@ -51,42 +54,16 @@ class CustomObjParam( override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value) override def jsonEncode(value: ObjectiveTrait): String = { - implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) + implicit val format = Serialization.formats(TypeHintsUtil.getTypeHints(value)) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): ObjectiveTrait = { - implicit val formats = DefaultFormats.withHints(SavedTypeHints.typeHints) - parse(json).extract[ObjectiveTrait] - } -} - -object SavedTypeHints { - /** - * Stores type hints for (de)serialization of custom objective and eval params. - */ - var typeHints: TypeHints = NoTypeHints - private var typeHintsAdded = Set[String]() - - final def addClassOf(instance: Any): Boolean = { - val clazz = instance.getClass() - val className = clazz.getName() - if (!typeHintsAdded.contains(className)) { - addClass(clazz) - typeHintsAdded += className - true - } else { - false - } + val js = parse(json) + implicit val formats = DefaultFormats.withHints(TypeHintsUtil.extractTypeHint(js)) + js.extract[ObjectiveTrait] } - final def addClass(value: Class[_]): Unit = { - addClasss(ShortTypeHints(List(value))) - } - - final def addClasss(value: TypeHints): Unit = { - typeHints = typeHints + value - } } class TrackerConfParam( diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala deleted file mode 100644 index bac02c82282b..000000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TypeHintsTrait.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - - package ml.dmlc.xgboost4j.scala.spark.params - -trait TypeHintsTrait { - /** - * Trait that helps creating type hints for (de)serialization using json4s. - */ - val typeHintAdded = SavedTypeHints.addClassOf(this) -} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index 1d64f4af433a..432a950dcb7e 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -18,7 +18,6 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} -import ml.dmlc.xgboost4j.scala.spark.params.TypeHintsTrait import org.apache.commons.logging.LogFactory import scala.collection.mutable.ListBuffer @@ -26,7 +25,7 @@ import scala.collection.mutable.ListBuffer /** * loglikelihood loss obj function */ -class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait with TypeHintsTrait { +class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait { val logger = LogFactory.getLog(classOf[CustomObj]) @@ -47,9 +46,8 @@ class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait with TypeHi } catch { case e: XGBoostError => logger.error(e) - null - case _: Throwable => - null + throw e + case e: Throwable => throw e } val grad = new Array[Float](nrow) val hess = new Array[Float](nrow) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index 982000a182f3..91a840911a32 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -18,10 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} -import ml.dmlc.xgboost4j.scala.spark.params.TypeHintsTrait import org.apache.commons.logging.LogFactory -class EvalError extends EvalTrait with TypeHintsTrait { +class EvalError extends EvalTrait { val logger = LogFactory.getLog(classOf[EvalError]) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index 4533e36b74b6..788436e89017 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -133,8 +133,7 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { } test("test persistence of XGBoostClassifier and XGBoostClassificationModel " + - "using custom Eval and Obj") { - val eval = new EvalError() + "using custom Eval and Obj") { val trainingDF = buildDataFrame(Classification.train) val testDM = new DMatrix(Classification.test.iterator) val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", @@ -147,17 +146,16 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { val xgbc2 = XGBoostClassifier.load(xgbcPath) val paramMap2 = xgbc2.MLlib2XGBoostParams paramMap.foreach { - case ("custom_eval", v) => { - assert(v.isInstanceOf[EvalError]) - } - case ("custom_obj", v) => { + case ("custom_eval", v) => assert(v.isInstanceOf[EvalError]) + case ("custom_obj", v) => assert(v.isInstanceOf[CustomObj]) assert(v.asInstanceOf[CustomObj].customParameter == paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter) - } case (_, _) => } + val eval = new EvalError() + val model = xgbc.fit(trainingDF) val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) assert(evalResults < 0.1) From 0247104fdaf93cc2d4000b59f1598928dc1a4b96 Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Mon, 18 Oct 2021 12:10:56 +0200 Subject: [PATCH 14/16] moved some logic (CustomGeneralParam and params.Utils) --- .../scala/spark/params/CustomParams.scala | 54 +++++++++---------- .../xgboost4j/scala/spark/params/Utils.scala | 38 +++++++++++++ 2 files changed, 62 insertions(+), 30 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index 1815fc2df00d..c74560218481 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -18,53 +18,47 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.TrackerConf -import org.json4s.JsonAST.JField -import org.json4s.{DefaultFormats, Extraction, FullTypeHints, JValue, NoTypeHints, TypeHints} +import org.apache.spark.ml.param.{Param, ParamPair, Params} +import org.json4s.{DefaultFormats, Extraction, NoTypeHints} import org.json4s.jackson.JsonMethods.{compact, parse, render} import org.json4s.jackson.Serialization -import org.apache.spark.ml.param.{Param, ParamPair, Params} - -class CustomEvalParam( +/** + * General spark parameter that includes TypeHints for (de)serialization using json4s. + */ +class CustomGeneralParam[T: Manifest]( parent: Params, name: String, - doc: String) extends Param[EvalTrait](parent, name, doc) { + doc: String) extends Param[T](parent, name, doc) { /** Creates a param pair with the given value (for Java). */ - override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value) + override def w(value: T): ParamPair[T] = super.w(value) - override def jsonEncode(value: EvalTrait): String = { - implicit val format = Serialization.formats(TypeHintsUtil.getTypeHints(value)) + override def jsonEncode(value: T): String = { + implicit val format = Serialization.formats(Utils.getTypeHintsFromClass(value)) compact(render(Extraction.decompose(value))) } - override def jsonDecode(json: String): EvalTrait = { - val js = parse(json) - implicit val formats = DefaultFormats.withHints(TypeHintsUtil.extractTypeHint(js)) - js.extract[EvalTrait] + override def jsonDecode(json: String): T = { + jsonDecodeT(json) + } + + private def jsonDecodeT[T](jsonString: String)(implicit m: Manifest[T]): T = { + val json = parse(jsonString) + implicit val formats = DefaultFormats.withHints(Utils.getTypeHintsFromJsonClass(json)) + json.extract[T] } } -class CustomObjParam( +class CustomEvalParam( parent: Params, name: String, - doc: String) extends Param[ObjectiveTrait](parent, name, doc) { - - /** Creates a param pair with the given value (for Java). */ - override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value) - - override def jsonEncode(value: ObjectiveTrait): String = { - implicit val format = Serialization.formats(TypeHintsUtil.getTypeHints(value)) - compact(render(Extraction.decompose(value))) - } - - override def jsonDecode(json: String): ObjectiveTrait = { - val js = parse(json) - implicit val formats = DefaultFormats.withHints(TypeHintsUtil.extractTypeHint(js)) - js.extract[ObjectiveTrait] - } + doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc) -} +class CustomObjParam( + parent: Params, + name: String, + doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc) class TrackerConfParam( parent: Params, diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala index 7d6e7b9ed715..cd4d588f788b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala @@ -16,6 +16,8 @@ package ml.dmlc.xgboost4j.scala.spark.params +import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints} + // based on org.apache.spark.util copy /paste private[spark] object Utils { @@ -30,4 +32,40 @@ private[spark] object Utils { Class.forName(className, true, getContextOrSparkClassLoader) // scalastyle:on classforname } + + /** + * Get the TypeHints according to the value + * @param value the instance of customized obj/eval + * @return if value is null, + * return NoTypeHints + * else return the FullTypeHints. + * + * The FullTypeHints will save the full class name into the "jsonClass" of the json, + * so we can find the jsonClass and turn it to FullTypeHints when deserializing. + */ + def getTypeHintsFromClass(value: Any): TypeHints = { + if (value == null) { + NoTypeHints + } else { // XGBoost will save the default values + FullTypeHints(List(value.getClass)) + } + } + + /** + * Get the TypeHints according to the saved jsonClass field + * @param json + * @return TypeHints + */ + def getTypeHintsFromJsonClass(json: JValue): TypeHints = { + val jsonClassField = json findField { + case JField("jsonClass", _) => true + case _ => false + } + + jsonClassField.map { field => + implicit val formats = DefaultFormats + val className = field._2.extract[String] + FullTypeHints(List(Utils.classForName(className))) + }.getOrElse(NoTypeHints) + } } From aa6c2b48c148b034722510924e3aa1a61c6407fb Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Mon, 18 Oct 2021 16:33:14 +0200 Subject: [PATCH 15/16] updated copyright year and a code comment --- .../ml/dmlc/xgboost4j/scala/spark/params/Utils.scala | 8 ++++---- .../scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala | 2 +- .../ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala index cd4d588f788b..ddbef93747f8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014,2021 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ private[spark] object Utils { /** * Get the TypeHints according to the value - * @param value the instance of customized obj/eval + * @param value the instance of class to be serialized * @return if value is null, * return NoTypeHints * else return the FullTypeHints. @@ -44,9 +44,9 @@ private[spark] object Utils { * so we can find the jsonClass and turn it to FullTypeHints when deserializing. */ def getTypeHintsFromClass(value: Any): TypeHints = { - if (value == null) { + if (value == null) { // XGBoost will save the default value (null) NoTypeHints - } else { // XGBoost will save the default values + } else { // XGBoost will save the full instance FullTypeHints(List(value.getClass)) } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala index 432a950dcb7e..b9a39a14d4f7 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2021 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index 788436e89017..a1732c7f7e1b 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014,2021 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From 3ad1414ec1d3a4f1b43c86ee5d84214c98996553 Mon Sep 17 00:00:00 2001 From: nicovdijk Date: Tue, 19 Oct 2021 08:16:22 +0200 Subject: [PATCH 16/16] removed comment as suggested --- .../main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala index ddbef93747f8..fb84ad6d6a85 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/Utils.scala @@ -46,7 +46,7 @@ private[spark] object Utils { def getTypeHintsFromClass(value: Any): TypeHints = { if (value == null) { // XGBoost will save the default value (null) NoTypeHints - } else { // XGBoost will save the full instance + } else { FullTypeHints(List(value.getClass)) } }