Skip to content

Commit

Permalink
fix: use same server_key within pipeline when issued watch (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleg committed Jul 18, 2023
1 parent 4027ada commit 640faeb
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 14 deletions.
23 changes: 9 additions & 14 deletions fakeredis/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, *args, **kwargs):
self.server_key = path
else:
host, port = kwargs.get('host'), kwargs.get('port')
self.server_key = uuid.uuid4().hex if host is None or port is None else f'{host}:{port}'
self.server_key = f'{host}:{port}'
self.server_key += f'v{version}'
self._server = FakeServer.get_server(self.server_key, version=version)
self._server.connected = connected
Expand Down Expand Up @@ -136,15 +136,13 @@ class FakeRedisMixin:
def __init__(self, *args, server=None, connected=True, version=(7,), **kwargs):
# Interpret the positional and keyword arguments according to the
# version of redis in use.
parameters = inspect.signature(redis.Redis.__init__).parameters
parameter_names = list(parameters.keys())
default_args = parameters.values()
ignore_default_param_values = {'host', 'port', 'db'}
kwds = {p.name: p.default
for p in default_args
if (p.default != inspect.Parameter.empty
and p.name not in ignore_default_param_values)}
kwds.update(kwargs)
parameters = list(inspect.signature(redis.Redis.__init__).parameters.values())[1:]
# Convert args => kwargs
kwargs.update({parameters[i].name: args[i] for i in range(len(args))})
kwargs.setdefault('host', uuid.uuid4().hex)
kwds = {p.name: kwargs.get(p.name, p.default)
for ind, p in enumerate(parameters)
if p.default != inspect.Parameter.empty}
if not kwds.get('connection_pool', None):
charset = kwds.get('charset', None)
errors = kwds.get('errors', None)
Expand Down Expand Up @@ -183,10 +181,7 @@ def __init__(self, *args, server=None, connected=True, version=(7,), **kwargs):
kwds.pop('server', None)
kwds.pop('connected', None)
kwds.pop('version', None)
parameter_names_to_cut = parameter_names[1:len(args) + 1]
for param in parameter_names_to_cut:
kwds.pop(param, None)
super().__init__(*args, **kwds)
super().__init__(**kwds)

@classmethod
def from_url(cls, *args, **kwargs):
Expand Down
35 changes: 35 additions & 0 deletions test/test_init_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,38 @@ def test_same_connection_params(self):
r1.set('foo', 'bar')
assert r2.get('foo') == b'bar'
assert not r3.exists('foo')

def test_new_server_with_positional_args(self):
from fakeredis import FakeRedis

# same host, default port and db index
fake_redis_1 = FakeRedis('localhost')
fake_redis_2 = FakeRedis('localhost')

fake_redis_1.set("foo", "bar")

assert fake_redis_2.get("foo") == b'bar'

# same host and port
fake_redis_1 = FakeRedis('localhost', 6000)
fake_redis_2 = FakeRedis('localhost', 6000)

fake_redis_1.set("foo", "bar")

assert fake_redis_2.get("foo") == b'bar'

# same connection parameters, but different db index
fake_redis_1 = FakeRedis('localhost', 6000, 0)
fake_redis_2 = FakeRedis('localhost', 6000, 1)

fake_redis_1.set("foo", "bar")

assert fake_redis_2.get("foo") is None

# mix of positional arguments and keyword args
fake_redis_1 = FakeRedis('localhost', port=6000, db=0)
fake_redis_2 = FakeRedis('localhost', port=6000, db=1)

fake_redis_1.set("foo", "bar")

assert fake_redis_2.get("foo") is None
32 changes: 32 additions & 0 deletions test/test_mixins/test_transactions_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,35 @@ def test_socket_cleanup_watch(fake_server):
sock = pipeline.connection._sock # noqa: F841
pipeline.connection.disconnect()
r2.set('test', 'foo')


def test_get_within_pipeline(r: redis.Redis):
r.set("test", "foo")
r.set("test2", "foo2")
expected_keys = set(r.keys())
with r.pipeline() as p:
assert set(r.keys()) == expected_keys
p.watch("test")
assert set(r.keys()) == expected_keys

@pytest.mark.fake
def test_get_within_pipeline_w_host():
r = fakeredis.FakeRedis('localhost')
r.set("test", "foo")
r.set("test2", "foo2")
expected_keys = set(r.keys())
with r.pipeline() as p:
assert set(r.keys()) == expected_keys
p.watch("test")
assert set(r.keys()) == expected_keys

@pytest.mark.fake
def test_get_within_pipeline_no_args():
r = fakeredis.FakeRedis()
r.set("test", "foo")
r.set("test2", "foo2")
expected_keys = set(r.keys())
with r.pipeline() as p:
assert set(r.keys()) == expected_keys
p.watch("test")
assert set(r.keys()) == expected_keys

0 comments on commit 640faeb

Please sign in to comment.