Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] allow tight control over ports #3994

Merged
merged 30 commits into from Feb 23, 2021

Conversation

jameslamb
Copy link
Collaborator

@jameslamb jameslamb commented Feb 17, 2021

In #3823, @jmoralez proposed a change for the Dask module that would speed up training by taking a bit of setup overhead that is O(n_workers) and make it O(1). That discussion led to a conversation about how to give tight control over the ports used by LightGBM, for users who are using a cluster of machines whose communication is limited by firewall rules.

I promised to go do some research in LightGBM's source code and come back with a proposal. This PR is that proposal.

Proposed Changes

This PR proposes that lightgbm.dask should handle ports as follows.

If no network params are given in params (the default)

model = DaskLGBMRegressor()
model.fit(data, labels)

LightGBM chooses ports randomly. As of this PR, it will use the current approach on master (searching 1000 ports starting with the default value of local_listen_port), but after this PR we should replace that with @jmoralez 's proposal in #3823.

If machines is provided in params

machines = "10.0.0.1:15000,10.0.0.1:15001,10.0.1.1:15000,10.0.1.2:15005"
model = DaskLGBMRegressor(
    machines=machines
)
model.fit(data, labels)

LightGBM respects machines and does not do any searching. This gives people who are in a constrained environment total control over the ports used, without needing to invent a new parameter specific to the Dask interface.

This can work for cases where you run multiple Dask worker processes on the same host. For example, for a FargateCluster with n_workers=4, nprocs=2, you might use machines like this:

10.0.0.1:15000,10.0.0.1:15001,10.0.1.1:15000,10.0.1.2:15005

In a follow-up PR after this, to make this easier, I'd propose adding a function dask_collection_to_machines_param(), which takes in your training data and a list of ports that you've allowed in your firewall settings, and returns a machines string you can put into params.

def dask_collection_to_machines_param(
    data: Union[dask.DataFrame, dask.Array],
    allowed_ports: Set[int]
) -> str:

If local_listen_port is provided in params

model = DaskLGBMRegressor(
    local_listen_port=16000
)
model.fit(data, labels)

Create machines by looking at which worker addresses have a piece of the training data, and assume that each of them will use the same port (local_listen_port). This will only work in the case where you have 1 Dask worker process per physical host.

How this PR improves lightgbm.dask

This gives an official answer to the question "how I should I used distributed LightGBM training with Dask if my Dask cluster is constrained by firewall rules?", while also making it possible to take advantage of the speedups from #3823 .

It does this in a way that only uses existing LightGBM network parameters, and which doesn't require any notes in documentation that say "hey this parameter has a slightly different meaning in Dask than everywhere else".

How LightGBM's network setup works

This is optional background that will help explain the proposal I'm making. Everything below refers only to the socket-based build of LightGBM.

notes on how the code in src/network works

Each worker runs Network::Init() initialize a LightGBM network. The following happens when that method is called.

  1. "What is my ID (rank) and what other workers are in this network?" --> Create a Linkers
    1. "My ID will be an integer corresponding to my place in the machines list" -->
      rank_ = static_cast<int>(i);
    2. "I can open a TCP listener socket to accept connections from other workers" -->
      listener_ = std::unique_ptr<TcpSocket>(new TcpSocket());
      TryBind(local_listen_port_);
    3. "I'll set up some data structures to store information about the other workers in the network" -->
      for (int i = 0; i < num_machines_; ++i) {
      linkers_.push_back(nullptr);
      }
      // construct communication topo
      bruck_map_ = BruckMap::Construct(rank_, num_machines_);
      recursive_halving_map_ = RecursiveHalvingMap::Construct(rank_, num_machines_);
    4. "I need to establish a socket connection with each other worker in the network" --> call Linkers::Construct() to set up these connections
      1. "Since these are two-way communications, I'll just initiate communications with all workers that have a rank greater than mine. I know from machines which IP addresses to look for and what local_listen_port they'll be listening on. The workers with rank less than me will initiate communications with me. Each time I communicate successfully, I'll save that new TCP socket so I can use it during training to talk to that specific worker" -->
        for (auto it = need_connect.begin(); it != need_connect.end(); ++it) {
        int out_rank = it->first;
        // let smaller rank connect to larger rank
        if (out_rank > rank_) {
        int connect_fail_delay_time = connect_fail_retry_first_delay_interval;
        for (int i = 0; i < connect_fail_retry_cnt; ++i) {
        TcpSocket cur_socket;
        if (cur_socket.Connect(client_ips_[out_rank].c_str(), client_ports_[out_rank])) {
        // send local rank
        cur_socket.Send(reinterpret_cast<const char*>(&rank_), sizeof(rank_));
        SetLinker(out_rank, cur_socket);
        break;
        } else {
        Log::Warning("Connecting to rank %d failed, waiting for %d milliseconds", out_rank, connect_fail_delay_time);
        cur_socket.Close();
        std::this_thread::sleep_for(std::chrono::milliseconds(connect_fail_delay_time));
        connect_fail_delay_time = static_cast<int>(connect_fail_delay_time * connect_fail_retry_delay_factor);
        }
        }
        }
        }
    5. "Ok, now that I have sockets to talk to all the other workers, I can destroy the listener socket. I'm not accepting any more new connections" -->
      listener_->Close();
      is_init_ = true;
  2. "Now that I know how to talk to other workers, I should also set up some space in memory to store data I want to send to them and data I receive from them" -->
    block_start_ = std::vector<comm_size_t>(num_machines_);
    block_len_ = std::vector<comm_size_t>(num_machines_);
    buffer_size_ = 1024 * 1024;
    buffer_.resize(buffer_size_);

You can learn more about this from https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html#preparation, https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html#id1, and https://lightgbm.readthedocs.io/en/latest/Parameters.html#network-parameters.

Notes for Reviewers

I'll devote a section to this topic in the documentation introduced for #3814 .

The one Dask unit test that's failing will pass if #3993 is accepted and merged.

@jameslamb
Copy link
Collaborator Author

I thought I left a comment on here but I guess I navigated away before sending it haha.

@jmoralez @ffineis if you have time and interest, I'd really appreciate your feedback on this PR.

@jameslamb
Copy link
Collaborator Author

jameslamb commented Feb 18, 2021

I tested this on a FargateCluster on AWS tonight using dask-cloudprovider: https://github.com/jameslamb/lightgbm-dask-testing/blob/c7e94dd520e762bdcc6c45967dd8bff4c5936302/notebooks/aws.ipynb

I tried all three possible configurations described above, and all three trainings succeeded. So I'm confident this will work on either LocalCluster (what we use in tests) or a true distributed cluster.

* ``local_listen_port``: port that each LightGBM worker opens a listening socket on,
to accept connections from other workers. This can be differ from LightGBM worker
to LightGBM worker, but does not have to.
* ``machines``: a list of all machines in the cluster, plus a port to communicate
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This totally just my opinion, but "machines": a list of all machines in the cluster, plus a port to communicate..." is kind of circular - the second use of "machines" is referring to IP addresses, but the original use of machines (IMO) is an IP:port combo

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah I probably should not use the word "machines" again in the description haha, thank you

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok updated in 25462ea, thanks for catching this

@ffineis
Copy link
Contributor

ffineis commented Feb 18, 2021

I thought I left a comment on here but I guess I navigated away before sending it haha.

@jmoralez @ffineis if you have time and interest, I'd really appreciate your feedback on this PR.

Hey I think these three options make sense and that they're easy to understand! Nice work! Really, I don't have much to add.

@StrikerRUS
Copy link
Collaborator

@jameslamb I haven't looked at the diff yet, just read the detailed explanation you've provided.

I think I found one inconsistency.

If no network params are given in params (the default)
LightGBM chooses ports randomly.

It does this in a way that only uses existing LightGBM network parameters, and which doesn't require any notes in documentation that say "hey this parameter has a slightly different meaning in Dask than everywhere else".

Default value of the local_listen_port is 12400 and not "any random value starting from 12400 (or even simply 'any random value' after #3823)". And some users which are fine with default port and are under firewall constraints don't expect that ports will be random, they expect all ports on all machines will be exactly 12400 (default value). So, I still think that we need some clarification of this inconsistency in Dask docs.
Refer to #3823 (review).

@jameslamb
Copy link
Collaborator Author

@jameslamb Thank you! I just thought about it a little bit more and come up to that it probably will be easier to just write default = 12400 (random for Dask-package) because right now users must read docs very carefully in two places and we must duplicate this warning in multiple (at least 3) places after implementing #3846 and #3847.
Refer to

// default = train

for how to set custom default value in params docs.

ok sure, no problem. Done in 0c81f60

image

Copy link
Collaborator

@StrikerRUS StrikerRUS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job! But I have some comments below.

python-package/lightgbm/dask.py Outdated Show resolved Hide resolved
python-package/lightgbm/dask.py Outdated Show resolved Hide resolved
python-package/lightgbm/dask.py Outdated Show resolved Hide resolved
python-package/lightgbm/dask.py Outdated Show resolved Hide resolved
Comment on lines +403 to +404
"machine in the cluster has multiple Dask worker processes running on it. Please omit "
"'local_listen_port' or pass 'machines'."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"machine in the cluster has multiple Dask worker processes running on it. Please omit "
"'local_listen_port' or pass 'machines'."
"machine in the cluster has multiple Dask worker processes running on it.\nPlease omit "
"'local_listen_port' or pass full configuration via 'machines' parameter."

Copy link
Collaborator Author

@jameslamb jameslamb Feb 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why you think we should include a newline?

I'm concerned that in logs, it will look like an exception with only the text before the newline followed by a separate print statement.

image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, it is up to you. Feel free to revert new line. I personally don't like long line warnings/errors.

Copy link
Collaborator Author

@jameslamb jameslamb Feb 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright I'm not going to accept this suggestion then if it's just a matter a matter of personal preference.

I've had problems in the past with external logs-management systems and log messages that have newline characters. You can read about that general problem at https://www.datadoghq.com/blog/multiline-logging-guide/#the-multi-line-logging-problem if you're interested.

Long log messages will also be wrapped automatically in Jupyter notebooks

image

and in python REPLs

image

# * 'num_machines': set automatically from Dask worker list
# * 'num_threads': overridden to match nthreads on each Dask process
for param_alias in _ConfigAliases.get('machines', 'num_machines', 'num_threads'):
for param_alias in _ConfigAliases.get('num_machines', 'num_threads'):
params.pop(param_alias, None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it is important to notify users about this behavior.

Suggested change
params.pop(param_alias, None)
if param_alias in params:
_log_warning(f"Parameter {param_alias} will be ignored.")
params.pop(param_alias)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I accepted this suggestion but I think now that we should only apply it to num_machines, not num_threads.

This results in a warning that users cannot suppress.

/opt/conda/lib/python3.8/site-packages/lightgbm/dask.py:338: UserWarning: Parameter n_jobs will be ignored.
_log_warning(f"Parameter {param_alias} will be ignored.")

Caused by the fact that n_jobs is an alias of num_threads

I believe that every warning should be something that can be changed by user code changes. Otherwise, we're just adding noise to logs that might cause people to start filtering out ALL warnings.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is important to notify users about num_threads as well before implementing #3714. Silently ignore parameter is more serious problem compared to unfixable warning, I believe.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree in this specific case about the meaning of "ignore", since this is a parameter default and not something explicitly passed in. However, since num_threads isn't directly related to the purpose of this PR and since I don't want to delay this PR too long because I'd like to merge #3823 soon after it, I'll leave this warning in for now and propose another PR in a few days where we could discuss it further.

dask_model2.fit(dX, dy, group=dg)
else:
dask_model2.fit(dX, dy)
assert dask_model2.fitted_
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should really assert that

  1. assert dask_model2.get_params()['machines'] == machines
  2. check somehow that passed workers from machines were used and not any other as it is stated in the test name.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh good point! ok yes, I can do that.

check somehow that passed workers from machines were used and not any other as it is stated in the test name.

I can test this by binding one of the ports mentioned in machines, and asserting that training fails with the "failed to bind port" error

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I think I've addressed both oof these in my most recent commit. I'm glad you mentioned #1, because it's something I didn't think about correctly. To fix that, I captured whether machines was provided explicitly in parameters, and if it was I had _train() preserve it in the fitted model object. Like this

if not machines_in_params:
    for param in _ConfigAliases.get('machines'):
        model._other_params.pop(param, None)

@jmoralez
Copy link
Collaborator

Sorry for being late to the party, I've had a lot of stuff at work.

Default value of the local_listen_port is 12400 and not "any random value starting from 12400 (or even simply 'any random value' after #3823)". And some users which are fine with default port and are under firewall constraints don't expect that ports will be random, they expect all ports on all machines will be exactly 12400 (default value)

I agree with @StrikerRUS here and I really like the proposal of including the machines argument because I think it makes it very easy for the user to be explicit about the desired distributed training configuration. Right now I think when machines is specified it isn't being verified that the ports are actually open, so I think we could add something like:

def _check_port_is_open(port: int) -> int:
"""Check that port is open on the machine or find a random open one if port is 0."""
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
              s.bind(('', port))
              port = s.getsockname()[1]
    except OSError:
        # handle port already in use
    return port

to _machines_to_worker_map and that way for machines="ip1:12600,ip2:12700" the function would check that 12600 is available in ip1 and 12700 is available in ip2. And using random ports for training could be achieved with machines="ip1:0,ip2:0". This wouldn't change LightGBM's default behavior in the dask interface and should only be documented in the machines argument.

@jameslamb
Copy link
Collaborator Author

No problem @jmoralez , thanks for all your help so far with the Dask package.

Right now I think when machines is specified it isn't being verified that the ports are actually open, so I think we could add something like...

I'm against adding additional code for this in the Python package. You'll already get the informative error "Binding port 12345 failed" if one of the ports you chose is not open.

Log::Fatal("Binding port %d failed", port);
.

And using random ports for training could be achieved with machines="ip1:0,ip2:0"

I don't want to encourage or support this pattern. If the communication between workers is not governed by strict firewall rules, lightgbm.dask should "just work", just like every library in the Dask ecosystem works without you ever needing to think about IP addresses or ports.

The "list IPs explicitly with machines" approach is fragile because if any of those machines don't also have a piece of the training data, training will fail. And because those values have to be recomputed each time you connect to a new cluster, or whenever a Dask worker is restarted. It's worth paying the price of that fragility in exchange for being compliant with your organization's networking constraints, but I wouldn't encourage anyone who is comfortable with getting random ports to use machines.

So I'm against adding a 4th option that is like "if you use :0 in machines, lightgbm.dask will search randomly for ports". I think the 3 options explained in this pull request's description are sufficient.


@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_machines_should_be_used_if_provided(task, output):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to make this a separate test from test_network_params_not_required_but_respected_if_given because all the different context managers (client fixture + with pytests.raises() + with socket.socket()) were interacting in a strange way, and I was getting frequent teardown errors.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is also why I had to use LocalCluster instead of the client fixture for this test

Copy link
Collaborator

@StrikerRUS StrikerRUS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

python-package/lightgbm/dask.py Outdated Show resolved Hide resolved
tests/python_package_test/test_dask.py Outdated Show resolved Hide resolved
Comment on lines 1121 to 1122
with LocalCluster(n_workers=2) as cluster:
with Client(cluster) as client:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just letting you know that you can use one-line with for multiple variables to not make the code over-indented:

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oooo I didn't know you could do that when the second with references the as name from the first one, nice!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just made that change in b3c8a2c, thanks

@github-actions
Copy link

This pull request has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Aug 24, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants