Skip to content

Commit

Permalink
Start converting the Scala package.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 10, 2024
1 parent b31f862 commit 3001987
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,13 +68,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(cmd.contains(s" --host-ip=${hostIp}"))
case _ => assert(false, "expected python tracker implementation")
}
}

test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
Expand Down
12 changes: 9 additions & 3 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,16 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return Success();
}

RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id, StringView nccl_path)
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path)
: HostComm{std::move(tracker_host), tracker_port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
if (this->TrackerInfo().host.empty()) {
// Not in a distributed environment.
return;
}

loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
if (!rc.OK()) {
Expand Down
6 changes: 3 additions & 3 deletions src/collective/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ class RabitComm : public HostComm {
public:
// bootstrapping construction.
RabitComm() = default;
// ctor for testing where environment is known.
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id, StringView nccl_path);
RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path);
~RabitComm() noexcept(false) override;

[[nodiscard]] bool IsFederated() const override { return false; }
Expand Down
14 changes: 7 additions & 7 deletions src/collective/comm_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ CommGroup::CommGroup()
auto task_id = get_param("dmlc_task_id", std::string{}, String{});

if (type == "rabit") {
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
auto ptr =
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
auto ptr = new CommGroup{
std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
return ptr;
} else if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
Expand Down

0 comments on commit 3001987

Please sign in to comment.