Skip to content

Commit

Permalink
Bump xgboost to 1.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
lucagiovagnoli committed Sep 16, 2020
1 parent 01629c1 commit 69990d2
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ trait XgbConverters {
def asXGB: DMatrix = {
vector match {
case SparseVector(_, indices, values) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, indices, values.map(_.toFloat))))
new DMatrix(Iterator(new LabeledPoint(0.0f, indices.length, indices, values.map(_.toFloat))))

case DenseVector(values) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, null, values.map(_.toFloat))))
new DMatrix(Iterator(new LabeledPoint(0.0f, values.length, null, values.map(_.toFloat))))
}
}

Expand All @@ -34,10 +34,12 @@ trait XgbConverters {
def asXGB: DMatrix = {
tensor match {
case SparseTensor(indices, values, _) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, indices.map(_.head).toArray, values.map(_.toFloat))))
new DMatrix(Iterator(
new LabeledPoint(0.0f, indices.length, indices.map(_.head).toArray, values.map(_.toFloat))))

case DenseTensor(_, _) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, null, tensor.toDense.rawValues.map(_.toFloat))))
new DMatrix(Iterator(
new LabeledPoint(0.0f, tensor.size, null, tensor.toDense.rawValues.map(_.toFloat))))
}
}

Expand Down
1 change: 1 addition & 0 deletions project/Common.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ object Common {
javaOptions in test += sys.env.getOrElse("JVM_OPTS", ""),
resolvers += Resolver.mavenLocal,
resolvers += Resolver.jcenterRepo,
resolvers += "XGBoost4J Release Repo" at "https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/release",
resolvers ++= {
// Only add Sonatype Snapshots if this version itself is a snapshot version
if(isSnapshot.value) {
Expand Down
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ object Dependencies {
lazy val slf4jVersion = "1.7.25"
lazy val awsSdkVersion = "1.11.349"
val tensorflowVersion = "1.11.0"
val xgboostVersion = "1.0.0"
val xgboostVersion = "1.1.1"
val hadoopVersion = "2.6.5" // matches spark version
val kryoVersion = "4.0.2" // Remove upon upgrading to xgboost 1.1.1

Expand Down
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.6.1")
addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.9.3")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.10.0-RC1")
addSbtPlugin("com.frugalmechanic" % "fm-sbt-s3-resolver" % "0.19.0")


libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.7.1"

0 comments on commit 69990d2

Please sign in to comment.