From cfad841fc83fa25b88618d478046d0ef49b762de Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 18 Jan 2019 09:53:11 +0100 Subject: [PATCH] ENH: Add Semaphore extension --- distributed/__init__.py | 1 + distributed/distributed.yaml | 2 + distributed/scheduler.py | 2 + distributed/semaphore.py | 264 ++++++++++++++++++++++++++++ distributed/tests/test_semaphore.py | 133 ++++++++++++++ 5 files changed, 402 insertions(+) create mode 100644 distributed/semaphore.py create mode 100644 distributed/tests/test_semaphore.py diff --git a/distributed/__init__.py b/distributed/__init__.py index 7b2bc4ab082..6d2e8e6802a 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -22,6 +22,7 @@ from .nanny import Nanny from .pubsub import Pub, Sub from .queues import Queue +from .semaphore import Semaphore from .scheduler import Scheduler from .threadpoolexecutor import rejoin from .utils import sync diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 3ae9b7ee690..43bd94ad32b 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -22,6 +22,8 @@ distributed: worker-ttl: null # like '60s'. Time to live for workers. 
They must heartbeat faster than this preload: [] preload-argv: [] + locks: + lease-validation-interval: 10s worker: blocked-handlers: [] diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 68a80ac664b..2cc1dc5a016 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -62,6 +62,7 @@ from .publish import PublishExtension from .queues import QueueExtension +from .semaphore import SemaphoreExtension from .recreate_exceptions import ReplayExceptionScheduler from .lock import LockExtension from .pubsub import PubSubSchedulerExtension @@ -85,6 +86,7 @@ QueueExtension, VariableExtension, PubSubSchedulerExtension, + SemaphoreExtension, ] if dask.config.get("distributed.scheduler.work-stealing"): diff --git a/distributed/semaphore.py b/distributed/semaphore.py new file mode 100644 index 00000000000..3c5813d6441 --- /dev/null +++ b/distributed/semaphore.py @@ -0,0 +1,264 @@ +from __future__ import absolute_import, division, print_function + +import uuid +from collections import defaultdict, deque +from functools import partial + +import dask +import tornado.locks +import tornado.queues +from tornado import gen + +from .client import Client, _get_global_client +from .utils import PeriodicCallback, log_errors, parse_timedelta +from .worker import get_client, get_worker +from toolz.dicttoolz import valmap +from .metrics import time + + +class _Watch(object): + def __init__(self, duration=None): + self.duration = duration + self.started_at = None + + def start(self): + self.started_at = time() + + def leftover(self): + if self.duration is None: + return None + else: + elapsed = time() - self.started_at + return max(0, self.duration - elapsed) + + +class SemaphoreExtension(object): + """ An extension for the scheduler to manage Semaphores + + This adds the following routes to the scheduler + + * semaphore_acquire + * semaphore_release + * semaphore_create + """ + + def __init__(self, scheduler): + self.locks = defaultdict(tornado.locks.Lock) + 
class SemaphoreExtension(object):
    """ An extension for the scheduler to manage Semaphores

    This adds the following routes to the scheduler

    * semaphore_create
    * semaphore_acquire
    * semaphore_release
    * semaphore_close
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler
        # One tornado lock per semaphore name, serializing lease bookkeeping.
        self.locks = defaultdict(tornado.locks.Lock)
        # name -> deque of identifiers currently holding a lease
        self.leases = defaultdict(deque)
        # name -> event set on every release so blocked acquirers retry
        self.events = defaultdict(tornado.locks.Event)
        # name -> maximum number of concurrent leases
        self.max_leases = dict()
        # client -> name -> deque of identifiers; lets us reclaim leases of
        # clients that disappeared without releasing
        self.leases_per_client = defaultdict(partial(defaultdict, deque))
        self.scheduler.handlers.update(
            {
                "semaphore_create": self.create,
                "semaphore_acquire": self.acquire,
                "semaphore_release": self.release,
                # FIX: Semaphore.close() calls this route; it was never
                # registered, so close() failed with an unknown handler.
                "semaphore_close": self.close,
            }
        )

        self.scheduler.extensions["semaphores"] = self
        self.pc_validate_leases = PeriodicCallback(
            self._validate_leases,
            # FIX: PeriodicCallback expects milliseconds while
            # parse_timedelta returns seconds; without * 1000 a "10s"
            # interval fired every 10ms.
            parse_timedelta(
                dask.config.get("distributed.scheduler.locks.lease-validation-interval")
            )
            * 1000,
            io_loop=self.scheduler.loop,
        )
        self.pc_validate_leases.start()
        self._validation_running = False

    @gen.coroutine
    def create(
        self, stream=None, name=None, client=None, timeout=None, max_leases=None
    ):
        """Register semaphore *name* with capacity *max_leases*.

        Re-creating an existing semaphore is a no-op, but the declared
        capacity must match the registered one; otherwise ValueError.
        ``stream``, ``client`` and ``timeout`` belong to the generic handler
        signature and are unused here.
        """
        # FIX: check max_leases, not leases -- self.leases is only populated
        # on acquire, so re-creating before any acquire silently overwrote
        # the registered capacity instead of raising.
        if name not in self.max_leases:
            assert isinstance(max_leases, int), max_leases
            self.max_leases[name] = max_leases
        else:
            if max_leases != self.max_leases[name]:
                raise ValueError(
                    "Inconsistent max leases: %s, expected: %s"
                    % (max_leases, self.max_leases[name])
                )

    @gen.coroutine
    def _get_lease(self, client, name, identifier):
        """Try once to grant *identifier* a lease on *name*; return bool."""
        # We should make sure that the client is already properly registered
        # with the scheduler, otherwise the lease validation will mop up
        # every acquired lease immediately.
        while client not in self.scheduler.clients:
            yield  # let the IOLoop progress until the client registers
        with (yield self.locks[name].acquire()):
            result = True
            if len(self.leases[name]) < self.max_leases[name]:
                self.leases[name].append(identifier)
                self.leases_per_client[client][name].append(identifier)
            else:
                result = False
        raise gen.Return(result)

    @gen.coroutine
    def acquire(
        self, stream=None, name=None, client=None, timeout=None, identifier=None
    ):
        """Acquire a lease, retrying until success or *timeout* seconds.

        Returns True on success, False if the timeout elapsed first.
        """
        from datetime import timedelta

        with log_errors():
            if isinstance(name, list):
                name = tuple(name)
            w = _Watch(timeout)
            w.start()

            while True:
                # Reset the event and try to get a lease. The event will be
                # set if the state changed and helps to identify when it is
                # worth to retry an acquire.
                self.events[name].clear()
                future = self._get_lease(client, name, identifier)
                if timeout is not None:
                    # FIX: gen.with_timeout treats a bare number as an
                    # *absolute* IOLoop timestamp; the relative leftover must
                    # be wrapped in a timedelta or it expires immediately.
                    future = gen.with_timeout(
                        timedelta(seconds=w.leftover()), future
                    )
                try:
                    result = yield future
                except gen.TimeoutError:
                    result = False

                # If acquiring fails, wait for the event to be set, i.e.
                # something has been released and we can try again.
                if not result:
                    future = self.events[name].wait()
                    if timeout is not None:
                        future = gen.with_timeout(
                            timedelta(seconds=w.leftover()), future
                        )
                    try:
                        yield future
                        continue
                    except gen.TimeoutError:
                        result = False
                raise gen.Return(result)

    @gen.coroutine
    def release(self, stream=None, name=None, client=None, identifier=None):
        """Schedule the release of *identifier*'s lease on *name*.

        Raises ValueError if the identifier holds no lease (over-release).
        """
        with log_errors():
            if isinstance(name, list):
                name = tuple(name)
            if name in self.leases and identifier in self.leases[name]:
                self.scheduler.loop.add_callback(
                    self._release_value, name, client, identifier
                )
            else:
                raise ValueError("Semaphore released too many times.")

    @gen.coroutine
    def _release_value(self, name, client, identifier):
        """Drop one lease and wake up waiters; runs under the name's lock."""
        with (yield self.locks[name].acquire()):
            self.leases_per_client[client][name].remove(identifier)
            # FIX: drop empty bookkeeping entries so the per-client mapping
            # does not grow without bound.
            if not self.leases_per_client[client][name]:
                del self.leases_per_client[client][name]
            if not self.leases_per_client[client]:
                del self.leases_per_client[client]
            self.leases[name].remove(identifier)
            self.events[name].set()

    def _release_client(self, client):
        """Schedule the release of every lease the given client still holds."""
        semaphore_names = list(self.leases_per_client[client])
        for name in semaphore_names:
            ids = list(self.leases_per_client[client][name])
            for _id in list(ids):
                self.scheduler.loop.add_callback(
                    self._release_value, name=name, client=client, identifier=_id
                )

    @gen.coroutine
    def _validate_leases(self):
        """Periodically reclaim leases held by disconnected clients."""
        if self._validation_running:
            return
        self._validation_running = True
        try:
            known_clients = set(self.leases_per_client.keys())
            scheduler_clients = set(self.scheduler.clients.keys())
            for client in known_clients - scheduler_clients:
                client_has_leases = sum(
                    valmap(len, self.leases_per_client[client]).values()
                )
                if client_has_leases:
                    self._release_client(client)
                else:
                    # FIX: forget disconnected clients without leases instead
                    # of keeping their empty bookkeeping forever.
                    del self.leases_per_client[client]
        finally:
            # FIX: always clear the flag; previously it could stay True and
            # permanently disable further validation runs.
            self._validation_running = False

    @gen.coroutine
    def close(self, stream=None, name=None, client=None):
        """Remove all scheduler-side state of semaphore *name*.

        Counterpart of the client-side Semaphore.close(); previously the
        route did not exist at all.
        """
        with log_errors():
            if isinstance(name, list):
                name = tuple(name)
            with (yield self.locks[name].acquire()):
                self.max_leases.pop(name, None)
                self.leases.pop(name, None)
                self.events.pop(name, None)
                for holders in self.leases_per_client.values():
                    holders.pop(name, None)
            self.locks.pop(name, None)
class Semaphore(object):
    """Distributed semaphore living on the scheduler.

    All instances created with the same ``name`` refer to the same
    scheduler-side semaphore, no matter which client or worker created
    them; this is what makes the object safe to pickle and ship around.

    Parameters
    ----------
    max_leases : int
        How many leases may be held concurrently (default 1, i.e. a lock).
    name : str, optional
        Identity of the semaphore; generated randomly when omitted.
    client : Client, optional
        Client used for scheduler communication; falls back to the global
        client or the current worker's client.
    """

    def __init__(self, max_leases=1, name=None, client=None):
        self.client = client or _get_global_client() or get_worker().client
        # Unique per-instance id; the scheduler tracks leases by it.
        self.id = uuid.uuid4().hex
        self.name = name or "semaphore-" + uuid.uuid4().hex
        self.max_leases = max_leases

        if self.client.asynchronous:
            # Registration completes when the instance is awaited, see
            # __await__.
            self._started = self.client.scheduler.semaphore_create(
                name=self.name, max_leases=max_leases
            )
        else:
            self.client.sync(
                self.client.scheduler.semaphore_create,
                name=self.name,
                max_leases=max_leases,
            )
            self._started = gen.moment

    def __await__(self):
        @gen.coroutine
        def _():
            yield self._started
            raise gen.Return(self)

        return _().__await__()

    def acquire(self, timeout=None):
        """
        Acquire a semaphore.

        If the internal counter is greater than zero, decrement it by one and return True immediately.
        If it is zero, wait until a release() is called and return True.

        Parameters
        ----------
        timeout : number, optional
            Seconds to wait before giving up, in which case False is returned.
        """
        return self.client.sync(
            self.client.scheduler.semaphore_acquire,
            name=self.name,
            timeout=timeout,
            client=self.client.id,
            identifier=self.id,
        )

    def release(self):
        """
        Release an acquired semaphore, incrementing the internal counter
        by one and waking up one blocked acquirer.
        """
        # FIX: the original body carried a second, stray string statement
        # ("Release the lock if already acquired") after the docstring; it
        # was dead code and has been folded into the docstring above.
        return self.client.sync(
            self.client.scheduler.semaphore_release,
            name=self.name,
            client=self.client.id,
            identifier=self.id,
        )

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, *args, **kwargs):
        self.release()

    @gen.coroutine
    def __aenter__(self):
        yield self.acquire()
        raise gen.Return(self)

    @gen.coroutine
    def __aexit__(self, *args, **kwargs):
        yield self.release()

    def __getstate__(self):
        # Only name/address/capacity travel; the unpickling side recreates a
        # client connection and a fresh instance id.
        return (self.name, self.client.scheduler.address, self.max_leases)

    def __setstate__(self, state):
        name, address, max_leases = state
        try:
            client = get_client(address)
        except (AttributeError, AssertionError):
            client = Client(address, set_as_default=False)
        self.__init__(name=name, client=client, max_leases=max_leases)

    def close(self):
        """Tear down the scheduler-side state of this semaphore."""
        # Relies on the scheduler exposing the "semaphore_close" route.
        return self.client.sync(self.client.scheduler.semaphore_close, name=self.name)
@gen_cluster(client=True)
def test_serializable(c, s, a, b):
    # A Semaphore must survive pickling (e.g. when shipped to workers) and
    # the unpickled copy must talk to the *same* scheduler-side semaphore.
    sem = Semaphore(max_leases=2, name="x")
    res = yield sem.acquire()
    assert len(s.extensions["semaphores"].leases["x"]) == 1
    assert res
    sem2 = pickle.loads(pickle.dumps(sem))
    assert sem2.name == sem.name
    assert sem2.client.scheduler.address == sem.client.scheduler.address

    # actual leases didn't change
    assert len(s.extensions["semaphores"].leases["x"]) == 1

    res = yield sem2.acquire()
    assert res

    # Ensure that both objects access the same semaphore
    res = yield sem.acquire(timeout=0)

    assert not res
    res = yield sem2.acquire(timeout=0)

    assert not res


@gen_cluster(client=True)
def test_release_simple(c, s, a, b):
    # Using the semaphore as a context manager inside tasks must release
    # cleanly even under concurrent use from many tasks.
    def f(x, semaphore=None):
        with semaphore:
            assert semaphore.name == "x"
            return x + 1

    sem = Semaphore(max_leases=2, name="x")
    futures = c.map(f, list(range(10)), semaphore=sem)
    yield c.gather(futures)


@gen_cluster(client=True)
def test_acquires_with_zero_timeout(c, s, a, b):
    # timeout=0 means a single non-blocking attempt: True if a lease is
    # free right now, False otherwise.
    sem = Semaphore(1, "x")
    yield sem.acquire(timeout=0)
    res = yield sem.acquire(timeout=0)
    assert res is False
    yield sem.release()

    res = yield sem.acquire(timeout=1)
    assert res
    yield sem.release()
    res = yield sem.acquire(timeout=1)
    assert res
    yield sem.release()


def test_timeout_sync(client):
    # Synchronous API: a second instance with the same name contends for
    # the single lease and must give up after the timeout.
    with Semaphore(name="x"):
        assert Semaphore(1, "x").acquire(timeout=0.05) is False


def test_lock_name_only(client):
    # Semaphores constructed only by name (inside tasks, via get_client)
    # must still provide mutual exclusion across workers; the shared
    # "locked" metadata flag would flip mid-critical-section otherwise.
    def f(x):
        with Semaphore(name="x"):
            client = get_client()
            assert client.get_metadata("locked") is False
            client.set_metadata("locked", True)
            sleep(0.01)
            assert client.get_metadata("locked") is True
            client.set_metadata("locked", False)

    client.set_metadata("locked", False)
    futures = client.map(f, range(10))
    client.gather(futures)


@gen_cluster(client=True)
def test_release_semaphore_after_timeout(c, s, a, b):
    # When a client disappears without releasing, the scheduler's periodic
    # lease validation (here sped up to 100ms) must reclaim its leases so
    # other clients can acquire again.
    with dask.config.set(
        {"distributed.scheduler.locks.lease-validation-interval": "100ms"}
    ):
        sem = Semaphore(name="x", max_leases=2)
        yield sem.acquire()
        semY = Semaphore(name="y")

        # NOTE(review): Client(asynchronous=True) used as a plain context
        # manager inside a gen_cluster coroutine -- presumably relies on the
        # test's running IOLoop; confirm this shouldn't be `yield Client(...)`.
        with Client(s.address, asynchronous=True, name="ClientB") as clientB:
            semB = Semaphore(name="x", max_leases=2, client=clientB)
            semYB = Semaphore(name="y", client=clientB)

            assert (yield semB.acquire())
            assert (yield semYB.acquire())

            # Both semaphores are now exhausted.
            assert not (yield sem.acquire(timeout=0))
            assert not (yield semB.acquire(timeout=0))
            assert not (yield semYB.acquire(timeout=0))

        # ClientB is gone; after lease validation we should be able to
        # acquire x and y exactly once each.
        assert (yield sem.acquire())
        assert (yield semY.acquire())

        assert not (yield semY.acquire(timeout=0))
        assert not (yield sem.acquire(timeout=0))