diff --git a/django/core/cache/backends/base.py b/django/core/cache/backends/base.py index f360f4f57abd0..af531812bd831 100644 --- a/django/core/cache/backends/base.py +++ b/django/core/cache/backends/base.py @@ -3,6 +3,7 @@ import warnings from django.core.exceptions import ImproperlyConfigured +from django.utils.functional import cached_property from django.utils.module_loading import import_string @@ -52,6 +53,8 @@ def get_key_func(key_func): class BaseCache: + supports_get_with_default = True + def __init__(self, params): timeout = params.get('timeout', params.get('TIMEOUT', 300)) if timeout is not None: @@ -78,6 +81,10 @@ def __init__(self, params): self.version = params.get('VERSION', 1) self.key_func = get_key_func(params.get('KEY_FUNCTION')) + @cached_property + def _sentinel(self): + return object() if self.supports_get_with_default else None + def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT): """ Return the timeout value usable by this backend based upon the provided @@ -151,8 +158,8 @@ def get_many(self, keys, version=None): """ d = {} for k in keys: - val = self.get(k, version=version) - if val is not None: + val = self.get(k, self._sentinel, version=version) + if val is not self._sentinel: d[k] = val return d @@ -165,31 +172,30 @@ def get_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None): Return the value of the key stored or retrieved. """ - val = self.get(key, version=version) - if val is None: + val = self.get(key, self._sentinel, version=version) + if val is self._sentinel: if callable(default): default = default() - if default is not None: - self.add(key, default, timeout=timeout, version=version) - # Fetch the value again to avoid a race condition if another - # caller added a value between the first get() and the add() - # above. - return self.get(key, default, version=version) + self.add(key, default, timeout=timeout, version=version) + # Fetch the value again to avoid a race condition if another + # caller added a value between the first get() and the add() + # above. + return self.get(key, default, version=version) return val def has_key(self, key, version=None): """ Return True if the key is in the cache and has not expired. """ - return self.get(key, version=version) is not None + return self.get(key, self._sentinel, version=version) is not self._sentinel def incr(self, key, delta=1, version=None): """ Add delta to value in the cache. If the key does not exist, raise a ValueError exception. """ - value = self.get(key, version=version) - if value is None: + value = self.get(key, self._sentinel, version=version) + if value is self._sentinel: raise ValueError("Key '%s' not found" % key) new_value = value + delta self.set(key, new_value, version=version) @@ -257,8 +263,8 @@ def incr_version(self, key, delta=1, version=None): if version is None: version = self.version - value = self.get(key, version=version) - if value is None: + value = self.get(key, self._sentinel, version=version) + if value is self._sentinel: raise ValueError("Key '%s' not found" % key) self.set(key, value, version=version + delta) diff --git a/django/core/cache/backends/memcached.py b/django/core/cache/backends/memcached.py index cc5648bb1ca28..f87a77baadccb 100644 --- a/django/core/cache/backends/memcached.py +++ b/django/core/cache/backends/memcached.py @@ -72,6 +72,9 @@ def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): def get(self, key, default=None, version=None): key = self.make_key(key, version=version) self.validate_key(key) + if not self.supports_get_with_default: + val = self._cache.get(key) + return default if val is self._sentinel else val return self._cache.get(key, default) def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): @@ -163,6 +166,11 @@ def validate_key(self, key): class MemcachedCache(BaseMemcachedCache): "An implementation of a cache binding using python-memcached" + + # python-memcached doesn't support default values in get(). + # https://github.com/linsomniac/python-memcached/issues/159 + supports_get_with_default = False + def __init__(self, server, params): # python-memcached ≥ 1.45 returns None for a nonexistent key in # incr/decr(), python-memcached < 1.45 raises ValueError. @@ -170,17 +178,6 @@ def __init__(self, server, params): super().__init__(server, params, library=memcache, value_not_found_exception=ValueError) self._options = {'pickleProtocol': pickle.HIGHEST_PROTOCOL, **self._options} - def get(self, key, default=None, version=None): - key = self.make_key(key, version=version) - self.validate_key(key) - val = self._cache.get(key) - # python-memcached doesn't support default values in get(). - # https://github.com/linsomniac/python-memcached/issues/159 - # Remove this method if that issue is fixed. - if val is None: - return default - return val - def delete(self, key, version=None): # python-memcached's delete() returns True when key doesn't exist. # https://github.com/linsomniac/python-memcached/issues/170 diff --git a/tests/cache/tests.py b/tests/cache/tests.py index 367d2d7119c4a..7d018e5a80ac8 100644 --- a/tests/cache/tests.py +++ b/tests/cache/tests.py @@ -317,6 +317,9 @@ def test_get_many(self): self.assertEqual(cache.get_many(['a', 'c', 'd']), {'a': 'a', 'c': 'c', 'd': 'd'}) self.assertEqual(cache.get_many(['a', 'b', 'e']), {'a': 'a', 'b': 'b'}) self.assertEqual(cache.get_many(iter(['a', 'b', 'e'])), {'a': 'a', 'b': 'b'}) + # Test behavior with None value stored in the cache. + cache.set_many({'x': None, 'y': 1}) + self.assertEqual(cache.get_many(['x', 'y']), {'x': None, 'y': 1}) def test_delete(self): # Cache keys can be deleted @@ -336,6 +339,10 @@ def test_has_key(self): self.assertIs(cache.has_key("goodbye1"), False) cache.set("no_expiry", "here", None) self.assertIs(cache.has_key("no_expiry"), True) + # Test behavior with None value stored in the cache. + cache.set('null', None) + expected = True if cache.supports_get_with_default else False + self.assertIs(cache.has_key('null'), expected) def test_in(self): # The in operator can be used to inspect cache contents @@ -353,6 +360,12 @@ def test_incr(self): self.assertEqual(cache.incr('answer', -10), 42) with self.assertRaises(ValueError): cache.incr('does_not_exist') + # Test behavior with None value stored in the cache. + cache.set('null', None) + # FIXME: Memcached backends raise custom exceptions, not TypeError. + expected = TypeError if cache.supports_get_with_default else ValueError + with self.assertRaises(expected): + cache.incr('null') def test_decr(self): # Cache values can be decremented @@ -364,6 +377,12 @@ def test_decr(self): self.assertEqual(cache.decr('answer', -10), 42) with self.assertRaises(ValueError): cache.decr('does_not_exist') + # Test behavior with None value stored in the cache. + cache.set('null', None) + # FIXME: Memcached backends raise custom exceptions, not TypeError. + expected = TypeError if cache.supports_get_with_default else ValueError + with self.assertRaises(expected): + cache.decr('null') def test_close(self): self.assertTrue(hasattr(cache, 'close')) @@ -911,6 +930,14 @@ def test_incr_version(self): with self.assertRaises(ValueError): cache.incr_version('does_not_exist') + # Test behavior with None value stored in the cache. + cache.set('null', None) + if cache.supports_get_with_default: + self.assertEqual(cache.incr_version('null'), 2) + else: + with self.assertRaises(ValueError): + cache.incr_version('null') + def test_decr_version(self): cache.set('answer', 42, version=2) self.assertIsNone(cache.get('answer')) @@ -935,6 +962,14 @@ def test_decr_version(self): with self.assertRaises(ValueError): cache.decr_version('does_not_exist', version=2) + # Test behavior with None value stored in the cache. + cache.set('null', None, version=2) + if cache.supports_get_with_default: + self.assertEqual(cache.decr_version('null', version=2), 1) + else: + with self.assertRaises(ValueError): + cache.decr_version('null', version=2) + def test_custom_key_func(self): # Two caches with different key functions aren't visible to each other cache.set('answer1', 42) @@ -991,7 +1026,13 @@ def test_get_or_set(self): self.assertIsNone(cache.get('projector')) self.assertEqual(cache.get_or_set('projector', 42), 42) self.assertEqual(cache.get('projector'), 42) + # Test behavior with None value stored in the cache. self.assertIsNone(cache.get_or_set('null', None)) + if cache.supports_get_with_default: + # Previous get_or_set() stores None in the cache. + self.assertIsNone(cache.get('null', 'default')) + else: + self.assertEqual(cache.get('null', 'default'), 'default') def test_get_or_set_callable(self): def my_callable(): @@ -1000,10 +1041,13 @@ def my_callable(): self.assertEqual(cache.get_or_set('mykey', my_callable), 'value') self.assertEqual(cache.get_or_set('mykey', my_callable()), 'value') - def test_get_or_set_callable_returning_none(self): - self.assertIsNone(cache.get_or_set('mykey', lambda: None)) - # Previous get_or_set() doesn't store None in the cache. - self.assertEqual(cache.get('mykey', 'default'), 'default') + # Test behavior with None value stored in the cache. + self.assertIsNone(cache.get_or_set('null', lambda: None)) + if cache.supports_get_with_default: + # Previous get_or_set() stores None in the cache. + self.assertIsNone(cache.get('null', 'default')) + else: + self.assertEqual(cache.get('null', 'default'), 'default') def test_get_or_set_version(self): msg = "get_or_set() missing 1 required positional argument: 'default'"