diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 8a63c8c4..1658dacc 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -63,7 +63,7 @@ def __init__(self, *args, **kwargs): self.server_key = path else: host, port = kwargs.get('host'), kwargs.get('port') - self.server_key = uuid.uuid4().hex if host is None or port is None else f'{host}:{port}' + self.server_key = f'{host}:{port}' self.server_key += f'v{version}' self._server = FakeServer.get_server(self.server_key, version=version) self._server.connected = connected @@ -136,15 +136,13 @@ class FakeRedisMixin: def __init__(self, *args, server=None, connected=True, version=(7,), **kwargs): # Interpret the positional and keyword arguments according to the # version of redis in use. - parameters = inspect.signature(redis.Redis.__init__).parameters - parameter_names = list(parameters.keys()) - default_args = parameters.values() - ignore_default_param_values = {'host', 'port', 'db'} - kwds = {p.name: p.default - for p in default_args - if (p.default != inspect.Parameter.empty - and p.name not in ignore_default_param_values)} - kwds.update(kwargs) + parameters = list(inspect.signature(redis.Redis.__init__).parameters.values())[1:] + # Convert args => kwargs + kwargs.update({parameters[i].name: args[i] for i in range(len(args))}) + kwargs.setdefault('host', uuid.uuid4().hex) + kwds = {p.name: kwargs.get(p.name, p.default) + for ind, p in enumerate(parameters) + if p.default != inspect.Parameter.empty} if not kwds.get('connection_pool', None): charset = kwds.get('charset', None) errors = kwds.get('errors', None) @@ -183,10 +181,7 @@ def __init__(self, *args, server=None, connected=True, version=(7,), **kwargs): kwds.pop('server', None) kwds.pop('connected', None) kwds.pop('version', None) - parameter_names_to_cut = parameter_names[1:len(args) + 1] - for param in parameter_names_to_cut: - kwds.pop(param, None) - super().__init__(*args, **kwds) + super().__init__(**kwds) @classmethod def from_url(cls, *args, **kwargs): diff --git a/test/test_init_args.py b/test/test_init_args.py index 6be04264..ebfca890 100644 --- a/test/test_init_args.py +++ b/test/test_init_args.py @@ -129,3 +129,38 @@ def test_same_connection_params(self): r1.set('foo', 'bar') assert r2.get('foo') == b'bar' assert not r3.exists('foo') + + def test_new_server_with_positional_args(self): + from fakeredis import FakeRedis + + # same host, default port and db index + fake_redis_1 = FakeRedis('localhost') + fake_redis_2 = FakeRedis('localhost') + + fake_redis_1.set("foo", "bar") + + assert fake_redis_2.get("foo") == b'bar' + + # same host and port + fake_redis_1 = FakeRedis('localhost', 6000) + fake_redis_2 = FakeRedis('localhost', 6000) + + fake_redis_1.set("foo", "bar") + + assert fake_redis_2.get("foo") == b'bar' + + # same connection parameters, but different db index + fake_redis_1 = FakeRedis('localhost', 6000, 0) + fake_redis_2 = FakeRedis('localhost', 6000, 1) + + fake_redis_1.set("foo", "bar") + + assert fake_redis_2.get("foo") is None + + # mix of positional arguments and keyword args + fake_redis_1 = FakeRedis('localhost', port=6000, db=0) + fake_redis_2 = FakeRedis('localhost', port=6000, db=1) + + fake_redis_1.set("foo", "bar") + + assert fake_redis_2.get("foo") is None diff --git a/test/test_mixins/test_transactions_commands.py b/test/test_mixins/test_transactions_commands.py index b6fd3368..d1abe269 100644 --- a/test/test_mixins/test_transactions_commands.py +++ b/test/test_mixins/test_transactions_commands.py @@ -307,3 +307,35 @@ def test_socket_cleanup_watch(fake_server): sock = pipeline.connection._sock # noqa: F841 pipeline.connection.disconnect() r2.set('test', 'foo') + + +def test_get_within_pipeline(r: redis.Redis): + r.set("test", "foo") + r.set("test2", "foo2") + expected_keys = set(r.keys()) + with r.pipeline() as p: + assert set(r.keys()) == expected_keys + p.watch("test") + assert set(r.keys()) == expected_keys + +@pytest.mark.fake +def test_get_within_pipeline_w_host(): + r = fakeredis.FakeRedis('localhost') + r.set("test", "foo") + r.set("test2", "foo2") + expected_keys = set(r.keys()) + with r.pipeline() as p: + assert set(r.keys()) == expected_keys + p.watch("test") + assert set(r.keys()) == expected_keys + +@pytest.mark.fake +def test_get_within_pipeline_no_args(): + r = fakeredis.FakeRedis() + r.set("test", "foo") + r.set("test2", "foo2") + expected_keys = set(r.keys()) + with r.pipeline() as p: + assert set(r.keys()) == expected_keys + p.watch("test") + assert set(r.keys()) == expected_keys