Skip to content

Commit

Permalink
Add a threadlocal drop-in replacement as "local"
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgodwin committed Apr 13, 2019
1 parent 5f45237 commit 768a071
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 2 deletions.
9 changes: 9 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ Django view system with SyncToAsync to allow it to run inside the (asynchronous)
ASGI server.


Threadlocal replacement
-----------------------

This is a drop-in replacement for ``threading.local`` that works with both
threads and asyncio Tasks. Even better, it will proxy values through from a
task-local context to a thread-local context when you use ``sync_to_async``
to run things in a threadpool.


Server base classes
-------------------

Expand Down
48 changes: 48 additions & 0 deletions asgiref/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from .sync import SyncToAsync


class Local:
"""
A drop-in replacement for threading.locals that also works with asyncio
Tasks (via the current_task asyncio method), and correctly passes locals
in and out of threads made with sync_to_async.
This doesn't use contextvars as it needs to support 3.6. Once it can support
3.7 only, we can then reimplement the storage much more nicely.
"""

def __init__(self):
self._storage = {}

def _get_context_id(self):
"""
Get the ID we should use for looking up variables
"""
# First, pull the current task if we can
context_id = SyncToAsync.get_current_task()
# If that fails, then try and pull the proxy ID from a threadlocal
if context_id is None:
context_id = SyncToAsync.threadlocal.current_task
# If that fails, error
if context_id is None:
raise RuntimeError("Cannot find task context for Local storage")
return context_id

def __getattr__(self, key):
context_id = self._get_context_id()
if key in self._storage.get(context_id, {}):
return self._storage[context_id][key]
else:
raise AttributeError("%r object has no attribute %r" % (self, key))

def __setattr__(self, key, value):
if key == "_storage":
super().__setattr__(key, value)
self._storage.setdefault(self._get_context_id(), {})[key] = value

def __delattr__(self, key):
context_id = self._get_context_id()
if key in self._storage.get(context_id, {}):
del self._storage[context_id][key]
else:
raise AttributeError("%r object has no attribute %r" % (self, key))
29 changes: 27 additions & 2 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,15 @@ async def __call__(self, *args, **kwargs):
func = self.func

future = loop.run_in_executor(
None, functools.partial(self.thread_handler, loop, func, *args, **kwargs)
None,
functools.partial(
self.thread_handler,
loop,
self.get_current_task(),
func,
*args,
**kwargs
),
)
return await asyncio.wait_for(future, timeout=None)

Expand All @@ -127,15 +135,32 @@ def __get__(self, parent, objtype):
"""
return functools.partial(self.__call__, parent)

def thread_handler(self, loop, func, *args, **kwargs):
def thread_handler(self, loop, current_task, func, *args, **kwargs):
"""
Wraps the sync application with exception handling.
"""
# Set the threadlocal for AsyncToSync
self.threadlocal.main_event_loop = loop
# Set the threadlocal for task mapping (used for locals)
self.threadlocal.current_task = current_task
# Run the function
return func(*args, **kwargs)

@staticmethod
def get_current_task():
"""
Cross-version implementation of asyncio.current_task()
"""
if hasattr(asyncio, "current_task"):
# Python 3.7 and up
try:
return asyncio.current_task()
except RuntimeError:
return None
else:
# Python 3.6
return asyncio.Task.current_task


# Lowercase is more sensible for most things
sync_to_async = SyncToAsync
Expand Down
59 changes: 59 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

from asgiref.local import Local
from asgiref.sync import sync_to_async


@pytest.mark.asyncio
async def test_local_task():
"""
Tests that local works just inside a normal task context
"""

test_local = Local()
# Unassigned should be an error
with pytest.raises(AttributeError):
test_local.foo == 1
# Assign and check it persists
test_local.foo = 1
assert test_local.foo == 1
# Delete and check it errors again
del test_local.foo
with pytest.raises(AttributeError):
test_local.foo == 1


def test_local_thread():
"""
Tests that local works just inside a normal thread context
"""

test_local = Local()
# Unassigned should be an error
with pytest.raises(AttributeError):
test_local.foo == 2
# Assign and check it persists
test_local.foo = 2
assert test_local.foo == 2
# Delete and check it errors again
del test_local.foo
with pytest.raises(AttributeError):
test_local.foo == 2


@pytest.mark.asyncio
async def test_local_task_to_sync():
"""
Tests that local carries through sync_to_async
"""
# Set up the local
test_local = Local()
test_local.foo = 3
# Look at it in a sync context
def sync_function():
assert test_local.foo == 3
test_local.foo = "phew, done"

await sync_to_async(sync_function)()
# Check the value passed out again
assert test_local.foo == "phew, done"

0 comments on commit 768a071

Please sign in to comment.