diff --git a/docs/about/changelog.md b/docs/about/changelog.md index 535ae7b1..15b184ab 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -8,8 +8,9 @@ description: Change log of all fakeredis releases ## v2.18.0 ### 🚀 Features - -- Implement `PUBSUB NUMPAT` #195 + +- Implement `PUBSUB NUMPAT` #195, `SSUBSCRIBE` #199, `SPUBLISH` #198, + `SUNSUBSCRIBE` #200, ### 🧰 Bug Fixes diff --git a/docs/redis-commands/Redis.md b/docs/redis-commands/Redis.md index 55076797..dd46249e 100644 --- a/docs/redis-commands/Redis.md +++ b/docs/redis-commands/Redis.md @@ -1287,7 +1287,7 @@ Internal commands for debugging HyperLogLog values. An internal command for testing HyperLogLog values. -## `pubsub` commands (9/15 implemented) +## `pubsub` commands (13/15 implemented) ### [PSUBSCRIBE](https://redis.io/commands/psubscribe/) @@ -1309,6 +1309,10 @@ Returns the active channels. Returns helpful text about the different subcommands. +### [PUBSUB NUMPAT](https://redis.io/commands/pubsub-numpat/) + +Returns a count of unique pattern subscriptions. + ### [PUBSUB NUMSUB](https://redis.io/commands/pubsub-numsub/) Returns a count of subscribers to channels. @@ -1317,10 +1321,22 @@ Returns a count of subscribers to channels. Stops listening to messages published to channels that match one or more patterns. +### [SPUBLISH](https://redis.io/commands/spublish/) + +Post a message to a shard channel + +### [SSUBSCRIBE](https://redis.io/commands/ssubscribe/) + +Listens for messages published to shard channels. + ### [SUBSCRIBE](https://redis.io/commands/subscribe/) Listens for messages published to channels. +### [SUNSUBSCRIBE](https://redis.io/commands/sunsubscribe/) + +Stops listening to messages posted to shard channels. + ### [UNSUBSCRIBE](https://redis.io/commands/unsubscribe/) Stops listening to messages posted to channels. @@ -1329,10 +1345,6 @@ Stops listening to messages posted to channels. ### Unsupported pubsub commands > To implement support for a command, see [here](../../guides/implement-command/) -#### [PUBSUB NUMPAT](https://redis.io/commands/pubsub-numpat/) (not implemented) - -Returns a count of unique pattern subscriptions. - #### [PUBSUB SHARDCHANNELS](https://redis.io/commands/pubsub-shardchannels/) (not implemented) Returns the active shard channels. @@ -1341,18 +1353,6 @@ Returns the active shard channels. Returns the count of subscribers of shard channels. -#### [SPUBLISH](https://redis.io/commands/spublish/) (not implemented) - -Post a message to a shard channel - -#### [SSUBSCRIBE](https://redis.io/commands/ssubscribe/) (not implemented) - -Listens for messages published to shard channels. - -#### [SUNSUBSCRIBE](https://redis.io/commands/sunsubscribe/) (not implemented) - -Stops listening to messages posted to shard channels. - ## `set` commands (17/17 implemented) diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index 77105d40..ef448d01 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -35,7 +35,10 @@ def bin_reverse(x, bits_count): class BaseFakeSocket: - ACCEPTED_COMMANDS_WHILE_PUBSUB = {'ping', 'subscribe', 'unsubscribe', 'psubscribe', 'punsubscribe', 'quit', } + ACCEPTED_COMMANDS_WHILE_PUBSUB = { + 'ping', 'subscribe', 'unsubscribe', 'psubscribe', 'punsubscribe', 'quit', + 'ssubscribe', 'sunsubscribe', + } _connection_error_class = redis.ConnectionError def __init__(self, server, db, *args, **kwargs): diff --git a/fakeredis/_server.py b/fakeredis/_server.py index d2b10d75..740c57c0 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -38,6 +38,7 @@ def __init__(self, version: Tuple[int] = (7,)): # Maps channel/pattern to weak set of sockets self.subscribers = defaultdict(weakref.WeakSet) self.psubscribers = defaultdict(weakref.WeakSet) + self.ssubscribers = defaultdict(weakref.WeakSet) self.lastsave = int(time.time()) self.connected = True # List of weakrefs to sockets that are being closed lazily diff --git a/fakeredis/commands_mixins/pubsub_mixin.py b/fakeredis/commands_mixins/pubsub_mixin.py index 582e04b1..f43ad84c 100644 --- a/fakeredis/commands_mixins/pubsub_mixin.py +++ b/fakeredis/commands_mixins/pubsub_mixin.py @@ -48,6 +48,10 @@ def psubscribe(self, *patterns): def subscribe(self, *channels): return self._subscribe(channels, self._server.subscribers, b'subscribe') + @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT) + def ssubscribe(self, *channels): + return self._subscribe(channels, self._server.ssubscribers, b'ssubscribe') + @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT) def punsubscribe(self, *patterns): return self._unsubscribe(patterns, self._server.psubscribers, b'punsubscribe') @@ -56,6 +60,10 @@ def punsubscribe(self, *patterns): def unsubscribe(self, *channels): return self._unsubscribe(channels, self._server.subscribers, b'unsubscribe') + @command(fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT) + def sunsubscribe(self, *channels): + return self._unsubscribe(channels, self._server.ssubscribers, b'sunsubscribe') + @command((bytes, bytes)) def publish(self, channel, message): receivers = 0 @@ -73,6 +81,23 @@ def publish(self, channel, message): receivers += 1 return receivers + @command((bytes, bytes)) + def spublish(self, channel, message): + receivers = 0 + msg = [b'smessage', channel, message] + subs = self._server.ssubscribers.get(channel, set()) + for sock in subs: + sock.put_response(msg) + receivers += 1 + for (pattern, socks) in self._server.psubscribers.items(): + regex = compile_pattern(pattern) + if regex.match(channel): + msg = [b'pmessage', pattern, channel, message] + for sock in socks: + sock.put_response(msg) + receivers += 1 + return receivers + @command(name='PUBSUB NUMPAT', fixed=(), repeat=()) def pubsub_numpat(self, *_): return len(self._server.psubscribers) diff --git a/test/test_mixins/test_pubsub_commands.py b/test/test_mixins/test_pubsub_commands.py index 6fb87131..a9c85459 100644 --- a/test/test_mixins/test_pubsub_commands.py +++ b/test/test_mixins/test_pubsub_commands.py @@ -28,6 +28,15 @@ def wait_for_message( return None +def make_message(_type, channel, data, pattern=None): + return { + "type": _type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, + } + + def test_ping_pubsub(r: redis.Redis): p = r.pubsub() p.subscribe('channel') @@ -455,3 +464,60 @@ def test_pubsub_numsub(r: redis.Redis): assert r.pubsub_numsub(a, b, c) == [(a.encode(), 2), (b.encode(), 2), (c.encode(), 1), ] assert r.pubsub_numsub() == [] assert r.pubsub_numsub(a, "non-existing") == [(a.encode(), 2), (b"non-existing", 0)] + + +@testtools.run_test_if_redispy_ver('above', '5.0.0rc2') +def test_published_message_to_shard_channel(r: redis.Redis): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p) == make_message("ssubscribe", "foo", 1) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + + +@testtools.run_test_if_redispy_ver('above', '5.0.0rc2') +def test_subscribe_property_with_shard_channels_cluster(r: redis.Redis): + p = r.pubsub() + keys = ["foo", "bar", "uni" + chr(4456) + "code"] + assert p.subscribed is False + p.ssubscribe(keys[0]) + # we're now subscribed even though we haven't processed the reply from the server just yet + assert p.subscribed is True + assert wait_for_message(p) == make_message("ssubscribe", keys[0], 1) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all shard_channels + p.sunsubscribe() + # we're still technically subscribed until we process the response messages from the server + assert p.subscribed is True + assert wait_for_message(p) == make_message("sunsubscribe", keys[0], 0) + # now we're no longer subscribed as no more messages can be delivered to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + p.ssubscribe(keys[0]) + assert p.subscribed is True + assert wait_for_message(p) == make_message("ssubscribe", keys[0], 1) + + # unsubscribe again + p.sunsubscribe() + assert p.subscribed is True + # subscribe to another shard_channel before reading the unsubscribe response + p.ssubscribe(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert wait_for_message(p) == make_message("sunsubscribe", keys[0], 0) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert wait_for_message(p) == make_message("ssubscribe", keys[1], 1) + p.sunsubscribe() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert wait_for_message(p) == make_message("sunsubscribe", keys[1], 0) + # now we're finally unsubscribed + assert p.subscribed is False