Skip to content

Commit

Permalink
Implement XGROUP DELCONSUMER, XGROUP CREATECONSUMER, and `XINFO C…
Browse files Browse the repository at this point in the history
…ONSUMERS`

Fix #162
Fix #163
Fix #167
  • Loading branch information
cunla committed Jun 11, 2023
1 parent 126cb2b commit fb03707
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 23 deletions.
66 changes: 51 additions & 15 deletions fakeredis/_stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import bisect
import time
from dataclasses import dataclass
from typing import List, Union, Tuple, Optional, NamedTuple, Dict

from fakeredis._commands import BeforeAny, AfterAny
Expand Down Expand Up @@ -30,15 +31,59 @@ def format_record(self):
return [self.key.encode(), results]


current_time = lambda: int(time.time() * 1000)


@dataclass
class StreamConsumerInfo(object):
name: bytes
pending: int = 0
last_attempt: int = current_time()
last_success: int = current_time()

def info(self) -> List[bytes]:
curr_time = current_time()
return [
b'name', self.name,
b'pending', self.pending,
b'idle', curr_time - self.last_attempt,
b'inactive', curr_time - self.last_success,
]


class StreamGroup(object):
def __init__(self, name: bytes, start_index: int, entries_read: int = None):
def __init__(self, stream: 'XStream', name: bytes, start_index: int, entries_read: int = None):
self.stream = stream
self.name = name
self.start_index = start_index
self.entries_read = entries_read
self.consumers = list()
# consumer_name -> #pending_messages
self.consumers: Dict[bytes, StreamConsumerInfo] = dict()
self.last_delivered_index = start_index
self.last_ack_index = start_index

def set_id(self, last_delivered_str: bytes, entries_read: Union[int, None]) -> None:
"""Set last_delivered_id for group
"""
self.start_index, _ = self.stream.find_index_key_as_str(last_delivered_str)
self.entries_read = entries_read

def add_consumer(self, consumer_name: bytes) -> int:
if consumer_name in self.consumers:
return 0
self.consumers[consumer_name] = StreamConsumerInfo()
return 1

def del_consumer(self, consumer_name: bytes) -> int:
if consumer_name not in self.consumers:
return 0
res = self.consumers[consumer_name].pending
del self.consumers[consumer_name]
return res

def consumers_info(self):
return [self.consumers[k].info() for k in self.consumers]


class StreamRangeTest:
"""Argument converter for sorted set LEX endpoints."""
Expand Down Expand Up @@ -84,6 +129,9 @@ def __init__(self):
self._values: List[StreamEntry] = list()
self._groups: Dict[bytes, StreamGroup] = dict()

def group_get(self, group_name: bytes) -> StreamGroup:
return self._groups.get(group_name, None)

def group_add(self, name: bytes, start_key_str: bytes, entries_read: Union[int, None]) -> None:
"""Add a group listening to stream
Expand All @@ -93,26 +141,14 @@ def group_add(self, name: bytes, start_key_str: bytes, entries_read: Union[int,
"""
start_index, found = self.find_index_key_as_str(start_key_str)
start_index -= (0 if found else -1)
self._groups[name] = StreamGroup(name, start_index, entries_read)
self._groups[name] = StreamGroup(self, name, start_index, entries_read)

def group_delete(self, group_name: bytes) -> int:
if group_name in self._groups:
del self._groups[group_name]
return 1
return 0

def group_set_id(self, group_name: bytes, last_delivered_str: bytes, entries_read: Union[int, None]) -> bool:
"""Set last_delivered_id for group
:returns: True if successful, False if the group is not found.
"""
group = self._groups.get(group_name, None)
if group is None:
return False
group.start_index, _ = self.find_index_key_as_str(last_delivered_str)
group.entries_read = entries_read
return True

def groups_info(self):
res = []
for group in self._groups.values():
Expand Down
38 changes: 33 additions & 5 deletions fakeredis/commands_mixins/streams_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, CommandItem
from fakeredis._helpers import SimpleError, casematch, OK
from fakeredis._stream import XStream, StreamRangeTest
from fakeredis._stream import XStream, StreamRangeTest, StreamGroup


class StreamsCommandsMixin:
Expand Down Expand Up @@ -128,20 +128,48 @@ def xgroup_setid(self, key, group_name, start_key, *args):
(entries_read,), _ = extract_args(args, ('+entriesread',))
if key.value is None:
raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
res = key.value.group_set_id(group_name, start_key, entries_read)
if not res:
group = key.value.group_get(group_name)
if not group:
raise SimpleError(msgs.XGROUP_GROUP_NOT_FOUND_MSG.format(key, group_name))
group.set_id(start_key, entries_read)
return OK

@command(name="XGROUP DESTROY", fixed=(Key(XStream), bytes,), repeat=(), )
def xgroup_destroy(self, key, group_name,):
def xgroup_destroy(self, key, group_name, ):
if key.value is None:
raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
res = key.value.group_delete(group_name)
return res

@command(name="XGROUP CREATECONSUMER", fixed=(Key(XStream), bytes, bytes), repeat=(), )
def xgroup_createconsumer(self, key, group_name, consumer_name):
if key.value is None:
raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
group: StreamGroup = key.value.group_get(group_name)
if not group:
raise SimpleError(msgs.XGROUP_GROUP_NOT_FOUND_MSG.format(key, group_name))
return group.add_consumer(consumer_name)

@command(name="XGROUP DELCONSUMER", fixed=(Key(XStream), bytes, bytes), repeat=(), )
def xgroup_delconsumer(self, key, group_name, consumer_name):
if key.value is None:
raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
group: StreamGroup = key.value.group_get(group_name)
if not group:
raise SimpleError(msgs.XGROUP_GROUP_NOT_FOUND_MSG.format(key, group_name))
return group.del_consumer(consumer_name)

@command(name="XINFO GROUPS", fixed=(Key(XStream),), repeat=(), )
def xinfo_groups(self, key,):
def xinfo_groups(self, key, ):
if key.value is None:
raise SimpleError(msgs.NO_KEY_MSG)
return key.value.groups_info()

@command(name="XINFO CONSUMERS", fixed=(Key(XStream), bytes), repeat=(), )
def xinfo_consumers(self, key, group_name, ):
if key.value is None:
raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
group: StreamGroup = key.value.group_get(group_name)
if not group:
raise SimpleError(msgs.XGROUP_GROUP_NOT_FOUND_MSG.format(key, group_name))
return group.consumers_info()
3 changes: 0 additions & 3 deletions test/test_mixins/test_streams_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,6 @@ def test_xclaim_trimmed(r: redis.Redis):
assert item[0][0] == sid2


@pytest.mark.xfail
def test_xgroup_delconsumer(r: redis.Redis):
stream, group, consumer = "stream", "group", "consumer"
r.xadd(stream, {"foo": "bar"})
Expand All @@ -433,7 +432,6 @@ def test_xgroup_delconsumer(r: redis.Redis):
assert r.xgroup_delconsumer(stream, group, consumer) == 2


@pytest.mark.xfail
def test_xgroup_createconsumer(r: redis.Redis):
stream, group, consumer = "stream", "group", "consumer"
r.xadd(stream, {"foo": "bar"})
Expand All @@ -448,7 +446,6 @@ def test_xgroup_createconsumer(r: redis.Redis):
assert r.xgroup_delconsumer(stream, group, consumer) == 2


@pytest.mark.xfail
def test_xinfo_consumers(r: redis.Redis):
stream, group, consumer1, consumer2 = "stream", "group", "consumer1", "consumer2"
r.xadd(stream, {"foo": "bar"})
Expand Down

0 comments on commit fb03707

Please sign in to comment.