Skip to content

Commit 6fb65e6

Browse files
authored
Refactored ReentrantAsyncLock and AsyncLock with a new method 'base'. (#142)
* Update docstring for AsyncLock. * Swap ReentrantAsyncLock with AsyncLock so AsyncLock is subclassed from ReentrantAsyncLock. * Added method ReentrantAsyncLock.base. * Fix typo. * Drop pytest reruns since failing tests will continue to fail.
1 parent ed94cb7 commit 6fb65e6

File tree

3 files changed

+44
-47
lines changed

3 files changed

+44
-47
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141

4242
- name: Run tests
4343
timeout-minutes: 5
44-
run: uv run pytest --reruns 3 --reruns-delay 1 --maxfail 10 --rerun-except AssertionError -vvl
44+
run: uv run pytest -vvl
4545

4646
- name: Type checking with basedpyright
4747
run: uv run basedpyright
@@ -61,7 +61,7 @@ jobs:
6161

6262
- name: Run tests with coverage
6363
timeout-minutes: 5
64-
run: uv run pytest --reruns 2 --rerun-except AssertionError --maxfail 10 --cov --cov-report=xml --junitxml=junit.xml -o junit_family=legacy --cov-fail-under=100
64+
run: uv run pytest --cov --cov-report=xml --junitxml=junit.xml -o junit_family=legacy --cov-fail-under=100
6565

6666
- name: Upload coverage reports to Codecov
6767
if: ${{ !cancelled() }}

src/async_kernel/caller.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,24 +1076,29 @@ async def wait(
10761076
return done, pending
10771077

10781078

1079-
class AsyncLock:
1079+
class ReentrantAsyncLock:
10801080
"""
1081-
Implements a mutex asynchronous lock that is compatible with [async_kernel.caller.Caller][].
1081+
A Reentrant asynchronous lock compatible with [async_kernel.caller.Caller][].
1082+
1083+
The lock is reentrant in terms of [contextvars.Context][].
10821084
10831085
!!! note
10841086
1085-
- Attempting to lock a 'mutuex' configured lock that is *locked* will raise a [RuntimeError][].
1087+
- The lock context can be exitied in any order.
1088+
- The context can potentially leak.
1089+
- A 'reentrant' lock can *release* control to another context and then re-enter later for
1090+
tasks or threads called from a locked thread maintaining the same reentrant context.
10861091
"""
10871092

1088-
_reentrant: ClassVar[bool] = False
1093+
_reentrant: ClassVar[bool] = True
10891094
_count: int = 0
10901095
_ctx_count: int = 0
10911096
_ctx_current: int = 0
10921097
_releasing: bool = False
10931098

10941099
def __init__(self):
10951100
self._ctx_var: contextvars.ContextVar[int] = contextvars.ContextVar(f"Lock:{id(self)}", default=0)
1096-
self._queue: deque[tuple[int, Future[Future | None]]] = deque()
1101+
self._queue: deque[tuple[int, Future[bool]]] = deque()
10971102

10981103
@override
10991104
def __repr__(self) -> str:
@@ -1115,13 +1120,13 @@ async def acquire(self) -> Self:
11151120
"""
11161121
Acquire a lock.
11171122
1118-
If the lock is reentrant the internal counter increments to share the lock.
1123+
The internal counter increments when the lock is entered.
11191124
"""
11201125
if not self._reentrant and self.is_in_context():
11211126
msg = "Already locked and not reentrant!"
11221127
raise RuntimeError(msg)
11231128
# Get the context.
1124-
if not self._reentrant or not (ctx := self._ctx_var.get()):
1129+
if (self._ctx_count == 0) or not self._reentrant or not (ctx := self._ctx_var.get()):
11251130
self._ctx_count = ctx = self._ctx_count + 1
11261131
self._ctx_var.set(ctx)
11271132
# Check if we can lock or re-enter an active lock.
@@ -1130,21 +1135,20 @@ async def acquire(self) -> Self:
11301135
self._ctx_current = ctx
11311136
return self
11321137
# Join the queue.
1133-
k: tuple[int, Future[None | Future[Future[None] | None]]] = ctx, Future()
1138+
k: tuple[int, Future[bool]] = ctx, Future()
11341139
self._queue.append(k)
11351140
try:
1136-
fut = await k[1]
1141+
result = await k[1]
11371142
finally:
11381143
if k in self._queue:
11391144
self._queue.remove(k)
1140-
if fut:
1145+
if result:
11411146
self._ctx_current = ctx
1142-
fut.set_result(None)
11431147
if self._reentrant:
11441148
for k in tuple(self._queue):
11451149
if k[0] == ctx:
11461150
self._queue.remove(k)
1147-
k[1].set_result(None)
1151+
k[1].set_result(False)
11481152
self._count += 1
11491153
self._releasing = False
11501154
return self
@@ -1155,49 +1159,45 @@ async def release(self) -> None:
11551159
11561160
If the current depth==1 the lock will be passed to the next queued or released if there isn't one.
11571161
"""
1158-
if not self.is_in_context():
1159-
raise InvalidStateError
11601162
if self._count == 1 and self._queue and not self._releasing:
11611163
self._releasing = True
11621164
self._ctx_var.set(0)
1163-
try:
1164-
fut = Future()
1165-
k = self._queue.popleft()
1166-
k[1].set_result(fut)
1167-
await k[1]
1168-
except Exception:
1169-
self._releasing = False
1165+
self._queue.popleft()[1].set_result(True)
11701166
else:
11711167
self._count -= 1
11721168
if self._count == 0:
11731169
self._ctx_current = 0
11741170

11751171
def is_in_context(self) -> bool:
1176-
"Returns `True` if the current context has the lock."
1172+
"Returns `True` if the current [contextvars.Context][] has the lock."
11771173
return bool(self._count and self._ctx_current and (self._ctx_var.get() == self._ctx_current))
11781174

1175+
@asynccontextmanager
1176+
async def base(self) -> AsyncGenerator[Self, Any]:
1177+
"""
1178+
Acquire the lock as a new [contextvars.Context][].
11791179
1180-
class ReentrantAsyncLock(AsyncLock):
1181-
"""
1182-
Implements a Reentrant asynchronous lock compatible with [async_kernel.caller.Caller][].
1180+
Use this to ensure exclusive access from within this [contextvars.Context][].
11831181
1182+
!!! note
1183+
- This method is not useful for the mutex variant ([async_kernel.caller.AsyncLock][]) which does this by default.
11841184
1185-
!!! example
1185+
!!! warning
1186+
Using this inside its own acquired lock will cause a deadlock.
1187+
"""
1188+
if self._reentrant:
1189+
self._ctx_var.set(0)
1190+
async with self:
1191+
yield self
11861192

1187-
```python
1188-
# Inside a coroutine running inside a thread where a [asyncio.caller.Caller][] instance is running.
11891193

1190-
lock = ReentrantAsyncLock(reentrant=True) # a reentrant lock
1191-
async with lock:
1192-
async with lock:
1193-
Caller().to_thread(...) # The lock is shared with the thread.
1194-
```
1194+
class AsyncLock(ReentrantAsyncLock):
1195+
"""
1196+
A mutex asynchronous lock that is compatible with [async_kernel.caller.Caller][].
11951197
11961198
!!! note
11971199
1198-
- The lock context can be exitied in any order.
1199-
- A 'reentrant' lock can *release* control to another context and then re-enter later for
1200-
tasks or threads called from a locked thread maintaining the same reentrant context.
1200+
- Attempting to acquire the lock from inside a locked [contextvars.Context][] will raise a [RuntimeError][].
12011201
"""
12021202

1203-
_reentrant: ClassVar[bool] = True
1203+
_reentrant: ClassVar[bool] = False

tests/test_caller.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -704,26 +704,23 @@ async def _locked():
704704
unlock.set()
705705
assert not lock._queue # pyright: ignore[reportPrivateUsage]
706706

707-
async def test_invald_release(self, caller):
708-
lock = AsyncLock()
709-
with pytest.raises(InvalidStateError):
710-
await lock.release()
711-
712707
async def test_reentrant(self, caller: Caller):
713-
lock = ReentrantAsyncLock()
708+
lock: ReentrantAsyncLock = ReentrantAsyncLock()
714709

715710
async def func():
716711
assert lock.count == 2
717712
async with lock:
718713
assert lock.is_in_context()
719714
assert lock.count == 3
720-
return True
721715

722-
async with lock:
716+
async with lock.base():
723717
assert lock.is_in_context()
724718
assert lock.count == 1
725719
async with lock:
726720
await caller.call_soon(func)
721+
with pytest.raises(TimeoutError), anyio.fail_after(0.1):
722+
async with lock.base():
723+
pass
727724
assert lock.count == 0
728725
assert not lock.is_in_context()
729726

0 commit comments

Comments
 (0)