diff --git a/.github/workflows/native_s3_xgboost.yml b/.github/workflows/native_s3_xgboost.yml index 3d92e1bd3a8..72a8adfc83e 100644 --- a/.github/workflows/native_s3_xgboost.yml +++ b/.github/workflows/native_s3_xgboost.yml @@ -34,7 +34,7 @@ jobs: run: | yum -y update yum -y install centos-release-scl-rh epel-release - yum -y install devtoolset-7 git patch libstdc++-static curl python3-devel + yum -y install devtoolset-8 git patch libstdc++-static curl python3-devel curl -L -o cmake.tar.gz https://github.com/Kitware/CMake/releases/download/v3.27.0-rc2/cmake-3.27.0-rc2-linux-aarch64.tar.gz tar xvfz cmake.tar.gz ln -sf $PWD/cmake-3.*/bin/cmake /usr/bin/cmake @@ -50,7 +50,7 @@ jobs: XGBOOST_VERSION=${{ github.event.inputs.xgb_version }} XGBOOST_VERSION=${XGBOOST_VERSION:-$(cat gradle.properties | awk -F '=' '/xgboost_version/ {print $2}')} git clone https://github.com/dmlc/xgboost --recursive -b v"$XGBOOST_VERSION" - export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin + export PATH=$PATH:/opt/rh/devtoolset-8/root/usr/bin cd xgboost/jvm-packages python3 create_jni.py cd ../.. diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java index 3b56cbca241..81f9708e72b 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java @@ -39,6 +39,7 @@ public class XgbNDManager extends BaseNDManager { private static final XgbNDManager SYSTEM_MANAGER = new SystemManager(); private float missingValue = Float.NaN; + private int nthread = 1; private XgbNDManager(NDManager parent, Device device) { super(parent, device); @@ -57,6 +58,15 @@ public void setMissingValue(float missingValue) { this.missingValue = missingValue; } + /** + * Sets the default number of threads. + * + * @param nthread the default number of threads + */ + public void setNthread(int nthread) { + this.nthread = nthread; + } + /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { @@ -166,7 +176,7 @@ public NDArray createCSR(Buffer buffer, long[] indptr, long[] indices, Shape sha int[] intIndices = Arrays.stream(indices).mapToInt(Math::toIntExact).toArray(); float[] data = new float[buffer.remaining()]; ((FloatBuffer) buffer).get(data); - long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data); + long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data, missingValue, nthread); return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.CSR); } diff --git a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java index fefbe7f0716..eb071552fd0 100644 --- a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java +++ b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java @@ -67,9 +67,12 @@ public static long createDMatrix(ColumnBatch columnBatch, float missing, int nth return handles[0]; } - public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) { + public static long createDMatrixCSR( + long[] indptr, int[] indices, float[] array, float missing, int nthread) { long[] handles = new long[1]; - checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles)); + checkCall( + XGBoostJNI.XGDMatrixCreateFromCSR( + indptr, indices, array, 0, missing, nthread, handles)); return handles[0]; } diff --git a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java index 0b09ed6807c..7d928b121f4 100644 --- a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java +++ b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java @@ -53,7 +53,7 @@ public void downloadXGBoostModel() throws IOException { @Test public void testVersion() { Engine engine = Engine.getEngine("XGBoost"); - Assert.assertEquals("1.7.5", engine.getVersion()); + Assert.assertEquals("2.0.1", engine.getVersion()); } /* @@ -93,6 +93,7 @@ public void testNDArray() { try (XgbNDManager manager = (XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) { manager.setMissingValue(Float.NaN); + manager.setNthread(1); NDArray zeros = manager.zeros(new Shape(1, 2)); Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray); diff --git a/gradle.properties b/gradle.properties index 0ef8f6e991b..ca970df1e66 100644 --- a/gradle.properties +++ b/gradle.properties @@ -22,7 +22,7 @@ paddlepaddle_version=2.3.2 sentencepiece_version=0.1.97 tokenizers_version=0.14.1 fasttext_version=0.9.2 -xgboost_version=1.7.5 +xgboost_version=2.0.1 lightgbm_version=3.2.110 rapis_version=22.12.0