Skip to content

Commit

Permalink
Fixed #33060 -- Added BaseCache.make_and_validate_key() hook.
Browse files Browse the repository at this point in the history
This helper function reduces the amount of duplicated code and makes it
easier to ensure that we always validate the keys.
  • Loading branch information
ngnpope authored and felixxm committed Sep 3, 2021
1 parent 3316fe9 commit 670d5ff
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 75 deletions.
6 changes: 6 additions & 0 deletions django/core/cache/backends/base.py
Expand Up @@ -276,6 +276,12 @@ def close(self, **kwargs):
"""Close the cache connection"""
pass

def make_and_validate_key(self, key, version=None):
"""Helper to make and validate keys."""
key = self.make_key(key, version=version)
self.validate_key(key)
return key


def memcache_key_warnings(key):
if len(key) > MEMCACHE_MAX_KEY_LENGTH:
Expand Down
29 changes: 8 additions & 21 deletions django/core/cache/backends/db.py
Expand Up @@ -54,11 +54,7 @@ def get_many(self, keys, version=None):
if not keys:
return {}

key_map = {}
for key in keys:
new_key = self.make_key(key, version)
self.validate_key(new_key)
key_map[new_key] = key
key_map = {self.make_and_validate_key(key, version=version): key for key in keys}

db = router.db_for_read(self.cache_model_class)
connection = connections[db]
Expand Down Expand Up @@ -96,18 +92,15 @@ def get_many(self, keys, version=None):
return result

def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
self._base_set('set', key, value, timeout)

def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return self._base_set('add', key, value, timeout)

def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return self._base_set('touch', key, None, timeout)

def _base_set(self, mode, key, value, timeout=DEFAULT_TIMEOUT):
Expand Down Expand Up @@ -197,17 +190,12 @@ def _base_set(self, mode, key, value, timeout=DEFAULT_TIMEOUT):
return True

def delete(self, key, version=None):
key = self.make_key(key, version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return self._base_delete_many([key])

def delete_many(self, keys, version=None):
key_list = []
for key in keys:
key = self.make_key(key, version)
self.validate_key(key)
key_list.append(key)
self._base_delete_many(key_list)
keys = [self.make_and_validate_key(key, version=version) for key in keys]
self._base_delete_many(keys)

def _base_delete_many(self, keys):
if not keys:
Expand All @@ -230,8 +218,7 @@ def _base_delete_many(self, keys):
return bool(cursor.rowcount)

def has_key(self, key, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)

db = router.db_for_read(self.cache_model_class)
connection = connections[db]
Expand Down
18 changes: 6 additions & 12 deletions django/core/cache/backends/dummy.py
Expand Up @@ -8,32 +8,26 @@ def __init__(self, host, *args, **kwargs):
super().__init__(*args, **kwargs)

def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
self.make_and_validate_key(key, version=version)
return True

def get(self, key, default=None, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
self.make_and_validate_key(key, version=version)
return default

def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
self.make_and_validate_key(key, version=version)

def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
self.make_and_validate_key(key, version=version)
return False

def delete(self, key, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
self.make_and_validate_key(key, version=version)
return False

def has_key(self, key, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
self.make_and_validate_key(key, version=version)
return False

def clear(self):
Expand Down
3 changes: 1 addition & 2 deletions django/core/cache/backends/filebased.py
Expand Up @@ -127,8 +127,7 @@ def _key_to_file(self, key, version=None):
Convert a key into a cache file path. Basically this is the
root cache path joined with the md5sum of the key and a suffix.
"""
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return os.path.join(self._dir, ''.join(
[hashlib.md5(key.encode()).hexdigest(), self.cache_suffix]))

Expand Down
21 changes: 7 additions & 14 deletions django/core/cache/backends/locmem.py
Expand Up @@ -23,8 +23,7 @@ def __init__(self, name, params):
self._lock = _locks.setdefault(name, Lock())

def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
pickled = pickle.dumps(value, self.pickle_protocol)
with self._lock:
if self._has_expired(key):
Expand All @@ -33,8 +32,7 @@ def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
return False

def get(self, key, default=None, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
Expand All @@ -51,24 +49,21 @@ def _set(self, key, value, timeout=DEFAULT_TIMEOUT):
self._expire_info[key] = self.get_backend_timeout(timeout)

def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
pickled = pickle.dumps(value, self.pickle_protocol)
with self._lock:
self._set(key, pickled, timeout)

def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
return False
self._expire_info[key] = self.get_backend_timeout(timeout)
return True

def incr(self, key, delta=1, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
Expand All @@ -82,8 +77,7 @@ def incr(self, key, delta=1, version=None):
return new_value

def has_key(self, key, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
Expand Down Expand Up @@ -113,8 +107,7 @@ def _delete(self, key):
return True

def delete(self, key, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
with self._lock:
return self._delete(key)

Expand Down
38 changes: 12 additions & 26 deletions django/core/cache/backends/memcached.py
Expand Up @@ -67,36 +67,29 @@ def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
return int(timeout)

def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return self._cache.add(key, value, self.get_backend_timeout(timeout))

def get(self, key, default=None, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return self._cache.get(key, default)

def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
if not self._cache.set(key, value, self.get_backend_timeout(timeout)):
# make sure the key doesn't keep its old value in case of failure to set (memcached's 1MB limit)
self._cache.delete(key)

def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return bool(self._cache.touch(key, self.get_backend_timeout(timeout)))

def delete(self, key, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return bool(self._cache.delete(key))

def get_many(self, keys, version=None):
key_map = {self.make_key(key, version=version): key for key in keys}
for key in key_map:
self.validate_key(key)
key_map = {self.make_and_validate_key(key, version=version): key for key in keys}
ret = self._cache.get_multi(key_map.keys())
return {key_map[k]: v for k, v in ret.items()}

Expand All @@ -105,8 +98,7 @@ def close(self, **kwargs):
self._cache.disconnect_all()

def incr(self, key, delta=1, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
try:
# Memcached doesn't support negative delta.
if delta < 0:
Expand All @@ -126,17 +118,14 @@ def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
safe_data = {}
original_keys = {}
for key, value in data.items():
safe_key = self.make_key(key, version=version)
self.validate_key(safe_key)
safe_key = self.make_and_validate_key(key, version=version)
safe_data[safe_key] = value
original_keys[safe_key] = key
failed_keys = self._cache.set_multi(safe_data, self.get_backend_timeout(timeout))
return [original_keys[k] for k in failed_keys]

def delete_many(self, keys, version=None):
keys = [self.make_key(key, version=version) for key in keys]
for key in keys:
self.validate_key(key)
keys = [self.make_and_validate_key(key, version=version) for key in keys]
self._cache.delete_multi(keys)

def clear(self):
Expand Down Expand Up @@ -167,8 +156,7 @@ def __init__(self, server, params):
self._options = {'pickleProtocol': pickle.HIGHEST_PROTOCOL, **self._options}

def get(self, key, default=None, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
val = self._cache.get(key)
# python-memcached doesn't support default values in get().
# https://github.com/linsomniac/python-memcached/issues/159
Expand All @@ -181,8 +169,7 @@ def delete(self, key, version=None):
# python-memcached's delete() returns True when key doesn't exist.
# https://github.com/linsomniac/python-memcached/issues/170
# Call _deletetouch() without the NOT_FOUND in expected results.
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
return bool(self._cache._deletetouch([b'DELETED'], 'delete', key))


Expand All @@ -200,8 +187,7 @@ def client_servers(self):
return output

def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_key(key, version=version)
self.validate_key(key)
key = self.make_and_validate_key(key, version=version)
if timeout == 0:
return self._cache.delete(key)
return self._cache.touch(key, self.get_backend_timeout(timeout))
Expand Down

0 comments on commit 670d5ff

Please sign in to comment.