[jvm-packages] fix the persistence of XGBoostEstimator (#2265)
* add back train method but mark as deprecated

* fix scalastyle error

* fix the persistence of XGBoostEstimator

* test persistence of a complete pipeline

* fix compilation issue

* do not allow persisting custom_eval and custom_obj

* fix the failed test
CodingCat committed May 9, 2017
1 parent 6bf968e commit 428453f
Showing 12 changed files with 362 additions and 66 deletions.
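In practice, the headline change lets an XGBoostEstimator round-trip through Spark ML's standard persistence API. A minimal usage sketch, assuming a running Spark application; the parameter values and save path below are illustrative, not taken from this commit:

import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator

// Illustrative training parameters; any supported XGBoost params map works.
val estimator = new XGBoostEstimator(Map(
  "eta" -> 0.1,
  "max_depth" -> 6,
  "objective" -> "binary:logistic"))

// Round trip through the MLWritable/MLReadable support added below.
estimator.write.overwrite().save("/tmp/xgb-estimator")
val restored = XGBoostEstimator.load("/tmp/xgb-estimator")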
5 changes: 4 additions & 1 deletion .gitignore
@@ -88,4 +88,7 @@ build_tests
/tests/cpp/xgboost_test

.DS_Store
-lib/
+lib/
+
+# spark
+metastore_db
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -18,35 +18,37 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}

import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext}

-import scala.concurrent.duration.{Duration, MILLISECONDS}
+import scala.concurrent.duration.{Duration, FiniteDuration, MILLISECONDS}

object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(Duration.apply(0L, MILLISECONDS), "python")
+  def apply(): TrackerConf = TrackerConf(0L, "python")
}

/**
* Rabit tracker configurations.
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
* Set timeout length to zero to disable timeout.
* Use a finite, non-zero timeout value to prevent tracker from
- * hanging indefinitely (supported by "scala" implementation only.)
+ * hanging indefinitely (in milliseconds)
+ * (supported by "scala" implementation only.)
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*/
-case class TrackerConf(workerConnectionTimeout: Duration, trackerImpl: String)
+case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
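For context, a minimal sketch of the new Long-based TrackerConf: the timeout is now plain milliseconds, which keeps the case class trivially JSON-serializable for persistence. The "tracker_conf" key below is an assumed illustration of how it is passed with other training parameters:

import ml.dmlc.xgboost4j.scala.spark.TrackerConf

// 60-second worker connection timeout, using the experimental Scala tracker.
val trackerConf = TrackerConf(60000L, "scala")

// Assumed usage: supplied alongside other training parameters.
val params: Map[String, Any] = Map(
  "eta" -> 0.1,
  "tracker_conf" -> trackerConf)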

object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
@@ -240,14 +242,7 @@ object XGBoost extends Serializable {
case _ => new PyRabitTracker(nWorkers)
}

-    val connectionTimeout = if (trackerConf.workerConnectionTimeout.isFinite()) {
-      trackerConf.workerConnectionTimeout.toMillis
-    } else {
-      // 0 == Duration.Inf
-      0L
-    }
-
-    require(tracker.start(connectionTimeout), "FAULT: Failed to start tracker")
+    require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
tracker
}

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -18,12 +18,14 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

-import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, GeneralParams, LearningTaskParams}
+import ml.dmlc.xgboost4j.scala.spark.params._
+import org.json4s.DefaultFormats

import org.apache.spark.ml.Predictor
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{Dataset, Row}
@@ -34,7 +36,7 @@ import org.apache.spark.sql.{Dataset, Row}
class XGBoostEstimator private[spark](
override val uid: String, xgboostParams: Map[String, Any])
extends Predictor[MLVector, XGBoostEstimator, XGBoostModel]
-  with LearningTaskParams with GeneralParams with BoosterParams {
+  with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {

def this(xgboostParams: Map[String, Any]) =
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
@@ -129,4 +131,38 @@ class XGBoostEstimator private[spark](
override def copy(extra: ParamMap): XGBoostEstimator = {
defaultCopy(extra).asInstanceOf[XGBoostEstimator]
}

override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
}

object XGBoostEstimator extends MLReadable[XGBoostEstimator] {

override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader

override def load(path: String): XGBoostEstimator = super.load(path)

private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
instance.fromParamsToXGBParamMap("custom_obj") == null,
"we do not support persist XGBoostEstimator with customized evaluator and objective" +
" function for now")
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
}
}

private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {

override def load(path: String): XGBoostEstimator = {
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
instance.asInstanceOf[XGBoostEstimator]
}
}
}
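The commit message also mentions testing persistence of a complete pipeline; a hedged sketch of that scenario follows, where the column names, stages, and path are assumptions for illustration:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler

// Assemble raw columns into the "features" vector the estimator expects.
val assembler = new VectorAssembler()
  .setInputCols(Array("f0", "f1", "f2"))
  .setOutputCol("features")
val xgb = new XGBoostEstimator(Map("eta" -> 0.1, "objective" -> "binary:logistic"))

// With this commit, the estimator stage no longer breaks pipeline persistence.
val pipeline = new Pipeline().setStages(Array(assembler, xgb))
pipeline.write.overwrite().save("/tmp/xgb-pipeline")
val reloaded = Pipeline.load("/tmp/xgb-pipeline")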
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -324,14 +324,13 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)

val dataPath = new Path(path, "data").toString
instance.saveModelAsHadoopFile(dataPath)
}
}

private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
private val className = classOf[XGBoostModel].getName

override def load(path: String): XGBoostModel = {
implicit val sc = super.sparkSession.sparkContext
val dataPath = new Path(path, "data").toString
@@ -340,5 +339,4 @@ object XGBoostModel extends MLReadable[XGBoostModel] {
XGBoost.loadModelFromHadoopFile(dataPath)
}
}

}
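The model side was already persistable; for completeness, a sketch of the corresponding round trip, in which the booster binary is written under <path>/data next to the params metadata (the path and the trained model value are illustrative):

// "model" is a previously trained XGBoostModel.
model.write.overwrite().save("/tmp/xgb-model")
val loadedModel = XGBoostModel.load("/tmp/xgb-model")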
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala (new file)
@@ -0,0 +1,106 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}

import org.apache.spark.ml.param.{Param, ParamPair, Params}

class GroupDataParam(
parent: Params,
name: String,
doc: String) extends Param[Seq[Seq[Int]]](parent, name, doc) {

/** Creates a param pair with the given value (for Java). */
override def w(value: Seq[Seq[Int]]): ParamPair[Seq[Seq[Int]]] = super.w(value)

override def jsonEncode(value: Seq[Seq[Int]]): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}

override def jsonDecode(json: String): Seq[Seq[Int]] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[Seq[Int]]]
}
}

class CustomEvalParam(
parent: Params,
name: String,
doc: String) extends Param[EvalTrait](parent, name, doc) {

/** Creates a param pair with the given value (for Java). */
override def w(value: EvalTrait): ParamPair[EvalTrait] = super.w(value)

override def jsonEncode(value: EvalTrait): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}

override def jsonDecode(json: String): EvalTrait = {
implicit val formats = DefaultFormats
parse(json).extract[EvalTrait]
}
}

class CustomObjParam(
parent: Params,
name: String,
doc: String) extends Param[ObjectiveTrait](parent, name, doc) {

/** Creates a param pair with the given value (for Java). */
override def w(value: ObjectiveTrait): ParamPair[ObjectiveTrait] = super.w(value)

override def jsonEncode(value: ObjectiveTrait): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}

override def jsonDecode(json: String): ObjectiveTrait = {
implicit val formats = DefaultFormats
parse(json).extract[ObjectiveTrait]
}
}

class TrackerConfParam(
parent: Params,
name: String,
doc: String) extends Param[TrackerConf](parent, name, doc) {

/** Creates a param pair with the given value (for Java). */
override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)

override def jsonEncode(value: TrackerConf): String = {
import org.json4s.jackson.Serialization
implicit val formats = Serialization.formats(NoTypeHints)
compact(render(Extraction.decompose(value)))
}

override def jsonDecode(json: String): TrackerConf = {
implicit val formats = DefaultFormats
val parsedValue = parse(json)
parsedValue.extract[TrackerConf]
}
}
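A hedged round-trip sketch for the new Param types: TrackerConf survives jsonEncode/jsonDecode because it is a plain case class, whereas EvalTrait/ObjectiveTrait implementations generally do not, which is why saveImpl above rejects estimators with custom_eval or custom_obj set. The parent estimator and param name here are illustrative:

import ml.dmlc.xgboost4j.scala.spark.{TrackerConf, XGBoostEstimator}
import ml.dmlc.xgboost4j.scala.spark.params.TrackerConfParam

val owner = new XGBoostEstimator(Map("eta" -> 0.1))
val param = new TrackerConfParam(owner, "tracker_conf", "Rabit tracker configuration")

val json = param.jsonEncode(TrackerConf(30000L, "scala"))
val back = param.jsonDecode(json) // == TrackerConf(30000, "scala")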
(diffs for the remaining changed files are not shown here)
