-
Notifications
You must be signed in to change notification settings - Fork 205
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a threadlocal drop-in replacement as "local"
- Loading branch information
1 parent
5f45237
commit 768a071
Showing
4 changed files
with
143 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from .sync import SyncToAsync | ||
|
||
|
||
class Local: | ||
""" | ||
A drop-in replacement for threading.locals that also works with asyncio | ||
Tasks (via the current_task asyncio method), and correctly passes locals | ||
in and out of threads made with sync_to_async. | ||
This doesn't use contextvars as it needs to support 3.6. Once it can support | ||
3.7 only, we can then reimplement the storage much more nicely. | ||
""" | ||
|
||
def __init__(self): | ||
self._storage = {} | ||
|
||
def _get_context_id(self): | ||
""" | ||
Get the ID we should use for looking up variables | ||
""" | ||
# First, pull the current task if we can | ||
context_id = SyncToAsync.get_current_task() | ||
# If that fails, then try and pull the proxy ID from a threadlocal | ||
if context_id is None: | ||
context_id = SyncToAsync.threadlocal.current_task | ||
# If that fails, error | ||
if context_id is None: | ||
raise RuntimeError("Cannot find task context for Local storage") | ||
return context_id | ||
|
||
def __getattr__(self, key): | ||
context_id = self._get_context_id() | ||
if key in self._storage.get(context_id, {}): | ||
return self._storage[context_id][key] | ||
else: | ||
raise AttributeError("%r object has no attribute %r" % (self, key)) | ||
|
||
def __setattr__(self, key, value): | ||
if key == "_storage": | ||
super().__setattr__(key, value) | ||
self._storage.setdefault(self._get_context_id(), {})[key] = value | ||
|
||
def __delattr__(self, key): | ||
context_id = self._get_context_id() | ||
if key in self._storage.get(context_id, {}): | ||
del self._storage[context_id][key] | ||
else: | ||
raise AttributeError("%r object has no attribute %r" % (self, key)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pytest | ||
|
||
from asgiref.local import Local | ||
from asgiref.sync import sync_to_async | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_local_task(): | ||
""" | ||
Tests that local works just inside a normal task context | ||
""" | ||
|
||
test_local = Local() | ||
# Unassigned should be an error | ||
with pytest.raises(AttributeError): | ||
test_local.foo == 1 | ||
# Assign and check it persists | ||
test_local.foo = 1 | ||
assert test_local.foo == 1 | ||
# Delete and check it errors again | ||
del test_local.foo | ||
with pytest.raises(AttributeError): | ||
test_local.foo == 1 | ||
|
||
|
||
def test_local_thread(): | ||
""" | ||
Tests that local works just inside a normal thread context | ||
""" | ||
|
||
test_local = Local() | ||
# Unassigned should be an error | ||
with pytest.raises(AttributeError): | ||
test_local.foo == 2 | ||
# Assign and check it persists | ||
test_local.foo = 2 | ||
assert test_local.foo == 2 | ||
# Delete and check it errors again | ||
del test_local.foo | ||
with pytest.raises(AttributeError): | ||
test_local.foo == 2 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_local_task_to_sync(): | ||
""" | ||
Tests that local carries through sync_to_async | ||
""" | ||
# Set up the local | ||
test_local = Local() | ||
test_local.foo = 3 | ||
# Look at it in a sync context | ||
def sync_function(): | ||
assert test_local.foo == 3 | ||
test_local.foo = "phew, done" | ||
|
||
await sync_to_async(sync_function)() | ||
# Check the value passed out again | ||
assert test_local.foo == "phew, done" |