Skip to content

Commit

Permalink
[dask] [python-package] Search for available ports when setting up ne…
Browse files Browse the repository at this point in the history
…twork (fixes #3753) (#3766)

* starting work

* fixed port-binding issue on localhost

* minor cleanup

* updates

* getting closer

* definitely working for LocalCluster

* it works, it works

* docs

* add tests

* removing testing-only files

* linting

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* remove duplicated code

* remove unnecessary listen()

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
jameslamb and StrikerRUS committed Jan 15, 2021
1 parent 9bacf03 commit f6d2dce
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 37 deletions.
117 changes: 95 additions & 22 deletions python-package/lightgbm/dask.py
Expand Up @@ -5,15 +5,17 @@
It is based on dask-xgboost package.
"""
import logging
import socket
from collections import defaultdict
from typing import Dict, Iterable
from urllib.parse import urlparse

import numpy as np
import pandas as pd
from dask import array as da
from dask import dataframe as dd
from dask import delayed
from dask.distributed import default_client, get_worker, wait
from dask.distributed import Client, default_client, get_worker, wait

from .basic import _LIB, _safe_call
from .sklearn import LGBMClassifier, LGBMRegressor
Expand All @@ -23,33 +25,84 @@
logger = logging.getLogger(__name__)


def _parse_host_port(address):
parsed = urlparse(address)
return parsed.hostname, parsed.port
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port.
This function tries to find a free port on the machine it's run on. It is intended to
be run once on each Dask worker, sequentially.
def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
"""Build network parameters suitable for LightGBM C backend.
Parameters
----------
worker_ip : str
IP address for the Dask worker.
local_listen_port : int
First port to try when searching for open ports.
ports_to_skip: Iterable[int]
An iterable of integers referring to ports that should be skipped. Since multiple Dask
workers can run on the same physical machine, this method may be called multiple times
on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
the same port for two worker processes running on the same machine.
Returns
-------
result : int
A free port on the machine referenced by ``worker_ip``.
"""
max_tries = 1000
out_port = None
found_port = False
for i in range(max_tries):
out_port = local_listen_port + i
if out_port in ports_to_skip:
continue
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, out_port))
found_port = True
break
# if unavailable, you'll get OSError: Address already in use
except OSError:
continue
if not found_port:
msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
return out_port


def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
"""Find an open port on each worker.
LightGBM distributed training uses TCP sockets by default, and this method is used to
identify open ports on each worker so LightGBM can reliable create those sockets.
Parameters
----------
worker_addresses : iterable of str - collection of worker addresses in `<protocol>://<host>:port` format
local_worker_ip : str
client : dask.distributed.Client
Dask client.
worker_addresses : Iterable[str]
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``
local_listen_port : int
time_out : int
First port to try when searching for open ports.
Returns
-------
params: dict
result : Dict[str, int]
Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
"""
addr_port_map = {addr: (local_listen_port + i) for i, addr in enumerate(worker_addresses)}
params = {
'machines': ','.join('%s:%d' % (_parse_host_port(addr)[0], port) for addr, port in addr_port_map.items()),
'local_listen_port': addr_port_map[local_worker_ip],
'time_out': time_out,
'num_machines': len(addr_port_map)
}
return params
lightgbm_ports = set()
worker_ip_to_port = {}
for worker_address in worker_addresses:
port = client.submit(
func=_find_open_port,
workers=[worker_address],
worker_ip=urlparse(worker_address).hostname,
local_listen_port=local_listen_port,
ports_to_skip=lightgbm_ports
).result()
lightgbm_ports.add(port)
worker_ip_to_port[worker_address] = port

return worker_ip_to_port


def _concat(seq):
Expand All @@ -63,9 +116,20 @@ def _concat(seq):
raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))


def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400,
def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
time_out=120, **kwargs):
network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
local_worker_address = get_worker().address
machine_list = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
for worker_address, port
in worker_address_to_port.items()
])
network_params = {
'machines': machine_list,
'local_listen_port': worker_address_to_port[local_worker_address],
'time_out': time_out,
'num_machines': len(worker_address_to_port)
}
params.update(network_params)

# Concatenate many parts into one
Expand Down Expand Up @@ -138,13 +202,22 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
'(%s), using "data" as default', params.get("tree_learner", None))
params['tree_learner'] = 'data'

# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
local_listen_port = params.get('local_listen_port', 12400)
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=local_listen_port
)

# Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part,
model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]},
list_of_parts=list_of_parts,
worker_addresses=list(worker_map.keys()),
local_listen_port=params.get('local_listen_port', 12400),
worker_address_to_port=worker_address_to_port,
time_out=params.get('time_out', 120),
return_model=(worker == master_worker),
**kwargs)
Expand Down
57 changes: 42 additions & 15 deletions tests/python_package_test/test_dask.py
@@ -1,5 +1,6 @@
# coding: utf-8
import os
import socket
import sys

import pytest
Expand Down Expand Up @@ -89,6 +90,26 @@ def test_classifier(output, centers, client, listen_port):
assert_eq(y, p2)


def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('classification', output='array')

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 12400))

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=12400
)
for i in range(5):
dask_classifier.fit(
X=dX,
y=dy,
sample_weight=dw,
client=client
)
assert dask_classifier.booster_


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_proba(output, centers, client, listen_port):
Expand Down Expand Up @@ -183,21 +204,27 @@ def test_regressor_local_predict(client, listen_port):
assert_eq(s1, s2)


def test_build_network_params():
workers_ips = [
'tcp://192.168.0.1:34545',
'tcp://192.168.0.2:34346',
'tcp://192.168.0.3:34347'
]

params = dlgbm._build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120)
exp_params = {
'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402',
'local_listen_port': 12401,
'num_machines': len(workers_ips),
'time_out': 120
}
assert exp_params == params
def test_find_open_port_works():
worker_ip = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, 12400))
new_port = dlgbm._find_open_port(
worker_ip=worker_ip,
local_listen_port=12400,
ports_to_skip=set()
)
assert new_port == 12401

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_1:
s_1.bind((worker_ip, 12400))
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_2:
s_2.bind((worker_ip, 12401))
new_port = dlgbm._find_open_port(
worker_ip=worker_ip,
local_listen_port=12400,
ports_to_skip=set()
)
assert new_port == 12402


@gen_cluster(client=True, timeout=None)
Expand Down

2 comments on commit f6d2dce

@pseudotensor
Copy link

@pseudotensor pseudotensor commented on f6d2dce Jan 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jameslamb , thanks for all your and others work on dask lightgbm, I think it will be alot more stable in the end.

I wonder if you have noticed or considered these issues I hit:

dask/dask-lightgbm#22
dask/dask-lightgbm#24

I still hit these. The KeyError is most concerning, since it seems to happen randomly. GPU support would be nice to have.

Also, what is the status of dask support in lightgbm currently? Should I transition from dask-lightgbm package to lightgbm now? Were you able to incorporate my changes for early stopping support?

e.g. I notice you guys still point to dask-lightgbm package at: https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html

@jameslamb
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if you have noticed or considered these issues I hit:

Ah I hadn't moved those over to LightGBM. That's done now.

Were you able to incorporate my changes for early stopping support?

Not yet. This is documented in #3712 and @ffineis has offered to contribute it.

Also, what is the status of dask support in lightgbm currently? Should I transition from dask-lightgbm package to lightgbm now?
e.g. I notice you guys still point to dask-lightgbm package at: https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html

There isn't an official release of LightGBM yet that includes lightgbm.dask. That will come in LightGBM 3.2.0, which doesn't have a planned release date yet. You shouldn't consider the version currently on master to be stable. We're still working on it and still might make breaking changes to it. We won't remove references to dask-lightgbm from the documentation until the next LightGBM release.

If you cut over from dask-lightgbm to master of LightGBM I'd appreciate it, because you might find other bugs we can fix before 3.2.0. But I want to be sure you understand that it isn't stable so you can make an informed decision.


One other note...in the future, please open an issue instead of commenting on a commit. That will be more visible to maintainers, contributors, and others arriving from search engines.

Please sign in to comment.