[jvm-packages] Spark pipeline persistence (#1906)

geoHeil authored and CodingCat committed Mar 6, 2017
1 parent 5b54b94 commit cf6b173
Showing 5 changed files with 263 additions and 8 deletions.
@@ -17,10 +17,9 @@
package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
-import org.apache.spark.ml.param.{DoubleArrayParam, Param, ParamMap}
+import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -43,7 +42,7 @@ class XGBoostClassificationModel private[spark](
/**
* whether to output raw margin
*/
-final val outputMargin: Param[Boolean] = new Param[Boolean](this, "outputMargin", "whether to output untransformed margin value ")
+final val outputMargin = new BooleanParam(this, "outputMargin", "whether to output untransformed margin value")

setDefault(outputMargin, false)
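
The move from Param[Boolean] to BooleanParam is what makes this flag persistable: Spark's generic Param.jsonEncode only knows how to encode String and Vector values and throws NotImplementedError for anything else, whereas BooleanParam overrides JSON encoding and decoding. A minimal sketch of the difference, assuming Spark 2.x behaviour (Demo and its params are hypothetical, not part of this commit):

import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

// Hypothetical Params holder comparing the two param flavours.
class Demo(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("demo"))
  override def copy(extra: ParamMap): Demo = defaultCopy(extra)
  val generic = new Param[Boolean](this, "generic", "plain Param[Boolean]")
  val typed = new BooleanParam(this, "typed", "typed BooleanParam")
}

object BooleanParamDemo extends App {
  val d = new Demo
  println(d.typed.jsonEncode(true)) // prints: true -- safe to store in metadata JSON
  // d.generic.jsonEncode(true)     // would throw NotImplementedError
}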

@@ -17,30 +17,32 @@
package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._

-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit}
+import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.spark.params.DefaultXGBoostParamsWriter
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.apache.spark.ml.PredictionModel
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
-import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.param.{BooleanParam, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.types.{ArrayType, FloatType}
import org.apache.spark.{SparkContext, TaskContext}
import org.json4s.DefaultFormats

/**
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
*/
abstract class XGBoostModel(protected var _booster: Booster)
-extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params {
+extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params with MLWritable {

def setLabelCol(name: String): XGBoostModel = set(labelCol, name)

// scalastyle:off

-final val useExternalMemory: Param[Boolean] = new Param[Boolean](this, "useExternalMemory", "whether to use external memory for prediction")
+final val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external memory for prediction")

setDefault(useExternalMemory, false)

@@ -295,4 +297,38 @@ abstract class XGBoostModel(protected var _booster: Booster)
}

def booster: Booster = _booster

override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)

override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
}

object XGBoostModel extends MLReadable[XGBoostModel] {

override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader

override def load(path: String): XGBoostModel = super.load(path)

private[XGBoostModel] class XGBoostModelModelWriter(instance: XGBoostModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)

val dataPath = new Path(path, "data").toString
instance.saveModelAsHadoopFile(dataPath)
}
}

private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
private val className = classOf[XGBoostModel].getName
override def load(path: String): XGBoostModel = {
implicit val sc = super.sparkSession.sparkContext
val dataPath = new Path(path, "data").toString
// not used / all data resides in platform independent xgboost model file
// val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
XGBoost.loadModelFromHadoopFile(dataPath)
}
}

}
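
Taken together, the writer above persists params metadata next to the raw booster (under path + "/data"), and the reader restores the booster directly. A minimal usage sketch under assumed names (xgbEstimator and trainDF are hypothetical; the test added at the end of this commit exercises the same round trip end to end):

val model: XGBoostModel = xgbEstimator.fit(trainDF) // hypothetical estimator and data
model.write.overwrite().save("/tmp/xgbModel")       // writes metadata + booster bytes
val restored = XGBoostModel.load("/tmp/xgbModel")
// the booster survives the round trip byte for byte
assert(java.util.Arrays.equals(model.booster.toByteArray, restored.booster.toByteArray))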
@@ -0,0 +1,86 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.{ParamPair, Params}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.json4s.{JObject, _}

// This originates from a copy-paste of Apache Spark's DefaultParamsWriter
private[spark] object DefaultXGBoostParamsWriter {

/**
* Saves metadata + Params to: path + "/metadata"
* - class
* - timestamp
* - sparkVersion
* - uid
* - paramMap
* - (optionally, extra metadata)
*
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
* [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
val metadataPath = new Path(path, "metadata").toString
val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
}

/**
* Helper for [[saveMetadata()]] which extracts the JSON to save.
* This is useful for ensemble models which need to save metadata for many sub-models.
*
* @see [[saveMetadata()]] for details on what this includes.
*/
def getMetadataToSave(
instance: Params,
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = paramMap.getOrElse(render(params.map {
case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
val metadata = extraMetadata match {
case Some(jObject) =>
basicMetadata ~ jObject
case None =>
basicMetadata
}
val metadataJson: String = compact(render(metadata))
metadataJson
}
}
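
For orientation, getMetadataToSave renders everything into one compact JSON line, and saveMetadata writes that line to path + "/metadata" as a single-partition text file. A hedged sketch of the output shape (uid, timestamp, and version are illustrative placeholders; paramMap carries whatever params the instance defines):

// model and sc are assumed to be in scope, e.g. inside a writer's saveImpl.
val json = DefaultXGBoostParamsWriter.getMetadataToSave(model, sc)
// json is a single line shaped roughly like:
// {"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel",
//  "timestamp":1488771234567,"sparkVersion":"2.1.0",
//  "uid":"XGBoostClassificationModel_0123456789ab",
//  "paramMap":{"outputMargin":false,"use_external_memory":false}}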
@@ -0,0 +1,33 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark.params

// based on a copy-paste of org.apache.spark.util.Utils
private[spark] object Utils {

def getSparkClassLoader: ClassLoader = getClass.getClassLoader

def getContextOrSparkClassLoader: ClassLoader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)

// scalastyle:off classforname
/** Preferred alternative to Class.forName(className) */
def classForName(className: String): Class[_] = {
Class.forName(className, true, getContextOrSparkClassLoader)
// scalastyle:on classforname
}
}
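
This mirrors Spark's own classloader lookup: prefer the thread's context classloader (set by the REPL or the application) and fall back to the loader that loaded this library. A hypothetical call site, sketching how a params reader could resolve the class name recorded in persisted metadata:

// Resolves and eagerly initializes a persisted class by name.
val cls: Class[_] = Utils.classForName("ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel")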
@@ -0,0 +1,101 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession

import scala.concurrent.duration._

case class Foobar(TARGET: Int, bar: Double, baz: Double)

class XGBoostSparkPipelinePersistence extends SharedSparkContext with Utils {
test("test Spark's pipeline persistence of a dataframe-based model") {
// maybe move to a shared context, but that requires a session to import implicits.
// what about introducing https://github.com/holdenk/spark-testing-base?
val conf: SparkConf = new SparkConf()
.setAppName("foo")
.setMaster("local[*]")

val spark: SparkSession = SparkSession
.builder()
.config(conf)
.getOrCreate()

import spark.implicits._

val df = Seq(Foobar(0, 0.5, 1), Foobar(1, 0.01, 0.8),
Foobar(0, 0.8, 0.5), Foobar(1, 8.4, 0.04))
.toDS

val vectorAssembler = new VectorAssembler()
.setInputCols(df.columns
.filter(!_.contains("TARGET")))
.setOutputCol("features")

val xgbEstimator = new XGBoostEstimator(Map("num_rounds" -> 10,
"tracker_conf" -> TrackerConf(1 minute, "scala")
))
.setFeaturesCol("features")
.setLabelCol("TARGET")

// separate
val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
predModel.write.overwrite.save("test2xgbPipe")
val same2Model = XGBoostModel.load("test2xgbPipe")

assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
val predParamMap = predModel.extractParamMap()
val same2ParamMap = same2Model.extractParamMap()
assert(predParamMap.get(predModel.useExternalMemory)
=== same2ParamMap.get(same2Model.useExternalMemory))
assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol))
assert(predParamMap.get(predModel.predictionCol)
=== same2ParamMap.get(same2Model.predictionCol))
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))

// chained
val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
predictionModel.write.overwrite.save("testxgbPipe")
val sameModel = PipelineModel.load("testxgbPipe")

val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb }.head
val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb }.head

assert(java.util.Arrays.equals(
predictionModelXGB.booster.toByteArray,
sameModelXGB.booster.toByteArray
))
val predictionModelXGBParamMap = predictionModel.extractParamMap()
val sameModelXGBParamMap = sameModel.extractParamMap()
assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory)
=== sameModelXGBParamMap.get(sameModelXGB.useExternalMemory))
assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol)
=== sameModelXGBParamMap.get(sameModelXGB.featuresCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol)
=== sameModelXGBParamMap.get(sameModelXGB.predictionCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
}
}
