Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:

- name: Run tests
timeout-minutes: 5
run: uv run pytest --reruns 3 --reruns-delay 1 --maxfail 10 --rerun-except AssertionError -vvl
run: uv run pytest -vvl

- name: Type checking with basedpyright
run: uv run basedpyright
Expand All @@ -61,7 +61,7 @@ jobs:

- name: Run tests with coverage
timeout-minutes: 5
run: uv run pytest --reruns 2 --rerun-except AssertionError --maxfail 10 --cov --cov-report=xml --junitxml=junit.xml -o junit_family=legacy --cov-fail-under=100
run: uv run pytest --cov --cov-report=xml --junitxml=junit.xml -o junit_family=legacy --cov-fail-under=100

- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
Expand Down
74 changes: 37 additions & 37 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,24 +1076,29 @@ async def wait(
return done, pending


class AsyncLock:
class ReentrantAsyncLock:
"""
Implements a mutex asynchronous lock that is compatible with [async_kernel.caller.Caller][].
A Reentrant asynchronous lock compatible with [async_kernel.caller.Caller][].

The lock is reentrant in terms of [contextvars.Context][].

!!! note

- Attempting to lock a 'mutuex' configured lock that is *locked* will raise a [RuntimeError][].
- The lock context can be exitied in any order.
- The context can potentially leak.
- A 'reentrant' lock can *release* control to another context and then re-enter later for
tasks or threads called from a locked thread maintaining the same reentrant context.
"""

_reentrant: ClassVar[bool] = False
_reentrant: ClassVar[bool] = True
_count: int = 0
_ctx_count: int = 0
_ctx_current: int = 0
_releasing: bool = False

def __init__(self):
self._ctx_var: contextvars.ContextVar[int] = contextvars.ContextVar(f"Lock:{id(self)}", default=0)
self._queue: deque[tuple[int, Future[Future | None]]] = deque()
self._queue: deque[tuple[int, Future[bool]]] = deque()

@override
def __repr__(self) -> str:
Expand All @@ -1115,13 +1120,13 @@ async def acquire(self) -> Self:
"""
Acquire a lock.

If the lock is reentrant the internal counter increments to share the lock.
The internal counter increments when the lock is entered.
"""
if not self._reentrant and self.is_in_context():
msg = "Already locked and not reentrant!"
raise RuntimeError(msg)
# Get the context.
if not self._reentrant or not (ctx := self._ctx_var.get()):
if (self._ctx_count == 0) or not self._reentrant or not (ctx := self._ctx_var.get()):
self._ctx_count = ctx = self._ctx_count + 1
self._ctx_var.set(ctx)
# Check if we can lock or re-enter an active lock.
Expand All @@ -1130,21 +1135,20 @@ async def acquire(self) -> Self:
self._ctx_current = ctx
return self
# Join the queue.
k: tuple[int, Future[None | Future[Future[None] | None]]] = ctx, Future()
k: tuple[int, Future[bool]] = ctx, Future()
self._queue.append(k)
try:
fut = await k[1]
result = await k[1]
finally:
if k in self._queue:
self._queue.remove(k)
if fut:
if result:
self._ctx_current = ctx
fut.set_result(None)
if self._reentrant:
for k in tuple(self._queue):
if k[0] == ctx:
self._queue.remove(k)
k[1].set_result(None)
k[1].set_result(False)
self._count += 1
self._releasing = False
return self
Expand All @@ -1155,49 +1159,45 @@ async def release(self) -> None:

If the current depth==1 the lock will be passed to the next queued or released if there isn't one.
"""
if not self.is_in_context():
raise InvalidStateError
if self._count == 1 and self._queue and not self._releasing:
self._releasing = True
self._ctx_var.set(0)
try:
fut = Future()
k = self._queue.popleft()
k[1].set_result(fut)
await k[1]
except Exception:
self._releasing = False
self._queue.popleft()[1].set_result(True)
else:
self._count -= 1
if self._count == 0:
self._ctx_current = 0

def is_in_context(self) -> bool:
"Returns `True` if the current context has the lock."
"Returns `True` if the current [contextvars.Context][] has the lock."
return bool(self._count and self._ctx_current and (self._ctx_var.get() == self._ctx_current))

@asynccontextmanager
async def base(self) -> AsyncGenerator[Self, Any]:
"""
Acquire the lock as a new [contextvars.Context][].

class ReentrantAsyncLock(AsyncLock):
"""
Implements a Reentrant asynchronous lock compatible with [async_kernel.caller.Caller][].
Use this to ensure exclusive access from within this [contextvars.Context][].

!!! note
- This method is not useful for the mutex variant ([async_kernel.caller.AsyncLock][]) which does this by default.

!!! example
!!! warning
Using this inside its own acquired lock will cause a deadlock.
"""
if self._reentrant:
self._ctx_var.set(0)
async with self:
yield self

```python
# Inside a coroutine running inside a thread where a [asyncio.caller.Caller][] instance is running.

lock = ReentrantAsyncLock(reentrant=True) # a reentrant lock
async with lock:
async with lock:
Caller().to_thread(...) # The lock is shared with the thread.
```
class AsyncLock(ReentrantAsyncLock):
"""
A mutex asynchronous lock that is compatible with [async_kernel.caller.Caller][].

!!! note

- The lock context can be exitied in any order.
- A 'reentrant' lock can *release* control to another context and then re-enter later for
tasks or threads called from a locked thread maintaining the same reentrant context.
- Attempting to acquire the lock from inside a locked [contextvars.Context][] will raise a [RuntimeError][].
"""

_reentrant: ClassVar[bool] = True
_reentrant: ClassVar[bool] = False
13 changes: 5 additions & 8 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,26 +704,23 @@ async def _locked():
unlock.set()
assert not lock._queue # pyright: ignore[reportPrivateUsage]

async def test_invald_release(self, caller):
lock = AsyncLock()
with pytest.raises(InvalidStateError):
await lock.release()

async def test_reentrant(self, caller: Caller):
lock = ReentrantAsyncLock()
lock: ReentrantAsyncLock = ReentrantAsyncLock()

async def func():
assert lock.count == 2
async with lock:
assert lock.is_in_context()
assert lock.count == 3
return True

async with lock:
async with lock.base():
assert lock.is_in_context()
assert lock.count == 1
async with lock:
await caller.call_soon(func)
with pytest.raises(TimeoutError), anyio.fail_after(0.1):
async with lock.base():
pass
assert lock.count == 0
assert not lock.is_in_context()

Expand Down
Loading