Skip to content

Commit

Permalink
Remove need for packaging
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Jun 6, 2023
1 parent b2dbd49 commit 38c4742
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 76 deletions.
6 changes: 3 additions & 3 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Any, Tuple

import redis
from packaging.version import Version


if redis.VERSION >= (5, 0):
from redis.parsers import BaseParser
Expand Down Expand Up @@ -331,11 +331,11 @@ def _ttl(self, key, scale):
return int(round((key.expireat - self._db.time) * scale))

def _encodefloat(self, value, humanfriendly):
if self.version >= Version('7'):
if self.version >= (7,):
value = 0 + value
return Float.encode(value, humanfriendly)

def _encodeint(self, value):
if self.version >= Version('7'):
if self.version >= (7,):
value = 0 + value
return Int.encode(value)
6 changes: 5 additions & 1 deletion fakeredis/_command_args_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def extract_args(
expected: Tuple[str, ...],
error_on_unexpected: bool = True,
left_from_first_unexpected: bool = True,
exception=None
) -> Tuple[List, List]:
"""Parse argument values
Expand Down Expand Up @@ -135,7 +136,10 @@ def _parse_params(

if not found:
if error_on_unexpected:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
raise (
SimpleError(msgs.SYNTAX_ERROR_MSG)
if exception is None
else SimpleError(exception.format(actual_args[i])))
if left_from_first_unexpected:
return results, actual_args[i:]
left_args.append(actual_args[i])
Expand Down
7 changes: 3 additions & 4 deletions fakeredis/_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import functools
import math
import re

from packaging.version import Version
from typing import Tuple

from . import _msgs as msgs
from ._helpers import null_terminate, SimpleError, SimpleString
Expand Down Expand Up @@ -317,14 +316,14 @@ def __init__(self, name, func_name, fixed, repeat=(), args=(), flags=""):
self.flags = set(flags)
self.command_args = args

def check_arity(self, args, version):
def check_arity(self, args, version: Tuple[int]):
if len(args) != len(self.fixed):
delta = len(args) - len(self.fixed)
if delta < 0 or not self.repeat:
msg = msgs.WRONG_ARGS_MSG6.format(self.name)
raise SimpleError(msg)
if delta % len(self.repeat) != 0:
msg = msgs.WRONG_ARGS_MSG7 if version >= Version('7') else msgs.WRONG_ARGS_MSG6.format(self.name)
msg = msgs.WRONG_ARGS_MSG7 if version >= (7,) else msgs.WRONG_ARGS_MSG6.format(self.name)
raise SimpleError(msg)

def apply(self, args, db, version):
Expand Down
22 changes: 12 additions & 10 deletions fakeredis/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import warnings
import weakref
from collections import defaultdict
from typing import Dict
from typing import Dict, Tuple

import redis
from packaging.version import Version

from fakeredis._fakesocket import FakeSocket
from fakeredis._helpers import (Database, FakeSelector)
Expand All @@ -19,18 +18,21 @@
LOGGER = logging.getLogger('fakeredis')


def _create_version(v) -> Version:
if isinstance(v, Version):
def _create_version(v) -> Tuple[int]:
if isinstance(v, tuple):
return v
if isinstance(v, int):
return Version(str(v))
return Version(v)
return (v,)
if isinstance(v, str):
v = v.split('.')
return tuple(int(x) for x in v)
return v


class FakeServer:
_servers_map: Dict[str, 'FakeServer'] = dict()

def __init__(self, version: Version = Version("7")):
def __init__(self, version: Tuple[int] = (7,)):
self.lock = threading.Lock()
self.dbs = defaultdict(lambda: Database(self.lock))
# Maps channel/pattern to weak set of sockets
Expand All @@ -43,7 +45,7 @@ def __init__(self, version: Version = Version("7")):
self.version = _create_version(version)

@staticmethod
def get_server(key, version: Version):
def get_server(key, version: Tuple[int]):
return FakeServer._servers_map.setdefault(key, FakeServer(version=version))


Expand All @@ -54,7 +56,7 @@ def __init__(self, *args, **kwargs):
self._selector = None
self._server = kwargs.pop('server', None)
path = kwargs.pop('path', None)
version = kwargs.pop('version', Version('7.0'))
version = kwargs.pop('version', (7, 0))
connected = kwargs.pop('connected', True)
if self._server is None:
if path:
Expand Down Expand Up @@ -131,7 +133,7 @@ def __str__(self):


class FakeRedisMixin:
def __init__(self, *args, server=None, connected=True, version=Version('7'), **kwargs):
def __init__(self, *args, server=None, connected=True, version=(7,), **kwargs):
# Interpret the positional and keyword arguments according to the
# version of redis in use.
parameters = inspect.signature(redis.Redis.__init__).parameters
Expand Down
10 changes: 4 additions & 6 deletions fakeredis/commands_mixins/bitmap_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from packaging.version import Version

from fakeredis import _msgs as msgs
from fakeredis._commands import (command, Key, Int, BitOffset, BitValue, fix_range_string, fix_range)
from fakeredis._helpers import SimpleError, casematch
Expand All @@ -17,10 +15,10 @@ def bitpos(self, key, bit, *args):
raise SimpleError(msgs.BIT_ARG_MUST_BE_ZERO_OR_ONE)
if len(args) > 3:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
if len(args) == 3 and self.version < Version('7'):
if len(args) == 3 and self.version < (7,):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
bit_mode = False
if len(args) == 3 and self.version >= Version('7'):
if len(args) == 3 and self.version >= (7,):
bit_mode = casematch(args[2], b'bit')
if not bit_mode and not casematch(args[2], b'byte'):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
Expand Down Expand Up @@ -56,9 +54,9 @@ def bitcount(self, key, *args):
start = Int.decode(args[0])
end = Int.decode(args[1])
bit_mode = False
if len(args) == 3 and self.version < Version('7'):
if len(args) == 3 and self.version < (7,):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
if len(args) == 3 and self.version >= Version('7'):
if len(args) == 3 and self.version >= (7,):
bit_mode = casematch(args[2], b'bit')
if not bit_mode and not casematch(args[2], b'byte'):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
Expand Down
21 changes: 3 additions & 18 deletions fakeredis/commands_mixins/generic_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import pickle
import random

from packaging.version import Version

from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import (
Expand Down Expand Up @@ -43,22 +41,9 @@ def _lookup_key(self, key, pattern):
return item.value

def _expireat(self, key, timestamp, *args):
nx = False
xx = False
gt = False
lt = False
for arg in args:
if casematch(b'nx', arg):
nx = True
elif casematch(b'xx', arg):
xx = True
elif casematch(b'gt', arg):
gt = True
elif casematch(b'lt', arg):
lt = True
else:
raise SimpleError(msgs.EXPIRE_UNSUPPORTED_OPTION.format(arg))
if self.version < Version('7') and any((nx, xx, gt, lt)):
(nx, xx, gt, lt,), _ = extract_args(
args, ('nx', 'xx', 'gt', 'lt',), exception=msgs.EXPIRE_UNSUPPORTED_OPTION,)
if self.version < (7,) and any((nx, xx, gt, lt)):
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('expire'))
counter = (nx, gt, lt).count(True)
if (counter > 1) or (nx and xx):
Expand Down
6 changes: 2 additions & 4 deletions fakeredis/commands_mixins/pubsub_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from packaging.version import Version

from fakeredis import _msgs as msgs
from fakeredis._commands import (command)
from fakeredis._helpers import (NoResponse, compile_pattern, SimpleError)
Expand Down Expand Up @@ -93,7 +91,7 @@ def pubsub(self, *args):

@command(name='PUBSUB HELP', fixed=())
def pubsub_help(self, *args):
if self.version >= Version('7'):
if self.version >= (7,):
help_strings = [
'PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:',
'CHANNELS [<pattern>]',
Expand All @@ -111,7 +109,7 @@ def pubsub_help(self, *args):
' Return the number of subscribers for the specified shard level channel(s'
')',
'HELP',
(' Prints this help.' if self.version < Version('7.1') else ' Print this help.'),
(' Prints this help.' if self.version < (7, 1) else ' Print this help.'),
]
else:
help_strings = [
Expand Down
8 changes: 3 additions & 5 deletions fakeredis/commands_mixins/scripting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import itertools
import logging

from packaging.version import Version

from fakeredis import _msgs as msgs
from fakeredis._commands import (command, Int)
from fakeredis._helpers import (SimpleError, SimpleString, null_terminate, OK, encode_command)
Expand Down Expand Up @@ -179,7 +177,7 @@ def eval(self, script, numkeys, *keys_and_args):
try:
result = lua_runtime.execute(script)
except SimpleError as ex:
if self.version <= Version('6'):
if self.version < (7,):
raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))
raise SimpleError(ex.value)
except LuaError as ex:
Expand Down Expand Up @@ -208,7 +206,7 @@ def script_load(self, *args):

@command(name='script exists', fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT, )
def script_exists(self, *args):
if self.version >= Version('7') and len(args) == 0:
if self.version >= (7,) and len(args) == 0:
raise SimpleError(msgs.WRONG_ARGS_MSG7)
return [int(sha1 in self.script_cache) for sha1 in args]

Expand Down Expand Up @@ -244,7 +242,7 @@ def script_help(self, *args):
'LOAD <script>',
' Load a script into the scripts cache without executing it.',
'HELP',
(' Prints this help.' if self.version < Version('7.1') else ' Print this help.'),
(' Prints this help.' if self.version < (7, 1) else ' Print this help.'),
]

return [s.encode() for s in help_strings]
4 changes: 1 addition & 3 deletions fakeredis/commands_mixins/set_mixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import random

from packaging.version import Version

from fakeredis import _msgs as msgs
from fakeredis._commands import (command, Key, Int, CommandItem)
from fakeredis._helpers import (OK, SimpleError, casematch)
Expand Down Expand Up @@ -69,7 +67,7 @@ def sinter(self, *keys):

@command((Int, bytes), (bytes,))
def sintercard(self, numkeys, *args):
if self.version < Version('7'):
if self.version < (7,):
raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format('sintercard'))
if numkeys < 1:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
Expand Down
4 changes: 1 addition & 3 deletions fakeredis/commands_mixins/sortedset_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import math
from typing import Union, Optional

from packaging.version import Version

from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import (command, Key, Int, Float, CommandItem, Timeout, ScoreTest, StringTest, fix_range)
Expand Down Expand Up @@ -97,7 +95,7 @@ def zadd(self, key, *args):
raise SimpleError(msgs.ZADD_INCR_LEN_ERROR_MSG)
# Parse all scores first, before updating
items = [
((0.0 + Float.decode(elements[j]) if self.version >= Version('7')
((0.0 + Float.decode(elements[j]) if self.version >= (7,)
else Float.decode(elements[j]), elements[j + 1]))
for j in range(0, len(elements), 2)
]
Expand Down
4 changes: 1 addition & 3 deletions fakeredis/commands_mixins/streams_mixin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import functools
from typing import List

from packaging.version import Version

import fakeredis._msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, CommandItem
Expand All @@ -22,7 +20,7 @@ def xadd(self, key, *args):
if not elements or len(elements) % 2 != 0:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format('XADD'))
stream = key.value or XStream()
if self.version < Version('7') and entry_key != b'*' and not StreamRangeTest.valid_key(entry_key):
if self.version < (7,) and entry_key != b'*' and not StreamRangeTest.valid_key(entry_key):
raise SimpleError(msgs.XADD_INVALID_ID)
entry_key = stream.add(elements, entry_key=entry_key)
if entry_key is None:
Expand Down
4 changes: 2 additions & 2 deletions fakeredis/commands_mixins/string_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math

from packaging.version import Version


from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
Expand Down Expand Up @@ -160,7 +160,7 @@ def set_(self, key, value, *args):

if (xx and nx) or ((px is not None) + (ex is not None) + keepttl > 1):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
if nx and get and self.version < Version('7'):
if nx and get and self.version < (7,):
# The command docs say this is allowed from Redis 7.0.
raise SimpleError(msgs.SYNTAX_ERROR_MSG)

Expand Down
18 changes: 9 additions & 9 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import pytest
import pytest_asyncio
import redis
from packaging.version import Version

import fakeredis
from fakeredis._server import _create_version


@pytest_asyncio.fixture(scope="session")
Expand All @@ -26,7 +26,7 @@ def real_redis_version() -> Union[None, str]:
@pytest_asyncio.fixture(name='fake_server')
def _fake_server(request):
min_server_marker = request.node.get_closest_marker('min_server')
server_version = Version(min_server_marker.args[0]) if min_server_marker else Version('6.2')
server_version = min_server_marker.args[0] if min_server_marker else '6.2'
server = fakeredis.FakeServer(version=server_version)
server.connected = request.node.get_closest_marker('disconnected') is None
return server
Expand All @@ -48,8 +48,8 @@ def r(request, create_redis) -> redis.Redis:
def _marker_version_value(request, marker_name: str):
marker_value = request.node.get_closest_marker(marker_name)
if marker_value is None:
return Version(str(0 if marker_name == 'min_server' else 100))
return Version(marker_value.args[0])
return (0,) if marker_name == 'min_server' else (100,)
return _create_version(marker_value.args[0])


@pytest_asyncio.fixture(
Expand All @@ -64,13 +64,13 @@ def _create_redis(request) -> Callable[[int], redis.Redis]:
server_version = request.getfixturevalue('real_redis_version')
if not cls_name.startswith('Fake') and not server_version:
pytest.skip('Redis is not running')
server_version = server_version or '6'
server_version = _create_version(server_version) or (6,)
min_server = _marker_version_value(request, 'min_server')
max_server = _marker_version_value(request, 'max_server')
if Version(server_version) < min_server:
pytest.skip(f'Redis server {min_server.base_version} or more required but {server_version} found')
if Version(server_version) > max_server:
pytest.skip(f'Redis server {max_server.base_version} or less required but {server_version} found')
if server_version < min_server:
pytest.skip(f'Redis server {min_server} or more required but {server_version} found')
if server_version > max_server:
pytest.skip(f'Redis server {max_server} or less required but {server_version} found')
decode_responses = request.node.get_closest_marker('decode_responses') is not None

def factory(db=0):
Expand Down
Loading

0 comments on commit 38c4742

Please sign in to comment.