[jvm-packages] refine tracker #10313

Merged 2 commits on May 23, 2024
14 changes: 5 additions & 9 deletions jvm-packages/pom.xml
@@ -35,16 +35,17 @@
     <maven.compiler.target>1.8</maven.compiler.target>
     <flink.version>1.19.0</flink.version>
     <junit.version>4.13.2</junit.version>
-    <spark.version>3.4.1</spark.version>
-    <spark.version.gpu>3.4.1</spark.version.gpu>
+    <spark.version>3.5.1</spark.version>
+    <spark.version.gpu>3.5.1</spark.version.gpu>
+    <fasterxml.jackson.version>2.15.2</fasterxml.jackson.version>
     <scala.version>2.12.18</scala.version>
     <scala.binary.version>2.12</scala.binary.version>
     <hadoop.version>3.4.0</hadoop.version>
     <maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
     <log.capi.invocation>OFF</log.capi.invocation>
     <use.cuda>OFF</use.cuda>
-    <cudf.version>23.12.1</cudf.version>
-    <spark.rapids.version>23.12.1</spark.rapids.version>
+    <cudf.version>24.04.0</cudf.version>
+    <spark.rapids.version>24.04.0</spark.rapids.version>
     <cudf.classifier>cuda12</cudf.classifier>
     <scalatest.version>3.2.18</scalatest.version>
     <scala-collection-compat.version>2.12.0</scala-collection-compat.version>
@@ -489,11 +490,6 @@
       <artifactId>kryo</artifactId>
       <version>5.6.0</version>
     </dependency>
-    <dependency>
-      <groupId>com.fasterxml.jackson.core</groupId>
-      <artifactId>jackson-databind</artifactId>
-      <version>2.14.2</version>
-    </dependency>
     <dependency>
       <groupId>commons-logging</groupId>
       <artifactId>commons-logging</artifactId>
@@ -176,7 +176,7 @@ public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
         new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
     if (tracker.start()) {
       return dtrain
-          .mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
+          .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs()))
           .reduce((x, y) -> x)
           .collect()
           .get(0);
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2021-2022 by Contributors
+ Copyright (c) 2021-2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.functions.{col, collect_list, struct}
 import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider
       .groupBy(groupName)
       .agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")

-    implicit val encoder = RowEncoder(schema)
+    implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
     // Expand the grouped rows after repartition
     repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
       new Iterator[Row] {
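Note on the encoder change above: the pre-3.5 `RowEncoder(schema)` shorthand no longer compiles against Spark 3.5, so the GPU path now builds an `ExpressionEncoder` from the `AgnosticEncoder` that `RowEncoder.encoderFor` returns. A minimal standalone sketch of the same pattern (assumes a Spark 3.5 classpath; the schema here is a made-up stand-in for the grouped-row schema in the diff):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.types.{FloatType, StructField, StructType}

// Illustrative schema, not the one from GpuPreXGBoost.
val schema = StructType(Seq(
  StructField("group", FloatType, nullable = false),
  StructField("label", FloatType, nullable = false)))

// Spark 3.4:  implicit val encoder = RowEncoder(schema)
// Spark 3.5:  encoderFor returns an AgnosticEncoder[Row]; wrapping it in an
// ExpressionEncoder yields the implicit that Dataset.mapPartitions needs.
// The second argument disables lenient serialization, matching the diff.
implicit val encoder: ExpressionEncoder[Row] =
  ExpressionEncoder(RowEncoder.encoderFor(schema, false))
```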
@@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
-
-  private[spark] def buildRabitParams : Map[String, String] = Map(
-    "rabit_reduce_ring_mincount" ->
-      overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
-    "rabit_debug" ->
-      (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
-    "rabit_timeout" ->
-      (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
-    "rabit_timeout_sec" -> {
-      if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
-        overridedParams.get("rabit_timeout").toString
-      } else {
-        "1800"
-      }
-    },
-    "DMLC_WORKER_CONNECT_RETRY" ->
-      overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
-  )
 }

 /**
@@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
     }
   }

-  /** visiable for testing */
-  private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
-    val tracker: ITracker = new RabitTracker(
-      nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
-    tracker
-  }
-
-  private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
-    val tracker = getTracker(nWorkers, trackerConf)
+  // Executes the provided code block inside a tracker and then stops the tracker
+  private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
+    val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
     require(tracker.start(), "FAULT: Failed to start tracker")
-    tracker
+    try {
+      block(tracker)
+    } finally {
+      tracker.stop()
+    }
   }

   /**
@@ -501,55 +481,53 @@
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")

     val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
-    val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
-    val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
+    val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams

-    val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
+    val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
       val checkpointManager = new ExternalCheckpointManager(
         checkpointParam.checkpointPath,
         FileSystem.get(sc.hadoopConfiguration))
-      checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
+      checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
       checkpointManager.loadCheckpointAsScalaBooster()
     }.orNull

     // Get the training data RDD and the cachedRDD
-    val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
+    val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)

     try {
-      // Train for every ${savingRound} rounds and save the partially completed booster
-      val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
-      val (booster, metrics) = try {
-        tracker.workerArgs().putAll(xgbRabitParams)
-        val rabitEnv = tracker.workerArgs
+      val (booster, metrics) = withTracker(
+        runtimeParams.numWorkers,
+        runtimeParams.trackerConf
+      ) { tracker =>
+        val rabitEnv = tracker.getWorkerArgs()

-        val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
+        val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
          var optionWatches: Option[() => Watches] = None

          // take the first Watches to train
          if (iter.hasNext) {
            optionWatches = Some(iter.next())
          }

-          optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
-            xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
-            .getOrElse(throw new RuntimeException("No Watches to train"))
-
-        }}
+          optionWatches.map { buildWatches =>
+            buildDistributedBooster(buildWatches,
+              runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
+          }.getOrElse(throw new RuntimeException("No Watches to train"))
+        }

-        val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
+        val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
           boostersAndMetrics)
        // The repartition step is to make training stage as ShuffleMapStage, so that when one
        // of the training task fails the training stage can retry. ResultStage won't retry when
        // it fails.
        val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
        (booster, metrics)
-      } finally {
-        tracker.stop()
       }

       // we should delete the checkpoint directory after a successful training
-      xgbExecParams.checkpointParam.foreach {
+      runtimeParams.checkpointParam.foreach {
         cpParam =>
-          if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
+          if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
             val checkpointManager = new ExternalCheckpointManager(
               cpParam.checkpointPath,
               FileSystem.get(sc.hadoopConfiguration))
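The new `withTracker` helper is the loan pattern: it owns the tracker's lifecycle and guarantees `stop()` runs whether the training block succeeds or throws, which the old `startTracker`/`tracker.stop()` pairing only achieved through a hand-written try/finally at each call site. A self-contained sketch of the pattern with a hypothetical stand-in resource (not the PR's code):

```scala
// Loan pattern: acquire a resource, lend it to a block, always release it.
// FakeTracker is an invented stand-in for RabitTracker, just to keep this runnable.
final class FakeTracker {
  def start(): Boolean = true
  def stop(): Unit = println("tracker stopped")
}

def withTracker[T](block: FakeTracker => T): T = {
  val tracker = new FakeTracker
  require(tracker.start(), "FAULT: Failed to start tracker")
  try block(tracker)      // run the caller's code against the live tracker
  finally tracker.stop()  // runs on success, exception, or early return
}

// Usage: "tracker stopped" prints even though the block throws.
try withTracker { _ => throw new RuntimeException("boom") }
catch { case _: RuntimeException => () }
```

Having the block return `T` also lets the `(booster, metrics)` tuple flow straight out of the bracket, which is exactly how `train()` uses it above.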
@@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {

     val tracker = new RabitTracker(numWorkers)
     tracker.start()
-    val trackerEnvs = tracker. workerArgs
+    val trackerEnvs = tracker.getWorkerArgs

     val workerCount: Int = numWorkers
     /*
@@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
     val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
     val tracker = new RabitTracker(numWorkers)
     tracker.start()
-    val trackerEnvs = tracker.workerArgs
+    val trackerEnvs = tracker.getWorkerArgs

     val workerCount: Int = numWorkers

6 changes: 6 additions & 0 deletions jvm-packages/xgboost4j/pom.xml
@@ -53,6 +53,12 @@
       <version>${scalatest.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-databind</artifactId>
+      <version>${fasterxml.jackson.version}</version>
+      <scope>provided</scope>
+    </dependency>
   </dependencies>

   <build>
@@ -7,7 +7,7 @@
 *
 * - start(timeout): Start the tracker awaiting for worker connections, with a given
 *   timeout value (in seconds).
-* - workerArgs(): Return the arguments needed to initialize Rabit clients.
+* - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
 * - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
 *   milliseconds.
 *
@@ -21,21 +21,8 @@
 * brokers connections between workers.
 */
 public interface ITracker extends Thread.UncaughtExceptionHandler {
-  enum TrackerStatus {
-    SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
-
-    private int statusCode;
-
-    TrackerStatus(int statusCode) {
-      this.statusCode = statusCode;
-    }
-
-    public int getStatusCode() {
-      return this.statusCode;
-    }
-  }
-
-  Map<String, Object> workerArgs() throws XGBoostError;
+  Map<String, Object> getWorkerArgs() throws XGBoostError;

   boolean start() throws XGBoostError;

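Taken together, the interface methods spell out the driver-side protocol the header comment describes: start the tracker, hand its worker arguments to the executors, wait for them, stop it. A hedged Scala sketch of that call sequence (workers are assumed to be launched out of band; `launchWorkers` and `runDistributed` are hypothetical names, not part of the library):

```scala
import ml.dmlc.xgboost4j.java.RabitTracker

// Driver-side tracker lifecycle, following the ITracker contract above.
// launchWorkers is an invented callback that ships the args to the workers.
def runDistributed(numWorkers: Int)(launchWorkers: java.util.Map[String, Object] => Unit): Unit = {
  val tracker = new RabitTracker(numWorkers)
  require(tracker.start(), "tracker failed to start")
  try {
    // Connection info (host, port, ...) every worker needs to reach the tracker.
    launchWorkers(tracker.getWorkerArgs())
    tracker.waitFor(0) // block until the workers finish; 0 appears to mean no timeout
  } finally {
    tracker.stop()
  }
}
```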
@@ -1,3 +1,19 @@
+/*
+ Copyright (c) 2014-2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
 package ml.dmlc.xgboost4j.java;

 import java.util.Map;
@@ -10,14 +26,12 @@

 /**
  * Java implementation of the Rabit tracker to coordinate distributed workers.
  *
- * The tracker must be started on driver node before running distributed jobs.
  */
 public class RabitTracker implements ITracker {
-  // Maybe per tracker logger?
   private static final Log logger = LogFactory.getLog(RabitTracker.class);
   private long handle = 0;
-  private Thread tracker_daemon;
+  private Thread trackerDaemon;

   public RabitTracker(int numWorkers) throws XGBoostError {
     this(numWorkers, "");
@@ -44,24 +58,22 @@ public void uncaughtException(Thread t, Throwable e) {
     } catch (InterruptedException ex) {
       logger.error(ex);
     } finally {
-      this.tracker_daemon.interrupt();
+      this.trackerDaemon.interrupt();
     }
   }

   /**
    * Get environments that can be used to pass to worker.
    * @return The environment settings.
    */
-  public Map<String, Object> workerArgs() throws XGBoostError {
+  public Map<String, Object> getWorkerArgs() throws XGBoostError {
     // fixme: timeout
     String[] args = new String[1];
     XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
     ObjectMapper mapper = new ObjectMapper();
-    TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
-    };
     Map<String, Object> config;
     try {
-      config = mapper.readValue(args[0], typeRef);
+      config = mapper.readValue(args[0], new TypeReference<Map<String, Object>>() {});
     } catch (JsonProcessingException ex) {
       throw new XGBoostError("Failed to get worker arguments.", ex);
     }
@@ -74,18 +86,18 @@ public void stop() throws XGBoostError {

   public boolean start() throws XGBoostError {
     XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
-    this.tracker_daemon = new Thread(() -> {
+    this.trackerDaemon = new Thread(() -> {
       try {
-        XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
+        waitFor(0);
       } catch (XGBoostError ex) {
         logger.error(ex);
         return; // exit the thread
       }
     });
-    this.tracker_daemon.setDaemon(true);
-    this.tracker_daemon.start();
+    this.trackerDaemon.setDaemon(true);
+    this.trackerDaemon.start();

-    return this.tracker_daemon.isAlive();
+    return this.trackerDaemon.isAlive();
   }

   public void waitFor(long timeout) throws XGBoostError {
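One detail worth calling out in `getWorkerArgs` above: the worker arguments cross the JNI boundary as a JSON string, and the anonymous `TypeReference` subclass is what lets Jackson recover the generic `Map<String, Object>` type that erasure would otherwise hide. The same idiom as a runnable Scala sketch (the JSON payload and key names are invented for illustration):

```scala
import com.fasterxml.jackson.core.`type`.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper

object WorkerArgsDemo {
  def main(args: Array[String]): Unit = {
    // Stand-in for the JSON the native tracker hands back (illustrative values).
    val json = """{"DMLC_TRACKER_URI": "127.0.0.1", "DMLC_TRACKER_PORT": 9091}"""
    val mapper = new ObjectMapper()
    // The anonymous subclass pins the full generic type at runtime, so Jackson
    // deserializes into Map[String, Object] instead of a raw, untyped Map.
    val config: java.util.Map[String, Object] =
      mapper.readValue(json, new TypeReference[java.util.Map[String, Object]] {})
    println(config)
  }
}
```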