Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[jvm-packages] Integration with Spark Dataframe/Dataset (#1559)
* bump up to scala 2.11
* framework of data frame integration
* test consistency between RDD and DataFrame
* order preservation
* test order preservation
* example code and fix makefile
* improve type checking
* improve APIs
* user docs
* work around travis CI's limitation on log length
* adjust test structure
* integrate with Spark 1.x
* spark 2.x integration
* remove spark 1.x implementation but provide instructions on how to downgrade
Showing 15 changed files with 623 additions and 137 deletions.
@@ -79,3 +79,5 @@ tags
 *.class
 target
 *.swp
+
+.DS_Store
65 additions & 0 deletions
...t4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala
@@ -0,0 +1,65 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.example.spark

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.{SparkContext, SparkConf}

object SparkWithDataFrame {
  def main(args: Array[String]): Unit = {
    if (args.length != 5) {
      println(
        "usage: program num_of_rounds num_workers training_path test_path model_path")
      sys.exit(1)
    }
    // create SparkContext and SQLContext
    val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    sparkConf.registerKryoClasses(Array(classOf[Booster]))
    val sqlContext = new SQLContext(new SparkContext(sparkConf))
    // create training and testing dataframes
    val inputTrainPath = args(2)
    val inputTestPath = args(3)
    val outputModelPath = args(4)
    // number of iterations
    val numRound = args(0).toInt
    import DataUtils._
    val trainRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTrainPath).
      map { labeledPoint => Row(labeledPoint.features, labeledPoint.label) }
    val trainDF = sqlContext.createDataFrame(trainRDDOfRows, StructType(
      Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
    val testRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTestPath).
      zipWithIndex().map { case (labeledPoint, id) =>
        Row(id, labeledPoint.features, labeledPoint.label) }
    val testDF = sqlContext.createDataFrame(testRDDOfRows, StructType(
      Array(StructField("id", LongType),
        StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
    // training parameters
    val paramMap = List(
      "eta" -> 0.1f,
      "max_depth" -> 2,
      "objective" -> "binary:logistic").toMap
    val xgboostModel = XGBoost.trainWithDataFrame(
      trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
    // xgboost-spark appends a column containing the prediction results
    xgboostModel.transform(testDF).show()
  }
}
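For context, launching the example above from the command line might look like the following sketch. The jar name, master URL, and data paths are placeholders invented for illustration; they are not taken from this commit, so substitute the artifacts produced by your own build.

```shell
# Hypothetical invocation of SparkWithDataFrame above.
# Positional arguments: num_of_rounds num_workers training_path test_path model_path
spark-submit \
  --class ml.dmlc.xgboost4j.scala.example.spark.SparkWithDataFrame \
  --master spark://master:7077 \
  xgboost4j-example.jar \
  100 4 \
  hdfs:///data/agaricus.txt.train \
  hdfs:///data/agaricus.txt.test \
  hdfs:///model/xgb.model
```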
81 additions & 0 deletions
...kages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -0,0 +1,81 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.spark.ml.{Predictor, Estimator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{NumericType, DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row}

/**
 * the estimator wrapping XGBoost to produce a trained model
 *
 * @param inputCol the name of the input column
 * @param labelCol the name of the label column
 * @param xgboostParams the parameters configuring XGBoost
 * @param round the number of iterations to train
 * @param nWorkers the total number of xgboost workers
 * @param obj the customized objective function, defaults to null (use the model's default)
 * @param eval the customized eval function, defaults to null (use the model's default)
 * @param useExternalMemory whether to use external memory when training
 * @param missing the value treated as missing
 */
class XGBoostEstimator(
    inputCol: String, labelCol: String,
    xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN)
  extends Estimator[XGBoostModel] {

  override val uid: String = Identifiable.randomUID("XGBoostEstimator")

  /**
   * produces an XGBoostModel by fitting the given dataset
   */
  def fit(trainingSet: Dataset[_]): XGBoostModel = {
    val instances = trainingSet.select(
      col(inputCol), col(labelCol).cast(DoubleType)).rdd.map {
      case Row(feature: Vector, label: Double) =>
        LabeledPoint(label, feature)
    }
    transformSchema(trainingSet.schema, logging = true)
    val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, round, nWorkers, obj,
      eval, useExternalMemory, missing).setParent(this)
    copyValues(trainedModel)
  }

  override def copy(extra: ParamMap): Estimator[XGBoostModel] = {
    defaultCopy(extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // check the input type; for now we only support VectorUDT as the input feature type
    val inputType = schema(inputCol).dataType
    require(inputType.equals(new VectorUDT), s"the type of input column $inputCol has to" +
      s" be VectorUDT")
    // check the label type
    val labelType = schema(labelCol).dataType
    require(labelType.isInstanceOf[NumericType], s"the type of label column $labelCol has to" +
      s" be NumericType")
    schema
  }
}
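The estimator above is meant to slot into the usual spark.ml fit/transform flow. A minimal usage sketch follows; the column names, parameter values, and the `trainDF`/`testDF` DataFrames are invented for illustration (they are not defined in this commit), and running it requires an active Spark application whose "features" column is a VectorUDT and whose "label" column is numeric.

```scala
// Hypothetical usage sketch of XGBoostEstimator; trainDF and testDF are
// assumed to exist in the surrounding Spark application.
val estimator = new XGBoostEstimator(
  inputCol = "features", labelCol = "label",
  xgboostParams = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic"),
  round = 100, nWorkers = 4)
// fit delegates to XGBoost.trainWithRDD and returns an XGBoostModel;
// transform appends the prediction column to the input DataFrame
val model = estimator.fit(trainDF)
model.transform(testDF).show()
```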
Comments on commit fb02797:
@CodingCat @tqchen - https://groups.google.com/forum/#!topic/xgboost-user/gwt8ozwdZiE
if possible, please keep the dev-related discussion in github...
as mentioned in the PR, this will be included in another PR
please feel free to maintain a Spark 1.6 version; since Spark 2.0 has broken backward compatibility with Spark 1.x, we have to choose only one of them
if you try that, you will find that Scala does not allow overloading methods with default parameters
I suggest that you merge your params-related part with master
@tanwanirahul , if you decide to contribute to params implementation, please let me know ASAP, otherwise, I will start working on it after tmr
@tanwanirahul, regarding your question about release management:
JVM APIs are subject to change in recent releases; jvm-packages 0.5 was released with xgboost 0.6 because we skipped XGBoost 0.5.
There will be incompatible changes from 0.6 to 0.7 for the reasons I mentioned above.
Sorry for any inconvenience.
@CodingCat Thanks for the explanation. I will push the params-related change by tomorrow. Also, let me see if I can get it working on Spark 1.6 and maintain a separate branch for it.
Backward-incompatible changes: XGBoost.train has been removed and replaced by trainWithRDD and trainWithDataFrame. We should rather keep the train method as-is and add an overloaded method for the Dataset.
-- if you try that, you will find that Scala does not allow overloading methods with default parameters
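The restriction being quoted here is easy to reproduce outside of Spark. In the sketch below, toy types stand in for RDD and DataFrame (nothing in it comes from the xgboost codebase); uncommenting the second `train` overload makes scalac reject the pair with "multiple overloaded alternatives of method train define default arguments", which is one motivation for the distinct trainWithRDD/trainWithDataFrame names:

```scala
object OverloadDemo {
  // Two overloads that BOTH declare a default argument do not compile in Scala,
  // because the synthesized default-getter names would clash:
  def train(data: List[Double], rounds: Int = 10): String =
    s"list:${data.size}:$rounds"
  // def train(data: Map[String, Double], rounds: Int = 10): String = ...  // <- compile error

  // Distinct method names sidestep the restriction while keeping the defaults:
  def trainWithSeq(data: Seq[Double], rounds: Int = 10): String =
    s"seq:${data.size}:$rounds"
  def trainWithMap(data: Map[String, Double], rounds: Int = 10): String =
    s"map:${data.size}:$rounds"
}
```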
@tanwanirahul for now I do not think maintaining different branches in the XGBoost repository is the adopted release strategy. If you like, you can maintain a version in a different repository
It would be very helpful to maintain a Spark 1.6 branch, as many enterprise software stacks still use Spark 1.6. Thanks.