Skip to content

Commit e5c4073

Browse files
committed
Replace threading.local() with ContextVar for async compatibility
threading.local() doesn't propagate across asyncio.to_thread() boundaries, so per-request state like timezone activation would be lost when sync views run in the executor. ContextVar propagates correctly through both asyncio.to_thread() and asyncio.create_task(). - timezone.py: _active uses ContextVar with token-based reset in override - resolvers.py: recursion guard uses module-level ContextVar[frozenset[int]]
1 parent 79383d0 commit e5c4073

File tree

2 files changed

+28
-27
lines changed

2 files changed

+28
-27
lines changed

plain/plain/urls/resolvers.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import functools
1212
import re
13-
from threading import local
13+
from contextvars import ContextVar
1414
from typing import TYPE_CHECKING, Any
1515
from urllib.parse import quote
1616

@@ -23,6 +23,10 @@
2323
from .exceptions import NoReverseMatch, Resolver404
2424
from .patterns import RegexPattern, RoutePattern, URLPattern
2525

26+
# Tracks which URLResolver instances are currently inside _populate(),
27+
# to prevent infinite recursion when resolvers reference each other.
28+
_populating: ContextVar[frozenset[int]] = ContextVar("_populating", default=frozenset())
29+
2630
if TYPE_CHECKING:
2731
from plain.preflight import PreflightResult
2832

@@ -111,7 +115,6 @@ def __init__(
111115
self._reverse_dict: MultiValueDict = MultiValueDict()
112116
self._namespace_dict: dict[str, tuple[str, URLResolver]] = {}
113117
self._populated = False
114-
self._local = local()
115118

116119
# Set these immediately, in part so we can find routers
117120
# where the attributes weren't set correctly.
@@ -129,14 +132,15 @@ def preflight(self) -> list[PreflightResult]:
129132
return messages
130133

131134
def _populate(self) -> None:
132-
# Short-circuit if called recursively in this thread to prevent
133-
# infinite recursion. Concurrent threads may call this at the same
134-
# time and will need to continue, so set 'populating' on a
135-
# thread-local variable.
136-
if getattr(self._local, "populating", False):
135+
# Short-circuit if called recursively in this context to prevent
136+
# infinite recursion. Concurrent contexts may call this at the same
137+
# time and will need to continue, so track populating resolvers in a
138+
# context variable.
139+
current = _populating.get()
140+
if id(self) in current:
137141
return
142+
token = _populating.set(current | {id(self)})
138143
try:
139-
self._local.populating = True
140144
lookups = MultiValueDict()
141145
namespaces = {}
142146
for url_pattern in reversed(self.url_patterns):
@@ -196,7 +200,7 @@ def _populate(self) -> None:
196200
self._reverse_dict = lookups
197201
self._populated = True
198202
finally:
199-
self._local.populating = False
203+
_populating.reset(token)
200204

201205
@property
202206
def reverse_dict(self) -> MultiValueDict:

plain/plain/utils/timezone.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import functools
88
import zoneinfo
99
from contextlib import ContextDecorator
10+
from contextvars import ContextVar
1011
from datetime import UTC, datetime, time, timedelta, timezone, tzinfo
11-
from threading import local
1212
from types import TracebackType
1313

1414
from plain.runtime import settings
@@ -59,12 +59,13 @@ def get_default_timezone_name() -> str:
5959
return _get_timezone_name(get_default_timezone())
6060

6161

62-
_active = local()
62+
_active: ContextVar[tzinfo | None] = ContextVar("_active", default=None)
6363

6464

6565
def get_current_timezone() -> tzinfo:
6666
"""Return the currently active time zone as a tzinfo instance."""
67-
return getattr(_active, "value", get_default_timezone())
67+
tz = _active.get()
68+
return tz if tz is not None else get_default_timezone()
6869

6970

7071
def get_current_timezone_name() -> str:
@@ -88,32 +89,31 @@ def _get_timezone_name(timezone: tzinfo) -> str:
8889

8990
def activate(timezone: tzinfo | str) -> None:
9091
"""
91-
Set the time zone for the current thread.
92+
Set the time zone for the current context.
9293
9394
The ``timezone`` argument must be an instance of a tzinfo subclass or a
9495
time zone name.
9596
"""
9697
if isinstance(timezone, tzinfo):
97-
_active.value = timezone
98+
_active.set(timezone)
9899
elif isinstance(timezone, str):
99-
_active.value = zoneinfo.ZoneInfo(timezone)
100+
_active.set(zoneinfo.ZoneInfo(timezone))
100101
else:
101102
raise ValueError(f"Invalid timezone: {timezone!r}")
102103

103104

104105
def deactivate() -> None:
105106
"""
106-
Unset the time zone for the current thread.
107+
Unset the time zone for the current context.
107108
108109
Plain will then use the time zone defined by settings.TIME_ZONE.
109110
"""
110-
if hasattr(_active, "value"):
111-
del _active.value
111+
_active.set(None)
112112

113113

114114
class override(ContextDecorator):
115115
"""
116-
Temporarily set the time zone for the current thread.
116+
Temporarily set the time zone for the current context.
117117
118118
This is a context manager that uses plain.utils.timezone.activate()
119119
to set the timezone on entry and restores the previously active timezone
@@ -126,25 +126,22 @@ class override(ContextDecorator):
126126

127127
def __init__(self, timezone: tzinfo | str | None) -> None:
128128
self.timezone = timezone
129-
self.old_timezone: tzinfo | None = None
130129

131130
def __enter__(self) -> None:
132-
self.old_timezone = getattr(_active, "value", None)
133131
if self.timezone is None:
134-
deactivate()
132+
self._token = _active.set(None)
133+
elif isinstance(self.timezone, str):
134+
self._token = _active.set(zoneinfo.ZoneInfo(self.timezone))
135135
else:
136-
activate(self.timezone)
136+
self._token = _active.set(self.timezone)
137137

138138
def __exit__(
139139
self,
140140
exc_type: type[BaseException] | None,
141141
exc_value: BaseException | None,
142142
traceback: TracebackType | None,
143143
) -> None:
144-
if self.old_timezone is None:
145-
deactivate()
146-
else:
147-
_active.value = self.old_timezone
144+
_active.reset(self._token)
148145

149146

150147
# Utilities

0 commit comments

Comments
 (0)