Skip to content

Commit

Permalink
[jvm-packages] call setGroup for ranking task (#2066)
Browse files Browse the repository at this point in the history
* [jvm-packages] call setGroup for ranking task

* passing groupData through xgBoostConfMap

* fix original comment position

* make groupData param

* remove groupData variable, use xgBoostConfMap directly

* set default groupData value

* add use groupData tests

* reduce rank-demo size

* use TaskContext.getPartitionId() instead of mapPartitionsWithIndex

* add DF use groupData test

* remove unused variable
  • Loading branch information
cloverrose authored and CodingCat committed Mar 6, 2017
1 parent cf6b173 commit 288f309
Show file tree
Hide file tree
Showing 10 changed files with 334 additions and 1 deletion.
Expand Up @@ -123,6 +123,12 @@ object XGBoost extends Serializable {
}
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
// When the caller supplied a non-null "groupData" entry, pick the group sizes that belong to
// the partition this task is processing and attach them to the training matrix.
xgBoostConfMap.get("groupData").flatMap(Option(_)).foreach { nestedGroups =>
  val groupsOfThisPartition =
    nestedGroups.asInstanceOf[Seq[Seq[Int]]](TaskContext.getPartitionId())
  trainingSet.setGroup(groupsOfThisPartition.toArray)
}
booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
watches = new mutable.HashMap[String, DMatrix] {
put("train", trainingSet)
Expand Down
Expand Up @@ -53,7 +53,14 @@ trait LearningTaskParams extends Params {
s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))

setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2)
/**
 * Group sizes for ranking tasks. The value is nested so that it lines up with the
 * partitioning of the training data: each inner sequence lists the group sizes for
 * one partition (presumably in partition-id order; confirm against the training code).
 */
val groupData = new Param[Seq[Seq[Int]]](this, "groupData", "group data specify each group size" +
" for ranking task. To correspond to partition of training data, it is nested.")

// groupData defaults to null, i.e. no group information unless the caller sets it explicitly.
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null)
}

private[spark] object LearningTaskParams {
Expand Down
@@ -0,0 +1,75 @@
0 1:985.574005058 2:320.223538037 3:0.621236086198
0 1:1010.52917943 2:635.535543082 3:2.14984030531
0 1:1012.91900422 2:132.387300057 3:0.488761066665
0 1:990.829194034 2:135.102081162 3:0.747701610673
0 1:1007.05103629 2:154.289183562 3:0.464118249201
0 1:994.9573036 2:317.483732878 3:0.0313685555674
0 1:987.8071541 2:731.349178363 3:0.244616944245
1 1:10.0349544469 2:2.29750906143 3:36.4949974282
0 1:9.92953881383 2:5.39134047297 3:120.041297548
0 1:10.0909866713 2:9.06191026312 3:138.807825798
1 1:10.2090970614 2:0.0784495944448 3:58.207703565
0 1:9.85695905893 2:9.99500727713 3:56.8610243778
1 1:10.0805758547 2:0.0410805760559 3:222.102302076
0 1:10.1209914486 2:9.9729127088 3:171.888238763
0 1:10.0331939798 2:0.853339303793 3:311.181328375
0 1:9.93901762951 2:2.72757449146 3:78.4859514413
0 1:10.0752365346 2:9.18695328235 3:49.8520256553
1 1:10.0456548902 2:0.270936043122 3:123.462958597
0 1:10.0568923673 2:0.82997113263 3:44.9391426001
0 1:9.8214143472 2:0.277538931578 3:15.4217659578
0 1:9.95258604431 2:8.69564346094 3:255.513470671
0 1:9.91934976357 2:7.72809741413 3:82.171591817
0 1:10.043239582 2:8.64168255553 3:38.9657919329
1 1:10.0236147929 2:0.0496662263659 3:4.40889812286
1 1:1001.85585324 2:3.75646886071 3:0.0179224994842
0 1:1014.25578571 2:0.285765311201 3:0.510329864983
1 1:1002.81422786 2:9.77676280375 3:0.433705951912
1 1:998.072711553 2:2.82100686538 3:0.889829076909
0 1:1003.77395036 2:2.55916592114 3:0.0359402151496
1 1:10.0807877782 2:4.98513959013 3:47.5266363559
0 1:10.0015013081 2:9.94302478763 3:78.3697486277
1 1:10.0441936789 2:0.305091816635 3:56.8213984987
0 1:9.94257106618 2:7.23909568913 3:442.463339039
1 1:9.86479307916 2:6.41701315844 3:55.1365304834
0 1:10.0428628516 2:9.98466447697 3:0.391632812588
0 1:9.94445884566 2:9.99970945878 3:260.438436534
1 1:9.84641392823 2:225.78051312 3:1.00525978847
1 1:9.86907690608 2:26.8971083147 3:0.577959255991
0 1:10.0177314626 2:0.110585342313 3:2.30545043031
0 1:10.0688190907 2:412.023866234 3:1.22421542264
0 1:10.1251769646 2:13.8212202925 3:0.129171734504
0 1:10.0840758802 2:407.359097187 3:0.477000870705
0 1:10.1007458705 2:987.183625145 3:0.149385677415
0 1:9.86472656059 2:169.559640615 3:0.147221652519
0 1:9.94207419238 2:507.290053755 3:0.41996207214
0 1:9.9671005502 2:1.62610457716 3:0.408173666788
0 1:1010.57126596 2:9.06673707562 3:0.672092284372
0 1:1001.6718262 2:9.53203990055 3:4.7364050044
0 1:995.777341384 2:4.43847316256 3:2.07229073634
0 1:1002.95701386 2:5.51711016665 3:1.24294450546
0 1:1016.0988238 2:0.626468941906 3:0.105627919134
0 1:1013.67571419 2:0.042315529666 3:0.717619310322
1 1:994.747747892 2:6.01989364024 3:0.772910130015
1 1:991.654593872 2:7.35575736952 3:1.19822091548
0 1:1008.47101732 2:8.28240754909 3:0.229582481359
0 1:1000.81975227 2:1.52448354056 3:0.096441660362
0 1:10.0900922344 2:322.656649307 3:57.8149073088
1 1:10.0868337371 2:2.88652339174 3:54.8865514572
0 1:10.0988984137 2:979.483832657 3:52.6809830901
0 1:9.97678959238 2:665.770979738 3:481.069628909
0 1:9.78554312773 2:257.309358658 3:47.7324475232
0 1:10.0985967566 2:935.896512941 3:138.937052808
0 1:10.0522252319 2:876.376299607 3:6.00373510669
1 1:9.88065229501 2:9.99979825653 3:0.0674603696149
0 1:10.0483244098 2:0.0653852316381 3:0.130679349938
1 1:9.99685215607 2:1.76602542774 3:0.2551321159
0 1:9.99750159428 2:1.01591534436 3:0.145445506504
1 1:9.97380908941 2:0.940048645571 3:0.411805696316
0 1:9.99977678382 2:6.91329929641 3:5.57858201258
0 1:978.876096381 2:933.775364741 3:0.579170824236
0 1:998.381016406 2:220.940470582 3:2.01491778565
0 1:987.917644594 2:8.74667873567 3:0.364006099758
0 1:1000.20994892 2:25.2945450565 3:3.5684398964
0 1:1014.57141264 2:675.593540733 3:0.164174055535
0 1:998.867283535 2:765.452750642 3:0.818425293238
@@ -0,0 +1,10 @@
7
7
10
5
7
10
10
7
6
6
@@ -0,0 +1,74 @@
0 1:10.2143092481 2:273.576539531 3:137.111774354
0 1:10.0366658918 2:842.469052609 3:2.32134375927
0 1:10.1281202091 2:395.654057342 3:35.4184893063
0 1:10.1443721289 2:960.058461049 3:272.887070637
0 1:10.1353234784 2:535.51304462 3:2.15393842032
1 1:10.0451640374 2:216.733858424 3:55.6533298016
1 1:9.94254592171 2:44.5985537358 3:304.614176871
0 1:10.1319257181 2:613.545504487 3:5.42391587912
0 1:1020.63622468 2:997.476744201 3:0.509425590461
0 1:986.304585519 2:822.669937965 3:0.605133561808
1 1:1012.66863221 2:26.7185759069 3:0.0875458784828
0 1:995.387656321 2:81.8540176995 3:0.691999430068
0 1:1020.6587198 2:848.826964547 3:0.540159430526
1 1:1003.81573853 2:379.84350931 3:0.0083682925194
0 1:1021.60921516 2:641.376951467 3:1.12339054807
0 1:1000.17585041 2:122.107138713 3:1.09906375372
1 1:987.64802348 2:5.98448541152 3:0.124241987204
1 1:9.94610136583 2:346.114985897 3:0.387708236565
0 1:9.96812192337 2:313.278109696 3:0.00863026595671
0 1:10.0181739194 2:36.7378924562 3:2.92179879835
0 1:9.89000102695 2:164.273723971 3:0.685222591968
0 1:10.1555212436 2:320.451459462 3:2.01341536261
0 1:10.0085727613 2:999.767117646 3:0.462294934168
1 1:9.93099658724 2:5.17478203909 3:0.213855205032
0 1:10.0629454957 2:663.088181857 3:0.049022351462
0 1:10.1109732417 2:734.904569784 3:1.6998450094
0 1:1006.6015266 2:505.023453703 3:1.90870566777
0 1:991.865769489 2:245.437343115 3:0.475109744256
0 1:998.682734072 2:950.041057232 3:1.9256314201
0 1:1005.02207209 2:2.9619314197 3:0.0517146822357
0 1:1002.54526214 2:860.562681899 3:0.915687092848
0 1:1000.38847359 2:808.416525088 3:0.209690673808
1 1:992.557818382 2:373.889409453 3:0.107571728577
0 1:1002.07722137 2:997.329626371 3:1.06504260496
0 1:1000.40504333 2:949.832139189 3:0.539159980327
0 1:10.1460179902 2:8.86082969819 3:135.953842715
1 1:9.98529296553 2:2.87366448495 3:1.74249892194
0 1:9.88942676744 2:9.4031821056 3:149.473066381
1 1:10.0192953341 2:1.99685737576 3:1.79502473397
0 1:10.0110654379 2:8.13112593726 3:87.7765628103
0 1:997.148677047 2:733.936190093 3:1.49298494242
0 1:1008.70465919 2:957.121652078 3:0.217414013634
1 1:997.356154278 2:541.599587807 3:0.100855972216
0 1:999.615897283 2:943.700501824 3:0.862874175879
1 1:997.36859077 2:0.200859940848 3:0.13601892182
0 1:10.0423255624 2:1.73855202168 3:0.956695338485
1 1:9.88440755486 2:9.9994600678 3:0.305080529665
0 1:10.0891026412 2:3.28031719474 3:0.364450973697
0 1:9.90078644258 2:8.77839663617 3:0.456660574479
1 1:9.79380029711 2:8.77220326156 3:0.527292005175
0 1:9.93613887011 2:9.76270841268 3:1.40865693823
0 1:10.0009239007 2:7.29056178263 3:0.498015866607
0 1:9.96603319905 2:5.12498000925 3:0.517492532783
0 1:10.0923827222 2:2.76652583955 3:1.56571226159
1 1:10.0983782035 2:587.788120694 3:0.031756483687
1 1:9.91397225464 2:994.527496819 3:3.72092164978
0 1:10.1057472738 2:2.92894440088 3:0.683506438532
0 1:10.1014053354 2:959.082038017 3:1.07039624129
0 1:10.1433253044 2:322.515119317 3:0.51408278993
1 1:9.82832510699 2:637.104433908 3:0.250272776427
0 1:1000.49729075 2:2.75336888111 3:0.576634423274
1 1:984.90338088 2:0.0295435794035 3:1.26273339929
0 1:1001.53811442 2:4.64164410861 3:0.0293389959504
1 1:995.875898395 2:5.08223403205 3:0.382330566779
0 1:996.405937252 2:6.26395190757 3:0.453645816611
0 1:10.0165140779 2:340.126072514 3:0.220794603312
0 1:9.93482824816 2:951.672000448 3:0.124406293612
0 1:10.1700278554 2:0.0140985961008 3:0.252452256311
0 1:9.99825079542 2:950.382643896 3:0.875382402062
0 1:9.87316410028 2:686.788257829 3:0.215886999825
0 1:10.2893240654 2:89.3947931451 3:0.569578232133
0 1:9.98689192703 2:0.430107535413 3:2.99869831728
0 1:10.1365175107 2:972.279245093 3:0.0865099386744
0 1:9.90744703306 2:50.810461183 3:3.00863325197
@@ -0,0 +1,10 @@
8
9
9
9
5
5
9
6
5
9
66 changes: 66 additions & 0 deletions jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test
@@ -0,0 +1,66 @@
0 1:10.0229017899 2:7.30178495562 3:0.118115020017
0 1:9.93639621859 2:9.93102159291 3:0.0435030004396
0 1:10.1301737265 2:0.00411765220572 3:2.4165878053
1 1:9.87828587087 2:0.608588414992 3:0.111262590883
0 1:10.1373430048 2:0.47764012225 3:0.991553052194
0 1:10.0523814718 2:4.72152505167 3:0.672978832666
0 1:10.0449715742 2:8.40373928536 3:0.384457573667
1 1:996.398498791 2:941.976309154 3:0.230269231292
0 1:1005.11269468 2:900.093680877 3:0.265031528873
0 1:997.160349441 2:891.331101688 3:2.19362017313
0 1:993.754139031 2:44.8000165317 3:1.03868009875
1 1:994.831299184 2:241.959208453 3:0.667631827024
0 1:995.948333283 2:7.94326917112 3:0.750490877118
0 1:989.733981273 2:7.52077625436 3:0.0126335967282
0 1:1003.54086516 2:6.48177510564 3:1.19441696788
0 1:996.56177804 2:9.71959812613 3:1.33082465111
0 1:1005.61382467 2:0.234339369309 3:1.17987797356
1 1:980.215758708 2:6.85554542926 3:2.63965085259
1 1:987.776408872 2:2.23354609991 3:0.841885278028
0 1:1006.54260396 2:8.12142049834 3:2.26639471174
0 1:1009.87927639 2:6.40028519044 3:0.775155669615
0 1:9.95006244393 2:928.76896718 3:234.948458244
1 1:10.0749152258 2:255.294574476 3:62.9728604166
1 1:10.1916541988 2:312.682867085 3:92.299413677
0 1:9.95646724484 2:742.263188416 3:53.3310473654
0 1:9.86211293222 2:996.237023866 3:2.00760301168
1 1:9.91801019468 2:303.971783709 3:50.3147230679
0 1:996.983996934 2:9.52188222766 3:1.33588120981
0 1:995.704388126 2:9.49260524915 3:0.908498516541
0 1:987.86480767 2:0.0870786716821 3:0.108859297837
0 1:1000.99561307 2:2.85272694575 3:0.171134518956
0 1:1011.05508066 2:7.55336771768 3:1.04950084825
1 1:985.52199365 2:0.763305780608 3:1.7402424375
0 1:10.0430321467 2:813.185427181 3:4.97728254185
0 1:10.0812334228 2:258.297288417 3:0.127477670549
0 1:9.84210504292 2:887.205815261 3:0.991689193955
1 1:9.94625332613 2:0.298622762132 3:0.147881353231
0 1:9.97800659954 2:727.619819757 3:0.0718361141866
1 1:9.8037938472 2:957.385549617 3:0.0618862028941
0 1:10.0880634741 2:185.024638577 3:1.7028095095
0 1:9.98630799154 2:109.10631473 3:0.681117359751
0 1:9.91671416638 2:166.248076588 3:122.538291094
0 1:10.1206910464 2:88.1539468531 3:141.189859069
1 1:10.1767160518 2:1.02960996847 3:172.02256237
0 1:9.93025147233 2:391.196641942 3:58.040338247
0 1:9.84850936037 2:474.63346537 3:17.5627875397
1 1:9.8162731343 2:61.9199554213 3:30.6740972851
0 1:10.0403482984 2:987.50416929 3:73.0472906209
1 1:997.019228359 2:133.294717663 3:0.0572254083186
0 1:973.303999107 2:1.79080888849 3:0.100478717048
0 1:1008.28808825 2:342.282350685 3:0.409806485495
0 1:1014.55621524 2:0.680510407082 3:0.929530602495
1 1:1012.74370325 2:823.105266455 3:0.0894693730585
0 1:1003.63554038 2:727.334432075 3:0.58206275756
0 1:10.1560432436 2:740.35938307 3:11.6823378533
0 1:9.83949099701 2:512.828227154 3:138.206666681
1 1:10.1837395682 2:179.287126088 3:185.479062365
1 1:9.9761881495 2:12.1093388336 3:9.1264604171
1 1:9.77402180766 2:318.561317743 3:80.6005221355
0 1:1011.15705381 2:0.215825852155 3:1.34429667906
0 1:1005.60353229 2:727.202346126 3:1.47146041005
1 1:1013.93702961 2:58.7312725205 3:0.421041560754
0 1:1004.86813074 2:757.693204258 3:0.566055205344
0 1:999.996324692 2:813.12386828 3:0.864428279513
0 1:996.55255931 2:918.760056995 3:0.43365051974
1 1:1004.1394132 2:464.371823646 3:0.312492288321
@@ -0,0 +1,10 @@
7
5
9
6
6
8
7
6
5
7
Expand Up @@ -239,4 +239,36 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers)
}

test("test DF use nested groupData") {
  // Reads one integer group size per line from a classpath resource. The file handle is
  // closed in `finally`, so the test does not leak it even when parsing fails.
  def readGroupSizes(resourcePath: String): Seq[Int] = {
    val source = Source.fromFile(getClass.getResource(resourcePath).getFile)
    try {
      // toList materializes the lazy line iterator before the source is closed.
      source.getLines().map(_.toInt).toList
    } finally {
      source.close()
    }
  }

  // Test rows tagged with their position so they can be turned into an (id, features, label) DF.
  val testItr = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile).iterator.
    zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
      (id, instance.features, instance.label)
    }
  // Two single-partition RDDs are unioned so that the training data has exactly two partitions,
  // one per entry of the nested groupData below.
  val trainingDF = {
    val rowList0 = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile)
    val labeledPointsRDD0 = sc.parallelize(rowList0, numSlices = 1)
    val rowList1 = loadLabelPoints(getClass.getResource("/rank-demo-1.txt.train").getFile)
    val labeledPointsRDD1 = sc.parallelize(rowList1, numSlices = 1)
    val labeledPointsRDD = labeledPointsRDD0.union(labeledPointsRDD1)
    val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
    import sparkSession.implicits._
    sparkSession.createDataset(labeledPointsRDD).toDF
  }
  val trainGroupData: Seq[Seq[Int]] = Seq(
    readGroupSizes("/rank-demo-0.txt.train.group"),
    readGroupSizes("/rank-demo-1.txt.train.group"))
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "rank:pairwise", "groupData" -> trainGroupData)

  val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
    round = 5, nWorkers = 2)
  val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
    "id", "features", "label")
  val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
    collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
  // Every test row must come back from transform exactly once.
  assert(testDF.count() === predResultsFromDF.size)
}
}
Expand Up @@ -20,6 +20,7 @@ import java.nio.file.Files
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}

import scala.collection.mutable.ListBuffer
import scala.io.Source
import scala.util.Random
import scala.concurrent.duration._
import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker}
Expand Down Expand Up @@ -341,4 +342,46 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
}

test("test use groupData") {
  // Reads one integer group size per line from a classpath resource. The file handle is
  // closed in `finally`, so the test does not leak it even when parsing fails.
  def readGroupSizes(resourcePath: String): Seq[Int] = {
    val source = Source.fromFile(getClass.getResource(resourcePath).getFile)
    try {
      // toList materializes the lazy line iterator before the source is closed.
      source.getLines().map(_.toInt).toList
    } finally {
      source.close()
    }
  }

  val trainSet = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile)
  val trainingRDD = sc.parallelize(trainSet, numSlices = 1)
  // A single training partition, hence a single inner sequence of group sizes.
  val trainGroupData: Seq[Seq[Int]] = Seq(readGroupSizes("/rank-demo-0.txt.train.group"))
  val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile)
  val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)

  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "rank:pairwise", "groupData" -> trainGroupData)

  val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1)
  val predRDD = xgBoostModel.predict(testRDD)
  // testRDD has one partition, so the first collected element holds all predictions.
  val predResult1: Array[Array[Float]] = predRDD.collect()(0)
  assert(testRDD.count() === predResult1.length)
}

test("test use nested groupData") {
  // Reads one integer group size per line from a classpath resource. The file handle is
  // closed in `finally`, so the test does not leak it even when parsing fails.
  def readGroupSizes(resourcePath: String): Seq[Int] = {
    val source = Source.fromFile(getClass.getResource(resourcePath).getFile)
    try {
      // toList materializes the lazy line iterator before the source is closed.
      source.getLines().map(_.toInt).toList
    } finally {
      source.close()
    }
  }

  // Two single-partition RDDs are unioned so that the training data has exactly two partitions,
  // one per entry of the nested groupData below.
  val trainSet0 = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile)
  val trainingRDD0 = sc.parallelize(trainSet0, numSlices = 1)
  val trainSet1 = loadLabelPoints(getClass.getResource("/rank-demo-1.txt.train").getFile)
  val trainingRDD1 = sc.parallelize(trainSet1, numSlices = 1)
  val trainingRDD = trainingRDD0.union(trainingRDD1)

  val trainGroupData: Seq[Seq[Int]] = Seq(
    readGroupSizes("/rank-demo-0.txt.train.group"),
    readGroupSizes("/rank-demo-1.txt.train.group"))

  val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile)
  val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)

  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "rank:pairwise", "groupData" -> trainGroupData)

  val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
  val predRDD = xgBoostModel.predict(testRDD)
  // testRDD has one partition, so the first collected element holds all predictions.
  val predResult1: Array[Array[Float]] = predRDD.collect()(0)
  assert(testRDD.count() === predResult1.length)
}
}

0 comments on commit 288f309

Please sign in to comment.