Skip to content

Commit

Permalink
[dask] find all needed ports in each host at once (fixes #4458) (#4498)
Browse files Browse the repository at this point in the history
* find all needed ports in each worker at once

* lint

* better naming

* use _HostWorkers in test
  • Loading branch information
jmoralez committed Aug 3, 2021
1 parent 1dbf438 commit 5fe27d5
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 95 deletions.
123 changes: 65 additions & 58 deletions python-package/lightgbm/dask.py
Expand Up @@ -7,7 +7,7 @@
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import socket
from collections import defaultdict
from collections import defaultdict, namedtuple
from copy import deepcopy
from enum import Enum, auto
from functools import partial
Expand All @@ -30,6 +30,8 @@
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]

_HostWorkers = namedtuple('HostWorkers', ['default', 'all'])


class _DatasetNames(Enum):
"""Placeholder names used by lightgbm.dask internals to say 'also evaluate the training data'.
Expand Down Expand Up @@ -62,18 +64,71 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client


def _find_random_open_port() -> int:
"""Find a random open port on localhost.
def _find_n_open_ports(n: int) -> List[int]:
"""Find n random open ports on localhost.
Returns
-------
port : int
A free port on localhost
ports : list of int
n random open ports on localhost.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
sockets = []
for _ in range(n):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0))
port = s.getsockname()[1]
return port
sockets.append(s)
ports = []
for s in sockets:
ports.append(s.getsockname()[1])
s.close()
return ports


def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWorkers]:
    """Group all worker addresses by hostname.

    The first worker seen for a host becomes that host's ``default`` worker.

    Returns
    -------
    host_to_workers : dict
        mapping from hostname to all its workers.
    """
    grouped: Dict[str, _HostWorkers] = {}
    for worker_address in worker_addresses:
        host = urlparse(worker_address).hostname
        existing = grouped.get(host)
        if existing is None:
            grouped[host] = _HostWorkers(default=worker_address, all=[worker_address])
        else:
            existing.all.append(worker_address)
    return grouped


def _assign_open_ports_to_workers(
    client: Client,
    host_to_workers: Dict[str, _HostWorkers]
) -> Dict[str, int]:
    """Assign an open port to each worker.

    One port-finding task is submitted per host (pinned to that host's
    default worker), asking for as many ports as the host has workers.

    Returns
    -------
    worker_to_port: dict
        mapping from worker address to an open port.
    """
    # launch one _find_n_open_ports task per host, pinned to that host
    futures_by_host = {
        hostname: client.submit(
            _find_n_open_ports,
            n=len(workers.all),
            workers=[workers.default],
            pure=False,
            allow_other_workers=False,
        )
        for hostname, workers in host_to_workers.items()
    }
    found_ports = client.gather(futures_by_host)
    # pair each worker on a host with one of the ports found on that host
    return {
        worker: port
        for hostname, workers in host_to_workers.items()
        for worker, port in zip(workers.all, found_ports[hostname])
    }


def _concat(seq: List[_DaskPart]) -> _DaskPart:
Expand Down Expand Up @@ -330,44 +385,6 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
return out


def _possibly_fix_worker_map_duplicates(worker_map: Dict[str, int], client: Client) -> Dict[str, int]:
    """Fix any duplicate IP-port pairs in a ``worker_map``."""
    fixed_map = deepcopy(worker_map)

    # record which ports are already taken on each host, and which
    # workers collided with an earlier worker on the same host
    ports_seen_per_host = defaultdict(set)
    duplicated_workers = []
    for worker_address, assigned_port in fixed_map.items():
        worker_host = urlparse(worker_address).hostname
        if assigned_port not in ports_seen_per_host[worker_host]:
            ports_seen_per_host[worker_host].add(assigned_port)
        else:
            duplicated_workers.append(worker_address)

    # each colliding worker searches serially for a fresh port, up to 100 attempts
    for worker_address in duplicated_workers:
        _log_info(f"Searching for a LightGBM training port for worker '{worker_address}'")
        worker_host = urlparse(worker_address).hostname
        for _ in range(100):
            candidate_port = client.submit(
                _find_random_open_port,
                workers=[worker_address],
                allow_other_workers=False,
                pure=False
            ).result()
            if candidate_port not in ports_seen_per_host[worker_host]:
                fixed_map[worker_address] = candidate_port
                ports_seen_per_host[worker_host].add(candidate_port)
                break
        else:
            raise LightGBMError(
                "Failed to find an open port. Try re-running training or explicitly setting 'machines' or 'local_listen_port'."
            )

    return fixed_map


def _train(
client: Client,
data: _DaskMatrixLike,
Expand Down Expand Up @@ -726,18 +743,8 @@ def _train(
}
else:
_log_info("Finding random open ports for workers")
# this approach with client.run() is faster than searching for ports
# serially, but can produce duplicates sometimes. Try the fast approach one
# time, then pass it through a function that will use a slower but more reliable
# approach if duplicates are found.
worker_address_to_port = client.run(
_find_random_open_port,
workers=list(worker_addresses)
)
worker_address_to_port = _possibly_fix_worker_map_duplicates(
worker_map=worker_address_to_port,
client=client
)
host_to_workers = _group_workers_by_host(worker_map.keys())
worker_address_to_port = _assign_open_ports_to_workers(client, host_to_workers)

machines = ','.join([
f'{urlparse(worker_address).hostname}:{port}'
Expand Down
61 changes: 24 additions & 37 deletions tests/python_package_test/test_dask.py
Expand Up @@ -446,11 +446,29 @@ def test_classifier_pred_contrib(output, task, cluster):
assert len(np.unique(preds_with_contrib[:, base_value_col]) == 1)


def test_find_random_open_port(cluster):
def test_group_workers_by_host():
    hosts = ['0.0.0.0', '0.0.0.1']
    # two workers per host, interleaved so grouping actually has to regroup them
    workers = [f'tcp://{host}:{port}' for port in range(2) for host in hosts]
    expected = {}
    for host in hosts:
        addresses = [f'tcp://{host}:0', f'tcp://{host}:1']
        expected[host] = lgb.dask._HostWorkers(default=addresses[0], all=addresses)
    host_to_workers = lgb.dask._group_workers_by_host(workers)
    assert host_to_workers == expected


def test_assign_open_ports_to_workers(cluster):
with Client(cluster) as client:
for _ in range(5):
worker_address_to_port = client.run(lgb.dask._find_random_open_port)
workers = client.scheduler_info()['workers'].keys()
n_workers = len(workers)
host_to_workers = lgb.dask._group_workers_by_host(workers)
for _ in range(1_000):
worker_address_to_port = lgb.dask._assign_open_ports_to_workers(client, host_to_workers)
found_ports = worker_address_to_port.values()
assert len(found_ports) == n_workers
# check that found ports are different for same address (LocalCluster)
assert len(set(found_ports)) == len(found_ports)
# check that the ports are indeed open
Expand All @@ -459,37 +477,6 @@ def test_find_random_open_port(cluster):
s.bind(('', port))


def test_possibly_fix_worker_map(capsys, cluster):
    with Client(cluster) as client:
        worker_addresses = list(client.scheduler_info()["workers"].keys())

        retry_msg = 'Searching for a LightGBM training port for worker'

        # a map without duplicates should pass through unchanged, with no retries logged
        unique_ports_map = {
            address: 12400 + offset
            for offset, address in enumerate(worker_addresses)
        }
        patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
            client=client,
            worker_map=unique_ports_map
        )
        assert patched_map == unique_ports_map
        assert retry_msg not in capsys.readouterr().out

        # a map where every worker shares one port must be repaired via retries
        duplicated_ports_map = {address: 12400 for address in worker_addresses}
        patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
            client=client,
            worker_map=duplicated_ports_map
        )
        assert retry_msg in capsys.readouterr().out
        assert len(set(patched_map.values())) == len(worker_addresses)


def test_training_does_not_fail_on_port_conflicts(cluster):
with Client(cluster) as client:
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
Expand Down Expand Up @@ -1406,7 +1393,7 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c

# model 2 - machines given
n_workers = len(client.scheduler_info()['workers'])
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
Expand Down Expand Up @@ -1452,7 +1439,7 @@ def test_machines_should_be_used_if_provided(task, cluster):

n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model = dask_model_factory(
n_estimators=5,
num_leaves=5,
Expand All @@ -1474,7 +1461,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
client.restart()

# an informative error should be raised if "machines" has duplicates
one_open_port = lgb.dask._find_random_open_port()
one_open_port = lgb.dask._find_n_open_ports(1)
dask_model.set_params(
machines=",".join([
f"127.0.0.1:{one_open_port}"
Expand Down

0 comments on commit 5fe27d5

Please sign in to comment.