Skip to content

Commit

Permalink
[jvm-packages] refine tracker (#10313)
Browse files Browse the repository at this point in the history
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
  • Loading branch information
wbo4958 and trivialfis committed May 23, 2024
1 parent 966dc81 commit 932d720
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 92 deletions.
14 changes: 5 additions & 9 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@
<maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.19.0</flink.version>
<junit.version>4.13.2</junit.version>
<spark.version>3.4.1</spark.version>
<spark.version.gpu>3.4.1</spark.version.gpu>
<spark.version>3.5.1</spark.version>
<spark.version.gpu>3.5.1</spark.version.gpu>
<fasterxml.jackson.version>2.15.2</fasterxml.jackson.version>
<scala.version>2.12.18</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>3.4.0</hadoop.version>
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
<cudf.version>23.12.1</cudf.version>
<spark.rapids.version>23.12.1</spark.rapids.version>
<cudf.version>24.04.0</cudf.version>
<spark.rapids.version>24.04.0</spark.rapids.version>
<cudf.classifier>cuda12</cudf.classifier>
<scalatest.version>3.2.18</scalatest.version>
<scala-collection-compat.version>2.12.0</scala-collection-compat.version>
Expand Down Expand Up @@ -489,11 +490,6 @@
<artifactId>kryo</artifactId>
<version>5.6.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.14.2</version>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start()) {
return dtrain
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs()))
.reduce((x, y) -> x)
.collect()
.get(0);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2022 by Contributors
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.functions.{col, collect_list, struct}
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
Expand Down Expand Up @@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
.groupBy(groupName)
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")

implicit val encoder = RowEncoder(schema)
implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
// Expand the grouped rows after repartition
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
new Iterator[Row] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}

/**
 * Builds the string-valued Rabit settings derived from `overridedParams`.
 *
 * Produced keys:
 *  - `rabit_reduce_ring_mincount`: value of `rabit_ring_reduce_threshold`, default `32 << 10`.
 *  - `rabit_debug`: `"true"` only when `verbosity` equals 3.
 *  - `rabit_timeout`: `"true"` when a non-negative `rabit_timeout` is configured.
 *  - `rabit_timeout_sec`: the configured non-negative `rabit_timeout`, otherwise `"1800"`.
 *  - `DMLC_WORKER_CONNECT_RETRY`: value of `dmlc_worker_connect_retry`, default 5.
 *
 * @return the Rabit settings as an immutable map of strings
 */
private[spark] def buildRabitParams : Map[String, String] = {
  // Parse the timeout once; a missing entry is treated as -1, i.e. "timeout disabled".
  val rabitTimeout = overridedParams.getOrElse("rabit_timeout", -1).toString.toInt
  val timeoutEnabled = rabitTimeout >= 0
  Map(
    "rabit_reduce_ring_mincount" ->
      overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
    "rabit_debug" ->
      (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
    "rabit_timeout" -> timeoutEnabled.toString,
    // Use the parsed value: calling toString on `overridedParams.get(...)` would
    // render the Option wrapper ("Some(30)") rather than the number itself.
    "rabit_timeout_sec" -> (if (timeoutEnabled) rabitTimeout.toString else "1800"),
    "DMLC_WORKER_CONNECT_RETRY" ->
      overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
  )
}
}

/**
Expand Down Expand Up @@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
}
}

/** Visible for testing: constructs the Rabit tracker for `nWorkers` without calling start(). */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker =
  new RabitTracker(nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)

private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker = getTracker(nWorkers, trackerConf)
// Executes the provided code block inside a tracker and then stops the tracker
private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
require(tracker.start(), "FAULT: Failed to start tracker")
tracker
try {
block(tracker)
} finally {
tracker.stop()
}
}

/**
Expand All @@ -501,55 +481,53 @@ object XGBoost extends XGBoostStageLevel {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")

val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams

val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull

// Get the training data RDD and the cachedRDD
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)

try {
// Train for every ${savingRound} rounds and save the partially completed booster
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
val (booster, metrics) = try {
tracker.workerArgs().putAll(xgbRabitParams)
val rabitEnv = tracker.workerArgs
val (booster, metrics) = withTracker(
runtimeParams.numWorkers,
runtimeParams.trackerConf
) { tracker =>
val rabitEnv = tracker.getWorkerArgs()

val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
var optionWatches: Option[() => Watches] = None

// take the first Watches to train
if (iter.hasNext) {
optionWatches = Some(iter.next())
}

optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
.getOrElse(throw new RuntimeException("No Watches to train"))

}}
optionWatches.map { buildWatches =>
buildDistributedBooster(buildWatches,
runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
}.getOrElse(throw new RuntimeException("No Watches to train"))
}

val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
boostersAndMetrics)
// The repartition step turns the training stage into a ShuffleMapStage, so that when one
// of the training tasks fails the training stage can be retried. A ResultStage won't be
// retried when it fails.
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
(booster, metrics)
} finally {
tracker.stop()
}

// we should delete the checkpoint directory after a successful training
xgbExecParams.checkpointParam.foreach {
runtimeParams.checkpointParam.foreach {
cpParam =>
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
val checkpointManager = new ExternalCheckpointManager(
cpParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {

val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker. workerArgs
val trackerEnvs = tracker.getWorkerArgs

val workerCount: Int = numWorkers
/*
Expand Down Expand Up @@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker.workerArgs
val trackerEnvs = tracker.getWorkerArgs

val workerCount: Int = numWorkers

Expand Down
6 changes: 6 additions & 0 deletions jvm-packages/xgboost4j/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@
<version>${scalatest.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.jackson.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*
* - start(timeout): Start the tracker and await worker connections, with a given
* timeout value (in seconds).
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
* - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
Expand All @@ -21,21 +21,8 @@
* brokers connections between workers.
*/
public interface ITracker extends Thread.UncaughtExceptionHandler {
/**
 * Outcome of a tracker run, each constant paired with a fixed integer status code.
 */
enum TrackerStatus {
  SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);

  // Assigned once in the constructor; final so a constant's code can never change.
  private final int statusCode;

  TrackerStatus(int statusCode) {
    this.statusCode = statusCode;
  }

  /**
   * @return the integer code associated with this status
   */
  public int getStatusCode() {
    return this.statusCode;
  }
}

Map<String, Object> workerArgs() throws XGBoostError;
Map<String, Object> getWorkerArgs() throws XGBoostError;

boolean start() throws XGBoostError;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/*
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.java;

import java.util.Map;
Expand All @@ -10,14 +26,12 @@

/**
* Java implementation of the Rabit tracker to coordinate distributed workers.
*
* The tracker must be started on driver node before running distributed jobs.
*/
public class RabitTracker implements ITracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
private long handle = 0;
private Thread tracker_daemon;
private Thread trackerDaemon;

public RabitTracker(int numWorkers) throws XGBoostError {
this(numWorkers, "");
Expand All @@ -44,24 +58,22 @@ public void uncaughtException(Thread t, Throwable e) {
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
this.tracker_daemon.interrupt();
this.trackerDaemon.interrupt();
}
}

/**
* Get environments that can be used to pass to worker.
* @return The environment settings.
*/
public Map<String, Object> workerArgs() throws XGBoostError {
public Map<String, Object> getWorkerArgs() throws XGBoostError {
// fixme: timeout
String[] args = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
ObjectMapper mapper = new ObjectMapper();
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
};
Map<String, Object> config;
try {
config = mapper.readValue(args[0], typeRef);
config = mapper.readValue(args[0], new TypeReference<Map<String, Object>>() {});
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to get worker arguments.", ex);
}
Expand All @@ -74,18 +86,18 @@ public void stop() throws XGBoostError {

public boolean start() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
this.tracker_daemon = new Thread(() -> {
this.trackerDaemon = new Thread(() -> {
try {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
waitFor(0);
} catch (XGBoostError ex) {
logger.error(ex);
return; // exit the thread
}
});
this.tracker_daemon.setDaemon(true);
this.tracker_daemon.start();
this.trackerDaemon.setDaemon(true);
this.trackerDaemon.start();

return this.tracker_daemon.isAlive();
return this.trackerDaemon.isAlive();
}

public void waitFor(long timeout) throws XGBoostError {
Expand Down

0 comments on commit 932d720

Please sign in to comment.