1 change: 1 addition & 0 deletions distributed/__init__.py
@@ -22,6 +22,7 @@
from .nanny import Nanny
from .pubsub import Pub, Sub
from .queues import Queue
from .semaphore import Semaphore
from .scheduler import Scheduler
from .threadpoolexecutor import rejoin
from .utils import sync
2 changes: 2 additions & 0 deletions distributed/distributed.yaml
@@ -22,6 +22,8 @@ distributed:
    worker-ttl: null # like '60s'. Time to live for workers. They must heartbeat faster than this
    preload: []
    preload-argv: []
    locks:
      lease-validation-interval: 10s

  worker:
    blocked-handlers: []
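For reference, the new interval can also be set through dask.config; a minimal sketch assuming the key path shown above:

import dask

# Override the lease-validation interval before the scheduler starts;
# the key path mirrors the YAML default above.
dask.config.set({"distributed.scheduler.locks.lease-validation-interval": "10s"})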
2 changes: 2 additions & 0 deletions distributed/scheduler.py
@@ -62,6 +62,7 @@

from .publish import PublishExtension
from .queues import QueueExtension
from .semaphore import SemaphoreExtension
from .recreate_exceptions import ReplayExceptionScheduler
from .lock import LockExtension
from .pubsub import PubSubSchedulerExtension
@@ -85,6 +86,7 @@
    QueueExtension,
    VariableExtension,
    PubSubSchedulerExtension,
    SemaphoreExtension,
]

if dask.config.get("distributed.scheduler.work-stealing"):
264 changes: 264 additions & 0 deletions distributed/semaphore.py
@@ -0,0 +1,264 @@
from __future__ import absolute_import, division, print_function

import uuid
from collections import defaultdict, deque
from functools import partial

import dask
import tornado.locks
import tornado.queues
from tornado import gen

from .client import Client, _get_global_client
from .utils import PeriodicCallback, log_errors, parse_timedelta
from .worker import get_client, get_worker
from toolz.dicttoolz import valmap
from .metrics import time


class _Watch(object):
    def __init__(self, duration=None):
        self.duration = duration
        self.started_at = None

    def start(self):
        self.started_at = time()

    def leftover(self):
        if self.duration is None:
            return None
        else:
            elapsed = time() - self.started_at
            return max(0, self.duration - elapsed)


class SemaphoreExtension(object):
""" An extension for the scheduler to manage Semaphores

This adds the following routes to the scheduler

* semaphore_acquire
* semaphore_release
* semaphore_create
"""

def __init__(self, scheduler):
self.locks = defaultdict(tornado.locks.Lock)
self.scheduler = scheduler
self.leases = defaultdict(deque)
self.events = defaultdict(tornado.locks.Event)
self.max_leases = dict()
self.leases_per_client = defaultdict(partial(defaultdict, deque))
self.scheduler.handlers.update(
{
"semaphore_create": self.create,
"semaphore_acquire": self.acquire,
"semaphore_release": self.release,
}
)

self.scheduler.extensions["semaphores"] = self
self.pc_validate_leases = PeriodicCallback(
self._validate_leases,
parse_timedelta(
dask.config.get("distributed.scheduler.locks.lease-validation-interval")
),
io_loop=self.scheduler.loop,
)
self.pc_validate_leases.start()
self._validation_running = False

    @gen.coroutine
    def create(
        self, stream=None, name=None, client=None, timeout=None, max_leases=None
Member: I recommend renaming stream= to comm=. This switch happened a while ago but apparently not all code was updated.

Member: I'm actually a little surprised that this worked. We must be using a positional argument somewhere.
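For reference, the rename the first comment asks for would look like this (a sketch; only the keyword changes, body elided):

    @gen.coroutine
    def create(
        self, comm=None, name=None, client=None, timeout=None, max_leases=None
    ):
        ...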

    ):
        if name not in self.leases:
            assert isinstance(max_leases, int), max_leases
Contributor: Question: is there a reason an assert is preferred to a more explicit ValueError here?

Member Author: I would probably even remove this; I used it a bit for debugging. If we keep it, a ValueError or TypeError would make the most sense.
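If the check stays, a minimal sketch of the explicit error the reviewers lean toward (the message text is an assumption, not part of the diff):

            # Hypothetical replacement for the assert above:
            if not isinstance(max_leases, int):
                raise TypeError("max_leases must be an int, got %r" % (max_leases,))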

            self.max_leases[name] = max_leases
        else:
            if max_leases != self.max_leases[name]:
                raise ValueError(
                    "Inconsistent max leases: %s, expected: %s"
                    % (max_leases, self.max_leases[name])
                )

    @gen.coroutine
    def _get_lease(self, client, name, identifier):
        # We should make sure that the client is already properly registered with the scheduler,
        # otherwise the lease validation will mop up every acquired lease immediately
        while client not in self.scheduler.clients:
            yield
Member: Suggested change:

-            yield
+            yield gen.sleep(0.100)

Otherwise this stresses out the event loop.

        with (yield self.locks[name].acquire()):
Member: Why do we need an async lock around this block? There are no yield points within it, so I wouldn't expect there to be any chance for another coroutine to run at the same time.

            result = True
            if len(self.leases[name]) < self.max_leases[name]:
                self.leases[name].append(identifier)
                self.leases_per_client[client][name].append(identifier)
            else:
                result = False
        raise gen.Return(result)

    @gen.coroutine
    def acquire(
        self, stream=None, name=None, client=None, timeout=None, identifier=None
    ):
        with log_errors():
            if isinstance(name, list):
                name = tuple(name)
            w = _Watch(timeout)
            w.start()

            while True:
                # Reset the event and try to get a lease. The event will be set if the state
                # is changed and helps to identify when it is worthwhile to retry an acquire
                self.events[name].clear()
                future = self._get_lease(client, name, identifier)
                if timeout is not None:
                    future = gen.with_timeout(w.leftover(), future)
                try:
                    result = yield future
                except gen.TimeoutError:
                    result = False

                # If acquiring fails, we wait for the event to be set, i.e. something has
                # been released and we can try to acquire again (continue loop)
                if not result:
                    future = self.events[name].wait()
                    if timeout is not None:
                        future = gen.with_timeout(w.leftover(), future)
                    try:
                        yield future
                        continue
                    except gen.TimeoutError:
                        result = False
                raise gen.Return(result)

    @gen.coroutine
    def release(self, stream=None, name=None, client=None, identifier=None):
        with log_errors():
            if isinstance(name, list):
                name = tuple(name)
            if name in self.leases and identifier in self.leases[name]:
                self.scheduler.loop.add_callback(
                    self._release_value, name, client, identifier
                )
            else:
                raise ValueError("Semaphore released too many times.")

    @gen.coroutine
    def _release_value(self, name, client, identifier):
        with (yield self.locks[name].acquire()):
Member: Assuming that these locks aren't necessary (I don't think they're ever used around code with yields), my guess is that a bit of this can be simplified.

Member Author: Yes, it works without locks. Will remove them again.
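A sketch of what _release_value might reduce to once the locks are dropped, per the author's reply (not part of this diff):

    @gen.coroutine
    def _release_value(self, name, client, identifier):
        # No yield points here, so the event loop cannot interleave another
        # coroutine inside this block; the tornado lock added nothing.
        self.leases_per_client[client][name].remove(identifier)
        self.leases[name].remove(identifier)
        self.events[name].set()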

            self.leases_per_client[client][name].remove(identifier)
            self.leases[name].remove(identifier)
            self.events[name].set()

    def _release_client(self, client):
        semaphore_names = list(self.leases_per_client[client])
        for name in semaphore_names:
            ids = list(self.leases_per_client[client][name])
            for _id in list(ids):
                self.scheduler.loop.add_callback(
                    self._release_value, name=name, client=client, identifier=_id
                )

    @gen.coroutine
    def _validate_leases(self):
        if not self._validation_running:
            self._validation_running = True
            known_clients = set(self.leases_per_client.keys())
            scheduler_clients = set(self.scheduler.clients.keys())
            for client in known_clients - scheduler_clients:
                client_has_leases = sum(
                    valmap(len, self.leases_per_client[client]).values()
Member: valmap(func, d).values() -> map(func, d.values())?
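Applied to the call above, the suggested simplification would read (a sketch, not part of the diff):

                # Count this client's outstanding leases without toolz.valmap:
                client_has_leases = sum(map(len, self.leases_per_client[client].values()))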

                )
                if client_has_leases:
                    self._release_client(client)
            else:
                self._validation_running = False


class Semaphore(object):
    def __init__(self, max_leases=1, name=None, client=None):
        self.client = client or _get_global_client() or get_worker().client
        self.id = uuid.uuid4().hex
        self.name = name or "semaphore-" + uuid.uuid4().hex
Member: reuse id here?

        self.max_leases = max_leases

        if self.client.asynchronous:
            self._started = self.client.scheduler.semaphore_create(
                name=self.name, max_leases=max_leases
            )
        else:
            self.client.sync(
                self.client.scheduler.semaphore_create,
                name=self.name,
                max_leases=max_leases,
            )
            self._started = gen.moment

    def __await__(self):
        @gen.coroutine
        def _():
            yield self._started
            raise gen.Return(self)

        return _().__await__()

    def acquire(self, timeout=None):
        """
        Acquire a semaphore.

        If the internal counter is greater than zero, decrement it by one and
        return True immediately. If it is zero, wait until a release() is
        called and return True.
        """
        return self.client.sync(
            self.client.scheduler.semaphore_acquire,
            name=self.name,
            timeout=timeout,
            client=self.client.id,
            identifier=self.id,
        )

    def release(self):
        """
        Release the semaphore, if already acquired.

        Increment the internal counter by one.
        """
        return self.client.sync(
            self.client.scheduler.semaphore_release,
            name=self.name,
            client=self.client.id,
            identifier=self.id,
        )

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, *args, **kwargs):
        self.release()

    @gen.coroutine
    def __aenter__(self):
        yield self.acquire()
        raise gen.Return(self)

    @gen.coroutine
    def __aexit__(self, *args, **kwargs):
        yield self.release()

    def __getstate__(self):
        return (self.name, self.client.scheduler.address, self.max_leases)
Contributor: Should id also be serialized here, or do we want Semaphores which have gone through the serialization / deserialization process to be considered different?

Member Author: TL;DR: I think we should not serialize the ID. I should probably add a test for this behavior.

Consider the following:

import dask
from dask import delayed
from distributed import Semaphore

def access_database(sem, *args):
    with sem:
        data = ...
    return data

sem = Semaphore(name="my_database", max_leases=2)
tasks = []
for ix in range(10):
    tasks.append(
        delayed(access_database)(sem, ix)
    )

dask.compute(tasks)

This is how I would probably write a piece of code to download data from a DB. I could rely on dask.distributed to serialize the class for me; otherwise I would need to pass all initialization parameters to the function and initialize the objects myself to keep everything consistent.
When I use the serialization approach, I still require the semaphore instance of every job to be unique by ID; otherwise the tracking is off and the mechanism for tracking leases by ID and client would not work.


    def __setstate__(self, state):
        name, address, max_leases = state
        try:
            client = get_client(address)
        except (AttributeError, AssertionError):
            client = Client(address, set_as_default=False)
        self.__init__(name=name, client=client, max_leases=max_leases)

    def close(self):
        self.client.sync(self.client.scheduler.semaphore_close, name=self.name)
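For orientation, a usage sketch of the API this diff adds; the local cluster and the semaphore name are illustrative assumptions, not part of the change:

from distributed import Client, Semaphore

client = Client()  # assumption: a fresh local cluster; any scheduler address works

# At most two leases may be held at once, cluster-wide
sem = Semaphore(max_leases=2, name="database")

def query(i):
    with sem:  # blocks until one of the two leases is free
        return i  # stand-in for the guarded database access

futures = client.map(query, range(10))
results = client.gather(futures)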