This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add types to async_helpers #8260

Merged: 4 commits, Sep 8, 2020
1 change: 1 addition & 0 deletions changelog.d/8260.misc
@@ -0,0 +1 @@
+ Add type hints to `synapse.util.async_helpers`.
3 changes: 2 additions & 1 deletion mypy.ini
@@ -34,7 +34,7 @@ files =
synapse/http/federation/well_known_resolver.py,
synapse/http/server.py,
synapse/http/site.py,
- synapse/logging/,
+ synapse/logging,
synapse/metrics,
synapse/module_api,
synapse/notifier.py,
@@ -54,6 +54,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
+ synapse/util/async_helpers.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
135 changes: 85 additions & 50 deletions synapse/util/async_helpers.py
@@ -17,13 +17,25 @@
import collections
import logging
from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
from typing import (
Any,
Callable,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
TypeVar,
Union,
)

import attr
from typing_extensions import ContextManager

from twisted.internet import defer
from twisted.internet.defer import CancelledError
+ from twisted.internet.interfaces import IReactorTime
from twisted.python import failure

from synapse.logging.context import (
@@ -54,7 +66,7 @@ class ObservableDeferred:

__slots__ = ["_deferred", "_observers", "_result"]

- def __init__(self, deferred, consumeErrors=False):
+ def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
Expand Down Expand Up @@ -111,44 +123,46 @@ def remove(r):
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)

- def observers(self):
+ def observers(self) -> List[defer.Deferred]:
return self._observers

- def has_called(self):
+ def has_called(self) -> bool:
return self._result is not None

- def has_succeeded(self):
+ def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True

- def get_result(self):
+ def get_result(self) -> Any:
return self._result[1]

- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self._deferred, name)

- def __setattr__(self, name, value):
+ def __setattr__(self, name: str, value: Any) -> None:
setattr(self._deferred, name, value)

- def __repr__(self):
+ def __repr__(self) -> str:
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self),
self._result,
self._deferred,
)
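
For context while reading this hunk: a minimal sketch of the fan-out pattern ObservableDeferred supports, assuming the class's existing observe() method (which is not touched by this PR):

    from twisted.internet import defer
    from synapse.util.async_helpers import ObservableDeferred

    underlying = defer.Deferred()
    observable = ObservableDeferred(underlying, consumeErrors=True)

    # Each observer gets an independent Deferred that fires with the result.
    first = observable.observe()
    second = observable.observe()
    first.addCallback(lambda r: print("first saw", r))
    second.addCallback(lambda r: print("second saw", r))

    underlying.callback("done")  # both observers now fire with "done"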


- def concurrently_execute(func, args, limit):
- """Executes the function with each argument conncurrently while limiting
+ def concurrently_execute(
+     func: Callable, args: Iterable[Any], limit: int
+ ) -> defer.Deferred:
+ """Executes the function with each argument concurrently while limiting
the number of concurrent executions.

Args:
- func (func): Function to execute, should return a deferred or coroutine.
- args (Iterable): List of arguments to pass to func, each invocation of func
+ func: Function to execute, should return a deferred or coroutine.
+ args: List of arguments to pass to func, each invocation of func
gets a single argument.
- limit (int): Maximum number of conccurent executions.
+ limit: Maximum number of concurrent executions.

Returns:
- deferred: Resolved when all function invocations have finished.
+ Deferred[list]: Resolved when all function invocations have finished.
"""
it = iter(args)
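
A quick usage sketch for the newly-typed signature (the names here are purely illustrative):

    from synapse.util.async_helpers import concurrently_execute

    async def process(item: str) -> None:
        print("processing", item)  # stand-in for real per-item work

    # Runs process() over all items, at most 3 at a time; the returned
    # Deferred resolves once every invocation has finished.
    d = concurrently_execute(process, ["a", "b", "c", "d"], 3)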

@@ -167,14 +181,17 @@ async def _concurrently_execute_inner():
).addErrback(unwrapFirstError)


- def yieldable_gather_results(func, iter, *args, **kwargs):
+ def yieldable_gather_results(
+     func: Callable, iter: Iterable, *args: Any, **kwargs: Any
+ ) -> defer.Deferred:
"""Executes the function with each argument concurrently.

Args:
- func (func): Function to execute that returns a Deferred
- iter (iter): An iterable that yields items that get passed as the first
+ func: Function to execute that returns a Deferred
+ iter: An iterable that yields items that get passed as the first
argument to the function
*args: Arguments to be passed to each call to func
+ **kwargs: Keyword arguments to be passed to each call to func

Returns:
Deferred[list]: Resolved when all functions have been invoked, or errors if
@@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
).addErrback(unwrapFirstError)
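
Similarly, a hypothetical call showing the pass-through *args/**kwargs (the save() helper and its table parameter are invented for illustration):

    from synapse.util.async_helpers import yieldable_gather_results

    async def save(item: str, table: str) -> None:
        print("saving", item, "to", table)

    # Each item is passed as the first argument; table is forwarded to
    # every call via **kwargs.
    d = yieldable_gather_results(save, ["a", "b"], table="events")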


+ @attr.s(slots=True)
+ class _LinearizerEntry:
+     # The number of things executing.
+     count = attr.ib(type=int)
+     # Deferreds for the things blocked from executing.
+     deferreds = attr.ib(type=collections.OrderedDict)
Review comment (Member): Why not awaitables? Deferreds are Awaitables too, right?

Or is this because the rest of the file refers to them as deferreds?

Reply (Member Author): Because it is actually a Deferred, not just an Awaitable (and since this is internal to the functionality of the Linearizer it seems better to be specific).

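To make the point of this change concrete: a small self-contained sketch of the attrs pattern adopted here, using a hypothetical stand-in class mirroring _LinearizerEntry:

    import collections

    import attr

    @attr.s(slots=True)
    class _Entry:  # illustrative stand-in for _LinearizerEntry
        count = attr.ib(type=int)
        deferreds = attr.ib(type=collections.OrderedDict)

    entry = _Entry(0, collections.OrderedDict())
    entry.count += 1            # previously entry[0] += 1 on a plain list
    assert not entry.deferreds  # previously bool(entry[1])

Unlike the old `Sequence[Union[int, Dict[...]]]` annotation, mypy can now check each field access.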



class Linearizer:
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few things happen at a time on a given resource.

Example:

- with (yield limiter.queue("test_key")):
+ with await limiter.queue("test_key"):
# do some work.

"""

- def __init__(self, name=None, max_count=1, clock=None):
+ def __init__(
+     self,
+     name: Optional[str] = None,
+     max_count: int = 1,
+     clock: Optional[Clock] = None,
+ ):
"""
Args:
- max_count(int): The maximum number of concurrent accesses
+ max_count: The maximum number of concurrent accesses
"""
if name is None:
- self.name = id(self)
+ self.name = id(self)  # type: Union[str, int]
else:
self.name = name

@@ -216,15 +246,10 @@ def __init__(self, name=None, max_count=1, clock=None):
self._clock = clock
self.max_count = max_count

- # key_to_defer is a map from the key to a 2 element list where
- # the first element is the number of things executing, and
- # the second element is an OrderedDict, where the keys are deferreds for the
- # things blocked from executing.
- self.key_to_defer = (
-     {}
- )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+ # key_to_defer is a map from the key to a _LinearizerEntry.
+ self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]

- def is_queued(self, key) -> bool:
+ def is_queued(self, key: Hashable) -> bool:
"""Checks whether there is a process queued up waiting
"""
entry = self.key_to_defer.get(key)
@@ -234,25 +259,27 @@ def is_queued(self, key) -> bool:

# There are waiting deferreds only if the OrderedDict of deferreds is
# non-empty.
- return bool(entry[1])
+ return bool(entry.deferreds)

- def queue(self, key):
+ def queue(self, key: Hashable) -> defer.Deferred:
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
# propagated inside inlineCallbacks until Twisted 18.7)
- entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
+ entry = self.key_to_defer.setdefault(
+     key, _LinearizerEntry(0, collections.OrderedDict())
+ )

# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
# When one of the things currently executing finishes it will callback
# this item so that it can continue executing.
- if entry[0] >= self.max_count:
+ if entry.count >= self.max_count:
res = self._await_lock(key)
else:
logger.debug(
"Acquired uncontended linearizer lock %r for key %r", self.name, key
)
- entry[0] += 1
+ entry.count += 1
res = defer.succeed(None)

# once we successfully get the lock, we need to return a context manager which
@@ -267,23 +294,23 @@ def _ctx_manager(_):

# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
- entry[0] -= 1
+ entry.count -= 1

- if entry[1]:
-     (next_def, _) = entry[1].popitem(last=False)
+ if entry.deferreds:
+     (next_def, _) = entry.deferreds.popitem(last=False)

# we need to run the next thing in the sentinel context.
with PreserveLoggingContext():
next_def.callback(None)
- elif entry[0] == 0:
+ elif entry.count == 0:
# We were the last thing for this key: remove it from the
# map.
del self.key_to_defer[key]

res.addCallback(_ctx_manager)
return res

- def _await_lock(self, key):
+ def _await_lock(self, key: Hashable) -> defer.Deferred:
"""Helper for queue: adds a deferred to the queue

Assumes that we've already checked that we've reached the limit of the number
@@ -298,11 +325,11 @@ def _await_lock(self, key):
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)

new_defer = make_deferred_yieldable(defer.Deferred())
- entry[1][new_defer] = 1
+ entry.deferreds[new_defer] = 1

def cb(_r):
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
- entry[0] += 1
+ entry.count += 1

# if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can
@@ -331,7 +358,7 @@ def eb(e):
)

# we just have to take ourselves back out of the queue.
- del entry[1][new_defer]
+ del entry.deferreds[new_defer]
return e

new_defer.addCallbacks(cb, eb)
@@ -419,14 +446,22 @@ def _ctx_manager():
return _ctx_manager()


- def _cancelled_to_timed_out_error(value, timeout):
+ R = TypeVar("R")
+
+
+ def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
if isinstance(value, failure.Failure):
value.trap(CancelledError)
raise defer.TimeoutError(timeout, "Deferred")
return value


- def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+ def timeout_deferred(
+     deferred: defer.Deferred,
+     timeout: float,
+     reactor: IReactorTime,
+     on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+ ) -> defer.Deferred:
"""The in built twisted `Deferred.addTimeout` fails to time out deferreds
that have a canceller that throws exceptions. This method creates a new
deferred that wraps and times out the given deferred, correctly handling
Expand All @@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred

Args:
- deferred (Deferred)
- timeout (float): Timeout in seconds
- reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
- on_timeout_cancel (callable): A callable which is called immediately
+ deferred: The Deferred to potentially timeout.
+ timeout: Timeout in seconds
+ reactor: The twisted reactor to use
+ on_timeout_cancel: A callable which is called immediately
after the deferred times out, and not if this deferred is
otherwise cancelled before the timeout.

@@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
CancelledError Failure into a defer.TimeoutError.

Returns:
- Deferred
+ A new Deferred.
"""

new_d = defer.Deferred()
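
Finally, a usage sketch for the newly-typed timeout_deferred, assuming a running reactor (note, per the docstring, that it returns a new Deferred rather than mutating the one passed in):

    from twisted.internet import defer, reactor

    from synapse.util.async_helpers import timeout_deferred

    d = defer.Deferred()
    # Fails with defer.TimeoutError after 10 seconds unless d fires first.
    timed = timeout_deferred(d, 10.0, reactor)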