Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp the rabit implementation. #10112

Merged
merged 45 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
6c446ca
Rework RABIT.
trivialfis Aug 31, 2023
abf2419
cleanup.
trivialfis Apr 22, 2024
9fa105d
cleanup jvm packages.
trivialfis Apr 25, 2024
3929daa
Fix demo chunks.
trivialfis Apr 23, 2024
ae0f49d
consistent jvm.
trivialfis Apr 25, 2024
08ccefc
Revert "consistent jvm."
trivialfis Apr 26, 2024
22194a7
non-jvm changes.
trivialfis Apr 26, 2024
bb8e40e
bisect.
trivialfis Apr 26, 2024
e5c35ec
revert test changes.
trivialfis Apr 27, 2024
81d8692
revert gh changes.
trivialfis Apr 28, 2024
99d55b3
jvm changes.
trivialfis Apr 28, 2024
baf99f3
unused import.
trivialfis Apr 29, 2024
353f2de
Revert "unused import."
trivialfis Apr 29, 2024
4690758
Revert "Revert "unused import.""
trivialfis Apr 29, 2024
72ef55b
update jackson.
trivialfis Apr 29, 2024
594bbc9
try latest spark.
trivialfis Apr 29, 2024
2232a6e
scope.
trivialfis Apr 29, 2024
33876f6
Revert "try latest spark."
trivialfis Apr 29, 2024
c3547ab
Fix.
trivialfis Apr 29, 2024
269090c
GPU package.
trivialfis May 6, 2024
eb7b88e
more.
trivialfis May 6, 2024
b4e97f9
Revert "more."
trivialfis May 8, 2024
b56ba97
Revert jackson version
trivialfis May 8, 2024
ae44c1a
Merge branch 'master' into rabit-worker-get-all
trivialfis May 9, 2024
83ac716
Fix secure definition.
trivialfis May 11, 2024
56cf189
Merge branch 'master' into rabit-worker-get-all
trivialfis May 11, 2024
a1bfe5f
Merge branch 'master' into rabit-worker-get-all
trivialfis May 14, 2024
04cb943
Work on revamping the blocking implementation.
trivialfis May 15, 2024
92a1bc2
Free.
trivialfis May 15, 2024
d91cb40
linter.
trivialfis May 15, 2024
ec4b35c
lint.
trivialfis May 15, 2024
44a43a9
Merge branch 'master' into rabit-worker-get-all
trivialfis May 15, 2024
cc70c88
rng.
trivialfis May 15, 2024
61983c4
Merge remote-tracking branch 'jiamingy/rabit-worker-get-all' into rab…
trivialfis May 15, 2024
e334141
Fix rc check.
trivialfis May 15, 2024
53d072e
err
trivialfis May 15, 2024
53d2a73
Log.
trivialfis May 16, 2024
3391490
Log time stamp.
trivialfis May 16, 2024
c9129ef
cleanup.
trivialfis May 16, 2024
4c90247
Simplify the loop.
trivialfis May 17, 2024
7bb4f22
Never throw.
trivialfis May 17, 2024
c5331f4
remove ref.
trivialfis May 17, 2024
a53c87b
Merge branch 'master' into rabit-worker-get-all
trivialfis May 17, 2024
abc5f3b
Windows
trivialfis May 17, 2024
e4aa87b
Windows.
trivialfis May 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF)
option(USE_DEVICE_DEBUG "Generate CUDA device debug info." OFF)
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
option(RABIT_MOCK "Build rabit with mock" OFF)
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
## CUDA
Expand Down Expand Up @@ -282,9 +281,6 @@ if(MSVC)
endif()
endif()

# rabit
add_subdirectory(rabit)

# core xgboost
add_subdirectory(${xgboost_SOURCE_DIR}/src)
target_link_libraries(objxgboost PUBLIC dmlc)
Expand Down
8 changes: 1 addition & 7 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
Expand All @@ -134,7 +131,4 @@ OBJECTS= \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
$(PKGROOT)/src/c_api/c_api_error.o \
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o \
$(PKGROOT)/rabit/src/rabit_c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o
$(PKGROOT)/amalgamation/dmlc-minimum0.o
8 changes: 1 addition & 7 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
Expand All @@ -134,7 +131,4 @@ OBJECTS= \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
$(PKGROOT)/src/c_api/c_api_error.o \
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o \
$(PKGROOT)/rabit/src/rabit_c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o
$(PKGROOT)/amalgamation/dmlc-minimum0.o
1 change: 1 addition & 0 deletions cmake/Utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ function(xgboost_set_cuda_flags target)
target_include_directories(
${target} PRIVATE
${xgboost_SOURCE_DIR}/gputreeshap
${xgboost_SOURCE_DIR}/rabit/include
${CUDAToolkit_INCLUDE_DIRS})

if(MSVC)
Expand Down
2 changes: 1 addition & 1 deletion demo/dask/cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def main(client: Client) -> None:
m = 100000
n = 100
rng = da.random.default_rng(1)
X = rng.normal(size=(m, n))
X = rng.normal(size=(m, n), chunks=(10000, -1))
y = X.sum(axis=1)

# DaskDMatrix acts like normal DMatrix, works as a proxy for local
Expand Down
184 changes: 105 additions & 79 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
*
* @return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
char const *c_json_config, DMatrixHandle m,
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values,
char const *config, DMatrixHandle m,
bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);

Expand Down Expand Up @@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
*
* @brief Experimental support for exposing internal communicator in XGBoost.
*
* @note This is still under development.
*
* The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
* changed significantly since its adoption. It consists of a tracker and a set of
* workers. The tracker is responsible for bootstrapping the communication group and
* handling centralized tasks like logging. The workers are actual communicators
* performing collective tasks like allreduce.
*
* To use the collective implementation, one needs to first create a tracker with
* corresponding parameters, then get the arguments for workers using
* XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
* XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a
* XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
* `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
* runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
* runtime is shutting down. This requirement is similar to a Python thread or socket,
* which should not be relied upon in a `__del__` function.
*
* Since it's used as a part of XGBoost, errors will be returned when a XGBoost function
* is called, for instance, training a booster might return a connection error.
*
* @{
*/

/**
* @brief Handle to tracker.
* @brief Handle to the tracker.
*
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
* other one is `federated`.
* other one is `federated`. `rabit` is used for normal collective communication, while
* `federated` is used for federated learning.
*
* This is still under development.
*/
typedef void *TrackerHandle; /* NOLINT */

Expand All @@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
*
* @param config JSON encoded parameters.
*
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
* and `federated`.
* - dmlc_communicator: String, the type of tracker to create. Available options are
* `rabit` and `federated`. See @ref TrackerHandle for more info.
* - n_workers: Integer, the number of workers.
* - port: (Optional) Integer, the port this tracker should listen to.
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
* - timeout: (Optional) Integer, timeout in seconds for various networking
operations. Default is 300 seconds.
*
* Some configurations are `rabit` specific:
*
* - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host.
* This can be useful when the communicator cannot reliably obtain the host address.
* - sortby: (Optional) Integer.
* + 0: Sort workers by their host name.
* + 1: Sort workers by task IDs.
*
* Some `federated` specific configurations:
* - federated_secure: Boolean, whether this is a secure server.
* - federated_secure: Boolean, whether this is a secure server. False for testing.
* - server_key_path: Path to the server key. Used only if this is a secure server.
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
Expand Down Expand Up @@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
*/
XGB_DLL int XGTrackerFree(TrackerHandle handle);

/*!
* \brief Initialize the collective communicator.
/**
* @brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
* Call this once in the worker process before using anything. Please make sure
* XGCommunicatorFinalize() is called after use. The initialized commuicator is a global
* thread-local variable.
*
* \param config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* @param config JSON encoded configuration. Accepted JSON keys are:
* - dmlc_communicator: The type of the communicator, this should match the tracker type.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
* `USE_DLOPEN_NCCL`.
* Only applicable to the Federated communicator (use upper case for environment variables, use
*
* Only applicable to the `rabit` communicator:
* - dmlc_tracker_uri: Hostname or IP address of the tracker.
* - dmlc_tracker_port: Port number of the tracker.
* - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - dmlc_retry: The number of retries for connection failure.
* - dmlc_timeout: Timeout in seconds.
* - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
*
* Only applicable to the `federated` communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
* \return 0 for success, -1 for failure.
* - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key_path: Client key file path. Only needed for the SSL mode.
* - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorInit(char const* config);

/*!
* \brief Finalize the collective communicator.
/**
* @brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
* Call this function after you have finished all jobs.
*
* \return 0 for success, -1 for failure.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorFinalize(void);

/*!
* \brief Get rank of current process.
/**
* @brief Get rank of the current process.
*
* \return Rank of the worker.
* @return Rank of the worker.
*/
XGB_DLL int XGCommunicatorGetRank(void);

/*!
* \brief Get total number of processes.
/**
* @brief Get the total number of processes.
*
* \return Total world size.
* @return Total world size.
*/
XGB_DLL int XGCommunicatorGetWorldSize(void);

/*!
* \brief Get if the communicator is distributed.
/**
* @brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
* @return True if the communicator is distributed.
*/
XGB_DLL int XGCommunicatorIsDistributed(void);

/*!
* \brief Print the message to the communicator.
/**
* @brief Print the message to the tracker.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
* This function can be used to communicate the information of the progress to the user
* who monitors the tracker.
*
* \param message The message to be printed.
* \return 0 for success, -1 for failure.
* @param message The message to be printed.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorPrint(char const *message);

/*!
* \brief Get the name of the processor.
/**
* @brief Get the name of the processor.
*
* \param name_str Pointer to received returned processor name.
* \return 0 for success, -1 for failure.
* @param name_str Pointer to received returned processor name.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);

/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
/**
* @brief Broadcast a memory region to all others from root. This function is NOT
* thread-safe.
*
* Example:
* \code
* @code
* int a = 1;
* Broadcast(&a, sizeof(a), root);
* \endcode
* @endcode
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Pointer to the send or receive buffer.
* @param size Size of the data in bytes.
* @param root The process rank to broadcast from.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);

/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
/**
* @brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* \code
* vector<int> data(10);
* @code
* enum class Op {
* kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
* };
* std::vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
* Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum);
* ...
* \endcode
* @endcode

* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Buffer for both sending and receiving data.
* @param count Number of elements to be reduced.
* @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);

Expand Down
11 changes: 1 addition & 10 deletions jvm-packages/create_jni.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"USE_NCCL": "OFF",
"JVM_BINDINGS": "ON",
"LOG_CAPI_INVOCATION": "OFF",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
}


Expand Down Expand Up @@ -97,10 +98,6 @@ def native_build(args):

args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]

# if enviorment set rabit_mock
if os.getenv("RABIT_MOCK", None) is not None:
args.append("-DRABIT_MOCK:BOOL=ON")

# if enviorment set GPU_ARCH_FLAG
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
if gpu_arch_flag is not None:
Expand Down Expand Up @@ -162,12 +159,6 @@ def native_build(args):
maybe_makedirs(output_folder)
cp("../lib/" + library_name, output_folder)

print("copying pure-Python tracker")
cp(
"../python-package/xgboost/tracker.py",
"{}/src/main/resources".format(xgboost4j),
)

print("copying train/test files")
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
with cd("../demo/CLI/regression"):
Expand Down
5 changes: 5 additions & 0 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,11 @@
<artifactId>kryo</artifactId>
<version>5.6.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.14.2</version>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
Expand Down
Loading
Loading