From bb70bcd2492e01cfa57ffe32961820fb53e4ab27 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 10 Jan 2018 14:39:49 -0600 Subject: [PATCH 01/24] Started implementing connection error throws --- .gitignore | 1 + fakeredis.py | 126 +++++++++++++++++++++++++++++++++++++++++----- test_fakeredis.py | 25 +++++++++ 3 files changed, 139 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index c1bdfc9..4ad3cab 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ dump.rdb extras/* .tox *.pyc +.idea diff --git a/fakeredis.py b/fakeredis.py index 29e7196..8b0c799 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -211,6 +211,15 @@ def _patch_responses(obj): setattr(obj, attr_name, func) +def check_conn(func): + """Used to mock connection errors""" + def func_wrapper(*args, **kwargs): + if not args[0]._connected: + raise redis.ConnectionError + return func(*args, **kwargs) + return func_wrapper + + class FakeStrictRedis(object): @classmethod def from_url(cls, url, db=None, **kwargs): @@ -223,7 +232,7 @@ def from_url(cls, url, db=None, **kwargs): return cls(db=db, **kwargs) def __init__(self, db=0, charset='utf-8', errors='strict', - decode_responses=False, **kwargs): + decode_responses=False, connected=True, **kwargs): if db not in DATABASES: DATABASES[db] = _StrKeyDict() self._db = DATABASES[db] @@ -232,13 +241,16 @@ def __init__(self, db=0, charset='utf-8', errors='strict', self._encoding_errors = errors self._pubsubs = [] self._decode_responses = decode_responses + self._connected = connected if decode_responses: _patch_responses(self) + @check_conn def flushdb(self): DATABASES[self._db_num].clear() return True + @check_conn def flushall(self): for db in DATABASES: DATABASES[db].clear() @@ -246,11 +258,13 @@ def flushall(self): del self._pubsubs[:] # Basic key commands + @check_conn def append(self, key, value): self._db.setdefault(key, b'') self._db[key] += to_bytes(value) return len(self._db[key]) + @check_conn def bitcount(self, name, start=0, end=-1): if end == -1: end = None @@ -262,6 +276,7 @@ def bitcount(self, name, start=0, end=-1): except KeyError: return 0 + @check_conn def decr(self, name, amount=1): try: self._db[name] = int(self._db.get(name, '0')) - amount @@ -270,6 +285,7 @@ def decr(self, name, amount=1): "range.") return self._db[name] + @check_conn def exists(self, name): return name in self._db __contains__ = exists @@ -280,6 +296,7 @@ def expire(self, name, time): def pexpire(self, name, millis): return self._expire(name, millis, 1000) + @check_conn def _expire(self, name, time, multiplier=1): if isinstance(time, timedelta): time = int(timedelta_total_seconds(time) * multiplier) @@ -299,6 +316,7 @@ def expireat(self, name, when): def pexpireat(self, name, when): return self._expireat(name, when, 1000) + @check_conn def _expireat(self, name, when, multiplier=1): if not isinstance(when, datetime): when = datetime.fromtimestamp(when / float(multiplier)) @@ -308,11 +326,13 @@ def _expireat(self, name, when, multiplier=1): else: return False + @check_conn def echo(self, value): if isinstance(value, text_type): return value.encode('utf-8') return value + @check_conn def get(self, name): value = self._db.get(name) if isinstance(value, _StrKeyDict): @@ -324,6 +344,7 @@ def get(self, name): def __getitem__(self, name): return self.get(name) + @check_conn def getbit(self, name, offset): """Returns a boolean indicating the value of ``offset`` in ``name``""" val = self._db.get(name, '\x00') @@ -336,6 +357,7 @@ def getbit(self, name, offset): return 0 return 1 if (1 << actual_bitoffset) & actual_val else 0 + @check_conn def getset(self, name, value): """ Set the value at key ``name`` to ``value`` if key doesn't exist @@ -345,6 +367,7 @@ def getset(self, name, value): self._db[name] = value return val + @check_conn def incr(self, name, amount=1): """ Increments the value of ``key`` by ``amount``. If no key exists, @@ -366,6 +389,7 @@ def incrby(self, name, amount=1): """ return self.incr(name, amount) + @check_conn def incrbyfloat(self, name, amount=1.0): try: self._db[name] = float(self._db.get(name, '0')) + amount @@ -373,11 +397,13 @@ def incrbyfloat(self, name, amount=1.0): raise redis.ResponseError("value is not a valid float.") return self._db[name] + @check_conn def keys(self, pattern=None): return [key for key in self._db if not key or not pattern or fnmatch.fnmatch(to_native(key), to_native(pattern))] + @check_conn def mget(self, keys, *args): all_keys = self._list_or_args(keys, args) found = [] @@ -388,6 +414,7 @@ def mget(self, keys, *args): found.append(self._db.get(key)) return found + @check_conn def mset(self, *args, **kwargs): if args: if len(args) != 1 or not isinstance(args[0], dict): @@ -398,6 +425,7 @@ def mset(self, *args, **kwargs): self.set(key, val) return True + @check_conn def msetnx(self, mapping): """ Sets each key in the ``mapping`` dict to its corresponding value if @@ -412,6 +440,7 @@ def msetnx(self, mapping): def move(self, name, db): pass + @check_conn def persist(self, name): self._db.persist(name) @@ -421,6 +450,7 @@ def ping(self): def randomkey(self): pass + @check_conn def rename(self, src, dst): try: value = self._db[src] @@ -430,12 +460,14 @@ def rename(self, src, dst): del self._db[src] return True + @check_conn def renamenx(self, src, dst): if dst in self._db: return False else: return self.rename(src, dst) + @check_conn def set(self, name, value, ex=None, px=None, nx=False, xx=False): if (not nx and not xx) or (nx and self._db.get(name, None) is None) \ or (xx and not self._db.get(name, None) is None): @@ -465,6 +497,7 @@ def set(self, name, value, ex=None, px=None, nx=False, xx=False): __setitem__ = set + @check_conn def setbit(self, name, offset, value): val = self._db.get(name, b'\x00') byte = offset // 8 @@ -505,6 +538,7 @@ def setnx(self, name, value): return False return result + @check_conn def setrange(self, name, offset, value): val = self._db.get(name, b"") if len(val) < offset: @@ -513,12 +547,14 @@ def setrange(self, name, offset, value): self.set(name, val) return len(val) + @check_conn def strlen(self, name): try: return len(self._db[name]) except KeyError: return 0 + @check_conn def substr(self, name, start, end=-1): if end == -1: end = None @@ -538,6 +574,7 @@ def ttl(self, name): def pttl(self, name): return self._ttl(name, 1000) + @check_conn def _ttl(self, name, multiplier=1): if name not in self._db: return -2 @@ -554,6 +591,7 @@ def _ttl(self, name, multiplier=1): (exp_time - now).seconds + (exp_time - now).microseconds / 1E6) * multiplier)) + @check_conn def type(self, name): key = self._db.get(name) if hasattr(key.__class__, 'redis_type'): @@ -574,6 +612,7 @@ def watch(self, *names): def unwatch(self): pass + @check_conn def delete(self, *names): deleted = 0 for name in names: @@ -586,6 +625,7 @@ def delete(self, *names): continue return deleted + @check_conn def sort(self, name, start=None, num=None, by=None, get=None, desc=False, alpha=False, store=None): """Sort and return the list, set or sorted set at ``name``. @@ -633,6 +673,7 @@ def sort(self, name, start=None, num=None, by=None, get=None, desc=False, except KeyError: return [] + @check_conn def _retrive_data_from_sort(self, data, get): if get is not None: if isinstance(get, string_types): @@ -645,6 +686,7 @@ def _retrive_data_from_sort(self, data, get): data = new_data return data + @check_conn def _get_single_item(self, k, g): g = to_bytes(g) if b'*' in g: @@ -660,6 +702,7 @@ def _get_single_item(self, k, g): single_item = None return single_item + @check_conn def _strtod_key_func(self, arg): # str()'ing the arg is important! Don't ever remove this. arg = to_bytes(arg) @@ -686,11 +729,13 @@ def _by_key(arg): data.sort(key=_by_key) + @check_conn def lpush(self, name, *values): self._db.setdefault(name, [])[0:0] = list(reversed( [to_bytes(x) for x in values])) return len(self._db[name]) + @check_conn def lrange(self, name, start, end): if end == -1: end = None @@ -698,9 +743,11 @@ def lrange(self, name, start, end): end += 1 return self._db.get(name, [])[start:end] + @check_conn def llen(self, name): return len(self._db.get(name, [])) + @check_conn def lrem(self, name, count, value): value = to_bytes(value) a_list = self._db.get(name, []) @@ -720,28 +767,33 @@ def lrem(self, name, count, value): del a_list[index] return len(indices_to_remove) + @check_conn def rpush(self, name, *values): self._db.setdefault(name, []).extend([to_bytes(x) for x in values]) return len(self._db[name]) + @check_conn def lpop(self, name): try: return self._db.get(name, []).pop(0) except IndexError: return None + @check_conn def lset(self, name, index, value): try: self._db.get(name, [])[index] = to_bytes(value) except IndexError: raise redis.ResponseError("index out of range") + @check_conn def rpushx(self, name, value): try: self._db[name].append(to_bytes(value)) except KeyError: return + @check_conn def ltrim(self, name, start, end): try: val = self._db[name] @@ -754,28 +806,33 @@ def ltrim(self, name, start, end): self._db[name] = val[start:end] return True + @check_conn def lindex(self, name, index): try: return self._db.get(name, [])[index] except IndexError: return None + @check_conn def lpushx(self, name, value): try: self._db[name].insert(0, to_bytes(value)) except KeyError: return + @check_conn def rpop(self, name): try: return self._db.get(name, []).pop() except IndexError: return None + @check_conn def linsert(self, name, where, refvalue, value): index = self._db.get(name, []).index(to_bytes(refvalue)) self._db.get(name, []).insert(index, to_bytes(value)) + @check_conn def rpoplpush(self, src, dst): el = self.rpop(src) if el is not None: @@ -786,6 +843,7 @@ def rpoplpush(self, src, dst): self._db[dst] = [el] return el + @check_conn def blpop(self, keys, timeout=0): # This has to be a best effort approximation which follows # these rules: @@ -801,6 +859,7 @@ def blpop(self, keys, timeout=0): if self._db.get(key, []): return (key, self._db[key].pop(0)) + @check_conn def brpop(self, keys, timeout=0): if isinstance(keys, string_types): keys = [to_bytes(keys)] @@ -810,6 +869,7 @@ def brpop(self, keys, timeout=0): if self._db.get(key, []): return (key, self._db[key].pop()) + @check_conn def brpoplpush(self, src, dst, timeout=0): el = self.rpop(src) if el is not None: @@ -820,6 +880,7 @@ def brpoplpush(self, src, dst, timeout=0): self._db[dst] = [el] return el + @check_conn def hdel(self, name, *keys): h = self._db.get(name, {}) rem = 0 @@ -829,6 +890,7 @@ def hdel(self, name, *keys): rem += 1 return rem + @check_conn def hexists(self, name, key): "Returns a boolean indicating if ``key`` exists within hash ``name``" if self._db.get(name, {}).get(key) is None: @@ -836,10 +898,12 @@ def hexists(self, name, key): else: return 1 + @check_conn def hget(self, name, key): "Return the value of ``key`` within the hash ``name``" return self._db.get(name, {}).get(key) + @check_conn def hgetall(self, name): "Return a Python dict of the hash's name/value pairs" all_items = self._db.get(name, {}) @@ -847,12 +911,14 @@ def hgetall(self, name): all_items = all_items.to_bare_dict() return all_items + @check_conn def hincrby(self, name, key, amount=1): "Increment the value of ``key`` in hash ``name`` by ``amount``" new = int(self._db.setdefault(name, _Hash()).get(key, '0')) + amount self._db[name][key] = new return new + @check_conn def hincrbyfloat(self, name, key, amount=1.0): """Increment the value of key in hash name by floating amount""" try: @@ -867,14 +933,17 @@ def hincrbyfloat(self, name, key, amount=1.0): self._db[name][key] = new return new + @check_conn def hkeys(self, name): "Return the list of keys within hash ``name``" return list(self._db.get(name, {})) + @check_conn def hlen(self, name): "Return the number of elements in hash ``name``" return len(self._db.get(name, {})) + @check_conn def hset(self, name, key, value): """ Set ``key`` to ``value`` within hash ``name`` @@ -884,6 +953,7 @@ def hset(self, name, key, value): self._db.setdefault(name, _Hash())[key] = to_bytes(value) return 1 if key_is_new else 0 + @check_conn def hsetnx(self, name, key, value): """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not @@ -894,6 +964,7 @@ def hsetnx(self, name, key, value): self._db.setdefault(name, _Hash())[key] = to_bytes(value) return True + @check_conn def hmset(self, name, mapping): """ Sets each key in the ``mapping`` dict to its corresponding value @@ -907,35 +978,41 @@ def hmset(self, name, mapping): self._db.setdefault(name, _Hash()).update(new_mapping) return True + @check_conn def hmget(self, name, keys, *args): - "Returns a list of values ordered identically to ``keys``" + """Returns a list of values ordered identically to ``keys``""" h = self._db.get(name, {}) all_keys = self._list_or_args(keys, args) return [h.get(k) for k in all_keys] + @check_conn def hvals(self, name): - "Return the list of values within hash ``name``" + """Return the list of values within hash ``name``""" return list(self._db.get(name, {}).values()) + @check_conn def sadd(self, name, *values): - "Add ``value`` to set ``name``" + """Add ``value`` to set ``name``""" a_set = self._db.setdefault(name, set()) card = len(a_set) a_set |= set(to_bytes(x) for x in values) return len(a_set) - card + @check_conn def scard(self, name): - "Return the number of elements in set ``name``" + """Return the number of elements in set ``name``""" return len(self._db.get(name, set())) + @check_conn def sdiff(self, keys, *args): - "Return the difference of sets specified by ``keys``" + """Return the difference of sets specified by ``keys``""" all_keys = (to_bytes(x) for x in self._list_or_args(keys, args)) diff = self._db.get(next(all_keys), set()).copy() for key in all_keys: diff -= self._db.get(key, set()) return diff + @check_conn def sdiffstore(self, dest, keys, *args): """ Store the difference of sets specified by ``keys`` into a new @@ -945,14 +1022,16 @@ def sdiffstore(self, dest, keys, *args): self._db[dest] = set(to_bytes(x) for x in diff) return len(diff) + @check_conn def sinter(self, keys, *args): - "Return the intersection of sets specified by ``keys``" + """Return the intersection of sets specified by ``keys``""" all_keys = (to_bytes(x) for x in self._list_or_args(keys, args)) intersect = self._db.get(next(all_keys), set()).copy() for key in all_keys: intersect.intersection_update(self._db.get(key, set())) return intersect + @check_conn def sinterstore(self, dest, keys, *args): """ Store the intersection of sets specified by ``keys`` into a new @@ -962,14 +1041,17 @@ def sinterstore(self, dest, keys, *args): self._db[dest] = set(to_bytes(x) for x in intersect) return len(intersect) + @check_conn def sismember(self, name, value): - "Return a boolean indicating if ``value`` is a member of set ``name``" + """Return a boolean indicating if ``value`` is a member of set ``name``""" return to_bytes(value) in self._db.get(name, set()) + @check_conn def smembers(self, name): - "Return all members of the set ``name``" + """Return all members of the set ``name``""" return self._db.get(name, set()) + @check_conn def smove(self, src, dst, value): value = to_bytes(value) try: @@ -979,13 +1061,15 @@ def smove(self, src, dst, value): except KeyError: return False + @check_conn def spop(self, name): - "Remove and return a random member of set ``name``" + """Remove and return a random member of set ``name``""" try: return self._db.get(name, set()).pop() except KeyError: return None + @check_conn def srandmember(self, name, number=None): """ If ``number`` is None, returns a random member of set ``name``. @@ -1015,21 +1099,24 @@ def srandmember(self, name, number=None): in sorted(random.sample(range(len(members)), number)) ] + @check_conn def srem(self, name, *values): - "Remove ``value`` from set ``name``" + """Remove ``value`` from set ``name``""" a_set = self._db.setdefault(name, set()) card = len(a_set) a_set -= set(to_bytes(x) for x in values) return card - len(a_set) + @check_conn def sunion(self, keys, *args): - "Return the union of sets specifiued by ``keys``" + """Return the union of sets specifiued by ``keys``""" all_keys = (to_bytes(x) for x in self._list_or_args(keys, args)) union = self._db.get(next(all_keys), set()).copy() for key in all_keys: union.update(self._db.get(key, set())) return union + @check_conn def sunionstore(self, dest, keys, *args): """ Store the union of sets specified by ``keys`` into a new @@ -1061,6 +1148,7 @@ def _matches(x): right_comparator(x, actual_max)) return _matches + @check_conn def _get_comparator_and_val(self, value): try: if isinstance(value, string_types) and value.startswith('('): @@ -1098,6 +1186,7 @@ def _matches(x): right_comparator(x, actual_max)) return _matches + @check_conn def _get_lexcomp_and_str(self, value): if value.startswith(b'('): comparator = operator.lt @@ -1122,6 +1211,7 @@ def _get_lexcomp_and_str(self, value): return comparator, actual_value + @check_conn def zadd(self, name, *args, **kwargs): """ Set any number of score, element-name pairs to the key ``name``. Pairs @@ -1154,10 +1244,12 @@ def zadd(self, name, *args, **kwargs): raise redis.ResponseError("value is not a valid float") return added + @check_conn def zcard(self, name): "Return the number of elements in the sorted set ``name``" return len(self._db.get(name, {})) + @check_conn def zcount(self, name, min, max): found = 0 filter_func = self._get_zelement_range_filter_func(min, max) @@ -1166,6 +1258,7 @@ def zcount(self, name, min, max): found += 1 return found + @check_conn def zincrby(self, name, value, amount=1): "Increment the score of ``value`` in sorted set ``name`` by ``amount``" d = self._db.setdefault(name, _ZSet()) @@ -1220,6 +1313,7 @@ def zrange(self, name, start, end, desc=False, withscores=False): else: return [(k, all_items[k]) for k in items] + @check_conn def _get_zelements_in_order(self, all_items, reverse=False): by_keyname = sorted( all_items.items(), key=lambda x: x[0], reverse=reverse) @@ -1241,6 +1335,7 @@ def zrangebyscore(self, name, min, max, return self._zrangebyscore(name, min, max, start, num, withscores, reverse=False) + @check_conn def _zrangebyscore(self, name, min, max, start, num, withscores, reverse): if (start is not None and num is None) or \ (num is not None and start is None): @@ -1278,6 +1373,7 @@ def zrangebylex(self, name, min, max, return self._zrangebylex(name, min, max, start, num, reverse=False) + @check_conn def _zrangebylex(self, name, min, max, start, num, reverse): if (start is not None and num is None) or \ (num is not None and start is None): @@ -1296,6 +1392,7 @@ def _zrangebylex(self, name, min, max, start, num, reverse): matches = matches[start:start + num] return matches + @check_conn def zrank(self, name, value): """ Returns a 0-based value indicating the rank of ``value`` in sorted set @@ -1308,8 +1405,9 @@ def zrank(self, name, value): except ValueError: return None + @check_conn def zrem(self, name, *values): - "Remove member ``value`` from sorted set ``name``" + """Remove member ``value`` from sorted set ``name``""" z = self._db.get(name, {}) rem = 0 for v in values: @@ -1337,6 +1435,7 @@ def zremrangebyrank(self, name, min, max): num_deleted += 1 return num_deleted + @check_conn def zremrangebyscore(self, name, min, max): """ Remove all elements in the sorted set ``name`` with scores @@ -1464,6 +1563,7 @@ def zunionstore(self, dest, keys, aggregate=None): "for ZINTERSTORE/ZUNIONSTORE") self._zaggregate(dest, keys, aggregate, lambda x: True) + @check_conn def _zaggregate(self, dest, keys, aggregate, should_include): new_zset = _ZSet() if aggregate is None: diff --git a/test_fakeredis.py b/test_fakeredis.py index c98a34b..f8a954c 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -2936,5 +2936,30 @@ def test_searches_for_c_stdlib_and_raises_if_missing(self): reload(fakeredis) +class TestFakeStrictRedisConnectionErrors(unittest.TestCase): + + def create_redis(self): + return fakeredis.FakeStrictRedis(db=0, connected=False) + + def setUp(self): + self.redis = self.create_redis() + + def test_flushdb(self): + with self.assertRaises(redis.ConnectionError): + self.redis.flushdb() + + self.assertEqual(self.redis._db, {}, 'DB should be empty') + + def test_flushall(self): + with self.assertRaises(redis.ConnectionError): + self.redis.flushall() + + self.assertEqual(self.redis._db, {}, 'DB should be empty') + + def test_append(self): + with self.assertRaises(redis.ConnectionError): + self.redis.append('key', 'value') + self.assertEqual(self.redis._db, {}, 'DB should be empty') + if __name__ == '__main__': unittest.main() From 6aea29c6d4125c61c2626e8f258054fefe1d019f Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 10 Jan 2018 15:12:06 -0600 Subject: [PATCH 02/24] updated test cases --- test_fakeredis.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/test_fakeredis.py b/test_fakeredis.py index f8a954c..2a0522f 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -2944,22 +2944,36 @@ def create_redis(self): def setUp(self): self.redis = self.create_redis() + def tearDown(self): + del self.redis + def test_flushdb(self): with self.assertRaises(redis.ConnectionError): self.redis.flushdb() - self.assertEqual(self.redis._db, {}, 'DB should be empty') - def test_flushall(self): with self.assertRaises(redis.ConnectionError): self.redis.flushall() - self.assertEqual(self.redis._db, {}, 'DB should be empty') - def test_append(self): with self.assertRaises(redis.ConnectionError): self.redis.append('key', 'value') - self.assertEqual(self.redis._db, {}, 'DB should be empty') + + self.assertEqual(self.redis._db, {}, 'DB should be empty') + + def test_bitcount(self): + with self.assertRaises(redis.ConnectionError): + self.redis.bitcount('name', 0, 20) + + def test_decr(self): + with self.assertRaises(redis.ConnectionError): + self.redis.decr('key', 2) + + self.assertEqual(self.redis._db, {}, 'DB should be empty') + + def test_exists(self): + with self.assertRaises(redis.ConnectionError): + self.redis.exists('key') if __name__ == '__main__': unittest.main() From 5359bdf608f17331f4a0914c13298a4a7d20e510 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 10 Jan 2018 15:50:37 -0600 Subject: [PATCH 03/24] Updated ConnectionError test cases --- test_fakeredis.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/test_fakeredis.py b/test_fakeredis.py index 2a0522f..32aa1c7 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -2963,7 +2963,7 @@ def test_append(self): def test_bitcount(self): with self.assertRaises(redis.ConnectionError): - self.redis.bitcount('name', 0, 20) + self.redis.bitcount('key', 0, 20) def test_decr(self): with self.assertRaises(redis.ConnectionError): @@ -2975,5 +2975,65 @@ def test_exists(self): with self.assertRaises(redis.ConnectionError): self.redis.exists('key') + def test_expire(self): + with self.assertRaises(redis.ConnectionError): + self.redis.expire('key', 20) + + def test_pexpire(self): + with self.assertRaises(redis.ConnectionError): + self.redis.pexpire('key', 20) + + def test_echo(self): + with self.assertRaises(redis.ConnectionError): + self.redis.echo('value') + + def test_get(self): + with self.assertRaises(redis.ConnectionError): + self.redis.get('key') + + def test_getbit(self): + with self.assertRaises(redis.ConnectionError): + self.redis.getbit('key', 2) + + def test_getset(self): + with self.assertRaises(redis.ConnectionError): + self.redis.getset('key', 'value') + + def test_incr(self): + with self.assertRaises(redis.ConnectionError): + self.redis.incr('key') + + def test_incrby(self): + with self.assertRaises(redis.ConnectionError): + self.redis.incrby('key') + + def test_ncrbyfloat(self): + with self.assertRaises(redis.ConnectionError): + self.redis.incrbyfloat('key') + + def test_keys(self): + with self.assertRaises(redis.ConnectionError): + self.redis.keys() + + def test_mget(self): + with self.assertRaises(redis.ConnectionError): + self.redis.mget(['key1', 'key2']) + + def test_mset(self): + with self.assertRaises(redis.ConnectionError): + self.redis.mset(('key', 'value')) + + def test_msetnx(self): + with self.assertRaises(redis.ConnectionError): + self.redis.msetnx({'key': 'value'}) + + def test_persist(self): + with self.assertRaises(redis.ConnectionError): + self.redis.persist('key') + + def test_rename(self): + with self.assertRaises(redis.ConnectionError): + self.redis.rename('key1', 'key2') + if __name__ == '__main__': unittest.main() From f94017669cd66805cf89109398df6e541c08f3be Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 10 Jan 2018 16:39:02 -0600 Subject: [PATCH 04/24] Updated test case to actually fail when decorator does not exist instead of throwing an error --- test_fakeredis.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test_fakeredis.py b/test_fakeredis.py index 32aa1c7..4b3e37d 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -3032,8 +3032,13 @@ def test_persist(self): self.redis.persist('key') def test_rename(self): + self.redis._connected = True + self.redis.set('key1', 'value') + self.redis._connected = False with self.assertRaises(redis.ConnectionError): self.redis.rename('key1', 'key2') + self.redis._connected = True + self.assertTrue(self.redis.exists('key1')) if __name__ == '__main__': unittest.main() From ffb62052cadad37eb36a47f234738d2a00b6244b Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Fri, 6 Jul 2018 21:12:58 -0500 Subject: [PATCH 05/24] Added venv to ignore file --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4ad3cab..881eb2a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ extras/* .tox *.pyc .idea +venv/ From 4263daa04eb04e1750d991bcbdcb7d92123bd70d Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Fri, 6 Jul 2018 21:13:23 -0500 Subject: [PATCH 06/24] Added missing check_conn so tests work --- fakeredis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fakeredis.py b/fakeredis.py index ae60e98..a1fb176 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -594,6 +594,7 @@ def msetnx(self, mapping): return True return False + @check_conn def persist(self, name): self._db.persist(name) From 2164cb68d5011cd1f344e26c84ff88cf8b6a235b Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Fri, 6 Jul 2018 22:12:05 -0500 Subject: [PATCH 07/24] Updated tests and added check decorator as needed --- fakeredis.py | 31 ++++-- test_fakeredis.py | 256 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 276 insertions(+), 11 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index a1fb176..29d15b8 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -440,13 +440,14 @@ def exists(self, name): return name in self._db __contains__ = exists + @check_conn def expire(self, name, time): return self._expire(name, time) + @check_conn def pexpire(self, name, millis): return self._expire(name, millis, 1000) - @check_conn def _expire(self, name, time, multiplier=1): if isinstance(time, timedelta): time = int(timedelta_total_seconds(time) * multiplier) @@ -460,13 +461,14 @@ def _expire(self, name, time, multiplier=1): else: return False + @check_conn def expireat(self, name, when): return self._expireat(name, when) + @check_conn def pexpireat(self, name, when): return self._expireat(name, when, 1000) - @check_conn def _expireat(self, name, when, multiplier=1): if not isinstance(when, datetime): when = datetime.fromtimestamp(when / float(multiplier)) @@ -719,13 +721,14 @@ def substr(self, name, start, end=-1): # according to the docs. getrange = substr + @check_conn def ttl(self, name): return self._ttl(name) + @check_conn def pttl(self, name): return self._ttl(name, 1000) - @check_conn def _ttl(self, name, multiplier=1): if name not in self._db: return -2 @@ -757,10 +760,12 @@ def type(self, name): assert key is None return b'none' + @check_conn @_lua_reply(_lua_bool_ok) def watch(self, *names): pass + @check_conn @_lua_reply(_lua_bool_ok) def unwatch(self): pass @@ -827,6 +832,7 @@ def sort(self, name, start=None, num=None, by=None, get=None, desc=False, except KeyError: return [] + @check_conn def eval(self, script, numkeys, *keys_and_args): from lupa import LuaRuntime, LuaError @@ -1002,7 +1008,6 @@ def _retrieve_data_from_sort(self, data, get): data = new_data return data - @check_conn def _get_single_item(self, k, g): g = to_bytes(g) if b'*' in g: @@ -1018,7 +1023,6 @@ def _get_single_item(self, k, g): single_item = None return single_item - @check_conn def _strtod_key_func(self, arg): # str()'ing the arg is important! Don't ever remove this. arg = to_bytes(arg) @@ -1549,7 +1553,6 @@ def _matches(x): right_comparator(x, actual_max)) return _matches - @check_conn def _get_comparator_and_val(self, value): try: if isinstance(value, string_types) and value.startswith('('): @@ -1587,7 +1590,6 @@ def _matches(x): right_comparator(x, actual_max)) return _matches - @check_conn def _get_lexcomp_and_str(self, value): if value.startswith(b'('): comparator = operator.lt @@ -1665,6 +1667,7 @@ def zincrby(self, name, value, amount=1): d[value] = score return score + @check_conn @_remove_empty def zinterstore(self, dest, keys, aggregate=None): """ @@ -1696,6 +1699,7 @@ def _apply_score_cast_func(self, items, all_items, withscores, score_cast_func): else: return [(k, score_cast_func(to_bytes(all_items[k]))) for k in items] + @check_conn def zrange(self, name, start, end, desc=False, withscores=False, score_cast_func=float): """ Return a range of values from sorted set ``name`` between @@ -1723,13 +1727,13 @@ def zrange(self, name, start, end, desc=False, withscores=False, score_cast_func items = in_order[start:end] return self._apply_score_cast_func(items, all_items, withscores, score_cast_func) - @check_conn def _get_zelements_in_order(self, all_items, reverse=False): by_keyname = sorted( all_items.items(), key=lambda x: x[0], reverse=reverse) in_order = sorted(by_keyname, key=lambda x: x[1], reverse=reverse) return [el[0] for el in in_order] + @check_conn def zrangebyscore(self, name, min, max, start=None, num=None, withscores=False, score_cast_func=float): """ @@ -1747,7 +1751,6 @@ def zrangebyscore(self, name, min, max, start=None, num=None, return self._zrangebyscore(name, min, max, start, num, withscores, score_cast_func, reverse=False) - @check_conn def _zrangebyscore(self, name, min, max, start, num, withscores, score_cast_func, reverse): if (start is not None and num is None) or \ (num is not None and start is None): @@ -1764,6 +1767,7 @@ def _zrangebyscore(self, name, min, max, start, num, withscores, score_cast_func matches = matches[start:start + num] return self._apply_score_cast_func(matches, all_items, withscores, score_cast_func) + @check_conn def zrangebylex(self, name, min, max, start=None, num=None): """ @@ -1783,7 +1787,6 @@ def zrangebylex(self, name, min, max, return self._zrangebylex(name, min, max, start, num, reverse=False) - @check_conn def _zrangebylex(self, name, min, max, start, num, reverse): if (start is not None and num is None) or \ (num is not None and start is None): @@ -1827,6 +1830,7 @@ def zrem(self, name, *values): rem += 1 return rem + @check_conn @_remove_empty def zremrangebyrank(self, name, min, max): """ @@ -1863,6 +1867,7 @@ def zremrangebyscore(self, name, min, max): removed += 1 return removed + @check_conn @_remove_empty def zremrangebylex(self, name, min, max): """ @@ -1884,6 +1889,7 @@ def zremrangebylex(self, name, min, max): removed += 1 return removed + @check_conn def zlexcount(self, name, min, max): """ Returns a count of elements in the sorted set ``name`` @@ -1903,6 +1909,7 @@ def zlexcount(self, name, min, max): found += 1 return found + @check_conn def zrevrange(self, name, start, num, withscores=False, score_cast_func=float): """ Return a range of values from sorted set ``name`` between @@ -1917,6 +1924,7 @@ def zrevrange(self, name, start, num, withscores=False, score_cast_func=float): """ return self.zrange(name, start, num, True, withscores, score_cast_func) + @check_conn def zrevrangebyscore(self, name, max, min, start=None, num=None, withscores=False, score_cast_func=float): """ @@ -1934,6 +1942,7 @@ def zrevrangebyscore(self, name, max, min, start=None, num=None, return self._zrangebyscore(name, min, max, start, num, withscores, score_cast_func, reverse=True) + @check_conn def zrevrangebylex(self, name, max, min, start=None, num=None): """ @@ -1953,6 +1962,7 @@ def zrevrangebylex(self, name, max, min, return self._zrangebylex(name, min, max, start, num, reverse=True) + # TODO: left off here def zrevrank(self, name, value): """ Returns a 0-based value indicating the descending rank of @@ -1982,7 +1992,6 @@ def zunionstore(self, dest, keys, aggregate=None): "for ZINTERSTORE/ZUNIONSTORE") self._zaggregate(dest, keys, aggregate, lambda x: True) - @check_conn def _zaggregate(self, dest, keys, aggregate, should_include): new_zset = _ZSet() if aggregate is None: diff --git a/test_fakeredis.py b/test_fakeredis.py index 1a91b2e..15f0bdf 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -4087,5 +4087,261 @@ def test_rename(self): self.redis._connected = True self.assertTrue(self.redis.exists('key1')) + def test_watch(self): + with self.assertRaises(redis.ConnectionError): + self.redis.watch() + + def test_unwatch(self): + with self.assertRaises(redis.ConnectionError): + self.redis.unwatch() + + def test_eval(self): + with self.assertRaises(redis.ConnectionError): + self.redis.eval('', 0) + + def test_lpush(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lpush('name', [1, 2]) + + def test_lrange(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lrange('name', 1, 5) + + def test_llen(self): + with self.assertRaises(redis.ConnectionError): + self.redis.llen('name') + + def test_lrem(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lrem('name', 2, 2) + + def test_rpush(self): + with self.assertRaises(redis.ConnectionError): + self.redis.rpush('name', [1]) + + def test_lpop(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lpop('name') + + def test_lset(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lset('name', 1, 4) + + def test_rpushx(self): + with self.assertRaises(redis.ConnectionError): + self.redis.rpushx('name', 1) + + def test_ltrim(self): + with self.assertRaises(redis.ConnectionError): + self.redis.ltrim('name', 1, 4) + + def test_lindex(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lindex('name', 1) + + def test_lpushx(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lpushx('name', 1) + + def test_rpop(self): + with self.assertRaises(redis.ConnectionError): + self.redis.rpop('name') + + def test_linsert(self): + with self.assertRaises(redis.ConnectionError): + self.redis.linsert('name', 'where', 'refvalue', 'value') + + def test_rpoplpush(self): + with self.assertRaises(redis.ConnectionError): + self.redis.rpoplpush('src', 'dst') + + def test_blpop(self): + with self.assertRaises(redis.ConnectionError): + self.redis.blpop('keys') + + def test_brpop(self): + with self.assertRaises(redis.ConnectionError): + self.redis.brpop('keys') + + def test_brpoplpush(self): + with self.assertRaises(redis.ConnectionError): + self.redis.brpoplpush('src', 'dst') + + def test_hdel(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hdel('name') + + def test_hexists(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hexists('name', 'key') + + def test_hget(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hget('name', 'key') + + def test_hgetall(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hgetall('name') + + def test_hincrby(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hincrby('name', 'key') + + def test_hincrbyfloat(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hincrbyfloat('name', 'key') + + def test_hkeys(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hkeys('name') + + def test_hlen(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hlen('name') + + def test_hset(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hset('name', 'key', 1) + + def test_hsetnx(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hsetnx('name', 'key', 2) + + def test_hmset(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hmset('name', {'key': 1}) + + def test_hmget(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hmget('name', ['a', 'b']) + + def test_hvals(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hvals('name') + + def test_sadd(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sadd('name', [1, 2]) + + def test_scard(self): + with self.assertRaises(redis.ConnectionError): + self.redis.scard('name') + + def test_sdiff(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sdiff(['a', 'b']) + + def test_sdiffstore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sdiffstore('dest', ['a', 'b']) + + def test_sinter(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sinter(['a', 'b']) + + def test_sinterstore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sinterstore('dest', ['a', 'b']) + + def test_sismember(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sismember('name', 20) + + def test_smembers(self): + with self.assertRaises(redis.ConnectionError): + self.redis.smembers('name') + + def test_smove(self): + with self.assertRaises(redis.ConnectionError): + self.redis.smove('src','dest', 20) + + def test_spop(self): + with self.assertRaises(redis.ConnectionError): + self.redis.spop('name') + + def test_srandmember(self): + with self.assertRaises(redis.ConnectionError): + self.redis.srandmember('name') + + def test_srem(self): + with self.assertRaises(redis.ConnectionError): + self.redis.srem('name') + + def test_sunion(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sunion(['a', 'b']) + + def test_sunionstore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sunionstore('dest', ['a', 'b']) + + def test_zadd(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zadd('name') + + def test_zcard(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zcard('name') + + def test_zcount(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zcount('name', 1, 5) + + def test_zincrby(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zincrby('name', 1) + + def test_zinterstore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zinterstore('dest', ['a', 'b']) + + def test_zrange(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrange('name', 1, 5) + + def test_zrangebyscore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrangebyscore('name', 1, 5) + + def test_rangebylex(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrangebylex('name', 1, 4) + + def test_zrange(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrank('name', 1) + + def test_zrem(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrem('name', [1]) + + def test_zremrangebyrank(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zremrangebyrank('name', 1, 5) + + def test_zremrangebyscore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zremrangebyscore('name', 1, 5) + + def test_zremrangebylex(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zremrangebylex('name', 1, 5) + + def test_zlexcount(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zlexcount('name', 1, 5) + + def test_zrevrange(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrevrange('name', 1, 5, 1) + + def test_zrevrangebyscore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrevrangebyscore('name', 5, 1) + + def test_zrevrangebylex(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrevrangebylex('name', 5, 1) + if __name__ == '__main__': unittest.main() From cbaba988508e2e05aae0f40b2ae7361273b14f46 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Sat, 7 Jul 2018 10:38:32 -0500 Subject: [PATCH 08/24] finished remaining tests --- fakeredis.py | 18 ++++++++++++- test_fakeredis.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/fakeredis.py b/fakeredis.py index 29d15b8..4a4ed5c 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -1962,7 +1962,7 @@ def zrevrangebylex(self, name, max, min, return self._zrangebylex(name, min, max, start, num, reverse=True) - # TODO: left off here + @check_conn def zrevrank(self, name, value): """ Returns a 0-based value indicating the descending rank of @@ -1973,6 +1973,7 @@ def zrevrank(self, name, value): if zrank is not None: return num_items - self.zrank(name, value) - 1 + @check_conn def zscore(self, name, value): "Return the score of element ``value`` in sorted set ``name``" all_items = self._get_zset(name) @@ -1981,6 +1982,7 @@ def zscore(self, name, value): except KeyError: return None + @check_conn def zunionstore(self, dest, keys, aggregate=None): """ Union multiple sorted sets specified by ``keys`` into @@ -2037,6 +2039,7 @@ def _list_or_args(self, keys, args): keys.extend(args) return keys + @check_conn def pipeline(self, transaction=True, shard_hint=None): """Return an object that can be used to issue Redis commands in a batch. @@ -2046,6 +2049,7 @@ def pipeline(self, transaction=True, shard_hint=None): """ return FakePipeline(self, transaction) + @check_conn def transaction(self, func, *keys, **kwargs): shard_hint = kwargs.pop('shard_hint', None) value_from_callable = kwargs.pop('value_from_callable', False) @@ -2068,10 +2072,12 @@ def transaction(self, func, *keys, **kwargs): continue raise redis.WatchError('Could not run transaction after 5 tries') + @check_conn def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None, lock_class=None, thread_local=True): return _Lock(self, name, timeout) + @check_conn def pubsub(self, ignore_subscribe_messages=False): """ Returns a new FakePubSub instance @@ -2082,6 +2088,7 @@ def pubsub(self, ignore_subscribe_messages=False): return ps + @check_conn def publish(self, channel, message): """ Loops through all available pubsub objects and publishes the @@ -2098,6 +2105,7 @@ def publish(self, channel, message): return count # HYPERLOGLOG COMMANDS + @check_conn def pfadd(self, name, *values): "Adds the specified elements to the specified HyperLogLog." # Simulate the behavior of HyperLogLog by using SETs underneath to @@ -2108,6 +2116,7 @@ def pfadd(self, name, *values): # - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise. return 1 if result > 0 else 0 + @check_conn def pfcount(self, *sources): """ Return the approximated cardinality of @@ -2115,6 +2124,7 @@ def pfcount(self, *sources): """ return len(self.sunion(*sources)) + @check_conn @_lua_reply(_lua_bool_ok) def pfmerge(self, dest, *sources): "Merge N different HyperLogLogs into a single one." @@ -2146,12 +2156,15 @@ def _scan(self, keys, cursor, match, count): result_cursor = 0 return result_cursor, result_data + @check_conn def scan(self, cursor=0, match=None, count=None): return self._scan(self.keys(), int(cursor), match, count or 10) + @check_conn def sscan(self, name, cursor=0, match=None, count=None): return self._scan(self.smembers(name), int(cursor), match, count or 10) + @check_conn def hscan(self, name, cursor=0, match=None, count=None): cursor, keys = self._scan(self.hkeys(name), int(cursor), match, count or 10) results = {} @@ -2159,6 +2172,7 @@ def hscan(self, name, cursor=0, match=None, count=None): results[k] = self.hget(name, k) return cursor, results + @check_conn def scan_iter(self, match=None, count=None): # This is from redis-py cursor = '0' @@ -2167,6 +2181,7 @@ def scan_iter(self, match=None, count=None): for item in data: yield item + @check_conn def sscan_iter(self, name, match=None, count=None): # This is from redis-py cursor = '0' @@ -2176,6 +2191,7 @@ def sscan_iter(self, name, match=None, count=None): for item in data: yield item + @check_conn def hscan_iter(self, name, match=None, count=None): # This is from redis-py cursor = '0' diff --git a/test_fakeredis.py b/test_fakeredis.py index 15f0bdf..f4a2486 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -4343,5 +4343,71 @@ def test_zrevrangebylex(self): with self.assertRaises(redis.ConnectionError): self.redis.zrevrangebylex('name', 5, 1) + def test_zrevran(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zrevrank('name', 2) + + def test_zscore(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zscore('name', 2) + + def test_zunionstor(self): + with self.assertRaises(redis.ConnectionError): + self.redis.zunionstore('dest', ['1', '2']) + + def test_pipeline(self): + with self.assertRaises(redis.ConnectionError): + self.redis.pipeline() + + def test_transaction(self): + with self.assertRaises(redis.ConnectionError): + def func(a): + return a * a + + self.redis.transaction(func, 3) + + def test_lock(self): + with self.assertRaises(redis.ConnectionError): + self.redis.lock('name') + + def test_pubsub(self): + with self.assertRaises(redis.ConnectionError): + self.redis.pubsub() + + def test_pfadd(self): + with self.assertRaises(redis.ConnectionError): + self.redis.pfadd('name', [1]) + + def test_pfmerge(self): + with self.assertRaises(redis.ConnectionError): + self.redis.pfmerge('dest', ['a', 'b']) + + def test_scan(self): + with self.assertRaises(redis.ConnectionError): + self.redis.scan() + + def test_sscan(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sscan('name') + + def test_hscan(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hscan('name') + + def test_scan_iter(self): + with self.assertRaises(redis.ConnectionError): + self.redis.scan_iter() + + def test_sscan_iter(self): + with self.assertRaises(redis.ConnectionError): + self.redis.sscan_iter('name') + + def test_hscan_iter(self): + with self.assertRaises(redis.ConnectionError): + self.redis.hscan_iter('name') + + + + if __name__ == '__main__': unittest.main() From ef97612e95f40b1900022afd51e03fccac826909 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Sat, 7 Jul 2018 10:48:04 -0500 Subject: [PATCH 09/24] flake8 --- test_fakeredis.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test_fakeredis.py b/test_fakeredis.py index f4a2486..53c5d15 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -4253,7 +4253,7 @@ def test_smembers(self): def test_smove(self): with self.assertRaises(redis.ConnectionError): - self.redis.smove('src','dest', 20) + self.redis.smove('src', 'dest', 20) def test_spop(self): with self.assertRaises(redis.ConnectionError): @@ -4307,10 +4307,6 @@ def test_rangebylex(self): with self.assertRaises(redis.ConnectionError): self.redis.zrangebylex('name', 1, 4) - def test_zrange(self): - with self.assertRaises(redis.ConnectionError): - self.redis.zrank('name', 1) - def test_zrem(self): with self.assertRaises(redis.ConnectionError): self.redis.zrem('name', [1]) @@ -4407,7 +4403,5 @@ def test_hscan_iter(self): self.redis.hscan_iter('name') - - if __name__ == '__main__': unittest.main() From 269fc1cc99c77fba458c0f61599bd2d8a1ae61d6 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Sat, 7 Jul 2018 11:09:48 -0500 Subject: [PATCH 10/24] Fixed issue with decorator not playing well with _lua_reply --- fakeredis.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index 4a4ed5c..1925577 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -333,6 +333,8 @@ def release(self): def check_conn(func): """Used to mock connection errors""" + + @functools.wraps(func) def func_wrapper(*args, **kwargs): if not args[0]._connected: raise redis.ConnectionError @@ -369,14 +371,14 @@ def __init__(self, db=0, charset='utf-8', errors='strict', if decode_responses: _patch_responses(self) - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def flushdb(self): self._db.clear() return True - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def flushall(self): for db in self._dbs.values(): db.clear() @@ -572,8 +574,8 @@ def mget(self, keys, *args): found.append(value) return found - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def mset(self, *args, **kwargs): if args: if len(args) != 1 or not isinstance(args[0], dict): @@ -603,8 +605,8 @@ def persist(self, name): def ping(self): return True - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def rename(self, src, dst): try: value = self._db[src] @@ -760,13 +762,13 @@ def type(self, name): assert key is None return b'none' - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def watch(self, *names): pass - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def unwatch(self): pass @@ -1123,8 +1125,8 @@ def lpop(self, name): except IndexError: return None - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def lset(self, name, index, value): try: lst = self._get_list_or_none(name) @@ -1139,8 +1141,8 @@ def lset(self, name, index, value): def rpushx(self, name, value): self._get_list(name).append(to_bytes(value)) - @check_conn @_lua_reply(_lua_bool_ok) + @check_conn def ltrim(self, name, start, end): val = self._get_list_or_none(name) if val is not None: From 49db0a3b44058472a6d90ceaf1cfeb6af4322da4 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Tue, 10 Jul 2018 14:46:40 -0500 Subject: [PATCH 11/24] Modified _patch_responses to take the decorator as a parameter --- fakeredis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index 1925577..370ffac 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -221,12 +221,12 @@ def decode_response(*args, **kwargs): return decode_response -def _patch_responses(obj): +def _patch_responses(obj, decorator): for attr_name in dir(obj): attr = getattr(obj, attr_name) if not callable(attr) or attr_name.startswith('_'): continue - func = _make_decode_func(attr) + func = decorator(attr) setattr(obj, attr_name, func) From 5dccd2a359ff621ad3a1a467a0f7b6d2e27b6ce6 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Tue, 10 Jul 2018 14:48:01 -0500 Subject: [PATCH 12/24] Made connection check decorator private function. Adjusted for use through _patch_responses --- fakeredis.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index 370ffac..35ed5d7 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -331,12 +331,11 @@ def release(self): self.redis.delete(self.name) -def check_conn(func): +def _check_conn(func): """Used to mock connection errors""" - @functools.wraps(func) def func_wrapper(*args, **kwargs): - if not args[0]._connected: + if not func.__self__.connected: raise redis.ConnectionError return func(*args, **kwargs) return func_wrapper @@ -367,9 +366,12 @@ def __init__(self, db=0, charset='utf-8', errors='strict', self._encoding_errors = errors self._pubsubs = [] self._decode_responses = decode_responses - self._connected = connected + self.connected = connected if decode_responses: - _patch_responses(self) + _patch_responses(self, _make_decode_func) + + if not connected: + _patch_responses(self, _check_conn) @_lua_reply(_lua_bool_ok) @check_conn From c1a80b3a076a759f97283b811928ed9f9864de0d Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Tue, 10 Jul 2018 14:51:14 -0500 Subject: [PATCH 13/24] Removed decorator from methods --- fakeredis.py | 114 --------------------------------------------------- 1 file changed, 114 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index 35ed5d7..9845fd0 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -374,13 +374,11 @@ def __init__(self, db=0, charset='utf-8', errors='strict', _patch_responses(self, _check_conn) @_lua_reply(_lua_bool_ok) - @check_conn def flushdb(self): self._db.clear() return True @_lua_reply(_lua_bool_ok) - @check_conn def flushall(self): for db in self._dbs.values(): db.clear() @@ -411,13 +409,11 @@ def _setdefault_string(self, name): return value # Basic key commands - @check_conn def append(self, key, value): self._setdefault_string(key) self._db[key] += to_bytes(value) return len(self._db[key]) - @check_conn def bitcount(self, name, start=0, end=-1): if end == -1: end = None @@ -429,7 +425,6 @@ def bitcount(self, name, start=0, end=-1): except KeyError: return 0 - @check_conn def decr(self, name, amount=1): try: value = int(self._get_string(name, b'0')) - amount @@ -439,16 +434,13 @@ def decr(self, name, amount=1): "range.") return value - @check_conn def exists(self, name): return name in self._db __contains__ = exists - @check_conn def expire(self, name, time): return self._expire(name, time) - @check_conn def pexpire(self, name, millis): return self._expire(name, millis, 1000) @@ -465,11 +457,9 @@ def _expire(self, name, time, multiplier=1): else: return False - @check_conn def expireat(self, name, when): return self._expireat(name, when) - @check_conn def pexpireat(self, name, when): return self._expireat(name, when, 1000) @@ -482,13 +472,11 @@ def _expireat(self, name, when, multiplier=1): else: return False - @check_conn def echo(self, value): if isinstance(value, text_type): return value.encode('utf-8') return value - @check_conn def get(self, name): value = self._get_string(name, None) if value is not None: @@ -500,7 +488,6 @@ def __getitem__(self, name): return value raise KeyError(name) - @check_conn def getbit(self, name, offset): """Returns a boolean indicating the value of ``offset`` in ``name``""" val = self._get_string(name) @@ -513,7 +500,6 @@ def getbit(self, name, offset): return 0 return 1 if (1 << actual_bitoffset) & actual_val else 0 - @check_conn def getset(self, name, value): """ Set the value at key ``name`` to ``value`` if key doesn't exist @@ -523,7 +509,6 @@ def getset(self, name, value): self._db[name] = to_bytes(value) return val - @check_conn def incr(self, name, amount=1): """ Increments the value of ``key`` by ``amount``. If no key exists, @@ -546,7 +531,6 @@ def incrby(self, name, amount=1): """ return self.incr(name, amount) - @check_conn def incrbyfloat(self, name, amount=1.0): try: value = float(self._get_string(name, b'0')) + amount @@ -555,13 +539,11 @@ def incrbyfloat(self, name, amount=1.0): raise redis.ResponseError("value is not a valid float.") return value - @check_conn def keys(self, pattern=None): if pattern is not None: regex = _compile_pattern(pattern) return [key for key in self._db if pattern is None or regex.match(key)] - @check_conn def mget(self, keys, *args): all_keys = self._list_or_args(keys, args) found = [] @@ -577,7 +559,6 @@ def mget(self, keys, *args): return found @_lua_reply(_lua_bool_ok) - @check_conn def mset(self, *args, **kwargs): if args: if len(args) != 1 or not isinstance(args[0], dict): @@ -588,7 +569,6 @@ def mset(self, *args, **kwargs): self.set(key, val) return True - @check_conn def msetnx(self, mapping): """ Sets each key in the ``mapping`` dict to its corresponding value if @@ -600,7 +580,6 @@ def msetnx(self, mapping): return True return False - @check_conn def persist(self, name): self._db.persist(name) @@ -608,7 +587,6 @@ def ping(self): return True @_lua_reply(_lua_bool_ok) - @check_conn def rename(self, src, dst): try: value = self._db[src] @@ -618,14 +596,12 @@ def rename(self, src, dst): del self._db[src] return True - @check_conn def renamenx(self, src, dst): if dst in self._db: return False else: return self.rename(src, dst) - @check_conn def set(self, name, value, ex=None, px=None, nx=False, xx=False): if (not nx and not xx) or (nx and self._db.get(name, None) is None) \ or (xx and not self._db.get(name, None) is None): @@ -654,7 +630,6 @@ def set(self, name, value, ex=None, px=None, nx=False, xx=False): __setitem__ = set - @check_conn def setbit(self, name, offset, value): val = self._get_string(name, b'\x00') byte = offset // 8 @@ -698,7 +673,6 @@ def setnx(self, name, value): return False return result - @check_conn def setrange(self, name, offset, value): val = self._get_string(name, b"") if len(val) < offset: @@ -707,11 +681,9 @@ def setrange(self, name, offset, value): self._db.setx(name, val) return len(val) - @check_conn def strlen(self, name): return len(self._get_string(name)) - @check_conn def substr(self, name, start, end=-1): if end == -1: end = None @@ -725,11 +697,9 @@ def substr(self, name, start, end=-1): # according to the docs. getrange = substr - @check_conn def ttl(self, name): return self._ttl(name) - @check_conn def pttl(self, name): return self._ttl(name, 1000) @@ -749,7 +719,6 @@ def _ttl(self, name, multiplier=1): (exp_time - now).seconds + (exp_time - now).microseconds / 1E6) * multiplier)) - @check_conn def type(self, name): key = self._db.get(name) if hasattr(key.__class__, 'redis_type'): @@ -765,16 +734,13 @@ def type(self, name): return b'none' @_lua_reply(_lua_bool_ok) - @check_conn def watch(self, *names): pass @_lua_reply(_lua_bool_ok) - @check_conn def unwatch(self): pass - @check_conn def delete(self, *names): deleted = 0 for name in names: @@ -785,7 +751,6 @@ def delete(self, *names): continue return deleted - @check_conn def sort(self, name, start=None, num=None, by=None, get=None, desc=False, alpha=False, store=None): """Sort and return the list, set or sorted set at ``name``. @@ -836,7 +801,6 @@ def sort(self, name, start=None, num=None, by=None, get=None, desc=False, except KeyError: return [] - @check_conn def eval(self, script, numkeys, *keys_and_args): from lupa import LuaRuntime, LuaError @@ -1075,13 +1039,11 @@ def _setdefault_list(self, name): raise redis.ResponseError(_WRONGTYPE_MSG) return value - @check_conn def lpush(self, name, *values): self._setdefault_list(name)[0:0] = list(reversed( [to_bytes(x) for x in values])) return len(self._db[name]) - @check_conn def lrange(self, name, start, end): if end == -1: end = None @@ -1089,11 +1051,9 @@ def lrange(self, name, start, end): end += 1 return self._get_list(name)[start:end] - @check_conn def llen(self, name): return len(self._get_list(name)) - @check_conn @_remove_empty def lrem(self, name, count, value): value = to_bytes(value) @@ -1114,12 +1074,10 @@ def lrem(self, name, count, value): del a_list[index] return len(indices_to_remove) - @check_conn def rpush(self, name, *values): self._setdefault_list(name).extend([to_bytes(x) for x in values]) return len(self._db[name]) - @check_conn @_remove_empty def lpop(self, name): try: @@ -1128,7 +1086,6 @@ def lpop(self, name): return None @_lua_reply(_lua_bool_ok) - @check_conn def lset(self, name, index, value): try: lst = self._get_list_or_none(name) @@ -1139,12 +1096,10 @@ def lset(self, name, index, value): raise redis.ResponseError("index out of range") return True - @check_conn def rpushx(self, name, value): self._get_list(name).append(to_bytes(value)) @_lua_reply(_lua_bool_ok) - @check_conn def ltrim(self, name, start, end): val = self._get_list_or_none(name) if val is not None: @@ -1155,18 +1110,15 @@ def ltrim(self, name, start, end): self._db.setx(name, val[start:end]) return True - @check_conn def lindex(self, name, index): try: return self._get_list(name)[index] except IndexError: return None - @check_conn def lpushx(self, name, value): self._get_list(name).insert(0, to_bytes(value)) - @check_conn @_remove_empty def rpop(self, name): try: @@ -1174,7 +1126,6 @@ def rpop(self, name): except IndexError: return None - @check_conn def linsert(self, name, where, refvalue, value): if where.lower() not in ('before', 'after'): raise redis.ResponseError('syntax error') @@ -1192,7 +1143,6 @@ def linsert(self, name, where, refvalue, value): lst.insert(index, to_bytes(value)) return len(lst) - @check_conn def rpoplpush(self, src, dst): # _get_list instead of _setdefault_list at this point because we # don't want to create the list if nothing gets popped. @@ -1204,7 +1154,6 @@ def rpoplpush(self, src, dst): self._db.setx(dst, dst_list) return el - @check_conn def blpop(self, keys, timeout=0): # This has to be a best effort approximation which follows # these rules: @@ -1223,7 +1172,6 @@ def blpop(self, keys, timeout=0): self._remove_if_empty(key) return ret - @check_conn def brpop(self, keys, timeout=0): if isinstance(keys, string_types): keys = [to_bytes(keys)] @@ -1236,7 +1184,6 @@ def brpop(self, keys, timeout=0): self._remove_if_empty(key) return ret - @check_conn def brpoplpush(self, src, dst, timeout=0): return self.rpoplpush(src, dst) @@ -1252,7 +1199,6 @@ def _setdefault_hash(self, name): raise redis.ResponseError(_WRONGTYPE_MSG) return value - @check_conn @_remove_empty def hdel(self, name, *keys): h = self._get_hash(name) @@ -1263,7 +1209,6 @@ def hdel(self, name, *keys): rem += 1 return rem - @check_conn def hexists(self, name, key): "Returns a boolean indicating if ``key`` exists within hash ``name``" if self._get_hash(name).get(key) is None: @@ -1271,26 +1216,22 @@ def hexists(self, name, key): else: return 1 - @check_conn def hget(self, name, key): "Return the value of ``key`` within the hash ``name``" return self._get_hash(name).get(key) - @check_conn def hgetall(self, name): "Return a Python dict of the hash's name/value pairs" all_items = dict() all_items.update(self._get_hash(name)) return all_items - @check_conn def hincrby(self, name, key, amount=1): "Increment the value of ``key`` in hash ``name`` by ``amount``" new = int(self._setdefault_hash(name).get(key, b'0')) + amount self._db[name][key] = to_bytes(new) return new - @check_conn def hincrbyfloat(self, name, key, amount=1.0): """Increment the value of key in hash name by floating amount""" try: @@ -1305,17 +1246,14 @@ def hincrbyfloat(self, name, key, amount=1.0): self._db[name][key] = to_bytes(new) return new - @check_conn def hkeys(self, name): "Return the list of keys within hash ``name``" return list(self._get_hash(name)) - @check_conn def hlen(self, name): "Return the number of elements in hash ``name``" return len(self._get_hash(name)) - @check_conn def hset(self, name, key, value): """ Set ``key`` to ``value`` within hash ``name`` @@ -1325,7 +1263,6 @@ def hset(self, name, key, value): self._setdefault_hash(name)[key] = to_bytes(value) return 1 if key_is_new else 0 - @check_conn def hsetnx(self, name, key, value): """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not @@ -1336,7 +1273,6 @@ def hsetnx(self, name, key, value): self._setdefault_hash(name)[key] = to_bytes(value) return True - @check_conn def hmset(self, name, mapping): """ Sets each key in the ``mapping`` dict to its corresponding value @@ -1350,14 +1286,12 @@ def hmset(self, name, mapping): self._setdefault_hash(name).update(new_mapping) return True - @check_conn def hmget(self, name, keys, *args): "Returns a list of values ordered identically to ``keys``" h = self._get_hash(name) all_keys = self._list_or_args(keys, args) return [h.get(k) for k in all_keys] - @check_conn def hvals(self, name): "Return the list of values within hash ``name``" return list(self._get_hash(name).values()) @@ -1374,7 +1308,6 @@ def _setdefault_set(self, name): raise redis.ResponseError(_WRONGTYPE_MSG) return value - @check_conn def sadd(self, name, *values): "Add ``value`` to set ``name``" a_set = self._setdefault_set(name) @@ -1382,12 +1315,10 @@ def sadd(self, name, *values): a_set |= set(to_bytes(x) for x in values) return len(a_set) - card - @check_conn def scard(self, name): "Return the number of elements in set ``name``" return len(self._get_set(name)) - @check_conn def sdiff(self, keys, *args): "Return the difference of sets specified by ``keys``" all_keys = (to_bytes(x) for x in self._list_or_args(keys, args)) @@ -1396,7 +1327,6 @@ def sdiff(self, keys, *args): diff -= self._get_set(key) return diff - @check_conn @_remove_empty def sdiffstore(self, dest, keys, *args): """ @@ -1407,7 +1337,6 @@ def sdiffstore(self, dest, keys, *args): self._db[dest] = set(to_bytes(x) for x in diff) return len(diff) - @check_conn def sinter(self, keys, *args): "Return the intersection of sets specified by ``keys``" all_keys = (to_bytes(x) for x in self._list_or_args(keys, args)) @@ -1416,7 +1345,6 @@ def sinter(self, keys, *args): intersect.intersection_update(self._get_set(key)) return intersect - @check_conn @_remove_empty def sinterstore(self, dest, keys, *args): """ @@ -1427,17 +1355,14 @@ def sinterstore(self, dest, keys, *args): self._db[dest] = set(to_bytes(x) for x in intersect) return len(intersect) - @check_conn def sismember(self, name, value): "Return a boolean indicating if ``value`` is a member of set ``name``" return to_bytes(value) in self._get_set(name) - @check_conn def smembers(self, name): "Return all members of the set ``name``" return self._get_set(name) - @check_conn @_remove_empty def smove(self, src, dst, value): value = to_bytes(value) @@ -1450,7 +1375,6 @@ def smove(self, src, dst, value): except KeyError: return False - @check_conn @_remove_empty def spop(self, name): "Remove and return a random member of set ``name``" @@ -1459,7 +1383,6 @@ def spop(self, name): except KeyError: return None - @check_conn def srandmember(self, name, number=None): """ If ``number`` is None, returns a random member of set ``name``. @@ -1489,7 +1412,6 @@ def srandmember(self, name, number=None): in sorted(random.sample(range(len(members)), number)) ] - @check_conn @_remove_empty def srem(self, name, *values): "Remove ``value`` from set ``name``" @@ -1498,7 +1420,6 @@ def srem(self, name, *values): a_set -= set(to_bytes(x) for x in values) return card - len(a_set) - @check_conn def sunion(self, keys, *args): "Return the union of sets specifiued by ``keys``" all_keys = (to_bytes(x) for x in self._list_or_args(keys, args)) @@ -1507,7 +1428,6 @@ def sunion(self, keys, *args): union.update(self._get_set(key)) return union - @check_conn def sunionstore(self, dest, keys, *args): """ Store the union of sets specified by ``keys`` into a new @@ -1618,7 +1538,6 @@ def _get_lexcomp_and_str(self, value): return comparator, actual_value - @check_conn def zadd(self, name, *args, **kwargs): """ Set any number of score, element-name pairs to the key ``name``. Pairs @@ -1649,12 +1568,10 @@ def zadd(self, name, *args, **kwargs): raise redis.ResponseError("value is not a valid float") return len(zset) - old_len - @check_conn def zcard(self, name): "Return the number of elements in the sorted set ``name``" return len(self._get_zset(name)) - @check_conn def zcount(self, name, min, max): found = 0 filter_func = self._get_zelement_range_filter_func(min, max) @@ -1663,7 +1580,6 @@ def zcount(self, name, min, max): found += 1 return found - @check_conn def zincrby(self, name, value, amount=1): "Increment the score of ``value`` in sorted set ``name`` by ``amount``" d = self._setdefault_zset(name) @@ -1671,7 +1587,6 @@ def zincrby(self, name, value, amount=1): d[value] = score return score - @check_conn @_remove_empty def zinterstore(self, dest, keys, aggregate=None): """ @@ -1703,7 +1618,6 @@ def _apply_score_cast_func(self, items, all_items, withscores, score_cast_func): else: return [(k, score_cast_func(to_bytes(all_items[k]))) for k in items] - @check_conn def zrange(self, name, start, end, desc=False, withscores=False, score_cast_func=float): """ Return a range of values from sorted set ``name`` between @@ -1737,7 +1651,6 @@ def _get_zelements_in_order(self, all_items, reverse=False): in_order = sorted(by_keyname, key=lambda x: x[1], reverse=reverse) return [el[0] for el in in_order] - @check_conn def zrangebyscore(self, name, min, max, start=None, num=None, withscores=False, score_cast_func=float): """ @@ -1771,7 +1684,6 @@ def _zrangebyscore(self, name, min, max, start, num, withscores, score_cast_func matches = matches[start:start + num] return self._apply_score_cast_func(matches, all_items, withscores, score_cast_func) - @check_conn def zrangebylex(self, name, min, max, start=None, num=None): """ @@ -1809,7 +1721,6 @@ def _zrangebylex(self, name, min, max, start, num, reverse): matches = matches[start:start + num] return matches - @check_conn def zrank(self, name, value): """ Returns a 0-based value indicating the rank of ``value`` in sorted set @@ -1822,7 +1733,6 @@ def zrank(self, name, value): except ValueError: return None - @check_conn @_remove_empty def zrem(self, name, *values): "Remove member ``value`` from sorted set ``name``" @@ -1834,7 +1744,6 @@ def zrem(self, name, *values): rem += 1 return rem - @check_conn @_remove_empty def zremrangebyrank(self, name, min, max): """ @@ -1855,7 +1764,6 @@ def zremrangebyrank(self, name, min, max): num_deleted += 1 return num_deleted - @check_conn @_remove_empty def zremrangebyscore(self, name, min, max): """ @@ -1871,7 +1779,6 @@ def zremrangebyscore(self, name, min, max): removed += 1 return removed - @check_conn @_remove_empty def zremrangebylex(self, name, min, max): """ @@ -1893,7 +1800,6 @@ def zremrangebylex(self, name, min, max): removed += 1 return removed - @check_conn def zlexcount(self, name, min, max): """ Returns a count of elements in the sorted set ``name`` @@ -1913,7 +1819,6 @@ def zlexcount(self, name, min, max): found += 1 return found - @check_conn def zrevrange(self, name, start, num, withscores=False, score_cast_func=float): """ Return a range of values from sorted set ``name`` between @@ -1928,7 +1833,6 @@ def zrevrange(self, name, start, num, withscores=False, score_cast_func=float): """ return self.zrange(name, start, num, True, withscores, score_cast_func) - @check_conn def zrevrangebyscore(self, name, max, min, start=None, num=None, withscores=False, score_cast_func=float): """ @@ -1946,7 +1850,6 @@ def zrevrangebyscore(self, name, max, min, start=None, num=None, return self._zrangebyscore(name, min, max, start, num, withscores, score_cast_func, reverse=True) - @check_conn def zrevrangebylex(self, name, max, min, start=None, num=None): """ @@ -1966,7 +1869,6 @@ def zrevrangebylex(self, name, max, min, return self._zrangebylex(name, min, max, start, num, reverse=True) - @check_conn def zrevrank(self, name, value): """ Returns a 0-based value indicating the descending rank of @@ -1977,7 +1879,6 @@ def zrevrank(self, name, value): if zrank is not None: return num_items - self.zrank(name, value) - 1 - @check_conn def zscore(self, name, value): "Return the score of element ``value`` in sorted set ``name``" all_items = self._get_zset(name) @@ -1986,7 +1887,6 @@ def zscore(self, name, value): except KeyError: return None - @check_conn def zunionstore(self, dest, keys, aggregate=None): """ Union multiple sorted sets specified by ``keys`` into @@ -2043,7 +1943,6 @@ def _list_or_args(self, keys, args): keys.extend(args) return keys - @check_conn def pipeline(self, transaction=True, shard_hint=None): """Return an object that can be used to issue Redis commands in a batch. @@ -2053,7 +1952,6 @@ def pipeline(self, transaction=True, shard_hint=None): """ return FakePipeline(self, transaction) - @check_conn def transaction(self, func, *keys, **kwargs): shard_hint = kwargs.pop('shard_hint', None) value_from_callable = kwargs.pop('value_from_callable', False) @@ -2076,12 +1974,10 @@ def transaction(self, func, *keys, **kwargs): continue raise redis.WatchError('Could not run transaction after 5 tries') - @check_conn def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None, lock_class=None, thread_local=True): return _Lock(self, name, timeout) - @check_conn def pubsub(self, ignore_subscribe_messages=False): """ Returns a new FakePubSub instance @@ -2092,7 +1988,6 @@ def pubsub(self, ignore_subscribe_messages=False): return ps - @check_conn def publish(self, channel, message): """ Loops through all available pubsub objects and publishes the @@ -2109,7 +2004,6 @@ def publish(self, channel, message): return count # HYPERLOGLOG COMMANDS - @check_conn def pfadd(self, name, *values): "Adds the specified elements to the specified HyperLogLog." # Simulate the behavior of HyperLogLog by using SETs underneath to @@ -2120,7 +2014,6 @@ def pfadd(self, name, *values): # - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise. return 1 if result > 0 else 0 - @check_conn def pfcount(self, *sources): """ Return the approximated cardinality of @@ -2128,7 +2021,6 @@ def pfcount(self, *sources): """ return len(self.sunion(*sources)) - @check_conn @_lua_reply(_lua_bool_ok) def pfmerge(self, dest, *sources): "Merge N different HyperLogLogs into a single one." @@ -2160,15 +2052,12 @@ def _scan(self, keys, cursor, match, count): result_cursor = 0 return result_cursor, result_data - @check_conn def scan(self, cursor=0, match=None, count=None): return self._scan(self.keys(), int(cursor), match, count or 10) - @check_conn def sscan(self, name, cursor=0, match=None, count=None): return self._scan(self.smembers(name), int(cursor), match, count or 10) - @check_conn def hscan(self, name, cursor=0, match=None, count=None): cursor, keys = self._scan(self.hkeys(name), int(cursor), match, count or 10) results = {} @@ -2176,7 +2065,6 @@ def hscan(self, name, cursor=0, match=None, count=None): results[k] = self.hget(name, k) return cursor, results - @check_conn def scan_iter(self, match=None, count=None): # This is from redis-py cursor = '0' @@ -2185,7 +2073,6 @@ def scan_iter(self, match=None, count=None): for item in data: yield item - @check_conn def sscan_iter(self, name, match=None, count=None): # This is from redis-py cursor = '0' @@ -2195,7 +2082,6 @@ def sscan_iter(self, name, match=None, count=None): for item in data: yield item - @check_conn def hscan_iter(self, name, match=None, count=None): # This is from redis-py cursor = '0' From fa63151d2ac78d4f6327da581b51b11a3d2153f5 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Tue, 10 Jul 2018 14:51:46 -0500 Subject: [PATCH 14/24] Adjusted patch_responses call for FakePubSub class --- fakeredis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fakeredis.py b/fakeredis.py index 9845fd0..fe1801b 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -2238,7 +2238,7 @@ def __init__(self, decode_responses=False, *args, **kwargs): self._q = Queue() self.subscribed = False if decode_responses: - _patch_responses(self) + _patch_responses(self, _make_decode_func) self._decode_responses = decode_responses self.ignore_subscribe_messages = kwargs.get( 'ignore_subscribe_messages', False) From 0fa4e3cadd9a76c7bf20a6b5160a192fed3de1a0 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Tue, 10 Jul 2018 14:52:09 -0500 Subject: [PATCH 15/24] Adjusted test cases for access to connected attribute --- test_fakeredis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test_fakeredis.py b/test_fakeredis.py index 53c5d15..1a8171f 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -4079,12 +4079,12 @@ def test_persist(self): self.redis.persist('key') def test_rename(self): - self.redis._connected = True + self.redis.connected = True self.redis.set('key1', 'value') - self.redis._connected = False + self.redis.connected = False with self.assertRaises(redis.ConnectionError): self.redis.rename('key1', 'key2') - self.redis._connected = True + self.redis.connected = True self.assertTrue(self.redis.exists('key1')) def test_watch(self): From ed56f4a7cafc4ce9895ba83c559799ff4758c92c Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 11 Jul 2018 11:12:42 -0500 Subject: [PATCH 16/24] Connection error decorator should be patched to response regardless of initial state of connection --- fakeredis.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index fe1801b..8c547df 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -370,8 +370,7 @@ def __init__(self, db=0, charset='utf-8', errors='strict', if decode_responses: _patch_responses(self, _make_decode_func) - if not connected: - _patch_responses(self, _check_conn) + _patch_responses(self, _check_conn) @_lua_reply(_lua_bool_ok) def flushdb(self): From dc3e59e71b39b615c8fc732444e3e4228f6393d6 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 11 Jul 2018 11:12:59 -0500 Subject: [PATCH 17/24] ConnectionError message --- fakeredis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fakeredis.py b/fakeredis.py index 8c547df..c5d4a37 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -336,7 +336,7 @@ def _check_conn(func): @functools.wraps(func) def func_wrapper(*args, **kwargs): if not func.__self__.connected: - raise redis.ConnectionError + raise redis.ConnectionError("FakeRedis is emulating a connection error.") return func(*args, **kwargs) return func_wrapper From bfe919bd76d22a4eae31b9f0b47172b51bf1c087 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 11 Jul 2018 11:20:17 -0500 Subject: [PATCH 18/24] Updated README with documentation on mocking connection errors --- README.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.rst b/README.rst index 6d5f97a..0d61825 100644 --- a/README.rst +++ b/README.rst @@ -74,6 +74,23 @@ test you run, be sure to call `r.flushall()` in your Alternatively, you can create an instance that does not share data with other instances, by passing `singleton=False` to the constructor. +It is also possible to mock connection errors so you can effectively test +your error handling. Simply pass `connected=False` to the constructor or +set the connected attribute to `False` after initialization. + +.. code-block:: python + >>> import fakeredis + >>> r = fakeredis.FakeStrictRedis(connected=False) + >>> r.set('foo', 'bar') + Traceback (most recent call last): + File "", line 1, in + File "~/fakeredis/fakeredis.py", line 339, in func_wrapper + raise redis.ConnectionError("FakeRedis is emulating a connection error.") + redis.exceptions.ConnectionError: FakeRedis is emulating a connection error. + >>> r.connected = True + >>> r.set('foo', 'bar') + True + Fakeredis implements the same interface as `redis-py`_, the popular redis client for python, and models the responses of redis 2.6. From 03615decff7f425a71abdf86acb1949421a64085 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 11 Jul 2018 11:23:38 -0500 Subject: [PATCH 19/24] Proper indentation in README --- README.rst | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/README.rst b/README.rst index 0d61825..b75255e 100644 --- a/README.rst +++ b/README.rst @@ -79,17 +79,17 @@ your error handling. Simply pass `connected=False` to the constructor or set the connected attribute to `False` after initialization. .. code-block:: python - >>> import fakeredis - >>> r = fakeredis.FakeStrictRedis(connected=False) - >>> r.set('foo', 'bar') - Traceback (most recent call last): - File "", line 1, in - File "~/fakeredis/fakeredis.py", line 339, in func_wrapper - raise redis.ConnectionError("FakeRedis is emulating a connection error.") - redis.exceptions.ConnectionError: FakeRedis is emulating a connection error. - >>> r.connected = True - >>> r.set('foo', 'bar') - True + >>> import fakeredis + >>> r = fakeredis.FakeStrictRedis(connected=False) + >>> r.set('foo', 'bar') + Traceback (most recent call last): + File "", line 1, in + File "~/fakeredis/fakeredis.py", line 339, in func_wrapper + raise redis.ConnectionError("FakeRedis is emulating a connection error.") + redis.exceptions.ConnectionError: FakeRedis is emulating a connection error. + >>> r.connected = True + >>> r.set('foo', 'bar') + True Fakeredis implements the same interface as `redis-py`_, the popular redis client for python, and models the responses From 71e89ba5407973113302a888f2c2f613ff294ca2 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 11 Jul 2018 16:34:51 -0500 Subject: [PATCH 20/24] _check_conn should be patched on before decode_responses --- fakeredis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fakeredis.py b/fakeredis.py index c5d4a37..d90971f 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -367,11 +367,11 @@ def __init__(self, db=0, charset='utf-8', errors='strict', self._pubsubs = [] self._decode_responses = decode_responses self.connected = connected + _patch_responses(self, _check_conn) + if decode_responses: _patch_responses(self, _make_decode_func) - _patch_responses(self, _check_conn) - @_lua_reply(_lua_bool_ok) def flushdb(self): self._db.clear() From 8412e2589bb2e3565dfd98cbeac15f4801b58b02 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Wed, 11 Jul 2018 16:41:36 -0500 Subject: [PATCH 21/24] Spacing fix on documentation --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index b75255e..d311c44 100644 --- a/README.rst +++ b/README.rst @@ -79,6 +79,7 @@ your error handling. Simply pass `connected=False` to the constructor or set the connected attribute to `False` after initialization. .. code-block:: python + >>> import fakeredis >>> r = fakeredis.FakeStrictRedis(connected=False) >>> r.set('foo', 'bar') From 6f301b39451dbe083e027a955cbb444041880218 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Thu, 26 Jul 2018 15:27:35 -0500 Subject: [PATCH 22/24] Patched connection error decorator to PubSub class. Added test cases --- fakeredis.py | 4 +++- test_fakeredis.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/fakeredis.py b/fakeredis.py index d90971f..8a83f91 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -2231,11 +2231,13 @@ class FakePubSub(object): PATTERN_MESSAGE_TYPES = ['psubscribe', 'punsubscribe'] LISTEN_DELAY = 0.1 # delay between listen loops (seconds) - def __init__(self, decode_responses=False, *args, **kwargs): + def __init__(self, connected=True, decode_responses=False, *args, **kwargs): self.channels = {} self.patterns = {} self._q = Queue() self.subscribed = False + self.connected = connected + _patch_responses(self, _check_conn) if decode_responses: _patch_responses(self, _make_decode_func) self._decode_responses = decode_responses diff --git a/test_fakeredis.py b/test_fakeredis.py index 1a8171f..9c45a19 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -4403,5 +4403,42 @@ def test_hscan_iter(self): self.redis.hscan_iter('name') +class TestPubSubConnected(unittest.TestCase): + + def setUp(self): + self.redis = fakeredis.FakePubSub(connected=False) + + def tearDown(self): + del self.redis + + def test_basic_subscript(self): + with self.assertRaises(redis.ConnectionError): + self.redis.subscribe('logs') + + def test_subscript_conn_lost(self): + self.redis.connected = True + self.redis.subscribe('logs') + self.redis.connected = False + with self.assertRaises(redis.ConnectionError): + self.redis.get_message() + + def test_put_listen(self): + self.redis.connected = True + count = self.redis.put('logs', 'mymessage', 'subscribe') + self.assertEqual(count, 1, 'Message could should be 1') + self.redis.connected = False + with self.assertRaises(redis.ConnectionError): + self.redis.get_message() + self.redis.connected = True + msg = self.redis.get_message() + check = { + 'type': 'subscribe', + 'pattern': None, + 'channel': b'logs', + 'data': 'mymessage' + } + self.assertEqual(msg, check, 'Message was not published to channel') + + if __name__ == '__main__': unittest.main() From 75dced3ee2d2b311e845ae2cf2c7ecbe0752ba70 Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Fri, 27 Jul 2018 03:14:05 -0500 Subject: [PATCH 23/24] Connected arg after decode arg in FakePubSub --- fakeredis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fakeredis.py b/fakeredis.py index 8a83f91..b27eff5 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -2231,7 +2231,7 @@ class FakePubSub(object): PATTERN_MESSAGE_TYPES = ['psubscribe', 'punsubscribe'] LISTEN_DELAY = 0.1 # delay between listen loops (seconds) - def __init__(self, connected=True, decode_responses=False, *args, **kwargs): + def __init__(self, decode_responses=False, connected=True, *args, **kwargs): self.channels = {} self.patterns = {} self._q = Queue() From 28cae30883d8e2863db15eaf6c62925e972c43ba Mon Sep 17 00:00:00 2001 From: Adam Mertz Date: Fri, 27 Jul 2018 03:14:33 -0500 Subject: [PATCH 24/24] Renamed test attribute redis to pubsub --- test_fakeredis.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/test_fakeredis.py b/test_fakeredis.py index 9c45a19..e66bd3b 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -3984,7 +3984,6 @@ def test_searches_for_c_stdlib_and_raises_if_missing(self): class TestFakeStrictRedisConnectionErrors(unittest.TestCase): - def create_redis(self): return fakeredis.FakeStrictRedis(db=0, connected=False) @@ -4404,33 +4403,35 @@ def test_hscan_iter(self): class TestPubSubConnected(unittest.TestCase): + def create_redis(self): + return fakeredis.FakePubSub(connected=False) def setUp(self): - self.redis = fakeredis.FakePubSub(connected=False) + self.pubsub = self.create_redis() def tearDown(self): - del self.redis + del self.pubsub def test_basic_subscript(self): with self.assertRaises(redis.ConnectionError): - self.redis.subscribe('logs') + self.pubsub.subscribe('logs') def test_subscript_conn_lost(self): - self.redis.connected = True - self.redis.subscribe('logs') - self.redis.connected = False + self.pubsub.connected = True + self.pubsub.subscribe('logs') + self.pubsub.connected = False with self.assertRaises(redis.ConnectionError): - self.redis.get_message() + self.pubsub.get_message() def test_put_listen(self): - self.redis.connected = True - count = self.redis.put('logs', 'mymessage', 'subscribe') + self.pubsub.connected = True + count = self.pubsub.put('logs', 'mymessage', 'subscribe') self.assertEqual(count, 1, 'Message could should be 1') - self.redis.connected = False + self.pubsub.connected = False with self.assertRaises(redis.ConnectionError): - self.redis.get_message() - self.redis.connected = True - msg = self.redis.get_message() + self.pubsub.get_message() + self.pubsub.connected = True + msg = self.pubsub.get_message() check = { 'type': 'subscribe', 'pattern': None,