Skip to content

Commit

Permalink
Updates XGBoost to 2.0.1 (#2833)
Browse files Browse the repository at this point in the history
* Updates XGBoost to 2.0.1

* Use devtools 8

* Updates based on new Xgboost JNI API.

---------

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
zachgk and frankfliu committed Nov 4, 2023
1 parent 6981d76 commit 715e620
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/native_s3_xgboost.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
run: |
yum -y update
yum -y install centos-release-scl-rh epel-release
yum -y install devtoolset-7 git patch libstdc++-static curl python3-devel
yum -y install devtoolset-8 git patch libstdc++-static curl python3-devel
curl -L -o cmake.tar.gz https://github.com/Kitware/CMake/releases/download/v3.27.0-rc2/cmake-3.27.0-rc2-linux-aarch64.tar.gz
tar xvfz cmake.tar.gz
ln -sf $PWD/cmake-3.*/bin/cmake /usr/bin/cmake
Expand All @@ -50,7 +50,7 @@ jobs:
XGBOOST_VERSION=${{ github.event.inputs.xgb_version }}
XGBOOST_VERSION=${XGBOOST_VERSION:-$(cat gradle.properties | awk -F '=' '/xgboost_version/ {print $2}')}
git clone https://github.com/dmlc/xgboost --recursive -b v"$XGBOOST_VERSION"
export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin
export PATH=$PATH:/opt/rh/devtoolset-8/root/usr/bin
cd xgboost/jvm-packages
python3 create_jni.py
cd ../..
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class XgbNDManager extends BaseNDManager {
private static final XgbNDManager SYSTEM_MANAGER = new SystemManager();

private float missingValue = Float.NaN;
private int nthread = 1;

private XgbNDManager(NDManager parent, Device device) {
super(parent, device);
Expand All @@ -57,6 +58,15 @@ public void setMissingValue(float missingValue) {
this.missingValue = missingValue;
}

/**
* Sets the default number of threads.
*
* @param nthread the default number of threads
*/
public void setNthread(int nthread) {
this.nthread = nthread;
}

/** {@inheritDoc} */
@Override
public ByteBuffer allocateDirect(int capacity) {
Expand Down Expand Up @@ -166,7 +176,7 @@ public NDArray createCSR(Buffer buffer, long[] indptr, long[] indices, Shape sha
int[] intIndices = Arrays.stream(indices).mapToInt(Math::toIntExact).toArray();
float[] data = new float[buffer.remaining()];
((FloatBuffer) buffer).get(data);
long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data);
long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data, missingValue, nthread);
return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.CSR);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,12 @@ public static long createDMatrix(ColumnBatch columnBatch, float missing, int nth
return handles[0];
}

public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) {
public static long createDMatrixCSR(
long[] indptr, int[] indices, float[] array, float missing, int nthread) {
long[] handles = new long[1];
checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles));
checkCall(
XGBoostJNI.XGDMatrixCreateFromCSR(
indptr, indices, array, 0, missing, nthread, handles));
return handles[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void downloadXGBoostModel() throws IOException {
@Test
public void testVersion() {
Engine engine = Engine.getEngine("XGBoost");
Assert.assertEquals("1.7.5", engine.getVersion());
Assert.assertEquals("2.0.1", engine.getVersion());
}

/*
Expand Down Expand Up @@ -93,6 +93,7 @@ public void testNDArray() {
try (XgbNDManager manager =
(XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) {
manager.setMissingValue(Float.NaN);
manager.setNthread(1);
NDArray zeros = manager.zeros(new Shape(1, 2));
Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray);

Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ paddlepaddle_version=2.3.2
sentencepiece_version=0.1.97
tokenizers_version=0.14.1
fasttext_version=0.9.2
xgboost_version=1.7.5
xgboost_version=2.0.1
lightgbm_version=3.2.110
rapis_version=22.12.0

Expand Down

0 comments on commit 715e620

Please sign in to comment.