diff --git a/.gitignore b/.gitignore
index 4be5f8c1e1a9..ed70379b87b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -88,4 +88,7 @@
 build_tests
 /tests/cpp/xgboost_test
 .DS_Store
-lib/
\ No newline at end of file
+lib/
+
+# spark
+metastore_db
\ No newline at end of file
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index d9dbcd543160..cce60043d795 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -18,21 +18,22 @@ package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
+
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix,
   RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.{FSDataInputStream, Path}
+
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.SparseVector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SparkContext, TaskContext}
-
-import scala.concurrent.duration.{Duration, MILLISECONDS}
+import scala.concurrent.duration.{Duration, FiniteDuration, MILLISECONDS}

 object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
+  def apply(): TrackerConf = TrackerConf(0L, "python")
 }

 /**
@@ -40,13 +41,14 @@ object TrackerConf {
  * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
  *                                Set timeout length to zero to disable timeout.
  *                                Use a finite, non-zero timeout value to prevent tracker from
- *                                hanging indefinitely (supported by "scala" implementation only.)
+ *                                hanging indefinitely (in milliseconds)
+ *                                (supported by "scala" implementation only.)
  * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
  *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
  *                    in Scala without Python components, and with full support of timeouts.
  *                    The Scala implementation is currently experimental, use at your own risk.
  */
-case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
+case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)

 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
@@ -240,14 +242,7 @@ object XGBoost extends Serializable {
       case _ => new PyRabitTracker(nWorkers)
     }

-    val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
-      trackerConf.workerConnectionTimeout.toMillis
-    } else {
-      // 0 == Duration.Inf
-      0L
-    }
-
-    require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
+    require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
     tracker
   }
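With workerConnectionTimeout now a plain Long in milliseconds, callers pass the timeout value directly instead of building a scala.concurrent.duration.Duration. A minimal usage sketch, assuming the parameter values are illustrative only and not part of this patch:

// Illustrative only: a 60 000 ms (one minute) worker connection timeout for the Scala tracker.
val paramMap = Map(
  "eta" -> "0.1",
  "objective" -> "binary:logistic",
  "tracker_conf" -> TrackerConf(60 * 1000L, "scala"))
// The no-arg factory still means "no timeout" and the Python tracker:
val defaultConf = TrackerConf() // == TrackerConf(0L, "python")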
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
index a5bbdb60ca22..ba7fe109899d 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -18,12 +18,14 @@ package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable

-import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, GeneralParams, LearningTaskParams}
+import ml.dmlc.xgboost4j.scala.spark.params._
+import org.json4s.DefaultFormats
+
 import org.apache.spark.ml.Predictor
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector => MLVector}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.sql.{Dataset, Row}
@@ -34,7 +36,7 @@ import org.apache.spark.sql.{Dataset, Row}
 class XGBoostEstimator private[spark](
   override val uid: String, xgboostParams: Map[String, Any])
   extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
-  with LearningTaskParams with GeneralParams with BoosterParams {
+  with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {

   def this(xgboostParams: Map[String, Any]) =
     this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
@@ -129,4 +131,38 @@ class XGBoostEstimator private[spark](
   override def copy(extra: ParamMap): XGBoostEstimator = {
     defaultCopy(extra).asInstanceOf[XGBoostEstimator]
   }
+
+  override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
+}
+
+object XGBoostEstimator extends MLReadable[XGBoostEstimator] {
+
+  override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader
+
+  override def load(path: String): XGBoostEstimator = super.load(path)
+
+  private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
+    extends MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
+        instance.fromParamsToXGBParamMap("custom_obj") == null,
+        "we do not support persisting XGBoostEstimator with customized evaluator and objective" +
+          " function for now")
+      implicit val format = DefaultFormats
+      implicit val sc = super.sparkSession.sparkContext
+      DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
+    }
+  }
+
+  private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {
+
+    override def load(path: String): XGBoostEstimator = {
+      val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
+      val cls = Utils.classForName(metadata.className)
+      val instance =
+        cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
+      DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
+      instance.asInstanceOf[XGBoostEstimator]
+    }
+  }
+}
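With the MLWritable mixin on the class and MLReadable on the companion, the estimator can be saved and restored like any other Spark ML stage. A minimal sketch, where the path and parameter values are illustrative only:

// Illustrative sketch: persist an estimator's params and restore them.
val estimator = new XGBoostEstimator(Map(
  "eta" -> "0.1", "max_depth" -> "6", "objective" -> "binary:logistic"))
estimator.write.overwrite().save("/tmp/xgbEstimator")
val restored = XGBoostEstimator.load("/tmp/xgbEstimator")
// restored.fromParamsToXGBParamMap should contain the same entries as the original map.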
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index 2731b9dd91d5..fe281bff6959 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -324,14 +324,13 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
       implicit val format = DefaultFormats
       implicit val sc = super.sparkSession.sparkContext
       DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
-
       val dataPath = new Path(path, "data").toString
       instance.saveModelAsHadoopFile(dataPath)
     }
   }

   private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
-    private val className = classOf[XGBoostModel].getName
+
     override def load(path: String): XGBoostModel = {
       implicit val sc = super.sparkSession.sparkContext
       val dataPath = new Path(path, "data").toString
@@ -340,5 +339,4 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
       XGBoost.loadModelFromHadoopFile(dataPath)
     }
   }
-
 }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala
new file mode 100644
index 000000000000..0a411054a29d
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala
@@ -0,0 +1,106 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import ml.dmlc.xgboost4j.scala.spark.TrackerConf
+import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
+import org.json4s.jackson.JsonMethods.{compact, parse, render}
+
+import org.apache.spark.ml.param.{Param, ParamPair, Params}
+
+class GroupDataParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[Seq[Seq[Int]]](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: Seq[Seq[Int]]): ParamPair[Seq[Seq[Int]]] = super.w(value)
+
+  override def jsonEncode(value: Seq[Seq[Int]]): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): Seq[Seq[Int]] = {
+    implicit val formats = DefaultFormats
+    parse(json).extract[Seq[Seq[Int]]]
+  }
+}
+
+class CustomEvalParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[EvalTrait](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)
+
+  override def jsonEncode(value: EvalTrait): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): EvalTrait = {
+    implicit val formats = DefaultFormats
+    parse(json).extract[EvalTrait]
+  }
+}
+
+class CustomObjParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[ObjectiveTrait](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)
+
+  override def jsonEncode(value: ObjectiveTrait): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): ObjectiveTrait = {
+    implicit val formats = DefaultFormats
+    parse(json).extract[ObjectiveTrait]
+  }
+}
+
+class TrackerConfParam(
+    parent: Params,
+    name: String,
+    doc: String) extends Param[TrackerConf](parent, name, doc) {
+
+  /** Creates a param pair with the given value (for Java). */
+  override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)
+
+  override def jsonEncode(value: TrackerConf): String = {
+    import org.json4s.jackson.Serialization
+    implicit val formats = Serialization.formats(NoTypeHints)
+    compact(render(Extraction.decompose(value)))
+  }
+
+  override def jsonDecode(json: String): TrackerConf = {
+    implicit val formats = DefaultFormats
+    val parsedValue = parse(json)
+    println(parsedValue.children)
+    parsedValue.extract[TrackerConf]
+  }
+}
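Each of these Param subclasses makes its value type JSON-serializable so that the params writer and reader below can persist it next to the ordinary string and numeric params. A rough illustration of the round trip for the tracker configuration, assuming json4s case-class extraction as used above and purely illustrative values:

// Illustrative sketch: encode and decode a TrackerConf through its custom Param.
val est = new XGBoostEstimator(Map("eta" -> "0.1"))           // any Params owner exposing trackerConf
val json = est.trackerConf.jsonEncode(TrackerConf(60000L, "scala"))
// json is roughly: {"workerConnectionTimeout":60000,"trackerImpl":"scala"}
val back = est.trackerConf.jsonDecode(json)                   // TrackerConf(60000, "scala")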
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala
new file mode 100644
index 000000000000..b79d5b694345
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala
@@ -0,0 +1,136 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JValue}
+import org.json4s.JsonAST.JObject
+import org.json4s.jackson.JsonMethods.{compact, parse, render}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.util.MLReader
+
+// This originates from a copy of Apache Spark's DefaultParamsReader
+private[spark] object DefaultXGBoostParamsReader {
+
+  /**
+   * All info from metadata file.
+   *
+   * @param params paramMap, as a `JValue`
+   * @param metadata All metadata, including the other fields
+   * @param metadataJson Full metadata file String (for debugging)
+   */
+  case class Metadata(
+      className: String,
+      uid: String,
+      timestamp: Long,
+      sparkVersion: String,
+      params: JValue,
+      metadata: JValue,
+      metadataJson: String) {
+
+    /**
+     * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
+     * This can be useful for getting a Param value before an instance of `Params`
+     * is available.
+     */
+    def getParamValue(paramName: String): JValue = {
+      implicit val format = DefaultFormats
+      params match {
+        case JObject(pairs) =>
+          val values = pairs.filter { case (pName, jsonValue) =>
+            pName == paramName
+          }.map(_._2)
+          assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
+            s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
+          values.head
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Cannot recognize JSON metadata: $metadataJson.")
+      }
+    }
+  }
+
+  /**
+   * Load metadata saved using [[DefaultXGBoostParamsWriter.saveMetadata()]]
+   *
+   * @param expectedClassName If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataStr = sc.textFile(metadataPath, 1).first()
+    parseMetadata(metadataStr, expectedClassName)
+  }
+
+  /**
+   * Parse metadata JSON string produced by [[DefaultXGBoostParamsWriter.getMetadataToSave()]].
+   * This is a helper function for [[loadMetadata()]].
+   *
+   * @param metadataStr JSON string of metadata
+   * @param expectedClassName If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
+    val metadata = parse(metadataStr)
+
+    implicit val format = DefaultFormats
+    val className = (metadata \ "class").extract[String]
+    val uid = (metadata \ "uid").extract[String]
+    val timestamp = (metadata \ "timestamp").extract[Long]
+    val sparkVersion = (metadata \ "sparkVersion").extract[String]
+    val params = metadata \ "paramMap"
+    if (expectedClassName.nonEmpty) {
+      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
+        s" $expectedClassName but found class name $className")
+    }
+
+    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
+  }
+
+  /**
+   * Extract Params from metadata, and set them in the instance.
+   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   * TODO: Move to [[Metadata]] method
+   */
+  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+    implicit val format = DefaultFormats
+    metadata.params match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          val param = instance.getParam(paramName)
+          val value = param.jsonDecode(compact(render(jsonValue)))
+          instance.set(param, value)
+        }
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
+    }
+  }
+
+  /**
+   * Load a `Params` instance from the given path, and return it.
+   * This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
+   */
+  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
+    val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
+    val cls = Utils.classForName(metadata.className)
+    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
+  }
+}
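The reader consumes a single-line JSON metadata document; parseMetadata above expects at least the class, uid, timestamp, sparkVersion, and paramMap fields. A hedged sketch of what such a line might look like, with field values and param names invented purely for illustration:

// Hypothetical metadata string consumed by parseMetadata (values are made up):
val metadataStr =
  """{"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator",
    |"uid":"XGBoostEstimator_4a1b2c3d","timestamp":1490000000000,
    |"sparkVersion":"2.1.0",
    |"paramMap":{"eta":0.1,"max_depth":6,"objective":"binary:logistic"}}""".stripMargin
val metadata = DefaultXGBoostParamsReader.parseMetadata(metadataStr)
// metadata.className == "ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator"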
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala
index 22ab8885432d..acf6815ecaae 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsWriter.scala
@@ -46,6 +46,7 @@
       sc: SparkContext,
       extraMetadata: Option[JObject] = None,
       paramMap: Option[JValue] = None): Unit = {
+
     val metadataPath = new Path(path, "metadata").toString
     val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
     sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -65,7 +66,9 @@
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
-    val jsonParams = paramMap.getOrElse(render(params.map {
+    val jsonParams = paramMap.getOrElse(render(params.filter {
+      case ParamPair(p, _) => p != null
+    }.map {
       case ParamPair(p, v) =>
         p.name -> parse(p.jsonEncode(v))
     }.toList))
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index 212daadbcb34..af14ce43ccc8 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -20,8 +20,6 @@ import ml.dmlc.xgboost4j.scala.spark.TrackerConf
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
 import org.apache.spark.ml.param._

-import scala.concurrent.duration.{Duration, NANOSECONDS}
-
 trait GeneralParams extends Params {

   /**
@@ -58,13 +56,13 @@ trait GeneralParams extends Params {
   /**
    * customized objective function provided by user. default: null
    */
-  val customObj = new Param[ObjectiveTrait](this, "custom_obj", "customized objective function " +
+  val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " +
     "provided by user")

   /**
    * customized evaluation function provided by user. default: null
    */
-  val customEval = new Param[EvalTrait](this, "custom_eval", "customized evaluation function " +
+  val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " +
     "provided by user")

   /**
@@ -99,7 +97,7 @@
    * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
    * Ignored if the tracker implementation is "python".
    */
-  val trackerConf = new Param[TrackerConf](this, "tracker_conf", "Rabit tracker configurations")
+  val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")

   setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
     useExternalMemory -> false, silent -> 0,
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index b02eecc433d4..722f7d079369 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -57,7 +57,7 @@ trait LearningTaskParams extends Params {
    * group data specify each group sizes for ranking task. To correspond to partition of
    * training data, it is nested.
    */
-  val groupData = new Param[Seq[Seq[Int]]](this, "groupData", "group data specify each group size" +
+  val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" +
     " for ranking task. To correspond to partition of training data, it is nested.")

   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
index 01eaca737997..2bab6028a233 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
@@ -18,17 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

+import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
 import scala.io.Source
+import scala.util.Random

 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+
 import org.apache.spark.SparkContext
 import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler}
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
 import org.apache.spark.sql._

 class XGBoostDFSuite extends SharedSparkContext with Utils {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 29cbf5c474e0..654fd1e9df98 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -110,7 +110,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic",
-      "tracker_conf" -> TrackerConf(1 minute, "scala")).toMap
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
     val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
       nWorkers = numWorkers)
     assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala
index 07a289528bc9..5d0a0c7550cd 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala
@@ -18,67 +18,84 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.{File, FileNotFoundException}

+import scala.util.Random
+
 import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.SparkSession

-import scala.concurrent.duration._
-
-case class Foobar(TARGET: Int, bar: Double, baz: Double)

 class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {

   override def afterAll(): Unit = {
     super.afterAll()
     delete(new File("./testxgbPipe"))
-    delete(new File("./test2xgbPipe"))
+    delete(new File("./testxgbEst"))
+    delete(new File("./testxgbModel"))
+    delete(new File("./test2xgbModel"))
   }

   private def delete(f: File) {
-    if (f.isDirectory()) {
-      for (c <- f.listFiles()) {
-        delete(c)
+    if (f.exists()) {
+      if (f.isDirectory()) {
+        for (c <- f.listFiles()) {
+          delete(c)
+        }
+      }
+      if (!f.delete()) {
+        throw new FileNotFoundException("Failed to delete file: " + f)
       }
-    }
-    if (!f.delete()) {
-      throw new FileNotFoundException("Failed to delete file: " + f)
     }
   }

-  test("test sparks pipeline persistence of dataframe-based model") {
-    // maybe move to shared context, but requires session to import implicits.
-    // what about introducing https://github.com/holdenk/spark-testing-base ?
-    val conf: SparkConf = new SparkConf()
-      .setAppName("foo")
-      .setMaster("local[*]")
+  test("test persistence of XGBoostEstimator") {
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    xgbEstimator.write.overwrite().save("./testxgbEst")
+    val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
+    val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }

-    val spark: SparkSession = SparkSession
-      .builder()
-      .config(conf)
-      .getOrCreate()
+  test("test persistence of a complete pipeline") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "multi:softmax", "num_class" -> "6")
+    val r = new Random(0)
+    val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
+    val xgbEstimator = new XGBoostEstimator(paramMap)
+    val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
+    pipeline.write.overwrite().save("testxgbPipe")
+    val loadedPipeline = Pipeline.read.load("testxgbPipe")
+    val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
+    val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
+    paramMap.foreach {
+      case (k, v) => assert(v == loadedParamMap(k).toString)
+    }
+  }

-    import spark.implicits._
+  test("test persistence of XGBoostModel") {
+    val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
+    val spark = SparkSession.builder().config(conf).getOrCreate()
+    val r = new Random(0)
     // maybe move to shared context, but requires session to import implicits
-
-    val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
-      Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
-      .toDS
-
+    val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
+      toDF("feature", "label")
     val vectorAssembler = new VectorAssembler()
       .setInputCols(df.columns
-        .filter(!_.contains("TARGET")))
+        .filter(!_.contains("label")))
       .setOutputCol("features")
-
     val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
-      "tracker_conf" -> TrackerConf(1 minute, "scala")
-    ))
-      .setFeaturesCol("features")
-      .setLabelCol("TARGET")
-
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
+    )).setFeaturesCol("features").setLabelCol("label")
     // separate
     val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
-    predModel.write.overwrite.save("test2xgbPipe")
-    val same2Model = XGBoostModel.load("test2xgbPipe")
+    predModel.write.overwrite.save("test2xgbModel")
+    val same2Model = XGBoostModel.load("test2xgbModel")

     assert(java.util.Arrays.equals(predModel.booster.toByteArray,
       same2Model.booster.toByteArray))
     val predParamMap = predModel.extractParamMap()
@@ -93,8 +110,8 @@ class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {

     // chained
     val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
-    predictionModel.write.overwrite.save("testxgbPipe")
-    val sameModel = PipelineModel.load("testxgbPipe")
+    predictionModel.write.overwrite.save("testxgbModel")
+    val sameModel = PipelineModel.load("testxgbModel")

     val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
     val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head
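Taken together, these changes let a fitted pipeline containing XGBoost round-trip through Spark ML persistence. A hedged end-to-end sketch mirroring the tests above, where the path is illustrative and trainingDF is assumed to be a DataFrame with "feature" and "label" columns:

// Illustrative sketch: persist and reload a fitted pipeline that contains XGBoost.
val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
val estimator = new XGBoostEstimator(Map("num_rounds" -> 10)).setLabelCol("label")
val fitted: PipelineModel = new Pipeline().setStages(Array(assembler, estimator)).fit(trainingDF)
fitted.write.overwrite.save("/tmp/xgbPipeline")        // writes stage metadata plus the booster bytes
val reloaded = PipelineModel.load("/tmp/xgbPipeline")  // restores all stages, including XGBoostModel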