diff --git a/fakeredis/commands_mixins/pubsub_mixin.py b/fakeredis/commands_mixins/pubsub_mixin.py index b08422d5..582e04b1 100644 --- a/fakeredis/commands_mixins/pubsub_mixin.py +++ b/fakeredis/commands_mixins/pubsub_mixin.py @@ -73,6 +73,10 @@ def publish(self, channel, message): receivers += 1 return receivers + @command(name='PUBSUB NUMPAT', fixed=(), repeat=()) + def pubsub_numpat(self, *_): + return len(self._server.psubscribers) + @command(name='PUBSUB CHANNELS', fixed=(), repeat=(bytes,)) def pubsub_channels(self, *args): channels = list(self._server.subscribers.keys()) diff --git a/test/test_mixins/test_pubsub_commands.py b/test/test_mixins/test_pubsub_commands.py index 4b5551b3..6fb87131 100644 --- a/test/test_mixins/test_pubsub_commands.py +++ b/test/test_mixins/test_pubsub_commands.py @@ -1,16 +1,33 @@ import threading +import time import uuid from queue import Queue from time import sleep +from typing import Optional, Dict, Any import pytest import redis +from redis.client import PubSub import fakeredis from .. import testtools from ..testtools import raw_command +def wait_for_message( + pubsub: PubSub, timeout=0.5, ignore_subscribe_messages=False +) -> Optional[Dict[str, Any]]: + now = time.time() + timeout = now + timeout + while now < timeout: + message = pubsub.get_message(ignore_subscribe_messages=ignore_subscribe_messages) + if message is not None: + return message + time.sleep(0.01) + now = time.time() + return None + + def test_ping_pubsub(r: redis.Redis): p = r.pubsub() p.subscribe('channel') @@ -40,6 +57,15 @@ def test_pubsub_subscribe(r: redis.Redis): assert message == expected_message +@pytest.mark.slow +def test_pubsub_numpat(r: redis.Redis): + p = r.pubsub() + p.psubscribe("*oo", "*ar", "b*z") + for i in range(3): + assert wait_for_message(p)["type"] == "psubscribe" + assert r.pubsub_numpat() == 3 + + @pytest.mark.slow def test_pubsub_psubscribe(r: redis.Redis): pubsub = r.pubsub()