Initial support for federated learning #7831

Merged: 28 commits, May 5, 2022
Commits (28)
48fc024
add federated plugin
rongou Apr 4, 2022
5e1ff7e
add federation server and test client
rongou Apr 8, 2022
cd52ceb
minor cleanup
rongou Apr 8, 2022
b63395a
implemented allreduce/allgather/broadcast
rongou Apr 16, 2022
66bf977
refactor federated server
rongou Apr 16, 2022
549fbdb
more refactoring of the federated server
rongou Apr 16, 2022
4723c00
support custom reduction
rongou Apr 18, 2022
457f690
remove unused includes
rongou Apr 19, 2022
095f717
Merge remote-tracking branch 'upstream/master' into federated
rongou Apr 19, 2022
66e9425
fix finalize
rongou Apr 19, 2022
4dc81df
no splitting data in federated mode
rongou Apr 19, 2022
e40ca5d
update readme
rongou Apr 19, 2022
c99883c
support more than 10 workers
rongou Apr 20, 2022
80f8593
Merge remote-tracking branch 'upstream/master' into federated
rongou Apr 21, 2022
60ef12e
add some comments and copyright headers
rongou Apr 21, 2022
5ce184b
Merge remote-tracking branch 'upstream/master' into federated
rongou Apr 22, 2022
dca5902
add mutual ssl/tls authentication
rongou Apr 23, 2022
05fce58
simplify cert generation
rongou Apr 23, 2022
ac8be18
Merge remote-tracking branch 'upstream/master' into federated
rongou Apr 25, 2022
6c16319
change to functors
rongou Apr 26, 2022
851ccde
add c api
rongou Apr 27, 2022
b7ba8ae
add federated server unit tests
rongou Apr 28, 2022
f2164c6
exclude federated tests when plugin not enabled
rongou Apr 29, 2022
9819f85
Merge remote-tracking branch 'upstream/master' into federated
rongou Apr 29, 2022
bb0896e
revert accidiental change
rongou Apr 29, 2022
7ea426c
Merge remote-tracking branch 'upstream/master' into federated
rongou Apr 29, 2022
5bf7e35
Merge remote-tracking branch 'upstream/master' into federated
rongou May 2, 2022
ba52021
address review comments
rongou May 2, 2022
Files changed
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -66,6 +66,7 @@ address, leak, undefined and thread.")
## Plugins
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
option(PLUGIN_FEDERATED "Build with Federated Learning" OFF)
## TODO: 1. Add check if DPC++ compiler is used for building
option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
5 changes: 5 additions & 0 deletions plugin/CMakeLists.txt
@@ -40,3 +40,8 @@ if (PLUGIN_UPDATER_ONEAPI)
# Add all objects of oneapi_plugin to objxgboost
target_sources(objxgboost INTERFACE $<TARGET_OBJECTS:oneapi_plugin>)
endif (PLUGIN_UPDATER_ONEAPI)

# Add the Federated Learning plugin if enabled.
if (PLUGIN_FEDERATED)
add_subdirectory(federated)
endif (PLUGIN_FEDERATED)
27 changes: 27 additions & 0 deletions plugin/federated/CMakeLists.txt
@@ -0,0 +1,27 @@
# gRPC needs to be installed first. See README.md.
find_package(Protobuf REQUIRED)
find_package(gRPC REQUIRED)
find_package(Threads)

# Generated code from the protobuf definition.
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)

get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
protobuf_generate(
TARGET federated_proto
LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}")

# Wrapper for the gRPC client.
add_library(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc engine_federated.cc)
target_link_libraries(objxgboost PRIVATE federated_client)
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
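
For orientation: the `federated_client` target above wraps the gRPC stub that `engine_federated.cc` (further down) talks to. The header itself does not appear in this diff, so the following is only a sketch of the interface implied by the engine's call sites, not the actual contents of `federated_client.h`; the collectives exchange opaque `std::string` payloads, and the constructor takes the server address, rank, and mutual-TLS material.

```cpp
// Hypothetical sketch of federated_client.h, inferred from how engine_federated.cc
// uses it; the real header (and the proto-generated enums) may differ.
#include <string>

namespace xgboost {
namespace federated {

// Enum constants referenced by the engine's GetDataType()/GetOp() mappings.
enum DataType { CHAR, UCHAR, INT, UINT, LONG, ULONG, FLOAT, DOUBLE, LONGLONG, ULONGLONG };
enum ReduceOperation { MAX, MIN, SUM };

class FederatedClient {
 public:
  // Constructed with the server address, this worker's rank, and the
  // mutual-TLS certificates/key read from disk by the engine.
  FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
                  std::string const &client_key, std::string const &client_cert);

  // Each collective sends the local buffer as a byte string and returns the
  // combined result received from the federated server.
  std::string Allgather(std::string const &send_buffer);
  std::string Allreduce(std::string const &send_buffer, DataType data_type, ReduceOperation op);
  std::string Broadcast(std::string const &send_buffer, int root);
};

}  // namespace federated
}  // namespace xgboost
```
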
35 changes: 35 additions & 0 deletions plugin/federated/README.md
@@ -0,0 +1,35 @@
XGBoost Plugin for Federated Learning
=====================================

This folder contains the plugin for federated learning. Follow these steps to build and test it.

Install gRPC
------------
```shell
sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build
git clone -b v1.45.2 https://github.com/grpc/grpc
cd grpc
git submodule update --init
cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON
cmake --build build --target install
```

Build the Plugin
----------------
```shell
# Under xgboost source tree.
mkdir build
cd build
cmake .. -GNinja -DPLUGIN_FEDERATED=ON
ninja
cd ../python-package
pip install -e . # or equivalently python setup.py develop
```

Test Federated XGBoost
----------------------
```shell
# Under xgboost source tree.
cd tests/distributed
./runtests-federated.sh
```
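
For context on how a worker is pointed at the federated server: `engine_federated.cc` (below) reads its configuration either from environment variables or from `KEY=VALUE` arguments passed to the engine's `Init`. A minimal sketch, assuming the plugin build above and the standard `rabit::Init`/`rabit::Finalize` entry points; the certificate file names are placeholders.

```cpp
// Minimal, hypothetical worker setup; the FEDERATED_* keys match SetParam in
// engine_federated.cc, everything else here is illustrative.
#include <rabit/rabit.h>

int main() {
  char const *args[] = {
      "FEDERATED_SERVER_ADDRESS=localhost:9091",  // engine default if unset
      "FEDERATED_WORLD_SIZE=2",
      "FEDERATED_RANK=0",
      "FEDERATED_SERVER_CERT=server-cert.pem",    // placeholder file names
      "FEDERATED_CLIENT_KEY=client-key.pem",
      "FEDERATED_CLIENT_CERT=client-cert.pem",
  };
  rabit::Init(static_cast<int>(sizeof(args) / sizeof(args[0])), const_cast<char **>(args));
  // ... run training; allreduce/broadcast now go through the federated server ...
  rabit::Finalize();
  return 0;
}
```
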
274 changes: 274 additions & 0 deletions plugin/federated/engine_federated.cc
@@ -0,0 +1,274 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <cstdio>
#include <fstream>
#include <sstream>

#include "federated_client.h"
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"

namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
size_t type_size;
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
} // namespace MPI

namespace rabit {
namespace engine {

/*! \brief implementation of engine using federated learning */
class FederatedEngine : public IEngine {
public:
void Init(int argc, char *argv[]) {
// Parse environment variables first.
for (auto const &env_var : env_vars_) {
char const *value = getenv(env_var.c_str());
if (value != nullptr) {
SetParam(env_var, value);
}
}
// Command line argument overrides.
for (int i = 0; i < argc; ++i) {
std::string const key_value = argv[i];
auto const delimiter = key_value.find('=');
if (delimiter != std::string::npos) {
SetParam(key_value.substr(0, delimiter), key_value.substr(delimiter + 1));
}
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}

void Finalize() { client_.reset(); }

void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end,
size_t size_prev_slice) override {
throw std::logic_error("FederatedEngine:: Allgather is not supported");
}

std::string Allgather(void *sendbuf, size_t total_size) {
std::string const send_buffer(reinterpret_cast<char *>(sendbuf), total_size);
return client_->Allgather(send_buffer);
}

void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, ReduceFunction reducer,
PreprocFunction prepare_fun, void *prepare_arg) override {
throw std::logic_error("FederatedEngine:: Allreduce is not supported, use Allreduce_ instead");
}

void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
receive_buffer.copy(buffer, size);
}

int GetRingPrevRank() const override {
throw std::logic_error("FederatedEngine:: GetRingPrevRank is not supported");
}

void Broadcast(void *sendrecvbuf, size_t size, int root) override {
if (world_size_ == 1) return;
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Broadcast(send_buffer, root);
if (rank_ != root) {
receive_buffer.copy(buffer, size);
}
}

int LoadCheckPoint(Serializable *global_model, Serializable *local_model = nullptr) override {
return 0;
}

void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}

void LazyCheckPoint(const Serializable *global_model) override { version_number_ += 1; }

int VersionNumber() const override { return version_number_; }

/*! \brief get rank of current node */
int GetRank() const override { return rank_; }

/*! \brief get total number of workers */
int GetWorldSize() const override { return world_size_; }

/*! \brief whether it is distributed */
bool IsDistributed() const override { return true; }

/*! \brief get the host name of current node */
std::string GetHost() const override { return "rank" + std::to_string(rank_); }

void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
}

private:
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
switch (data_type) {
case mpi::kChar:
return xgboost::federated::CHAR;
case mpi::kUChar:
return xgboost::federated::UCHAR;
case mpi::kInt:
return xgboost::federated::INT;
case mpi::kUInt:
return xgboost::federated::UINT;
case mpi::kLong:
return xgboost::federated::LONG;
case mpi::kULong:
return xgboost::federated::ULONG;
case mpi::kFloat:
return xgboost::federated::FLOAT;
case mpi::kDouble:
return xgboost::federated::DOUBLE;
case mpi::kLongLong:
return xgboost::federated::LONGLONG;
case mpi::kULongLong:
return xgboost::federated::ULONGLONG;
}
utils::Error("unknown mpi::DataType");
return xgboost::federated::CHAR;
}

/** @brief Transform mpi::OpType to xgboost::federated::ReduceOperation. */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
switch (op_type) {
case mpi::kMax:
return xgboost::federated::MAX;
case mpi::kMin:
return xgboost::federated::MIN;
case mpi::kSum:
return xgboost::federated::SUM;
case mpi::kBitwiseOR:
utils::Error("Bitwise OR is not supported");
return xgboost::federated::MAX;
}
utils::Error("unknown mpi::OpType");
return xgboost::federated::MAX;
}

void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;
} else if (!strcasecmp(name.c_str(), "FEDERATED_WORLD_SIZE")) {
world_size_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_RANK")) {
rank_ = std::stoi(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_CERT")) {
server_cert_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_KEY")) {
client_key_ = ReadFile(val);
} else if (!strcasecmp(name.c_str(), "FEDERATED_CLIENT_CERT")) {
client_cert_ = ReadFile(val);
}
}

static std::string ReadFile(std::string const &path) {
auto stream = std::ifstream(path.data());
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}

// clang-format off
std::vector<std::string> const env_vars_{
"FEDERATED_SERVER_ADDRESS",
"FEDERATED_WORLD_SIZE",
"FEDERATED_RANK",
"FEDERATED_SERVER_CERT",
"FEDERATED_CLIENT_KEY",
"FEDERATED_CLIENT_CERT" };
// clang-format on
std::string server_address_{"localhost:9091"};
int world_size_{1};
int rank_{0};
std::string server_cert_{};
std::string client_key_{};
std::string client_cert_{};
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
int version_number_{0};
};

// Singleton federated engine.
FederatedEngine engine; // NOLINT(cert-err58-cpp)

/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
try {
engine.Init(argc, argv);
return true;
} catch (std::exception const &e) {
fprintf(stderr, " failed in federated Init %s\n", e.what());
return false;
}
}

/*! \brief finalize synchronization module */
bool Finalize() {
try {
engine.Finalize();
return true;
} catch (const std::exception &e) {
fprintf(stderr, "failed in federated shutdown %s\n", e.what());
return false;
}
}

/*! \brief singleton method to get engine */
IEngine *GetEngine() { return &engine; }

// Perform in-place allreduce on sendrecvbuf.
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;
engine.Allreduce(sendrecvbuf, type_nbytes * count, dtype, op);
}

ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;

int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast<int>(dtype.type_size); }

void ReduceHandle::Init(IEngine::ReduceFunction redfunc,
__attribute__((unused)) size_t type_nbytes) {
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}

void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun, void *prepare_arg) {
utils::Assert(redfunc_ != nullptr, "must initialize handle to call AllReduce");
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (engine.GetWorldSize() == 1) return;

// Gather all the buffers and call the reduce function locally.
auto const buffer_size = type_nbytes * count;
auto const gathered = engine.Allgather(sendrecvbuf, buffer_size);
auto const *data = gathered.data();
for (int i = 0; i < engine.GetWorldSize(); i++) {
if (i != engine.GetRank()) {
redfunc_(data + buffer_size * i, sendrecvbuf, static_cast<int>(count),
MPI::Datatype(type_nbytes));
}
}
}

} // namespace engine
} // namespace rabit
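
The custom-reducer path above (`ReduceHandle::Allreduce`) emulates user-defined reductions that the federated server does not implement: every worker's buffer is gathered, and the reduce function is folded over the remote slices locally. Below is a minimal standalone sketch of that pattern with a toy sum reducer; the rank-ordered concatenation of gathered buffers is the layout the engine's loop assumes.

```cpp
// Standalone illustration of the gather-then-reduce-locally pattern; the
// "gathered" string stands in for the result of engine.Allgather().
#include <iostream>
#include <string>

// Toy reducer: fold src into dst element-wise; the engine's ReduceFunction has
// the same shape plus a trailing MPI::Datatype argument.
void SumInts(const void *src, void *dst, int count) {
  auto const *s = static_cast<const int *>(src);
  auto *d = static_cast<int *>(dst);
  for (int i = 0; i < count; ++i) d[i] += s[i];
}

int main() {
  int const world_size = 2, rank = 0, count = 3;
  int sendrecvbuf[3] = {1, 2, 3};     // our local contribution
  int const other[3] = {10, 20, 30};  // pretend this came from rank 1

  // Simulate Allgather: concatenate every worker's bytes in rank order.
  std::string gathered;
  gathered.append(reinterpret_cast<const char *>(sendrecvbuf), sizeof(sendrecvbuf));
  gathered.append(reinterpret_cast<const char *>(other), sizeof(other));

  // Fold every remote slice into our own buffer, as ReduceHandle::Allreduce does.
  size_t const buffer_size = sizeof(int) * count;
  for (int i = 0; i < world_size; ++i) {
    if (i == rank) continue;  // our slice is already in sendrecvbuf
    SumInts(gathered.data() + buffer_size * i, sendrecvbuf, count);
  }

  for (int v : sendrecvbuf) std::cout << v << " ";  // prints: 11 22 33
  std::cout << std::endl;
  return 0;
}
```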