Add Global Variable coordination object
This creates a global `Variable` object on which clients can set and get
values.  Coordination happens through the central scheduler.

```python
In [1]: from distributed import Client, Variable
In [2]: client = Client()

In [3]: v = Variable('var-1')
In [4]: v.set(1)
In [5]: v.get()
Out[5]: 1

In [6]: client.scheduler.address
Out[6]: 'tcp://127.0.0.1:36535'
```

```python
In [1]: from distributed import Client, Variable
In [2]: client = Client('tcp://127.0.0.1:36535')
In [3]: v = Variable('var-1')  # same name

In [4]: v.get()
Out[4]: 1
```

Variables can hold msgpack-serializable values or futures.
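
As a sketch of the future case, continuing the first session above (the
`lambda` task here is illustrative, not part of the commit message):

```python
In [7]: future = client.submit(lambda x: x + 1, 10)
In [8]: v.set(future)

In [9]: v.get().result()  # get() returns a Future pointing at the same key
Out[9]: 11
```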
mrocklin committed Jun 3, 2017
1 parent 74c2a9a commit ebc6bd3
Showing 5 changed files with 280 additions and 7 deletions.
1 change: 1 addition & 0 deletions distributed/__init__.py
@@ -11,6 +11,7 @@
 from .queues import Queue
 from .scheduler import Scheduler
 from .utils import sync
+from .variable import Variable
 from .worker import Worker, get_worker
 from .worker_client import local_client, worker_client

15 changes: 9 additions & 6 deletions distributed/scheduler.py
@@ -32,17 +32,19 @@
                     error_message, clean_exception, CommClosedError)
 from .metrics import time
 from .node import ServerNode
-from .publish import PublishExtension
-from .channels import ChannelScheduler
-from .queues import QueueExtension
-from .stealing import WorkStealing
-from .recreate_exceptions import ReplayExceptionScheduler
 from .security import Security
 from .utils import (All, ignoring, get_ip, get_fileno_limit, log_errors,
                     key_split, validate_key)
 from .utils_comm import (scatter_to_workers, gather_from_workers)
 from .versions import get_versions
+
+from .channels import ChannelScheduler
+from .publish import PublishExtension
+from .queues import QueueExtension
+from .recreate_exceptions import ReplayExceptionScheduler
+from .stealing import WorkStealing
+from .variable import VariableExtension


logger = logging.getLogger(__name__)

@@ -190,7 +192,8 @@ def __init__(self, center=None, loop=None,
                  delete_interval=500, synchronize_worker_interval=60000,
                  services=None, allowed_failures=ALLOWED_FAILURES,
                  extensions=[ChannelScheduler, PublishExtension, WorkStealing,
-                             ReplayExceptionScheduler, QueueExtension],
+                             ReplayExceptionScheduler, QueueExtension,
+                             VariableExtension],
                  validate=False, scheduler_file=None, security=None,
                  **kwargs):

1 change: 0 additions & 1 deletion distributed/tests/test_queues.py
@@ -8,7 +8,6 @@
 from tornado import gen

 from distributed import Client, Queue
-from distributed import worker_client
 from distributed.metrics import time
 from distributed.utils_test import gen_cluster, inc, loop, cluster, slowinc

114 changes: 114 additions & 0 deletions distributed/tests/test_variable.py
@@ -0,0 +1,114 @@
from __future__ import print_function, division, absolute_import

import pytest
from tornado import gen

from distributed import Client, Variable
from distributed.metrics import time
from distributed.utils_test import gen_cluster, inc, loop, cluster


@gen_cluster(client=True)
def test_variable(c, s, a, b):
x = Variable('x')
xx = Variable('x')
assert x.client is c

future = c.submit(inc, 1)

yield x._set(future)
future2 = yield xx._get()
assert future.key == future2.key

del future, future2

yield gen.sleep(0.1)
assert s.task_state # future still present

yield x.delete()

start = time()
while s.task_state:
yield gen.sleep(0.01)
assert time() < start + 5


@gen_cluster(client=True)
def test_variable_with_data(c, s, a, b):
x = Variable('x')
xx = Variable('x')
assert x.client is c

yield x._set([1, 'hello'])
data = yield xx._get()

assert data == [1, 'hello']


def test_sync(loop):
with cluster() as (s, [a, b]):
with Client(s['address']) as c:
future = c.submit(lambda x: x + 1, 10)
x = Variable('x')
xx = Variable('x')
x.set(future)
future2 = xx.get()

assert future2.result() == 11


@gen_cluster()
def test_hold_futures(s, a, b):
c1 = yield Client(s.address, asynchronous=True)
future = c1.submit(lambda x: x + 1, 10)
x1 = Variable('x')
yield x1._set(future)
del x1
yield c1._shutdown()

yield gen.sleep(0.1)

c2 = yield Client(s.address, asynchronous=True)
x2 = Variable('x')
future2 = yield x2._get()
result = yield future2

assert result == 11
yield c2._shutdown()


@gen_cluster(client=True)
def test_timeout(c, s, a, b):
v = Variable('v')

start = time()
with pytest.raises(gen.TimeoutError):
yield v._get(timeout=0.1)
assert time() - start < 0.5


@gen_cluster(client=True)
def test_cleanup(c, s, a, b):
v = Variable('v')
vv = Variable('v')

x = c.submit(lambda x: x + 1, 10)
y = c.submit(lambda x: x + 1, 20)
x_key = x.key

yield v._set(x)
del x
yield gen.sleep(0.1)

    t_future = vv._get()  # start the get before the variable is reset
    yield gen.moment
    v._set(y)  # not yielded: the set starts, but we don't wait for it

future = yield t_future
assert future.key == x_key
result = yield future
assert result == 11
156 changes: 156 additions & 0 deletions distributed/variable.py
@@ -0,0 +1,156 @@
from __future__ import print_function, division, absolute_import

from collections import defaultdict
import logging
import uuid

from tornado import gen
import tornado.locks

from .client import Future, _get_global_client
from .metrics import time
from .utils import tokey, sync, log_errors

logger = logging.getLogger(__name__)


class VariableExtension(object):
    """ An extension for the scheduler to manage variables

    This adds the following handlers to the scheduler:

    *  variable_set
    *  variable_get
    *  variable_delete
    *  variable-future-release
    """
    def __init__(self, scheduler):
        self.scheduler = scheduler
        self.variables = dict()  # name -> {'type': ..., 'value': ...} record
        self.lingering = defaultdict(set)  # (key, name) -> clients yet to confirm receipt
        self.events = defaultdict(tornado.locks.Event)  # name -> release-progress event
        self.started = tornado.locks.Condition()  # notified when a variable is first set

        self.scheduler.handlers.update({'variable_set': self.set,
                                        'variable_get': self.get})

        self.scheduler.client_handlers['variable-future-release'] = self.future_release
        self.scheduler.client_handlers['variable_delete'] = self.delete

        self.scheduler.extensions['variables'] = self

def set(self, stream=None, name=None, key=None, data=None, client=None, timeout=None):
if key is not None:
record = {'type': 'Future', 'value': key}
self.scheduler.client_desires_keys(keys=[key], client='variable-%s' % name)
else:
record = {'type': 'msgpack', 'value': data}
try:
old = self.variables[name]
except KeyError:
pass
else:
if old['type'] == 'Future':
self.release(old['value'], name)
if name not in self.variables:
self.started.notify_all()
self.variables[name] = record

@gen.coroutine
def release(self, key, name):
while self.lingering[key, name]:
yield self.events[name].wait()

self.scheduler.client_releases_keys(keys=[key],
client='variable-%s' % name)
del self.lingering[key, name]

def future_release(self, name=None, key=None, client=None):
self.lingering[key, name].remove(client)
self.events[name].set()

@gen.coroutine
def get(self, stream=None, name=None, client=None, timeout=None):
start = time()
while name not in self.variables:
            if timeout is not None:
                timeout2 = timeout - (time() - start)
                if timeout2 < 0:
                    raise gen.TimeoutError()
            else:
                timeout2 = None
            yield self.started.wait(timeout=timeout2)
record = self.variables[name]
if record['type'] == 'Future':
self.lingering[record['value'], name].add(client)
raise gen.Return(record)

@gen.coroutine
def delete(self, stream=None, name=None, client=None):
with log_errors():
try:
old = self.variables[name]
except KeyError:
pass
else:
if old['type'] == 'Future':
yield self.release(old['value'], name)
            self.events.pop(name, None)  # may never have been created
            self.variables.pop(name, None)


class Variable(object):
    """ Distributed Global Variable

    This allows multiple clients to share futures and data with each other
    through a single mutable variable.  All operations are sequenced through
    the central scheduler, but race conditions can still occur if several
    clients set the variable concurrently.

    Examples
    --------
    >>> from dask.distributed import Client, Variable  # doctest: +SKIP
    >>> client = Client()  # doctest: +SKIP
    >>> x = Variable('x')  # doctest: +SKIP
    >>> x.set(123)  # doctest: +SKIP
    >>> x.get()  # doctest: +SKIP
    123
    >>> future = client.submit(f, x)  # doctest: +SKIP
    >>> x.set(future)  # doctest: +SKIP
    """
    def __init__(self, name=None, client=None):
        self.client = client or _get_global_client()
        self.name = name or 'variable-' + uuid.uuid4().hex

@gen.coroutine
def _set(self, value):
if isinstance(value, Future):
yield self.client.scheduler.variable_set(key=tokey(value.key),
name=self.name)
else:
yield self.client.scheduler.variable_set(data=value,
name=self.name)

    def set(self, value):
        return sync(self.client.loop, self._set, value)

@gen.coroutine
def _get(self, timeout=None):
d = yield self.client.scheduler.variable_get(timeout=timeout,
name=self.name,
client=self.client.id)
if d['type'] == 'Future':
value = Future(d['value'], self.client, inform=True)
self.client._send_to_scheduler({'op': 'variable-future-release',
'name': self.name,
'key': d['value'],
'client': self.client.id})
else:
value = d['value']
raise gen.Return(value)

def get(self, timeout=None):
return sync(self.client.loop, self._get, timeout=timeout)

def delete(self):
if self.client.status == 'running': # TODO: can leave zombie futures
self.client._send_to_scheduler({'op': 'variable_delete',
'name': self.name})
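
To round out the diff, here is a minimal sketch of the cross-client flow the
new tests exercise; the names `c1`, `c2`, and `'config'` are illustrative,
not part of the commit:

```python
from distributed import Client, Variable

c1 = Client()                      # starts a local scheduler and workers
c2 = Client(c1.scheduler.address)  # second client on the same scheduler

v1 = Variable('config', client=c1)
v1.set({'threshold': 0.5})         # plain msgpack-serializable data

v2 = Variable('config', client=c2)  # same name, resolved via the scheduler
assert v2.get() == {'threshold': 0.5}

future = c1.submit(lambda x: x + 1, 10)
v1.set(future)                     # futures work too; the scheduler keeps
result = v2.get().result()         # the underlying key alive
assert result == 11
```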
