diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index bf22f7fcca22..7a2d8df04d38 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -123,6 +123,12 @@ object XGBoost extends Serializable { } val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing) val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName)) + if (xgBoostConfMap.isDefinedAt("groupData") + && xgBoostConfMap.get("groupData").get != null) { + trainingSet.setGroup( + xgBoostConfMap.get("groupData").get.asInstanceOf[Seq[Seq[Int]]]( + TaskContext.getPartitionId()).toArray) + } booster = SXGBoost.train(trainingSet, xgBoostConfMap, round, watches = new mutable.HashMap[String, DMatrix] { put("train", trainingSet) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index 1ac0778f7b25..b02eecc433d4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -53,7 +53,14 @@ trait LearningTaskParams extends Params { s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}", (value: String) => LearningTaskParams.supportedEvalMetrics.contains(value)) - setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2) + /** + * group data specify each group sizes for ranking task. To correspond to partition of + * training data, it is nested. + */ + val groupData = new Param[Seq[Seq[Int]]](this, "groupData", "group data specify each group size" + + " for ranking task. To correspond to partition of training data, it is nested.") + + setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null) } private[spark] object LearningTaskParams { diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train new file mode 100644 index 000000000000..1f31343dd4d3 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train @@ -0,0 +1,75 @@ +0 1:985.574005058 2:320.223538037 3:0.621236086198 +0 1:1010.52917943 2:635.535543082 3:2.14984030531 +0 1:1012.91900422 2:132.387300057 3:0.488761066665 +0 1:990.829194034 2:135.102081162 3:0.747701610673 +0 1:1007.05103629 2:154.289183562 3:0.464118249201 +0 1:994.9573036 2:317.483732878 3:0.0313685555674 +0 1:987.8071541 2:731.349178363 3:0.244616944245 +1 1:10.0349544469 2:2.29750906143 3:36.4949974282 +0 1:9.92953881383 2:5.39134047297 3:120.041297548 +0 1:10.0909866713 2:9.06191026312 3:138.807825798 +1 1:10.2090970614 2:0.0784495944448 3:58.207703565 +0 1:9.85695905893 2:9.99500727713 3:56.8610243778 +1 1:10.0805758547 2:0.0410805760559 3:222.102302076 +0 1:10.1209914486 2:9.9729127088 3:171.888238763 +0 1:10.0331939798 2:0.853339303793 3:311.181328375 +0 1:9.93901762951 2:2.72757449146 3:78.4859514413 +0 1:10.0752365346 2:9.18695328235 3:49.8520256553 +1 1:10.0456548902 2:0.270936043122 3:123.462958597 +0 1:10.0568923673 2:0.82997113263 3:44.9391426001 +0 1:9.8214143472 2:0.277538931578 3:15.4217659578 +0 1:9.95258604431 2:8.69564346094 3:255.513470671 +0 1:9.91934976357 2:7.72809741413 3:82.171591817 +0 1:10.043239582 2:8.64168255553 3:38.9657919329 +1 1:10.0236147929 2:0.0496662263659 3:4.40889812286 +1 1:1001.85585324 2:3.75646886071 3:0.0179224994842 +0 1:1014.25578571 2:0.285765311201 3:0.510329864983 +1 1:1002.81422786 2:9.77676280375 3:0.433705951912 +1 1:998.072711553 2:2.82100686538 3:0.889829076909 +0 1:1003.77395036 2:2.55916592114 3:0.0359402151496 +1 1:10.0807877782 2:4.98513959013 3:47.5266363559 +0 1:10.0015013081 2:9.94302478763 3:78.3697486277 +1 1:10.0441936789 2:0.305091816635 3:56.8213984987 +0 1:9.94257106618 2:7.23909568913 3:442.463339039 +1 1:9.86479307916 2:6.41701315844 3:55.1365304834 +0 1:10.0428628516 2:9.98466447697 3:0.391632812588 +0 1:9.94445884566 2:9.99970945878 3:260.438436534 +1 1:9.84641392823 2:225.78051312 3:1.00525978847 +1 1:9.86907690608 2:26.8971083147 3:0.577959255991 +0 1:10.0177314626 2:0.110585342313 3:2.30545043031 +0 1:10.0688190907 2:412.023866234 3:1.22421542264 +0 1:10.1251769646 2:13.8212202925 3:0.129171734504 +0 1:10.0840758802 2:407.359097187 3:0.477000870705 +0 1:10.1007458705 2:987.183625145 3:0.149385677415 +0 1:9.86472656059 2:169.559640615 3:0.147221652519 +0 1:9.94207419238 2:507.290053755 3:0.41996207214 +0 1:9.9671005502 2:1.62610457716 3:0.408173666788 +0 1:1010.57126596 2:9.06673707562 3:0.672092284372 +0 1:1001.6718262 2:9.53203990055 3:4.7364050044 +0 1:995.777341384 2:4.43847316256 3:2.07229073634 +0 1:1002.95701386 2:5.51711016665 3:1.24294450546 +0 1:1016.0988238 2:0.626468941906 3:0.105627919134 +0 1:1013.67571419 2:0.042315529666 3:0.717619310322 +1 1:994.747747892 2:6.01989364024 3:0.772910130015 +1 1:991.654593872 2:7.35575736952 3:1.19822091548 +0 1:1008.47101732 2:8.28240754909 3:0.229582481359 +0 1:1000.81975227 2:1.52448354056 3:0.096441660362 +0 1:10.0900922344 2:322.656649307 3:57.8149073088 +1 1:10.0868337371 2:2.88652339174 3:54.8865514572 +0 1:10.0988984137 2:979.483832657 3:52.6809830901 +0 1:9.97678959238 2:665.770979738 3:481.069628909 +0 1:9.78554312773 2:257.309358658 3:47.7324475232 +0 1:10.0985967566 2:935.896512941 3:138.937052808 +0 1:10.0522252319 2:876.376299607 3:6.00373510669 +1 1:9.88065229501 2:9.99979825653 3:0.0674603696149 +0 1:10.0483244098 2:0.0653852316381 3:0.130679349938 +1 1:9.99685215607 2:1.76602542774 3:0.2551321159 +0 1:9.99750159428 2:1.01591534436 3:0.145445506504 +1 1:9.97380908941 2:0.940048645571 3:0.411805696316 +0 1:9.99977678382 2:6.91329929641 3:5.57858201258 +0 1:978.876096381 2:933.775364741 3:0.579170824236 +0 1:998.381016406 2:220.940470582 3:2.01491778565 +0 1:987.917644594 2:8.74667873567 3:0.364006099758 +0 1:1000.20994892 2:25.2945450565 3:3.5684398964 +0 1:1014.57141264 2:675.593540733 3:0.164174055535 +0 1:998.867283535 2:765.452750642 3:0.818425293238 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train.group b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train.group new file mode 100644 index 000000000000..67e55b03b764 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train.group @@ -0,0 +1,10 @@ +7 +7 +10 +5 +7 +10 +10 +7 +6 +6 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train new file mode 100644 index 000000000000..44c0f1ae3177 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train @@ -0,0 +1,74 @@ +0 1:10.2143092481 2:273.576539531 3:137.111774354 +0 1:10.0366658918 2:842.469052609 3:2.32134375927 +0 1:10.1281202091 2:395.654057342 3:35.4184893063 +0 1:10.1443721289 2:960.058461049 3:272.887070637 +0 1:10.1353234784 2:535.51304462 3:2.15393842032 +1 1:10.0451640374 2:216.733858424 3:55.6533298016 +1 1:9.94254592171 2:44.5985537358 3:304.614176871 +0 1:10.1319257181 2:613.545504487 3:5.42391587912 +0 1:1020.63622468 2:997.476744201 3:0.509425590461 +0 1:986.304585519 2:822.669937965 3:0.605133561808 +1 1:1012.66863221 2:26.7185759069 3:0.0875458784828 +0 1:995.387656321 2:81.8540176995 3:0.691999430068 +0 1:1020.6587198 2:848.826964547 3:0.540159430526 +1 1:1003.81573853 2:379.84350931 3:0.0083682925194 +0 1:1021.60921516 2:641.376951467 3:1.12339054807 +0 1:1000.17585041 2:122.107138713 3:1.09906375372 +1 1:987.64802348 2:5.98448541152 3:0.124241987204 +1 1:9.94610136583 2:346.114985897 3:0.387708236565 +0 1:9.96812192337 2:313.278109696 3:0.00863026595671 +0 1:10.0181739194 2:36.7378924562 3:2.92179879835 +0 1:9.89000102695 2:164.273723971 3:0.685222591968 +0 1:10.1555212436 2:320.451459462 3:2.01341536261 +0 1:10.0085727613 2:999.767117646 3:0.462294934168 +1 1:9.93099658724 2:5.17478203909 3:0.213855205032 +0 1:10.0629454957 2:663.088181857 3:0.049022351462 +0 1:10.1109732417 2:734.904569784 3:1.6998450094 +0 1:1006.6015266 2:505.023453703 3:1.90870566777 +0 1:991.865769489 2:245.437343115 3:0.475109744256 +0 1:998.682734072 2:950.041057232 3:1.9256314201 +0 1:1005.02207209 2:2.9619314197 3:0.0517146822357 +0 1:1002.54526214 2:860.562681899 3:0.915687092848 +0 1:1000.38847359 2:808.416525088 3:0.209690673808 +1 1:992.557818382 2:373.889409453 3:0.107571728577 +0 1:1002.07722137 2:997.329626371 3:1.06504260496 +0 1:1000.40504333 2:949.832139189 3:0.539159980327 +0 1:10.1460179902 2:8.86082969819 3:135.953842715 +1 1:9.98529296553 2:2.87366448495 3:1.74249892194 +0 1:9.88942676744 2:9.4031821056 3:149.473066381 +1 1:10.0192953341 2:1.99685737576 3:1.79502473397 +0 1:10.0110654379 2:8.13112593726 3:87.7765628103 +0 1:997.148677047 2:733.936190093 3:1.49298494242 +0 1:1008.70465919 2:957.121652078 3:0.217414013634 +1 1:997.356154278 2:541.599587807 3:0.100855972216 +0 1:999.615897283 2:943.700501824 3:0.862874175879 +1 1:997.36859077 2:0.200859940848 3:0.13601892182 +0 1:10.0423255624 2:1.73855202168 3:0.956695338485 +1 1:9.88440755486 2:9.9994600678 3:0.305080529665 +0 1:10.0891026412 2:3.28031719474 3:0.364450973697 +0 1:9.90078644258 2:8.77839663617 3:0.456660574479 +1 1:9.79380029711 2:8.77220326156 3:0.527292005175 +0 1:9.93613887011 2:9.76270841268 3:1.40865693823 +0 1:10.0009239007 2:7.29056178263 3:0.498015866607 +0 1:9.96603319905 2:5.12498000925 3:0.517492532783 +0 1:10.0923827222 2:2.76652583955 3:1.56571226159 +1 1:10.0983782035 2:587.788120694 3:0.031756483687 +1 1:9.91397225464 2:994.527496819 3:3.72092164978 +0 1:10.1057472738 2:2.92894440088 3:0.683506438532 +0 1:10.1014053354 2:959.082038017 3:1.07039624129 +0 1:10.1433253044 2:322.515119317 3:0.51408278993 +1 1:9.82832510699 2:637.104433908 3:0.250272776427 +0 1:1000.49729075 2:2.75336888111 3:0.576634423274 +1 1:984.90338088 2:0.0295435794035 3:1.26273339929 +0 1:1001.53811442 2:4.64164410861 3:0.0293389959504 +1 1:995.875898395 2:5.08223403205 3:0.382330566779 +0 1:996.405937252 2:6.26395190757 3:0.453645816611 +0 1:10.0165140779 2:340.126072514 3:0.220794603312 +0 1:9.93482824816 2:951.672000448 3:0.124406293612 +0 1:10.1700278554 2:0.0140985961008 3:0.252452256311 +0 1:9.99825079542 2:950.382643896 3:0.875382402062 +0 1:9.87316410028 2:686.788257829 3:0.215886999825 +0 1:10.2893240654 2:89.3947931451 3:0.569578232133 +0 1:9.98689192703 2:0.430107535413 3:2.99869831728 +0 1:10.1365175107 2:972.279245093 3:0.0865099386744 +0 1:9.90744703306 2:50.810461183 3:3.00863325197 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train.group b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train.group new file mode 100644 index 000000000000..877ef9231fda --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train.group @@ -0,0 +1,10 @@ +8 +9 +9 +9 +5 +5 +9 +6 +5 +9 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test new file mode 100644 index 000000000000..fc237b7e106f --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test @@ -0,0 +1,66 @@ +0 1:10.0229017899 2:7.30178495562 3:0.118115020017 +0 1:9.93639621859 2:9.93102159291 3:0.0435030004396 +0 1:10.1301737265 2:0.00411765220572 3:2.4165878053 +1 1:9.87828587087 2:0.608588414992 3:0.111262590883 +0 1:10.1373430048 2:0.47764012225 3:0.991553052194 +0 1:10.0523814718 2:4.72152505167 3:0.672978832666 +0 1:10.0449715742 2:8.40373928536 3:0.384457573667 +1 1:996.398498791 2:941.976309154 3:0.230269231292 +0 1:1005.11269468 2:900.093680877 3:0.265031528873 +0 1:997.160349441 2:891.331101688 3:2.19362017313 +0 1:993.754139031 2:44.8000165317 3:1.03868009875 +1 1:994.831299184 2:241.959208453 3:0.667631827024 +0 1:995.948333283 2:7.94326917112 3:0.750490877118 +0 1:989.733981273 2:7.52077625436 3:0.0126335967282 +0 1:1003.54086516 2:6.48177510564 3:1.19441696788 +0 1:996.56177804 2:9.71959812613 3:1.33082465111 +0 1:1005.61382467 2:0.234339369309 3:1.17987797356 +1 1:980.215758708 2:6.85554542926 3:2.63965085259 +1 1:987.776408872 2:2.23354609991 3:0.841885278028 +0 1:1006.54260396 2:8.12142049834 3:2.26639471174 +0 1:1009.87927639 2:6.40028519044 3:0.775155669615 +0 1:9.95006244393 2:928.76896718 3:234.948458244 +1 1:10.0749152258 2:255.294574476 3:62.9728604166 +1 1:10.1916541988 2:312.682867085 3:92.299413677 +0 1:9.95646724484 2:742.263188416 3:53.3310473654 +0 1:9.86211293222 2:996.237023866 3:2.00760301168 +1 1:9.91801019468 2:303.971783709 3:50.3147230679 +0 1:996.983996934 2:9.52188222766 3:1.33588120981 +0 1:995.704388126 2:9.49260524915 3:0.908498516541 +0 1:987.86480767 2:0.0870786716821 3:0.108859297837 +0 1:1000.99561307 2:2.85272694575 3:0.171134518956 +0 1:1011.05508066 2:7.55336771768 3:1.04950084825 +1 1:985.52199365 2:0.763305780608 3:1.7402424375 +0 1:10.0430321467 2:813.185427181 3:4.97728254185 +0 1:10.0812334228 2:258.297288417 3:0.127477670549 +0 1:9.84210504292 2:887.205815261 3:0.991689193955 +1 1:9.94625332613 2:0.298622762132 3:0.147881353231 +0 1:9.97800659954 2:727.619819757 3:0.0718361141866 +1 1:9.8037938472 2:957.385549617 3:0.0618862028941 +0 1:10.0880634741 2:185.024638577 3:1.7028095095 +0 1:9.98630799154 2:109.10631473 3:0.681117359751 +0 1:9.91671416638 2:166.248076588 3:122.538291094 +0 1:10.1206910464 2:88.1539468531 3:141.189859069 +1 1:10.1767160518 2:1.02960996847 3:172.02256237 +0 1:9.93025147233 2:391.196641942 3:58.040338247 +0 1:9.84850936037 2:474.63346537 3:17.5627875397 +1 1:9.8162731343 2:61.9199554213 3:30.6740972851 +0 1:10.0403482984 2:987.50416929 3:73.0472906209 +1 1:997.019228359 2:133.294717663 3:0.0572254083186 +0 1:973.303999107 2:1.79080888849 3:0.100478717048 +0 1:1008.28808825 2:342.282350685 3:0.409806485495 +0 1:1014.55621524 2:0.680510407082 3:0.929530602495 +1 1:1012.74370325 2:823.105266455 3:0.0894693730585 +0 1:1003.63554038 2:727.334432075 3:0.58206275756 +0 1:10.1560432436 2:740.35938307 3:11.6823378533 +0 1:9.83949099701 2:512.828227154 3:138.206666681 +1 1:10.1837395682 2:179.287126088 3:185.479062365 +1 1:9.9761881495 2:12.1093388336 3:9.1264604171 +1 1:9.77402180766 2:318.561317743 3:80.6005221355 +0 1:1011.15705381 2:0.215825852155 3:1.34429667906 +0 1:1005.60353229 2:727.202346126 3:1.47146041005 +1 1:1013.93702961 2:58.7312725205 3:0.421041560754 +0 1:1004.86813074 2:757.693204258 3:0.566055205344 +0 1:999.996324692 2:813.12386828 3:0.864428279513 +0 1:996.55255931 2:918.760056995 3:0.43365051974 +1 1:1004.1394132 2:464.371823646 3:0.312492288321 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test.group b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test.group new file mode 100644 index 000000000000..81e3e05be03a --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test.group @@ -0,0 +1,10 @@ +7 +5 +9 +6 +6 +8 +7 +6 +5 +7 diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala index aff16f146624..f8098c0deebf 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala @@ -239,4 +239,36 @@ class XGBoostDFSuite extends SharedSparkContext with Utils { XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers) } + + test("test DF use nested groupData") { + val testItr = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile).iterator. + zipWithIndex.map { case (instance: LabeledPoint, id: Int) => + (id, instance.features, instance.label) + } + val trainingDF = { + val rowList0 = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile) + val labeledPointsRDD0 = sc.parallelize(rowList0, numSlices = 1) + val rowList1 = loadLabelPoints(getClass.getResource("/rank-demo-1.txt.train").getFile) + val labeledPointsRDD1 = sc.parallelize(rowList1, numSlices = 1) + val labeledPointsRDD = labeledPointsRDD0.union(labeledPointsRDD1) + val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate() + import sparkSession.implicits._ + sparkSession.createDataset(labeledPointsRDD).toDF + } + val trainGroupData0: Seq[Int] = Source.fromFile( + getClass.getResource("/rank-demo-0.txt.train.group").getFile).getLines().map(_.toInt).toList + val trainGroupData1: Seq[Int] = Source.fromFile( + getClass.getResource("/rank-demo-1.txt.train.group").getFile).getLines().map(_.toInt).toList + val trainGroupData: Seq[Seq[Int]] = Seq(trainGroupData0, trainGroupData1) + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "rank:pairwise", "groupData" -> trainGroupData) + + val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, + round = 5, nWorkers = 2) + val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF( + "id", "features", "label") + val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF). + collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap + assert(testDF.count() === predResultsFromDF.size) + } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 20c5263ef6c4..fb41beceabe3 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -20,6 +20,7 @@ import java.nio.file.Files import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} import scala.collection.mutable.ListBuffer +import scala.io.Source import scala.util.Random import scala.concurrent.duration._ import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix, RabitTracker => PyRabitTracker} @@ -341,4 +342,46 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { assert(loadedXGBoostModel.getLabelCol == "label") assert(loadedXGBoostModel.getPredictionCol == "prediction") } + + test("test use groupData") { + val trainSet = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile) + val trainingRDD = sc.parallelize(trainSet, numSlices = 1) + val trainGroupData: Seq[Seq[Int]] = Seq(Source.fromFile( + getClass.getResource("/rank-demo-0.txt.train.group").getFile).getLines().map(_.toInt).toList) + val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile) + val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features) + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "rank:pairwise", "groupData" -> trainGroupData) + + val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1) + val predRDD = xgBoostModel.predict(testRDD) + val predResult1: Array[Array[Float]] = predRDD.collect()(0) + assert(testRDD.count() === predResult1.length) + } + + test("test use nested groupData") { + val trainSet0 = loadLabelPoints(getClass.getResource("/rank-demo-0.txt.train").getFile) + val trainingRDD0 = sc.parallelize(trainSet0, numSlices = 1) + val trainSet1 = loadLabelPoints(getClass.getResource("/rank-demo-1.txt.train").getFile) + val trainingRDD1 = sc.parallelize(trainSet1, numSlices = 1) + val trainingRDD = trainingRDD0.union(trainingRDD1) + + val trainGroupData0: Seq[Int] = Source.fromFile( + getClass.getResource("/rank-demo-0.txt.train.group").getFile).getLines().map(_.toInt).toList + val trainGroupData1: Seq[Int] = Source.fromFile( + getClass.getResource("/rank-demo-1.txt.train.group").getFile).getLines().map(_.toInt).toList + val trainGroupData: Seq[Seq[Int]] = Seq(trainGroupData0, trainGroupData1) + + val testSet = loadLabelPoints(getClass.getResource("/rank-demo.txt.test").getFile) + val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features) + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "rank:pairwise", "groupData" -> trainGroupData) + + val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2) + val predRDD = xgBoostModel.predict(testRDD) + val predResult1: Array[Array[Float]] = predRDD.collect()(0) + assert(testRDD.count() === predResult1.length) + } }