diff --git a/distributed/__init__.py b/distributed/__init__.py index 520af057f4..8d3d2df214 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -11,6 +11,7 @@ from .queues import Queue from .scheduler import Scheduler from .utils import sync +from .variable import Variable from .worker import Worker, get_worker from .worker_client import local_client, worker_client diff --git a/distributed/scheduler.py b/distributed/scheduler.py index efe31b4402..7e1ec8d557 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -32,17 +32,19 @@ error_message, clean_exception, CommClosedError) from .metrics import time from .node import ServerNode -from .publish import PublishExtension -from .channels import ChannelScheduler -from .queues import QueueExtension -from .stealing import WorkStealing -from .recreate_exceptions import ReplayExceptionScheduler from .security import Security from .utils import (All, ignoring, get_ip, get_fileno_limit, log_errors, key_split, validate_key) from .utils_comm import (scatter_to_workers, gather_from_workers) from .versions import get_versions +from .channels import ChannelScheduler +from .publish import PublishExtension +from .queues import QueueExtension +from .recreate_exceptions import ReplayExceptionScheduler +from .stealing import WorkStealing +from .variable import VariableExtension + logger = logging.getLogger(__name__) @@ -190,7 +192,8 @@ def __init__(self, center=None, loop=None, delete_interval=500, synchronize_worker_interval=60000, services=None, allowed_failures=ALLOWED_FAILURES, extensions=[ChannelScheduler, PublishExtension, WorkStealing, - ReplayExceptionScheduler, QueueExtension], + ReplayExceptionScheduler, QueueExtension, + VariableExtension], validate=False, scheduler_file=None, security=None, **kwargs): diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 08ad76b1d0..a57c855cd1 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -8,7 +8,6 @@ from tornado import gen from distributed import Client, Queue -from distributed import worker_client from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, loop, cluster, slowinc diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py new file mode 100644 index 0000000000..9d6422923f --- /dev/null +++ b/distributed/tests/test_variable.py @@ -0,0 +1,114 @@ +from __future__ import print_function, division, absolute_import + +from operator import add +from time import sleep + +import pytest +from toolz import take +from tornado import gen + +from distributed import Client, Variable +from distributed.metrics import time +from distributed.utils_test import gen_cluster, inc, loop, cluster, slowinc + + +@gen_cluster(client=True) +def test_variable(c, s, a, b): + x = Variable('x') + xx = Variable('x') + assert x.client is c + + future = c.submit(inc, 1) + + yield x._set(future) + future2 = yield xx._get() + assert future.key == future2.key + + del future, future2 + + yield gen.sleep(0.1) + assert s.task_state # future still present + + yield x.delete() + + start = time() + while s.task_state: + yield gen.sleep(0.01) + assert time() < start + 5 + + +@gen_cluster(client=True) +def test_queue_with_data(c, s, a, b): + x = Variable('x') + xx = Variable('x') + assert x.client is c + + yield x._set([1, 'hello']) + data = yield xx._get() + + assert data == [1, 'hello'] + + +def test_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address']) as c: + future = c.submit(lambda x: x + 1, 10) + x = Variable('x') + xx = Variable('x') + x.set(future) + future2 = xx.get() + + assert future2.result() == 11 + + +@gen_cluster() +def test_hold_futures(s, a, b): + c1 = yield Client(s.address, asynchronous=True) + future = c1.submit(lambda x: x + 1, 10) + x1 = Variable('x') + yield x1._set(future) + del x1 + yield c1._shutdown() + + yield gen.sleep(0.1) + + c2 = yield Client(s.address, asynchronous=True) + x2 = Variable('x') + future2 = yield x2._get() + result = yield future2 + + assert result == 11 + yield c2._shutdown() + + +@gen_cluster(client=True) +def test_timeout(c, s, a, b): + v = Variable('v') + + start = time() + with pytest.raises(gen.TimeoutError): + yield v._get(timeout=0.1) + assert time() - start < 0.5 + + +@gen_cluster(client=True) +def test_cleanup(c, s, a, b): + v = Variable('v') + vv = Variable('v') + + x = c.submit(lambda x: x + 1, 10) + y = c.submit(lambda x: x + 1, 20) + x_key = x.key + + yield v._set(x) + del x + yield gen.sleep(0.1) + + t_future = xx = vv._get() + yield gen.moment + v._set(y) + + future = yield t_future + assert future.key == x_key + result = yield future + assert result == 11 diff --git a/distributed/variable.py b/distributed/variable.py new file mode 100644 index 0000000000..76eeb4fc09 --- /dev/null +++ b/distributed/variable.py @@ -0,0 +1,156 @@ +from __future__ import print_function, division, absolute_import + +from collections import defaultdict +import logging +import uuid + +from tornado import gen +import tornado.locks + +from .client import Future, _get_global_client +from .metrics import time +from .utils import tokey, sync, log_errors + +logger = logging.getLogger(__name__) + + +class VariableExtension(object): + """ An extension for the scheduler to manage queues + + This adds the following routes to the scheduler + + * variable-set + * variable-get + * variable-delete + """ + def __init__(self, scheduler): + self.scheduler = scheduler + self.variables = dict() + self.lingering = defaultdict(set) + self.events = defaultdict(tornado.locks.Event) + self.started = tornado.locks.Condition() + + self.scheduler.handlers.update({'variable_set': self.set, + 'variable_get': self.get}) + + self.scheduler.client_handlers['variable-future-release'] = self.future_release + self.scheduler.client_handlers['variable_delete'] = self.delete + + self.scheduler.extensions['queues'] = self + + def set(self, stream=None, name=None, key=None, data=None, client=None, timeout=None): + if key is not None: + record = {'type': 'Future', 'value': key} + self.scheduler.client_desires_keys(keys=[key], client='variable-%s' % name) + else: + record = {'type': 'msgpack', 'value': data} + try: + old = self.variables[name] + except KeyError: + pass + else: + if old['type'] == 'Future': + self.release(old['value'], name) + if name not in self.variables: + self.started.notify_all() + self.variables[name] = record + + @gen.coroutine + def release(self, key, name): + while self.lingering[key, name]: + yield self.events[name].wait() + + self.scheduler.client_releases_keys(keys=[key], + client='variable-%s' % name) + del self.lingering[key, name] + + def future_release(self, name=None, key=None, client=None): + self.lingering[key, name].remove(client) + self.events[name].set() + + @gen.coroutine + def get(self, stream=None, name=None, client=None, timeout=None): + start = time() + while name not in self.variables: + if timeout is not None: + timeout2 = timeout - (time() - start) + else: + timeout2 = None + if timeout2 < 0: + raise gen.TimeoutError() + yield self.started.wait(timeout=timeout2) + record = self.variables[name] + if record['type'] == 'Future': + self.lingering[record['value'], name].add(client) + raise gen.Return(record) + + @gen.coroutine + def delete(self, stream=None, name=None, client=None): + with log_errors(): + try: + old = self.variables[name] + except KeyError: + pass + else: + if old['type'] == 'Future': + yield self.release(old['value'], name) + del self.events[name] + del self.variables[name] + + +class Variable(object): + """ Distributed Global Variable + + This allows multiple clients to share futures and data between each other + with a single mutable variable. All metadata is sequentialized through the + scheduler. Race conditions can occur. + + Examples + -------- + >>> from dask.distributed import Client, Variable # doctest: +SKIP + >>> client = Client() # doctest: +SKIP + >>> x = Variable('x') # doctest: +SKIP + >>> x.set(123) # docttest: +SKIP + >>> x.get() # docttest: +SKIP + 123 + >>> future = client.submit(f, x) # doctest: +SKIP + >>> x.set(future) # doctest: +SKIP + """ + def __init__(self, name=None, client=None, maxsize=0): + self.client = client or _get_global_client() + self.name = name or 'variable-' + uuid.uuid4().hex + + @gen.coroutine + def _set(self, value): + if isinstance(value, Future): + yield self.client.scheduler.variable_set(key=tokey(value.key), + name=self.name) + else: + yield self.client.scheduler.variable_set(data=value, + name=self.name) + + def set(self, value, timeout=None): + return sync(self.client.loop, self._set, value) + + @gen.coroutine + def _get(self, timeout=None): + d = yield self.client.scheduler.variable_get(timeout=timeout, + name=self.name, + client=self.client.id) + if d['type'] == 'Future': + value = Future(d['value'], self.client, inform=True) + self.client._send_to_scheduler({'op': 'variable-future-release', + 'name': self.name, + 'key': d['value'], + 'client': self.client.id}) + else: + value = d['value'] + raise gen.Return(value) + + def get(self, timeout=None): + return sync(self.client.loop, self._get, timeout=timeout) + + def delete(self): + if self.client.status == 'running': # TODO: can leave zombie futures + self.client._send_to_scheduler({'op': 'variable_delete', + 'name': self.name})