From e80d3db64b2d97a982a5df81cd25c4e401db60b0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 3 Mar 2016 11:36:34 -0800 Subject: [PATCH] [DIST] Enable multiple thread and tracker, make rabit and xgboost more thread-safe by using thread local variables. --- NEWS.md | 5 + dmlc-core | 2 +- jvm-packages/.gitignore | 2 + jvm-packages/create_jni.sh | 2 + jvm-packages/test_distributed.sh | 2 +- .../ml/dmlc/xgboost4j/demo/DistTrain.java | 76 +++++++++--- .../main/java/ml/dmlc/xgboost4j/FileUtil.java | 78 ++++++++++++ .../ml/dmlc/xgboost4j/NativeLibLoader.java | 57 +-------- .../java/ml/dmlc/xgboost4j/RabitTracker.java | 98 +++++++++++++++ .../java/ml/dmlc/xgboost4j/XgboostJNI.java | 4 +- .../xgboost4j/src/native/xgboost4j.cpp | 116 ++++++++---------- jvm-packages/xgboost4j/src/native/xgboost4j.h | 6 +- rabit | 2 +- src/common/common.cc | 12 +- src/common/random.h | 3 +- src/common/thread_local.h | 2 + src/gbm/gbtree.cc | 9 +- 17 files changed, 323 insertions(+), 153 deletions(-) create mode 100644 jvm-packages/.gitignore create mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java create mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java diff --git a/NEWS.md b/NEWS.md index 48d5bedeae2f..0da657019c8a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -22,6 +22,11 @@ This file records the changes in xgboost library in reverse chronological order. - The windows version is still blocked due to Rtools do not support ```std::thread```. * rabit and dmlc-core are maintained through git submodule - Anyone can open PR to update these dependencies now. +* Improvements + - Rabit and xgboost libs are now thread-safe and use thread local PRNGs + - This could fix some of the previous problems with running xgboost on multiple threads. 
+* JVM Package + - Enable xgboost4j for java and scala ## v0.47 (2016.01.14) diff --git a/dmlc-core b/dmlc-core index 71360023dba4..3f6ff43d3976 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0 +Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f diff --git a/jvm-packages/.gitignore b/jvm-packages/.gitignore new file mode 100644 index 000000000000..d1d4390d6b5d --- /dev/null +++ b/jvm-packages/.gitignore @@ -0,0 +1,2 @@ +tracker.py +build.sh \ No newline at end of file diff --git a/jvm-packages/create_jni.sh b/jvm-packages/create_jni.sh index 13e6a8556997..13d2604e7419 100755 --- a/jvm-packages/create_jni.sh +++ b/jvm-packages/create_jni.sh @@ -27,6 +27,8 @@ fi rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl} mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl} +# copy python to native resources +cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py popd > /dev/null echo "complete" diff --git a/jvm-packages/test_distributed.sh b/jvm-packages/test_distributed.sh index 7b5515b495e8..c9a5b21be9fd 100644 --- a/jvm-packages/test_distributed.sh +++ b/jvm-packages/test_distributed.sh @@ -1,5 +1,5 @@ #!/bin/bash # Simple script to test distributed version, to be deleted later. cd xgboost4j-demo -../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain +java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4 cd .. 
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java index e64b3ef70794..3cff4bd79ab0 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java @@ -2,42 +2,78 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import ml.dmlc.xgboost4j.*; + /** * Distributed training example, used to quick test distributed training. * * @author tqchen */ public class DistTrain { + private static final Log logger = LogFactory.getLog(DistTrain.class); + private Map envs = null; + + private class Worker implements Runnable { + private int worker_id; + Worker(int worker_id) { + this.worker_id = worker_id; + } + + public void run() { + try { + Map worker_env = new HashMap(envs); - public static void main(String[] args) throws IOException, XGBoostError { - // always initialize rabit module before training. - Rabit.init(new HashMap()); + worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString()); + // always initialize rabit module before training. 
+ Rabit.init(worker_env); - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - HashMap params = new HashMap(); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); - params.put("objective", "binary:logistic"); + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("nthread", 2); + params.put("objective", "binary:logistic"); - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); - //set round - int round = 2; + //set round + int round = 2; - //train a boost model - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + //train a boost model + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + // always shutdown rabit module after training. + Rabit.shutdown(); + } catch (Exception ex){ + logger.error(ex); + } + } + } + + void start(int nworker) throws IOException, XGBoostError, InterruptedException { + RabitTracker tracker = new RabitTracker(nworker); + tracker.start(); + envs = tracker.getWorkerEnvs(); + for (int i = 0; i < nworker; ++i) { + new Thread(new Worker(i)).start(); + } + tracker.waitFor(); + } - // always shutdown rabit module after training. 
- Rabit.shutdown(); + public static void main(String[] args) throws IOException, XGBoostError, InterruptedException { + new DistTrain().start(Integer.parseInt(args[0])); } } \ No newline at end of file diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java new file mode 100644 index 000000000000..4b535bd2fb7b --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java @@ -0,0 +1,78 @@ +package ml.dmlc.xgboost4j; + + +import java.io.*; +import java.io.IOException; + +/** + * Auxiliary utils to copy resources from the JAR into temporary files. + */ +class FileUtil { + /** + * Create a temp file that copies the resource from current JAR archive + *

+ * The file from JAR is copied into system temp file. + * The temporary file is deleted after exiting. + * Method uses String as filename because the pathname is "abstract", not system-dependent. + *

+ * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to + * {@code path}. + * @param path Path to the resources in the jar + * @return The created temp file. + * @throws IOException if the temp file cannot be created, the resource is not found in the JAR, or copying fails + * @throws IllegalArgumentException if the path does not start with '/' or the filename prefix is shorter than 3 characters + */ + static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException { + // Obtain filename from path + if (!path.startsWith("/")) { + throw new IllegalArgumentException("The path has to be absolute (start with '/')."); + } + + String[] parts = path.split("/"); + String filename = (parts.length > 1) ? parts[parts.length - 1] : null; + + // Split filename to prefix and suffix (extension) + String prefix = ""; + String suffix = null; + if (filename != null) { + parts = filename.split("\\.", 2); + prefix = parts[0]; + suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-) + } + + // Check if the filename is okay + if (filename == null || prefix.length() < 3) { + throw new IllegalArgumentException("The filename has to be at least 3 characters long."); + } + // Prepare temporary file + File temp = File.createTempFile(prefix, suffix); + temp.deleteOnExit(); + + if (!temp.exists()) { + throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist."); + } + + // Prepare buffer for data copying + byte[] buffer = new byte[1024]; + int readBytes; + + // Open and check input stream + InputStream is = NativeLibLoader.class.getResourceAsStream(path); + if (is == null) { + throw new FileNotFoundException("File " + path + " was not found inside JAR."); + } + + // Open output stream and copy data between source file in JAR and the temporary file + OutputStream os = new FileOutputStream(temp); + try { + while ((readBytes = is.read(buffer)) != -1) { + os.write(buffer, 0, readBytes); + } + } finally { + // If read/write fails, close streams safely before throwing an exception + os.close(); + is.close(); + } + return temp; + } +} diff --git 
a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java index c23ace8f4464..01d846f62432 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java @@ -21,6 +21,9 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import ml.dmlc.xgboost4j.FileUtil; + + /** * class to load native library * @@ -61,59 +64,7 @@ public static synchronized void initXgBoost() throws IOException { * three characters */ private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{ - - if (!path.startsWith("/")) { - throw new IllegalArgumentException("The path has to be absolute (start with '/')."); - } - - // Obtain filename from path - String[] parts = path.split("/"); - String filename = (parts.length > 1) ? parts[parts.length - 1] : null; - - // Split filename to prexif and suffix (extension) - String prefix = ""; - String suffix = null; - if (filename != null) { - parts = filename.split("\\.", 2); - prefix = parts[0]; - suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! 
:-) - } - - // Check if the filename is okay - if (filename == null || prefix.length() < 3) { - throw new IllegalArgumentException("The filename has to be at least 3 characters long."); - } - - // Prepare temporary file - File temp = File.createTempFile(prefix, suffix); - temp.deleteOnExit(); - - if (!temp.exists()) { - throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist."); - } - - // Prepare buffer for data copying - byte[] buffer = new byte[1024]; - int readBytes; - - // Open and check input stream - InputStream is = NativeLibLoader.class.getResourceAsStream(path); - if (is == null) { - throw new FileNotFoundException("File " + path + " was not found inside JAR."); - } - - // Open output stream and copy data between source file in JAR and the temporary file - OutputStream os = new FileOutputStream(temp); - try { - while ((readBytes = is.read(buffer)) != -1) { - os.write(buffer, 0, readBytes); - } - } finally { - // If read/write fails, close streams safely before throwing an exception - os.close(); - is.close(); - } - + File temp = FileUtil.createTempFileFromResource(path); // Finally, load the library System.load(temp.getAbsolutePath()); } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java new file mode 100644 index 000000000000..793943ffffdd --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java @@ -0,0 +1,98 @@ +package ml.dmlc.xgboost4j; + + + +import java.io.*; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +/** + * Distributed RabitTracker, need to be started on driver code before running distributed jobs. + */ +public class RabitTracker { + // Maybe per tracker logger? + private static final Log logger = LogFactory.getLog(RabitTracker.class); + // tracker python file. 
+ private static File tracker_py = null; + // environment variables to be passed to the workers. + private Map envs = new HashMap(); + // number of workers to be submitted. + private int num_workers; + // child process + private Process process = null; + // logger thread + private Thread logger_thread = null; + + // extract the tracker python script from the JAR + static { + try { + initTrackerPy(); + } catch (IOException ex) { + logger.error("load tracker library failed."); + logger.error(ex); + } + } + + /** + * Tracker logger that logs output from tracker. + */ + private class TrackerLogger implements Runnable { + public void run() { + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + String line; + try { + while ((line = reader.readLine()) != null) { + logger.info(line); + } + } catch (IOException ex) { + logger.error(ex.toString()); + } + } + } + + private static synchronized void initTrackerPy() throws IOException { + tracker_py = FileUtil.createTempFileFromResource("/tracker.py"); + } + + + public RabitTracker(int num_workers) { + this.num_workers = num_workers; + } + + /** + * Get the environment variables that can be passed to the workers. + * @return The environment settings. 
+ */ + public Map getWorkerEnvs() { + return envs; + } + + public void start() throws IOException { + process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() + + " --num-workers=" + new Integer(num_workers).toString()); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START"); + String line; + while ((line = reader.readLine()) != null) { + if (line.trim().equals("DMLC_TRACKER_ENV_END")) { + break; + } + String []sep = line.split("="); + if (sep.length == 2) { + envs.put(sep[0], sep[1]); + } + } + logger.debug("Tracker started, with env=" + envs.toString()); + // also start a tracker logger + logger_thread = new Thread(new TrackerLogger()); + logger_thread.setDaemon(true); + logger_thread.start(); + } + + public void waitFor() throws InterruptedException { + process.waitFor(); + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java index c26968b54861..8eded82a735c 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java @@ -74,9 +74,9 @@ public final static native int XGBoosterPredict(long handle, long dmat, int opti public final static native int XGBoosterSaveModel(long handle, String fname); - public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len); + public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes); - public final static native int XGBoosterGetModelRaw(long handle, String[] out_string); + public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes); public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings); diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp 
b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 8556be4a5273..da3f5a92d56c 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -13,6 +13,8 @@ */ #include +#include +#include #include "./xgboost4j.h" #include #include @@ -276,27 +278,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { - DMatrixHandle* handles = NULL; - bst_ulong len = 0; - jlong* cjhandles = 0; - BoosterHandle result; - - if (jhandles) { - len = (bst_ulong)jenv->GetArrayLength(jhandles); - handles = new DMatrixHandle[len]; - //put handle from jhandles to chandles - cjhandles = jenv->GetLongArrayElements(jhandles, 0); - for(bst_ulong i=0; i handles; + if (jhandles != nullptr) { + size_t len = jenv->GetArrayLength(jhandles); + jlong *cjhandles = jenv->GetLongArrayElements(jhandles, 0); + for (size_t i = 0; i < len; ++i) { + handles.push_back((DMatrixHandle) cjhandles[i]); } - } - - int ret = XGBoosterCreate(handles, len, &result); - //release - if (jhandles) { - delete[] handles; jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); } + BoosterHandle result; + int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result); setHandle(jenv, jout, result); return ret; } @@ -369,43 +361,34 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle* dmats = 0; - char **evnames = 0; - char *result = 0; - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats); - if(len > 0) { - dmats = new DMatrixHandle[len]; - evnames = new char*[len]; - } - //put handle from jhandles to 
chandles + std::vector dmats; + std::vector evnames; + std::vector evchars; + + size_t len = static_cast(jenv->GetArrayLength(jdmats)); + // put handle from jhandles to chandles jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); - for(bst_ulong i=0; iGetObjectArrayElement(jevnames, i); - const char* cevname = jenv->GetStringUTFChars(jevname, 0); - evnames[i] = new char[jenv->GetStringLength(jevname)]; - strcpy(evnames[i], cevname); - jenv->ReleaseStringUTFChars(jevname, cevname); + const char *s =jenv->GetStringUTFChars(jevname, 0); + evnames.push_back(std::string(s, jenv->GetStringLength(jevname))); + if (s != nullptr) jenv->ReleaseStringUTFChars(jevname, s); } - - int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result); - if(len > 0) { - delete[] dmats; - //release string chars - for(bst_ulong i=0; iReleaseLongArrayElements(jdmats, cjdmats, 0); + jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); + for (size_t i = 0; i < len; ++i) { + evchars.push_back(evnames[i].c_str()); + } + const char* result; + int ret = XGBoosterEvalOneIter(handle, jiter, + dmlc::BeginPtr(dmats), + dmlc::BeginPtr(evchars), + len, &result); + jstring jinfo = nullptr; + if (result != nullptr) { + jinfo = jenv->NewStringUTF(result); } - - jstring jinfo = 0; - if (result) jinfo = jenv->NewStringUTF((const char *) result); jenv->SetObjectArrayElement(jout, 0, jinfo); - return ret; } @@ -456,37 +439,40 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel int ret = XGBoosterSaveModel(handle, fname); if (fname) jenv->ReleaseStringUTFChars(jfname, fname); - return ret; } /* * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterLoadModelFromBuffer - * Signature: (JJJ)V + * Signature: (J[B)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer - (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { + (JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) { 
BoosterHandle handle = (BoosterHandle) jhandle; - void *buf = (void*) jbuf; - return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen); + jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0); + int ret = XGBoosterLoadModelFromBuffer( + handle, buffer, jenv->GetArrayLength(jbytes)); + jenv->ReleaseByteArrayElements(jbytes, buffer, 0); + return ret; } /* * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterGetModelRaw - * Signature: (J)Ljava/lang/String; + * Signature: (J[[B)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; bst_ulong len = 0; - char *result; + const char* result; + int ret = XGBoosterGetModelRaw(handle, &len, &result); - int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result); if (result) { - jstring jinfo = jenv->NewStringUTF((const char *) result); - jenv->SetObjectArrayElement(jout, 0, jinfo); + jbyteArray jarray = jenv->NewByteArray(len); + jenv->SetByteArrayRegion(jarray, 0, len, (jbyte*)result); + jenv->SetObjectArrayElement(jout, 0, jarray); } return ret; } @@ -553,15 +539,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs); for (bst_ulong i = 0; i < len; ++i) { jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i); - std::string s(jenv->GetStringUTFChars(arg, 0), - jenv->GetStringLength(arg)); - if (s.length() != 0) args.push_back(s); + const char *s = jenv->GetStringUTFChars(arg, 0); + args.push_back(std::string(s, jenv->GetStringLength(arg))); + if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s); + if (args.back().length() == 0) args.pop_back(); } for (size_t i = 0; i < args.size(); ++i) { argv.push_back(&args[i][0]); } - RabitInit(args.size(), args.size() == 0 ? 
NULL : &argv[0]); + + RabitInit(args.size(), dmlc::BeginPtr(argv)); return 0; } diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 023827c44c8b..6d811ad88e90 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -194,15 +194,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel /* * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterLoadModelFromBuffer - * Signature: (JJJ)I + * Signature: (J[B)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer - (JNIEnv *, jclass, jlong, jlong, jlong); + (JNIEnv *, jclass, jlong, jbyteArray); /* * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterGetModelRaw - * Signature: (J[Ljava/lang/String;)I + * Signature: (J[[B)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw (JNIEnv *, jclass, jlong, jobjectArray); diff --git a/rabit b/rabit index 1392e9f3da59..be50e7b63224 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0 +Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043 diff --git a/src/common/common.cc b/src/common/common.cc index 6e12045f6e01..2010e9ee4af7 100644 --- a/src/common/common.cc +++ b/src/common/common.cc @@ -4,12 +4,20 @@ * \brief Enable all kinds of global variables in common. */ #include "./random.h" +#include "./thread_local.h" namespace xgboost { namespace common { +/*! \brief thread local entry for random. */ +struct RandomThreadLocalEntry { + /*! \brief the random engine instance. 
*/ + GlobalRandomEngine engine; +}; + +typedef ThreadLocalStore RandomThreadLocalStore; + GlobalRandomEngine& GlobalRandom() { - static GlobalRandomEngine inst; - return inst; + return RandomThreadLocalStore::Get()->engine; } } } // namespace xgboost diff --git a/src/common/random.h b/src/common/random.h index f47ff5f75fe2..92f41410838c 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -61,7 +61,8 @@ typedef RandomEngine GlobalRandomEngine; /*! * \brief global singleton of a random engine. - * Only use this engine when necessary, not thread-safe. + * This random engine is thread-local and + * only visible to current thread. */ GlobalRandomEngine& GlobalRandom(); // NOLINT(*) diff --git a/src/common/thread_local.h b/src/common/thread_local.h index 6ea8eb5ab400..812fe973c917 100644 --- a/src/common/thread_local.h +++ b/src/common/thread_local.h @@ -6,6 +6,8 @@ #ifndef XGBOOST_COMMON_THREAD_LOCAL_H_ #define XGBOOST_COMMON_THREAD_LOCAL_H_ +#include + #if DMLC_ENABLE_STD_THREAD #include #endif diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index f74ffc1777fd..29fb114a6a62 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -15,6 +15,7 @@ #include #include #include +#include "../common/common.h" namespace xgboost { namespace gbm { @@ -265,13 +266,11 @@ class GBTree : public GradientBooster { inline void InitUpdater() { if (updaters.size() != 0) return; std::string tval = tparam.updater_seq; - char *pstr; - pstr = std::strtok(&tval[0], ","); - while (pstr != nullptr) { - std::unique_ptr up(TreeUpdater::Create(pstr)); + std::vector ups = common::Split(tval, ','); + for (const std::string& pstr : ups) { + std::unique_ptr up(TreeUpdater::Create(pstr.c_str())); up->Init(this->cfg); updaters.push_back(std::move(up)); - pstr = std::strtok(nullptr, ","); } } // do group specific group