Add alternative SSHCluster implementation (#2827)
This is a proof of concept, here for two reasons:

1.  It opens up a possible alternative for SSH deployment (which was
    surprisingly popular in the user survey)
2.  It is the first non-local application of `SpecCluster` and so serves
    as a proof of concept for other future deployments that are mostly
    defined by creating a remote Worker/Scheduler object

This forced some changes in `SpecCluster`: notably, we now have an `rpc`
object that makes remote calls rather than accessing the scheduler
directly.  We also still have to figure out how to handle all of the
keyword arguments.  In this case we need to pass them from Python down
to the CLI, and presumably we'll also want a `dask-ssh` CLI command,
which will have to translate in the other direction.
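
As a minimal sketch of the new pattern (the scheduler address below is
hypothetical, error handling omitted), driving a scheduler through an
`rpc` handle rather than calling methods on a local `Scheduler` object
looks like this:

    import asyncio

    from distributed.core import rpc
    from distributed.security import Security

    async def main():
        # Make remote calls over the network, as SpecCluster now does,
        # instead of calling methods on a local Scheduler object.
        security = Security()
        scheduler_comm = rpc(
            "tcp://127.0.0.1:8786",  # hypothetical scheduler address
            connection_args=security.get_connection_args("client"),
        )
        info = await scheduler_comm.identity()  # remote call
        print(sorted(d["name"] for d in info["workers"].values()))
        scheduler_comm.close_rpc()

    asyncio.run(main())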
mrocklin committed Jul 18, 2019
1 parent af64e07 commit df2addc
Showing 7 changed files with 214 additions and 9 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -19,6 +19,7 @@ matrix:
 
 install:
   - if [[ $TESTS == true ]]; then source continuous_integration/travis/install.sh ; fi
+  - if [[ $TESTS == true ]]; then source continuous_integration/travis/setup-ssh.sh ; fi
 
 script:
   - if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi
1 change: 1 addition & 0 deletions continuous_integration/travis/install.sh
@@ -67,6 +67,7 @@ pip install -q git+https://github.com/dask/s3fs.git --upgrade --no-deps
 pip install -q git+https://github.com/dask/zict.git --upgrade --no-deps
 pip install -q sortedcollections msgpack --no-deps
 pip install -q keras --upgrade --no-deps
+pip install -q asyncssh
 
 if [[ $CRICK == true ]]; then
     conda install -q cython
2 changes: 2 additions & 0 deletions continuous_integration/travis/setup-ssh.sh
@@ -0,0 +1,2 @@
+ssh-keygen -t rsa -f ~/.ssh/id_rsa -N "" -q
+cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
1 change: 1 addition & 0 deletions distributed/deploy/local.py
@@ -200,6 +200,7 @@ def __init__(
             loop=loop,
             asynchronous=asynchronous,
             silence_logs=silence_logs,
+            security=security,
         )
 
     def __repr__(self):
30 changes: 21 additions & 9 deletions distributed/deploy/spec.py
@@ -5,8 +5,10 @@
 from tornado import gen
 
 from .cluster import Cluster
+from ..core import rpc
 from ..utils import LoopRunner, silence_logging, ignoring
 from ..scheduler import Scheduler
+from ..security import Security
 
 
 class SpecCluster(Cluster):
@@ -107,6 +109,7 @@ def __init__(
         worker=None,
         asynchronous=False,
         loop=None,
+        security=None,
         silence_logs=False,
     ):
         self._created = weakref.WeakSet()
@@ -125,6 +128,8 @@ def __init__(
         self.workers = {}
         self._i = 0
         self._asynchronous = asynchronous
+        self.security = security or Security()
+        self.scheduler_comm = None
 
         if silence_logs:
             self._old_logging_level = silence_logging(level=silence_logs)
@@ -156,6 +161,10 @@ async def _start(self):
         self._lock = asyncio.Lock()
         self.status = "starting"
         self.scheduler = await self.scheduler
+        self.scheduler_comm = rpc(
+            self.scheduler.address,
+            connection_args=self.security.get_connection_args("client"),
+        )
         self.status = "running"
 
     def _correct_state(self):
@@ -174,11 +183,13 @@ async def _correct_state_internal(self):
         pre = list(set(self.workers))
         to_close = set(self.workers) - set(self.worker_spec)
         if to_close:
-            await self.scheduler.retire_workers(workers=list(to_close))
+            if self.scheduler.status == "running":
+                await self.scheduler_comm.retire_workers(workers=list(to_close))
             tasks = [self.workers[w].close() for w in to_close]
             await asyncio.wait(tasks)
             for task in tasks:  # for tornado gen.coroutine support
-                await task
+                with ignoring(RuntimeError):
+                    await task
             for name in to_close:
                 del self.workers[name]
 
@@ -214,11 +225,10 @@ async def _():
         return _().__await__()
 
     async def _wait_for_workers(self):
-        # TODO: this function needs to query scheduler and worker state
-        # remotely without assuming that they are local
-        while {d["name"] for d in self.scheduler.identity()["workers"].values()} != set(
-            self.workers
-        ):
+        while {
+            str(d["name"])
+            for d in (await self.scheduler_comm.identity())["workers"].values()
+        } != set(map(str, self.workers)):
             if (
                 any(w.status == "closed" for w in self.workers.values())
                 and self.scheduler.status == "running"
@@ -240,12 +250,14 @@ async def _close(self):
             return
         self.status = "closing"
 
-        async with self._lock:
-            await self.scheduler.close(close_workers=True)
         self.scale(0)
         await self._correct_state()
+        async with self._lock:
+            await self.scheduler_comm.close(close_workers=True)
+            await self.scheduler.close()
         for w in self._created:
             assert w.status == "closed"
+        self.scheduler_comm.close_rpc()
 
         if hasattr(self, "_old_logging_level"):
             silence_logging(self._old_logging_level)
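One subtlety in the rewritten `_wait_for_workers` above: worker names
round-trip through the `dask-worker --name` CLI flag as strings, so an
integer key in `self.workers` comes back from the scheduler as a string.
Casting both sides to `str` keeps the set comparison from looping
forever. A tiny illustration of the failure mode, with made-up values:

    # Keys SpecCluster assigns locally vs. names the scheduler reports
    # back after they pass through the dask-worker CLI as strings.
    local_names = {0, 1}
    reported_names = {"0", "1"}

    assert reported_names != local_names                    # naive comparison never matches
    assert {str(n) for n in local_names} == reported_names  # stringified comparison does
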
171 changes: 171 additions & 0 deletions distributed/deploy/ssh2.py
@@ -0,0 +1,171 @@
+import asyncio
+import logging
+import sys
+import warnings
+import weakref
+
+import asyncssh
+
+from .spec import SpecCluster
+
+logger = logging.getLogger(__name__)
+
+warnings.warn(
+    "the distributed.deploy.ssh2 module is experimental "
+    "and will move/change in the future without notice"
+)
+
+
+class Process:
+    """ A superclass for SSH Workers and Nannies
+
+    See Also
+    --------
+    Worker
+    Scheduler
+    """
+
+    def __init__(self):
+        self.lock = asyncio.Lock()
+        self.connection = None
+        self.proc = None
+        self.status = "created"
+
+    def __await__(self):
+        async def _():
+            async with self.lock:
+                if not self.connection:
+                    await self.start()
+                    assert self.connection
+                    weakref.finalize(self, self.proc.terminate)
+            return self
+
+        return _().__await__()
+
+    async def close(self):
+        self.proc.terminate()
+        self.connection.close()
+        self.status = "closed"
+
+    def __repr__(self):
+        return "<SSH %s: status=%s>" % (type(self).__name__, self.status)
+
+
+class Worker(Process):
+    """ A remote Dask Worker controlled by SSH
+
+    Parameters
+    ----------
+    scheduler: str
+        The address of the scheduler
+    address: str
+        The hostname where we should run this worker
+    connect_kwargs: dict
+        kwargs to be passed to asyncssh connections
+    kwargs:
+        TODO
+    """
+
+    def __init__(self, scheduler: str, address: str, connect_kwargs: dict, **kwargs):
+        self.address = address
+        self.scheduler = scheduler
+        self.connect_kwargs = connect_kwargs
+        self.kwargs = kwargs
+
+        super().__init__()
+
+    async def start(self):
+        self.connection = await asyncssh.connect(self.address, **self.connect_kwargs)
+        self.proc = await self.connection.create_process(
+            " ".join(
+                [
+                    sys.executable,
+                    "-m",
+                    "distributed.cli.dask_worker",
+                    self.scheduler,
+                    "--name",  # we need to have name for SpecCluster
+                    str(self.kwargs["name"]),
+                ]
+            )
+        )
+
+        # We watch stderr in order to get the address, then we return
+        while True:
+            line = await self.proc.stderr.readline()
+            if "worker at" in line:
+                self.address = line.split("worker at:")[1].strip()
+                self.status = "running"
+                break
+            logger.debug("%s", line)
+
+
+class Scheduler(Process):
+    """ A remote Dask Scheduler controlled by SSH
+
+    Parameters
+    ----------
+    address: str
+        The hostname where we should run this scheduler
+    connect_kwargs: dict
+        kwargs to be passed to asyncssh connections
+    kwargs:
+        TODO
+    """
+
+    def __init__(self, address: str, connect_kwargs: dict, **kwargs):
+        self.address = address
+        self.kwargs = kwargs
+        self.connect_kwargs = connect_kwargs
+
+        super().__init__()
+
+    async def start(self):
+        logger.debug("Created Scheduler Connection")
+
+        self.connection = await asyncssh.connect(self.address, **self.connect_kwargs)
+
+        self.proc = await self.connection.create_process(
+            " ".join([sys.executable, "-m", "distributed.cli.dask_scheduler"])
+        )
+
+        # We watch stderr in order to get the address, then we return
+        while True:
+            line = await self.proc.stderr.readline()
+            if "Scheduler at" in line:
+                self.address = line.split("Scheduler at:")[1].strip()
+                break
+            logger.debug("%s", line)
+
+
+def SSHCluster(hosts, connect_kwargs, **kwargs):
+    """ Deploy a Dask cluster using SSH
+
+    Parameters
+    ----------
+    hosts: List[str]
+        List of hostnames or addresses on which to launch our cluster.
+        The first will be used for the scheduler and the rest for workers.
+    connect_kwargs:
+        known_hosts: List[str] or None
+            The list of keys which will be used to validate the server host
+            key presented during the SSH handshake.  If this is not specified,
+            the keys will be looked up in the file .ssh/known_hosts.  If this
+            is explicitly set to None, server host key validation will be
+            disabled.
+        TODO
+    kwargs:
+        TODO
+
+    Notes
+    -----
+    This doesn't handle any keyword arguments yet.  It is a proof of concept.
+    """
+    scheduler = {
+        "cls": Scheduler,
+        "options": {"address": hosts[0], "connect_kwargs": connect_kwargs},
+    }
+    workers = {
+        i: {
+            "cls": Worker,
+            "options": {"address": host, "connect_kwargs": connect_kwargs},
+        }
+        for i, host in enumerate(hosts[1:])
+    }
+    return SpecCluster(workers, scheduler, **kwargs)
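
`Process.__await__` above uses the same start-once idiom as
`SpecCluster`: awaiting the object starts it under a lock the first
time and simply returns it afterwards. A stripped-down sketch of the
idiom, independent of SSH:

    import asyncio

    class StartOnce:
        def __init__(self):
            self.lock = asyncio.Lock()
            self.started = False

        async def start(self):
            self.started = True  # real startup work would go here

        def __await__(self):
            async def _():
                async with self.lock:  # serialize concurrent awaits
                    if not self.started:
                        await self.start()
                return self            # `await obj` evaluates to obj itself

            return _().__await__()

    async def main():
        obj = StartOnce()
        same = await obj  # the first await starts the object
        assert same is obj and obj.started
        await obj         # later awaits are no-ops

    asyncio.run(main())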
17 changes: 17 additions & 0 deletions distributed/deploy/tests/test_ssh2.py
@@ -0,0 +1,17 @@
+import pytest
+
+pytest.importorskip("asyncssh")
+
+from dask.distributed import Client
+from distributed.deploy.ssh2 import SSHCluster
+
+
+@pytest.mark.asyncio
+async def test_basic():
+    async with SSHCluster(
+        ["127.0.0.1"] * 3, connect_kwargs=dict(known_hosts=None), asynchronous=True
+    ) as cluster:
+        assert len(cluster.workers) == 2
+        async with Client(cluster, asynchronous=True) as client:
+            result = await client.submit(lambda x: x + 1, 10)
+            assert result == 11
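
The test above exercises the asynchronous interface. A blocking variant
would look roughly like the sketch below, assuming passwordless SSH to
the listed hosts (the hostnames are hypothetical) and the same caveat
that host key validation is disabled:

    from dask.distributed import Client
    from distributed.deploy.ssh2 import SSHCluster

    # The first host runs the scheduler; the rest each run a worker.
    cluster = SSHCluster(
        ["host0", "host1", "host2"],            # hypothetical hostnames
        connect_kwargs=dict(known_hosts=None),  # skip host key validation, as in the test
    )
    client = Client(cluster)
    assert client.submit(lambda x: x + 1, 10).result() == 11
    client.close()
    cluster.close()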
