Skip to content

Commit

Permalink
[PoC] Support None values stored in caches.
Browse files Browse the repository at this point in the history
  • Loading branch information
ngnpope committed Nov 12, 2020
1 parent 4cd77f9 commit 2510b0b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 30 deletions.
36 changes: 21 additions & 15 deletions django/core/cache/backends/base.py
Expand Up @@ -3,6 +3,7 @@
import warnings

from django.core.exceptions import ImproperlyConfigured
from django.utils.functional import cached_property
from django.utils.module_loading import import_string


Expand Down Expand Up @@ -52,6 +53,8 @@ def get_key_func(key_func):


class BaseCache:
supports_get_with_default = True

def __init__(self, params):
timeout = params.get('timeout', params.get('TIMEOUT', 300))
if timeout is not None:
Expand All @@ -78,6 +81,10 @@ def __init__(self, params):
self.version = params.get('VERSION', 1)
self.key_func = get_key_func(params.get('KEY_FUNCTION'))

@cached_property
def _sentinel(self):
return object() if self.supports_get_with_default else None

def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
"""
Return the timeout value usable by this backend based upon the provided
Expand Down Expand Up @@ -151,8 +158,8 @@ def get_many(self, keys, version=None):
"""
d = {}
for k in keys:
val = self.get(k, version=version)
if val is not None:
val = self.get(k, self._sentinel, version=version)
if val is not self._sentinel:
d[k] = val
return d

Expand All @@ -165,31 +172,30 @@ def get_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None):
Return the value of the key stored or retrieved.
"""
val = self.get(key, version=version)
if val is None:
val = self.get(key, self._sentinel, version=version)
if val is self._sentinel:
if callable(default):
default = default()
if default is not None:
self.add(key, default, timeout=timeout, version=version)
# Fetch the value again to avoid a race condition if another
# caller added a value between the first get() and the add()
# above.
return self.get(key, default, version=version)
self.add(key, default, timeout=timeout, version=version)
# Fetch the value again to avoid a race condition if another
# caller added a value between the first get() and the add()
# above.
return self.get(key, default, version=version)
return val

def has_key(self, key, version=None):
"""
Return True if the key is in the cache and has not expired.
"""
return self.get(key, version=version) is not None
return self.get(key, self._sentinel, version=version) is not self._sentinel

def incr(self, key, delta=1, version=None):
"""
Add delta to value in the cache. If the key does not exist, raise a
ValueError exception.
"""
value = self.get(key, version=version)
if value is None:
value = self.get(key, self._sentinel, version=version)
if value is self._sentinel:
raise ValueError("Key '%s' not found" % key)
new_value = value + delta
self.set(key, new_value, version=version)
Expand Down Expand Up @@ -257,8 +263,8 @@ def incr_version(self, key, delta=1, version=None):
if version is None:
version = self.version

value = self.get(key, version=version)
if value is None:
value = self.get(key, self._sentinel, version=version)
if value is self._sentinel:
raise ValueError("Key '%s' not found" % key)

self.set(key, value, version=version + delta)
Expand Down
19 changes: 8 additions & 11 deletions django/core/cache/backends/memcached.py
Expand Up @@ -72,6 +72,9 @@ def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
def get(self, key, default=None, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
if not self.supports_get_with_default:
val = self._cache.get(key)
return default if val is self._sentinel else val
return self._cache.get(key, default)

def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
Expand Down Expand Up @@ -163,24 +166,18 @@ def validate_key(self, key):

class MemcachedCache(BaseMemcachedCache):
"An implementation of a cache binding using python-memcached"

# python-memcached doesn't support default values in get().
# https://github.com/linsomniac/python-memcached/issues/159
supports_get_with_default = False

def __init__(self, server, params):
# python-memcached ≥ 1.45 returns None for a nonexistent key in
# incr/decr(), python-memcached < 1.45 raises ValueError.
import memcache
super().__init__(server, params, library=memcache, value_not_found_exception=ValueError)
self._options = {'pickleProtocol': pickle.HIGHEST_PROTOCOL, **self._options}

def get(self, key, default=None, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
val = self._cache.get(key)
# python-memcached doesn't support default values in get().
# https://github.com/linsomniac/python-memcached/issues/159
# Remove this method if that issue is fixed.
if val is None:
return default
return val

def delete(self, key, version=None):
# python-memcached's delete() returns True when key doesn't exist.
# https://github.com/linsomniac/python-memcached/issues/170
Expand Down
52 changes: 48 additions & 4 deletions tests/cache/tests.py
Expand Up @@ -317,6 +317,9 @@ def test_get_many(self):
self.assertEqual(cache.get_many(['a', 'c', 'd']), {'a': 'a', 'c': 'c', 'd': 'd'})
self.assertEqual(cache.get_many(['a', 'b', 'e']), {'a': 'a', 'b': 'b'})
self.assertEqual(cache.get_many(iter(['a', 'b', 'e'])), {'a': 'a', 'b': 'b'})
# Test behavior with None value stored in the cache.
cache.set_many({'x': None, 'y': 1})
self.assertEqual(cache.get_many(['x', 'y']), {'x': None, 'y': 1})

def test_delete(self):
# Cache keys can be deleted
Expand All @@ -336,6 +339,10 @@ def test_has_key(self):
self.assertIs(cache.has_key("goodbye1"), False)
cache.set("no_expiry", "here", None)
self.assertIs(cache.has_key("no_expiry"), True)
# Test behavior with None value stored in the cache.
cache.set('null', None)
expected = True if cache.supports_get_with_default else False
self.assertIs(cache.has_key('null'), expected)

def test_in(self):
# The in operator can be used to inspect cache contents
Expand All @@ -353,6 +360,12 @@ def test_incr(self):
self.assertEqual(cache.incr('answer', -10), 42)
with self.assertRaises(ValueError):
cache.incr('does_not_exist')
# Test behavior with None value stored in the cache.
cache.set('null', None)
# FIXME: Memcached backends raise custom exceptions, not TypeError.
expected = TypeError if cache.supports_get_with_default else ValueError
with self.assertRaises(expected):
cache.incr('null')

def test_decr(self):
# Cache values can be decremented
Expand All @@ -364,6 +377,12 @@ def test_decr(self):
self.assertEqual(cache.decr('answer', -10), 42)
with self.assertRaises(ValueError):
cache.decr('does_not_exist')
# Test behavior with None value stored in the cache.
cache.set('null', None)
# FIXME: Memcached backends raise custom exceptions, not TypeError.
expected = TypeError if cache.supports_get_with_default else ValueError
with self.assertRaises(expected):
cache.decr('null')

def test_close(self):
self.assertTrue(hasattr(cache, 'close'))
Expand Down Expand Up @@ -911,6 +930,14 @@ def test_incr_version(self):
with self.assertRaises(ValueError):
cache.incr_version('does_not_exist')

# Test behavior with None value stored in the cache.
cache.set('null', None)
if cache.supports_get_with_default:
self.assertEqual(cache.incr_version('null'), 2)
else:
with self.assertRaises(ValueError):
cache.incr_version('null')

def test_decr_version(self):
cache.set('answer', 42, version=2)
self.assertIsNone(cache.get('answer'))
Expand All @@ -935,6 +962,14 @@ def test_decr_version(self):
with self.assertRaises(ValueError):
cache.decr_version('does_not_exist', version=2)

# Test behavior with None value stored in the cache.
cache.set('null', None, version=2)
if cache.supports_get_with_default:
self.assertEqual(cache.decr_version('null', version=2), 1)
else:
with self.assertRaises(ValueError):
cache.decr_version('null', version=2)

def test_custom_key_func(self):
# Two caches with different key functions aren't visible to each other
cache.set('answer1', 42)
Expand Down Expand Up @@ -991,7 +1026,13 @@ def test_get_or_set(self):
self.assertIsNone(cache.get('projector'))
self.assertEqual(cache.get_or_set('projector', 42), 42)
self.assertEqual(cache.get('projector'), 42)
# Test behavior with None value stored in the cache.
self.assertIsNone(cache.get_or_set('null', None))
if cache.supports_get_with_default:
# Previous get_or_set() stores None in the cache.
self.assertIsNone(cache.get('null', 'default'))
else:
self.assertEqual(cache.get('null', 'default'), 'default')

def test_get_or_set_callable(self):
def my_callable():
Expand All @@ -1000,10 +1041,13 @@ def my_callable():
self.assertEqual(cache.get_or_set('mykey', my_callable), 'value')
self.assertEqual(cache.get_or_set('mykey', my_callable()), 'value')

def test_get_or_set_callable_returning_none(self):
self.assertIsNone(cache.get_or_set('mykey', lambda: None))
# Previous get_or_set() doesn't store None in the cache.
self.assertEqual(cache.get('mykey', 'default'), 'default')
# Test behavior with None value stored in the cache.
self.assertIsNone(cache.get_or_set('null', lambda: None))
if cache.supports_get_with_default:
# Previous get_or_set() stores None in the cache.
self.assertIsNone(cache.get('null', 'default'))
else:
self.assertEqual(cache.get('null', 'default'), 'default')

def test_get_or_set_version(self):
msg = "get_or_set() missing 1 required positional argument: 'default'"
Expand Down

0 comments on commit 2510b0b

Please sign in to comment.