Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DIST] Enable multiple thread and tracker, make rabit and xgboost more thread-safe by using thread local variables.
- Loading branch information
Showing
17 changed files
with
323 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule dmlc-core
updated
4 files
+13 −0 | tracker/dmlc_tracker/launcher.py | |
+5 −0 | tracker/dmlc_tracker/opts.py | |
+63 −2 | tracker/dmlc_tracker/tracker.py | |
+6 −0 | tracker/dmlc_tracker/yarn.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tracker.py | ||
build.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
#!/bin/bash | ||
# Simple script to test distributed version, to be deleted later. | ||
cd xgboost4j-demo | ||
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain | ||
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4 | ||
cd .. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package ml.dmlc.xgboost4j; | ||
|
||
|
||
import java.io.*; | ||
import java.io.IOException; | ||
|
||
/**
 * Auxiliary utilities for extracting resources bundled on the classpath
 * (typically inside the JAR) into temporary files on the local file system.
 */
class FileUtil {
  /**
   * Create a temp file that copies the resource from current JAR archive.
   * <p/>
   * The file from JAR is copied into system temp file.
   * The temporary file is registered for deletion when the JVM exits.
   * Method uses String as filename because the pathname is "abstract", not system-dependent.
   * <p/>
   * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
   * {@code path}: the part of the file name before the first '.' is used as the temp-file
   * prefix and must therefore be at least 3 characters long.
   *
   * @param path Absolute path (starting with '/') to the resource in the jar
   * @return The created temp file.
   * @throws IOException if the resource is missing ({@link FileNotFoundException}) or the
   *                     copy fails
   * @throws IllegalArgumentException if {@code path} is not absolute or its file name is
   *                                  too short
   */
  static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException {
    if (!path.startsWith("/")) {
      throw new IllegalArgumentException("The path has to be absolute (start with '/').");
    }

    // Obtain filename from path: the last '/'-separated component.
    String[] parts = path.split("/");
    String filename = (parts.length > 1) ? parts[parts.length - 1] : null;

    // Split filename into prefix and suffix (extension) at the first '.'.
    String prefix = "";
    String suffix = null;
    if (filename != null) {
      parts = filename.split("\\.", 2);
      prefix = parts[0];
      suffix = (parts.length > 1) ? "." + parts[1] : null;
    }

    // Check if the filename is okay
    if (filename == null || prefix.length() < 3) {
      throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
    }

    // Open the resource BEFORE creating the temp file, so that a missing resource does not
    // leave an empty temp file behind. An absolute resource name is resolved by the class
    // loader, so looking it up through this class avoids depending on another class.
    InputStream is = FileUtil.class.getResourceAsStream(path);
    if (is == null) {
      throw new FileNotFoundException("File " + path + " was not found inside JAR.");
    }

    try {
      // Prepare temporary file
      File temp = File.createTempFile(prefix, suffix);
      temp.deleteOnExit();

      if (!temp.exists()) {
        throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
      }

      // Copy data between the resource and the temporary file.
      OutputStream os = new FileOutputStream(temp);
      try {
        byte[] buffer = new byte[1024];
        int readBytes;
        while ((readBytes = is.read(buffer)) != -1) {
          os.write(buffer, 0, readBytes);
        }
      } finally {
        // Closed in its own finally so a failure here cannot prevent closing the input.
        os.close();
      }
      return temp;
    } finally {
      // Always release the resource stream, whichever step above failed.
      is.close();
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
98 changes: 98 additions & 0 deletions
98
jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package ml.dmlc.xgboost4j; | ||
|
||
|
||
|
||
import java.io.*; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.apache.commons.logging.Log; | ||
import org.apache.commons.logging.LogFactory; | ||
|
||
/** | ||
* Distributed RabitTracker, need to be started on driver code before running distributed jobs. | ||
*/ | ||
public class RabitTracker { | ||
// Maybe per tracker logger? | ||
private static final Log logger = LogFactory.getLog(RabitTracker.class); | ||
// tracker python file. | ||
private static File tracker_py = null; | ||
// environment variable to be pased. | ||
private Map<String, String> envs = new HashMap<String, String>(); | ||
// number of workers to be submitted. | ||
private int num_workers; | ||
// child process | ||
private Process process = null; | ||
// logger thread | ||
private Thread logger_thread = null; | ||
|
||
//load native library | ||
static { | ||
try { | ||
initTrackerPy(); | ||
} catch (IOException ex) { | ||
logger.error("load tracker library failed."); | ||
logger.error(ex); | ||
} | ||
} | ||
|
||
/** | ||
* Tracker logger that logs output from tracker. | ||
*/ | ||
private class TrackerLogger implements Runnable { | ||
public void run() { | ||
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); | ||
String line; | ||
try { | ||
while ((line = reader.readLine()) != null) { | ||
logger.info(line); | ||
} | ||
} catch (IOException ex) { | ||
logger.error(ex.toString()); | ||
} | ||
} | ||
} | ||
|
||
private static synchronized void initTrackerPy() throws IOException { | ||
tracker_py = FileUtil.createTempFileFromResource("/tracker.py"); | ||
} | ||
|
||
|
||
public RabitTracker(int num_workers) { | ||
this.num_workers = num_workers; | ||
} | ||
|
||
/** | ||
* Get environments that can be used to pass to worker. | ||
* @return The environment settings. | ||
*/ | ||
public Map<String, String> getWorkerEnvs() { | ||
return envs; | ||
} | ||
|
||
public void start() throws IOException { | ||
process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() + | ||
" --num-workers=" + new Integer(num_workers).toString()); | ||
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); | ||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START"); | ||
String line; | ||
while ((line = reader.readLine()) != null) { | ||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) { | ||
break; | ||
} | ||
String []sep = line.split("="); | ||
if (sep.length == 2) { | ||
envs.put(sep[0], sep[1]); | ||
} | ||
} | ||
logger.debug("Tracker started, with env=" + envs.toString()); | ||
// also start a tracker logger | ||
logger_thread = new Thread(new TrackerLogger()); | ||
logger_thread.setDaemon(true); | ||
logger_thread.start(); | ||
} | ||
|
||
public void waitFor() throws InterruptedException { | ||
process.waitFor(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.