Skip to content

Commit

Permalink
Trying out xgboost 1.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
lucagiovagnoli committed Sep 16, 2020
1 parent 22fad90 commit 7466d5a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ trait XgbConverters {
def asXGB: DMatrix = {
vector match {
case SparseVector(_, indices, values) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, indices, values.map(_.toFloat))))
new DMatrix(Iterator(new LabeledPoint(0.0f, indices.length, indices, values.map(_.toFloat))))

case DenseVector(values) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, null, values.map(_.toFloat))))
new DMatrix(Iterator(new LabeledPoint(0.0f, values.length, null, values.map(_.toFloat))))
}
}

Expand All @@ -34,10 +34,12 @@ trait XgbConverters {
def asXGB: DMatrix = {
tensor match {
case SparseTensor(indices, values, _) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, indices.map(_.head).toArray, values.map(_.toFloat))))
new DMatrix(Iterator(
new LabeledPoint(0.0f, indices.length, indices.map(_.head).toArray, values.map(_.toFloat))))

case DenseTensor(_, _) =>
new DMatrix(Iterator(new LabeledPoint(0.0f, null, tensor.toDense.rawValues.map(_.toFloat))))
new DMatrix(Iterator(
new LabeledPoint(0.0f, tensor.size, null, tensor.toDense.rawValues.map(_.toFloat))))
}
}

Expand Down
1 change: 1 addition & 0 deletions project/Common.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ object Common {
javaOptions in test += sys.env.getOrElse("JVM_OPTS", ""),
resolvers += Resolver.mavenLocal,
resolvers += Resolver.jcenterRepo,
resolvers += "XGBoost4J Release Repo" at "https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/release",
resolvers ++= {
// Only add Sonatype Snapshots if this version itself is a snapshot version
if(isSnapshot.value) {
Expand Down
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.6.1")
addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.9.3")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.10.0-RC1")
addSbtPlugin("com.frugalmechanic" % "fm-sbt-s3-resolver" % "0.19.0")


libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.7.1"

0 comments on commit 7466d5a

Please sign in to comment.