Skip to content

Commit

Permalink
feat:implement ZMPOP and BZMPOP
Browse files Browse the repository at this point in the history
Fix #191, #186
  • Loading branch information
cunla committed Jul 12, 2023
1 parent 4342889 commit 5aa5e91
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions fakeredis/commands_mixins/sortedset_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,34 @@

class SortedSetCommandsMixin:
# Sorted set commands
def _zpop(self, key, count, reverse):
def _zpop(self, key: CommandItem, count: int, reverse: bool, flatten_list: bool) -> List[bytes]:
zset = key.value
members = list(zset)
if reverse:
members.reverse()
members = members[:count]
res = [(bytes(member), self._encodefloat(zset.get(member), True)) for member in members]
res = list(itertools.chain.from_iterable(res))
res = [[bytes(member), self._encodefloat(zset.get(member), True)] for member in members]
if flatten_list:
res = list(itertools.chain.from_iterable(res))
for item in members:
zset.discard(item)
return res

def _bzpop(self, keys, reverse, first_pass):
def _bzpop(self, keys: List[CommandItem], reverse: bool, first_pass: bool) -> List[bytes]:
for key in keys:
item = CommandItem(key, self._db, item=self._db.get(key), default=[])
temp_res = self._zpop(item, 1, reverse)
temp_res = self._zpop(item, 1, reverse, True)
if temp_res:
return [key, temp_res[0], temp_res[1]]
return None

@command((Key(ZSet),), (Int,))
def zpopmin(self, key, count=1):
return self._zpop(key, count, False)
return self._zpop(key, count, reverse=False, flatten_list=True)

@command((Key(ZSet),), (Int,))
def zpopmax(self, key, count=1):
return self._zpop(key, count, True)
return self._zpop(key, count, reverse=True, flatten_list=True)

@command((bytes, bytes), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
def bzpopmin(self, *args):
Expand Down Expand Up @@ -513,18 +514,14 @@ def zrandmember(self, key: CommandItem, *args) -> list[Optional[float]]:
def _encodefloat(self, value, humanfriendly):
raise NotImplementedError # Implemented in BaseFakeSocket

def _zmpop(self, keys, count, direction_left, first_pass):
if direction_left:
op = lambda count: slice(None, count) # noqa:E731
else:
op = lambda count: slice(None, -count - 1, -1) # noqa:E731

def _zmpop(self, keys, count, reverse, first_pass):
for key in keys:
item = CommandItem(key, self._db, item=self._db.get(key), default=[])
res = _list_pop_count(op, item, count)
res = self._zpop(item, count, reverse, flatten_list=False)
if res:
return [key, res]
return None

@command(fixed=(Int,), repeat=(bytes,))
def zmpop(self, numkeys, *args):
if numkeys == 0:
Expand All @@ -534,10 +531,10 @@ def zmpop(self, numkeys, *args):
args = args[:-2]
else:
count = 1
if len(args) != numkeys + 1 or (not casematch(args[-1], b'left') and not casematch(args[-1], b'right')):
if len(args) != numkeys + 1 or (not casematch(args[-1], b'min') and not casematch(args[-1], b'max')):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)

return self._zmpop(args[:-1], count, casematch(args[-1], b'left'), False)
return self._zmpop(args[:-1], count, casematch(args[-1], b'max'), False)

@command(fixed=(Timeout, Int,), repeat=(bytes,))
def bzmpop(self, timeout, numkeys, *args):
Expand All @@ -548,7 +545,7 @@ def bzmpop(self, timeout, numkeys, *args):
args = args[:-2]
else:
count = 1
if len(args) != numkeys + 1 or (not casematch(args[-1], b'left') and not casematch(args[-1], b'right')):
if len(args) != numkeys + 1 or (not casematch(args[-1], b'min') and not casematch(args[-1], b'max')):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)

return self._blocking(timeout, functools.partial(self._zmpop, args[:-1], count, casematch(args[-1], b'left')))
return self._blocking(timeout, functools.partial(self._zmpop, args[:-1], count, casematch(args[-1], b'max')))

0 comments on commit 5aa5e91

Please sign in to comment.