diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala index b7f696a2399a..bee6d75625ea 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala @@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.TrackerConf -import org.json4s.{DefaultFormats, Extraction, NoTypeHints} + +import org.json4s.{DefaultFormats, Extraction, NoTypeHints, ShortTypeHints, TypeHints} import org.json4s.jackson.JsonMethods.{compact, parse, render} import org.apache.spark.ml.param.{Param, ParamPair, Params} @@ -33,17 +34,28 @@ class CustomEvalParam( override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value) override def jsonEncode(value: EvalTrait): String = { - import org.json4s.jackson.Serialization - implicit val formats = Serialization.formats(NoTypeHints) + implicit val formats = DefaultFormats.withHints(CustomEvalParam.typeHints) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): EvalTrait = { - implicit val formats = DefaultFormats + implicit val formats = DefaultFormats.withHints(CustomEvalParam.typeHints) parse(json).extract[EvalTrait] } } +object CustomEvalParam { + var typeHints: TypeHints = NoTypeHints + + final def addTypeHints(value: TypeHints): Unit = { + typeHints = typeHints + value + } + + final def addShortTypeHint(value: Class[_]): Unit = { + typeHints = typeHints + ShortTypeHints(List(value)) + } +} + class CustomObjParam( parent: Params, name: String, @@ -53,17 +65,28 @@ class CustomObjParam( override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value) override def jsonEncode(value: 
ObjectiveTrait): String = { - import org.json4s.jackson.Serialization - implicit val formats = Serialization.formats(NoTypeHints) + implicit val formats = DefaultFormats.withHints(CustomObjParam.typeHints) compact(render(Extraction.decompose(value))) } override def jsonDecode(json: String): ObjectiveTrait = { - implicit val formats = DefaultFormats + implicit val formats = DefaultFormats.withHints(CustomObjParam.typeHints) parse(json).extract[ObjectiveTrait] } } +object CustomObjParam { + var typeHints: TypeHints = NoTypeHints + + final def addTypeHints(value: TypeHints): Unit = { + typeHints = typeHints + value + } + + final def addShortTypeHint(value: Class[_]): Unit = { + typeHints = typeHints + ShortTypeHints(List(value)) + } +} + class TrackerConfParam( parent: Params, name: String, diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala new file mode 100644 index 000000000000..106d73ecaccf --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala @@ -0,0 +1,89 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.java.XGBoostError +import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait} +import ml.dmlc.xgboost4j.scala.spark.params.CustomObjParam.addShortTypeHint +import org.apache.commons.logging.LogFactory +import org.json4s.ShortTypeHints +import scala.collection.mutable.ListBuffer + + +/** + * loglikelihood loss obj function + */ +case class CustomObj() extends ObjectiveTrait { + + val logger = LogFactory.getLog(classOf[CustomObj]) + + addShortTypeHint(classOf[CustomObj]) + + /** + * user-defined objective function, returns gradient and second order gradient + * + * @param predicts untransformed margin predicts + * @param dtrain training data + * @return List with two float array, correspond to first order grad and second order grad + */ + override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix) + : List[Array[Float]] = { + val nrow = predicts.length + val gradients = new ListBuffer[Array[Float]] + var labels: Array[Float] = null + try { + labels = dtrain.getLabel + } catch { + case e: XGBoostError => + logger.error(e) + null + case _: Throwable => + null + } + val grad = new Array[Float](nrow) + val hess = new Array[Float](nrow) + val transPredicts = transform(predicts) + + for (i <- 0 until nrow) { + val predict = transPredicts(i)(0) + grad(i) = predict - labels(i) + hess(i) = predict * (1 - predict) + } + gradients += grad + gradients += hess + gradients.toList + } + + /** + * simple sigmoid func + * + * @param input + * @return Note: this func is not concerned with numerical stability, only used as example + */ + def sigmoid(input: Float): Float = { + (1 / (1 + Math.exp(-input))).toFloat + } + + def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = { + val nrow = predicts.length + val transPredicts = Array.fill[Float](nrow, 1)(0) + for (i <- 0 until nrow) { + transPredicts(i)(0) = sigmoid(predicts(i)(0)) + } + transPredicts + } +} diff --git 
a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index 91a840911a32..c1f22e8dfcc2 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -18,12 +18,16 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} +import ml.dmlc.xgboost4j.scala.spark.params.CustomEvalParam.addShortTypeHint import org.apache.commons.logging.LogFactory +import org.json4s.ShortTypeHints -class EvalError extends EvalTrait { +case class EvalError() extends EvalTrait { val logger = LogFactory.getLog(classOf[EvalError]) + addShortTypeHint(classOf[EvalError]) + private[xgboost4j] var evalMetric: String = "custom_error" /** diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index ebe1d8546544..0ba396910121 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -25,6 +25,7 @@ import scala.util.Random import org.apache.spark.ml.feature._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.functions._ +import org.json4s.ShortTypeHints import org.scalatest.FunSuite class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { @@ -92,7 +93,6 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { } test("test persistence of MLlib pipeline with XGBoostClassificationModel") { - val r = new Random(0) // maybe move to shared context, but requires session to import 
implicits val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))). @@ -133,6 +133,41 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest { assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol) } + test("test persistence of XGBoostClassifier and XGBoostClassificationModel " + + "using custom Obj and Eval") { + val eval = new EvalError() + val trainingDF = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "custom_obj" -> new CustomObj, "custom_eval" -> new EvalError, + "num_round" -> "10", "num_workers" -> numWorkers) + val xgb = new XGBoostClassifier(paramMap) + + val xgbc = new XGBoostClassifier(paramMap) + val xgbcPath = new File(tempDir.toFile, "xgbc").getPath + xgbc.write.overwrite().save(xgbcPath) + val xgbc2 = XGBoostClassifier.load(xgbcPath) + val paramMap2 = xgbc2.MLlib2XGBoostParams + paramMap.foreach { + case (k, v) => assert(v.toString == paramMap2(k).toString) + } + + val model = xgbc.fit(trainingDF) + val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults < 0.1) + val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath + model.write.overwrite.save(xgbcModelPath) + val model2 = XGBoostClassificationModel.load(xgbcModelPath) + assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray)) + + assert(model.getEta === model2.getEta) + assert(model.getNumRound === model2.getNumRound) + assert(model.getRawPredictionCol === model2.getRawPredictionCol) + val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults === evalResults2) + } + test("cross-version model loading (0.82)") { val modelPath = getClass.getResource("/model/0.82/model").getPath val model = XGBoostClassificationModel.read.load(modelPath)