Skip to content

Commit

Permalink
Clean up local storage for dead threads/tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgodwin committed Apr 13, 2019
1 parent 7b9f5d7 commit 0ad2c43
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
32 changes: 30 additions & 2 deletions asgiref/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import threading
import time

from .sync import SyncToAsync

Expand All @@ -13,8 +15,12 @@ class Local:
3.7 only, we can then reimplement the storage much more nicely.
"""

CLEANUP_INTERVAL = 60 # seconds

def __init__(self):
self._storage = {}
self._last_cleanup = time.time()
self._clean_lock = threading.Lock()

@staticmethod
def _get_context_id():
Expand All @@ -37,6 +43,27 @@ def _get_context_id():
raise RuntimeError("Cannot find task context for Local storage")
return context_id

def _cleanup(self):
"""
Cleans up any references to dead threads or tasks
"""
for key in list(self._storage.keys()):
if isinstance(key, threading.Thread):
if not key.is_alive():
del self._storage[key]
elif isinstance(key, asyncio.Task):
if key.done():
del self._storage[key]
self._last_cleanup = time.time()

def _maybe_cleanup(self):
"""
Cleans up if enough time has passed
"""
if time.time() - self._last_cleanup > self.CLEANUP_INTERVAL:
with self._clean_lock:
self._cleanup()

def __getattr__(self, key):
context_id = self._get_context_id()
if key in self._storage.get(context_id, {}):
Expand All @@ -45,8 +72,9 @@ def __getattr__(self, key):
raise AttributeError("%r object has no attribute %r" % (self, key))

def __setattr__(self, key, value):
if key == "_storage":
super().__setattr__(key, value)
if key in ("_storage", "_last_cleanup", "_clean_lock"):
return super().__setattr__(key, value)
self._maybe_cleanup()
self._storage.setdefault(self._get_context_id(), {})[key] = value

def __delattr__(self, key):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import threading

import pytest
Expand Down Expand Up @@ -99,3 +100,33 @@ def sync_function():
await sync_to_async(sync_function)()
# Check the value passed out again
assert test_local.foo == "phew, done"


@pytest.mark.asyncio
async def test_local_cleanup():
"""
Tests that local cleans up dead threads and tasks
"""
# Set up the local
test_local = Local()
# Assign in a thread
class TestThread(threading.Thread):
def run(self):
test_local.foo = 456

thread = TestThread()
thread.start()
thread.join()
# Assign in a Task
async def test_task():
test_local.foo = 456

test_future = asyncio.ensure_future(test_task())
await test_future
# Check there are two things in the storage
assert len(test_local._storage) == 2
# Force cleanup
test_local._last_cleanup = 0
test_local.foo = 1
# There should now only be one thing (this task) in the storage
assert len(test_local._storage) == 1

0 comments on commit 0ad2c43

Please sign in to comment.