diff --git a/mleap-xgboost-runtime/src/main/scala/ml/combust/mleap/xgboost/runtime/XgbConverters.scala b/mleap-xgboost-runtime/src/main/scala/ml/combust/mleap/xgboost/runtime/XgbConverters.scala index 673381e6f..4f2a17419 100644 --- a/mleap-xgboost-runtime/src/main/scala/ml/combust/mleap/xgboost/runtime/XgbConverters.scala +++ b/mleap-xgboost-runtime/src/main/scala/ml/combust/mleap/xgboost/runtime/XgbConverters.scala @@ -13,10 +13,10 @@ trait XgbConverters { def asXGB: DMatrix = { vector match { case SparseVector(_, indices, values) => - new DMatrix(Iterator(new LabeledPoint(0.0f, indices, values.map(_.toFloat)))) + new DMatrix(Iterator(new LabeledPoint(0.0f, indices.length, indices, values.map(_.toFloat)))) case DenseVector(values) => - new DMatrix(Iterator(new LabeledPoint(0.0f, null, values.map(_.toFloat)))) + new DMatrix(Iterator(new LabeledPoint(0.0f, values.length, null, values.map(_.toFloat)))) } } @@ -34,10 +34,12 @@ trait XgbConverters { def asXGB: DMatrix = { tensor match { case SparseTensor(indices, values, _) => - new DMatrix(Iterator(new LabeledPoint(0.0f, indices.map(_.head).toArray, values.map(_.toFloat)))) + new DMatrix(Iterator( + new LabeledPoint(0.0f, indices.length, indices.map(_.head).toArray, values.map(_.toFloat)))) case DenseTensor(_, _) => - new DMatrix(Iterator(new LabeledPoint(0.0f, null, tensor.toDense.rawValues.map(_.toFloat)))) + new DMatrix(Iterator( + new LabeledPoint(0.0f, tensor.size, null, tensor.toDense.rawValues.map(_.toFloat)))) } } diff --git a/project/Common.scala b/project/Common.scala index 018ae14d2..ec70fd883 100644 --- a/project/Common.scala +++ b/project/Common.scala @@ -23,6 +23,7 @@ object Common { javaOptions in test += sys.env.getOrElse("JVM_OPTS", ""), resolvers += Resolver.mavenLocal, resolvers += Resolver.jcenterRepo, + resolvers += "XGBoost4J Release Repo" at "https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/release", resolvers ++= { // Only add Sonatype Snapshots if this version itself is a snapshot version if(isSnapshot.value) { diff --git a/project/plugins.sbt b/project/plugins.sbt index d01a8ddac..bc96542bf 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -9,6 +9,7 @@ addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.6.1") addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.9.3") addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.10.0-RC1") +addSbtPlugin("com.frugalmechanic" % "fm-sbt-s3-resolver" % "0.19.0") libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.7.1"