Skip to content

Commit

Permalink
[DIST] Enable multiple thread and tracker, make rabit and xgboost mor…
Browse files Browse the repository at this point in the history
…e thread-safe by using thread local variables.
  • Loading branch information
tqchen committed Mar 4, 2016
1 parent 12dc92f commit e80d3db
Show file tree
Hide file tree
Showing 17 changed files with 323 additions and 153 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Expand Up @@ -22,6 +22,11 @@ This file records the changes in xgboost library in reverse chronological order.
- The windows version is still blocked due to Rtools do not support ```std::thread```.
* rabit and dmlc-core are maintained through git submodule
- Anyone can open PR to update these dependencies now.
* Improvements
- Rabit and xgboost libs are not thread-safe and use thread local PRNGs
- This could fix some of the previous problem which runs xgboost on multiple threads.
* JVM Package
- Enable xgboost4j for java and scala

## v0.47 (2016.01.14)

Expand Down
2 changes: 2 additions & 0 deletions jvm-packages/.gitignore
@@ -0,0 +1,2 @@
tracker.py
build.sh
2 changes: 2 additions & 0 deletions jvm-packages/create_jni.sh
Expand Up @@ -27,6 +27,8 @@ fi

rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
# copy python to native resources
cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py

popd > /dev/null
echo "complete"
2 changes: 1 addition & 1 deletion jvm-packages/test_distributed.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Simple script to test distributed version, to be deleted later.
cd xgboost4j-demo
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
cd ..
Expand Up @@ -2,42 +2,78 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import ml.dmlc.xgboost4j.*;


/**
* Distributed training example, used to quick test distributed training.
*
* @author tqchen
*/
public class DistTrain {
private static final Log logger = LogFactory.getLog(DistTrain.class);
private Map<String, String> envs = null;

private class Worker implements Runnable {
private int worker_id;
Worker(int worker_id) {
this.worker_id = worker_id;
}

public void run() {
try {
Map<String, String> worker_env = new HashMap<String, String>(envs);

public static void main(String[] args) throws IOException, XGBoostError {
// always initialize rabit module before training.
Rabit.init(new HashMap<String, String>());
worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString());
// always initialize rabit module before training.
Rabit.init(worker_env);

// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("nthread", 2);
params.put("objective", "binary:logistic");


HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);

//set round
int round = 2;
//set round
int round = 2;

//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);

// always shutdown rabit module after training.
Rabit.shutdown();
} catch (Exception ex){
logger.error(ex);
}
}
}

void start(int nworker) throws IOException, XGBoostError, InterruptedException {
RabitTracker tracker = new RabitTracker(nworker);
tracker.start();
envs = tracker.getWorkerEnvs();
for (int i = 0; i < nworker; ++i) {
new Thread(new Worker(i)).start();
}
tracker.waitFor();
}

// always shutdown rabit module after training.
Rabit.shutdown();
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
new DistTrain().start(Integer.parseInt(args[0]));
}
}
@@ -0,0 +1,78 @@
package ml.dmlc.xgboost4j;


import java.io.*;
import java.io.IOException;

/**
* Auxiliary utils to
*/
class FileUtil {
/**
* Create a temp file that copies the resource from current JAR archive
* <p/>
* The file from JAR is copied into system temp file.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
* @param path Path to the resources in the jar
* @return The created temp file.
* @throws IOException
* @throws IllegalArgumentException
*/
static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException {
// Obtain filename from path
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}

String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;

// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
}

// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();

if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}

// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;

// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}

// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
return temp;
}
}
Expand Up @@ -21,6 +21,9 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import ml.dmlc.xgboost4j.FileUtil;


/**
* class to load native library
*
Expand Down Expand Up @@ -61,59 +64,7 @@ public static synchronized void initXgBoost() throws IOException {
* three characters
*/
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{

if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}

// Obtain filename from path
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;

// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
}

// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}

// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();

if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}

// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;

// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}

// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}

File temp = FileUtil.createTempFileFromResource(path);
// Finally, load the library
System.load(temp.getAbsolutePath());
}
Expand Down
@@ -0,0 +1,98 @@
package ml.dmlc.xgboost4j;



import java.io.*;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
* Distributed RabitTracker, need to be started on driver code before running distributed jobs.
*/
public class RabitTracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file.
private static File tracker_py = null;
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int num_workers;
// child process
private Process process = null;
// logger thread
private Thread logger_thread = null;

//load native library
static {
try {
initTrackerPy();
} catch (IOException ex) {
logger.error("load tracker library failed.");
logger.error(ex);
}
}

/**
* Tracker logger that logs output from tracker.
*/
private class TrackerLogger implements Runnable {
public void run() {
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
String line;
try {
while ((line = reader.readLine()) != null) {
logger.info(line);
}
} catch (IOException ex) {
logger.error(ex.toString());
}
}
}

private static synchronized void initTrackerPy() throws IOException {
tracker_py = FileUtil.createTempFileFromResource("/tracker.py");
}


public RabitTracker(int num_workers) {
this.num_workers = num_workers;
}

/**
* Get environments that can be used to pass to worker.
* @return The environment settings.
*/
public Map<String, String> getWorkerEnvs() {
return envs;
}

public void start() throws IOException {
process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() +
" --num-workers=" + new Integer(num_workers).toString());
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
break;
}
String []sep = line.split("=");
if (sep.length == 2) {
envs.put(sep[0], sep[1]);
}
}
logger.debug("Tracker started, with env=" + envs.toString());
// also start a tracker logger
logger_thread = new Thread(new TrackerLogger());
logger_thread.setDaemon(true);
logger_thread.start();
}

public void waitFor() throws InterruptedException {
process.waitFor();
}
}
Expand Up @@ -74,9 +74,9 @@ public final static native int XGBoosterPredict(long handle, long dmat, int opti

public final static native int XGBoosterSaveModel(long handle, String fname);

public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes);

public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);

public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings);
Expand Down

0 comments on commit e80d3db

Please sign in to comment.