Add alternative SSHCluster implementation #2827

Merged (8 commits) · Jul 18, 2019

Changes from all commits
1 change: 1 addition & 0 deletions .travis.yml
@@ -19,6 +19,7 @@ matrix:

install:
- if [[ $TESTS == true ]]; then source continuous_integration/travis/install.sh ; fi
- if [[ $TESTS == true ]]; then source continuous_integration/travis/setup-ssh.sh ; fi

script:
- if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi
1 change: 1 addition & 0 deletions continuous_integration/travis/install.sh
@@ -67,6 +67,7 @@ pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps
pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps
pip install -q sortedcollections msgpack --no-deps
pip install -q keras --upgrade --no-deps
pip install -q asyncssh

if [[ $CRICK == true ]]; then
conda install -q cython
2 changes: 2 additions & 0 deletions continuous_integration/travis/setup-ssh.sh
@@ -0,0 +1,2 @@
ssh-keygen -t rsa -f ~/.ssh/id_rsa -N "" -q
cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
1 change: 1 addition & 0 deletions distributed/deploy/local.py
@@ -198,6 +198,7 @@ def __init__(
loop=loop,
asynchronous=asynchronous,
silence_logs=silence_logs,
security=security,
)

def __repr__(self):
30 changes: 21 additions & 9 deletions distributed/deploy/spec.py
@@ -5,8 +5,10 @@
from tornado import gen

from .cluster import Cluster
from ..core import rpc
from ..utils import LoopRunner, silence_logging, ignoring
from ..scheduler import Scheduler
from ..security import Security


class SpecCluster(Cluster):
@@ -107,6 +109,7 @@ def __init__(
worker=None,
asynchronous=False,
loop=None,
security=None,
silence_logs=False,
):
self._created = weakref.WeakSet()
@@ -125,6 +128,8 @@ def __init__(
self.workers = {}
self._i = 0
self._asynchronous = asynchronous
self.security = security or Security()
self.scheduler_comm = None

if silence_logs:
self._old_logging_level = silence_logging(level=silence_logs)
@@ -156,6 +161,10 @@ async def _start(self):
self._lock = asyncio.Lock()
self.status = "starting"
self.scheduler = await self.scheduler
self.scheduler_comm = rpc(
self.scheduler.address,
connection_args=self.security.get_connection_args("client"),
)
self.status = "running"

def _correct_state(self):
@@ -174,11 +183,13 @@ async def _correct_state_internal(self):
pre = list(set(self.workers))
to_close = set(self.workers) - set(self.worker_spec)
if to_close:
await self.scheduler.retire_workers(workers=list(to_close))
if self.scheduler.status == "running":
await self.scheduler_comm.retire_workers(workers=list(to_close))
tasks = [self.workers[w].close() for w in to_close]
await asyncio.wait(tasks)
for task in tasks: # for tornado gen.coroutine support
await task
with ignoring(RuntimeError):
await task
for name in to_close:
del self.workers[name]

@@ -214,11 +225,10 @@ async def _():
return _().__await__()

async def _wait_for_workers(self):
# TODO: this function needs to query scheduler and worker state
# remotely without assuming that they are local
while {d["name"] for d in self.scheduler.identity()["workers"].values()} != set(
self.workers
):
while {
str(d["name"])
for d in (await self.scheduler_comm.identity())["workers"].values()
} != set(map(str, self.workers)):
if (
any(w.status == "closed" for w in self.workers.values())
and self.scheduler.status == "running"
@@ -240,12 +250,14 @@ async def _close(self):
return
self.status = "closing"

async with self._lock:
await self.scheduler.close(close_workers=True)
self.scale(0)
await self._correct_state()
async with self._lock:
await self.scheduler_comm.close(close_workers=True)

Review comment (Member):

If the scheduler comm has already closed for some reason and resulted in a broken pipe, we will get an exception here.

Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/comm/tcp.py", line 194, in read
    n_frames = yield stream.read_bytes(8)
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/iostream.py", line 436, in read_bytes
    future = self._start_read()
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/iostream.py", line 797, in _start_read
    self._check_closed()  # Before reading, check that stream is not closed.
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/iostream.py", line 1009, in _check_closed
    raise StreamClosedError(real_error=self.error)
tornado.iostream.StreamClosedError: Stream is closed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/core.py", line 675, in send_recv_from_rpc
    result = yield send_recv(comm=comm, op=key, **kwargs)
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/gen.py", line 735, in run
    value = future.result()
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/gen.py", line 742, in run
    yielded = self.gen.throw(*exc_info)  # type: ignore
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/core.py", line 535, in send_recv
    response = yield comm.read(deserializers=deserializers)
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/gen.py", line 735, in run
    value = future.result()
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/comm/tcp.py", line 214, in read
    convert_stream_closed_error(self, e)
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/comm/tcp.py", line 139, in convert_stream_closed_error
    raise CommClosedError("in %s: %s: %s" % (obj, exc.__class__.__name__, exc))
distributed.comm.core.CommClosedError: in <closed TCP>: BrokenPipeError: [Errno 32] Broken pipe

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/contextlib.py", line 130, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/utils.py", line 182, in ignoring
    yield
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/deploy/spec.py", line 343, in close_clusters
    cluster.close(timeout=10)
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/deploy/spec.py", line 271, in close
    return self.sync(self._close, callback_timeout=timeout)
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/deploy/cluster.py", line 245, in sync
    return sync(self.loop, func, *args, **kwargs)
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/utils.py", line 332, in sync
    six.reraise(*error[0])
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/six.py", line 693, in reraise
    raise value
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/utils.py", line 317, in f
    result[0] = yield future
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/gen.py", line 735, in run
    value = future.result()
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/deploy/spec.py", line 258, in _close
    await self.scheduler_comm.close(close_workers=True)
  File "/home/nfs/jtomlinson/miniconda3/envs/dask-csp/lib/python3.7/site-packages/tornado/gen.py", line 742, in run
    yielded = self.gen.throw(*exc_info)  # type: ignore
  File "/home/nfs/jtomlinson/Projects/dask/distributed/distributed/core.py", line 678, in send_recv_from_rpc
    "%s: while trying to call remote method %r" % (e, key)
distributed.comm.core.CommClosedError: in <closed TCP>: BrokenPipeError: [Errno 32] Broken pipe: while trying to call remote method 'close'

This results in self.scheduler.close() never being called and failing to clean up.
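
One possible guard, sketched as a reply to the comment above: tolerate a comm that is already broken so the local Scheduler.close() still runs. This is not the committed fix; the helper name close_scheduler_safely is hypothetical, and catching OSError alongside CommClosedError is an assumption. ignoring comes from distributed.utils and CommClosedError from distributed.comm.core, both of which already appear in this PR and traceback.

from distributed.comm.core import CommClosedError
from distributed.utils import ignoring


async def close_scheduler_safely(cluster):
    # Hypothetical sketch, not the committed fix.
    # Ask the remote scheduler to shut down, tolerating a comm that is
    # already closed or broken (e.g. the BrokenPipeError above).
    with ignoring(CommClosedError, OSError):
        await cluster.scheduler_comm.close(close_workers=True)
    # Always close the local Scheduler object so cleanup still happens.
    await cluster.scheduler.close()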

await self.scheduler.close()
for w in self._created:
assert w.status == "closed"
self.scheduler_comm.close_rpc()

if hasattr(self, "_old_logging_level"):
silence_logging(self._old_logging_level)
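For context, the scheduler_comm introduced above is an rpc proxy from distributed.core: attribute access becomes an operation sent over the scheduler's public comm, so SpecCluster no longer has to call methods on a local Scheduler object. A rough, standalone illustration of that proxy follows; the address is a placeholder, the function name example is hypothetical, and Security() here is just the default no-TLS configuration.

from distributed.core import rpc
from distributed.security import Security


async def example():
    security = Security()
    # Connect to a scheduler by address only; connection_args carries TLS
    # settings when a non-default Security is used.
    scheduler_comm = rpc(
        "tcp://127.0.0.1:8786",  # placeholder address
        connection_args=security.get_connection_args("client"),
    )
    info = await scheduler_comm.identity()           # round trip of {"op": "identity"}
    await scheduler_comm.retire_workers(workers=[])  # any scheduler handler works the same way
    scheduler_comm.close_rpc()                       # tear down pooled connections
    return info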
171 changes: 171 additions & 0 deletions distributed/deploy/ssh2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import asyncio
import logging
import sys
import warnings
import weakref

import asyncssh

from .spec import SpecCluster

logger = logging.getLogger(__name__)

warnings.warn(
"the distributed.deploy.ssh2 module is experimental "
"and will move/change in the future without notice"
)


class Process:
""" A superclass for SSH Workers and Nannies

See Also
--------
Worker
Scheduler
"""

def __init__(self):
self.lock = asyncio.Lock()
self.connection = None
self.proc = None
self.status = "created"

def __await__(self):
async def _():
async with self.lock:
if not self.connection:
await self.start()
assert self.connection
weakref.finalize(self, self.proc.terminate)
return self

return _().__await__()

async def close(self):
self.proc.terminate()
self.connection.close()
self.status = "closed"

def __repr__(self):
return "<SSH %s: status=%s>" % (type(self).__name__, self.status)


class Worker(Process):
""" A Remote Dask Worker controled by SSH

Parameters
----------
scheduler: str
The address of the scheduler
address: str
The hostname where we should run this worker
connect_kwargs: dict
kwargs to be passed to asyncssh connections
kwargs:
TODO
"""

def __init__(self, scheduler: str, address: str, connect_kwargs: dict, **kwargs):
self.address = address
self.scheduler = scheduler
self.connect_kwargs = connect_kwargs
self.kwargs = kwargs

super().__init__()

async def start(self):
self.connection = await asyncssh.connect(self.address, **self.connect_kwargs)
self.proc = await self.connection.create_process(
" ".join(
[
sys.executable,
"-m",
"distributed.cli.dask_worker",
self.scheduler,
"--name", # we need to have name for SpecCluster
str(self.kwargs["name"]),
]
)
)

# We watch stderr in order to get the address, then we return
while True:
line = await self.proc.stderr.readline()
if "worker at" in line:
self.address = line.split("worker at:")[1].strip()
self.status = "running"
break
logger.debug("%s", line)


class Scheduler(Process):
""" A Remote Dask Scheduler controled by SSH

Parameters
----------
address: str
The hostname where we should run this scheduler
connect_kwargs: dict
kwargs to be passed to asyncssh connections
kwargs:
TODO
"""

def __init__(self, address: str, connect_kwargs: dict, **kwargs):
self.address = address
self.kwargs = kwargs
self.connect_kwargs = connect_kwargs

super().__init__()

async def start(self):
logger.debug("Created Scheduler Connection")

self.connection = await asyncssh.connect(self.address, **self.connect_kwargs)

self.proc = await self.connection.create_process(
" ".join([sys.executable, "-m", "distributed.cli.dask_scheduler"])
)

# We watch stderr in order to get the address, then we return
while True:
line = await self.proc.stderr.readline()
if "Scheduler at" in line:
self.address = line.split("Scheduler at:")[1].strip()
break
logger.debug("%s", line)


def SSHCluster(hosts, connect_kwargs, **kwargs):
""" Deploy a Dask cluster using SSH

Parameters
----------
hosts: List[str]
List of hostnames or addresses on which to launch our cluster
The first will be used for the scheduler and the rest for workers
connect_kwargs:
known_hosts: List[str] or None
The list of keys which will be used to validate the server host
key presented during the SSH handshake. If this is not specified,
the keys will be looked up in the file .ssh/known_hosts. If this
is explicitly set to None, server host key validation will be disabled.
TODO
kwargs:
TODO
----
This doesn't handle any keyword arguments yet. It is a proof of concept
"""
scheduler = {
"cls": Scheduler,
"options": {"address": hosts[0], "connect_kwargs": connect_kwargs},
}
workers = {
i: {
"cls": Worker,
"options": {"address": host, "connect_kwargs": connect_kwargs},
}
for i, host in enumerate(hosts[1:])
}
return SpecCluster(workers, scheduler, **kwargs)
17 changes: 17 additions & 0 deletions distributed/deploy/tests/test_ssh2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

pytest.importorskip("asyncssh")

from dask.distributed import Client
from distributed.deploy.ssh2 import SSHCluster


@pytest.mark.asyncio
async def test_basic():
async with SSHCluster(
["127.0.0.1"] * 3, connect_kwargs=dict(known_hosts=None), asynchronous=True
) as cluster:
assert len(cluster.workers) == 2
async with Client(cluster, asynchronous=True) as client:
result = await client.submit(lambda x: x + 1, 10)
assert result == 11
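
For reference, usage against real machines would look much like the test above. This is a sketch under assumptions not in the PR: the hostnames are placeholders, each host needs passwordless SSH access and a Python environment with distributed installed, and host key checking is disabled only for illustration.

import asyncio

from dask.distributed import Client
from distributed.deploy.ssh2 import SSHCluster


async def main():
    # First host runs the scheduler, the remaining hosts run workers.
    async with SSHCluster(
        ["scheduler-host", "worker-host-1", "worker-host-2"],  # placeholder hostnames
        connect_kwargs={"known_hosts": None},  # skip host key checking, as in the test
        asynchronous=True,
    ) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            result = await client.submit(lambda x: x + 1, 10)
            print(result)  # 11


asyncio.get_event_loop().run_until_complete(main())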