Skip to content

Commit

Permalink
Implement GEO commands
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Feb 24, 2023
1 parent 6d729f6 commit ed00059
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 82 deletions.
1 change: 1 addition & 0 deletions fakeredis/_msgs.py
Expand Up @@ -63,3 +63,4 @@
FLAG_NO_SCRIPT = 's' # Command not allowed in scripts
FLAG_LEAVE_EMPTY_VAL = 'v'
FLAG_TRANSACTION = 't'
GEO_UNSUPPORTED_UNIT = 'unsupported unit provided. please use M, KM, FT, MI'
152 changes: 109 additions & 43 deletions fakeredis/commands_mixins/geo_mixin.py
@@ -1,37 +1,94 @@
import sys
from collections import namedtuple
from typing import List, Optional, Any
from typing import List, Any

from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import command, Key, Float
from fakeredis._commands import command, Key, Float, CommandItem
from fakeredis._helpers import SimpleError
from fakeredis._zset import ZSet
from fakeredis.geo import geohash
from fakeredis.geo.haversine import distance

UNIT_TO_M = {'km': 0.001, 'mi': 0.000621371, 'ft': 3.28084, 'm': 1}


def translate_meters_to_unit(unit_arg: bytes) -> float:
unit_str = unit_arg.decode().lower()
if unit_str == 'km':
unit = 0.001
elif unit_str == 'mi':
unit = 0.000621371
elif unit_str == 'ft':
unit = 3.28084
else: # meter
unit = 1
"""number of meters in a unit.
:param unit_arg: unit name (km, mi, ft, m)
:returns: number of meters in unit
"""
unit = UNIT_TO_M.get(unit_arg.decode().lower())
if unit is None:
raise SimpleError(msgs.GEO_UNSUPPORTED_UNIT)
return unit


GeoResult = namedtuple('GeoResult', 'name long lat hash distance')


def _parse_results(
items: List[GeoResult],
withcoord: bool, withdist: bool) -> List[Any]:
"""Parse list of GeoResults to redis response
:param withcoord: include coordinates in response
:param withdist: include distance in response
:returns: Parsed list
"""
res = list()
for item in items:
new_item = [item.name, ]
if withdist:
new_item.append(Float.encode(item.distance, False))
if withcoord:
new_item.append([Float.encode(item.long, False),
Float.encode(item.lat, False)])
if len(new_item) == 1:
new_item = new_item[0]
res.append(new_item)
return res


def _find_near(
zset: ZSet,
lat: float, long: float, radius: float,
conv: float, count: int, count_any: bool, desc: bool) -> List[GeoResult]:
"""Find items within area (lat,long)+radius
:param zset: list of items to check
:param lat: latitude
:param long: longitude
:param radius: radius in whatever units
:param conv: conversion of radius to meters
:param count: number of results to give
:param count_any: should we return any results that match? (vs. sorted)
:param desc: should results be sorted descending order?
:returns: List of GeoResults
"""
results = list()
for name, _hash in zset.items():
p_lat, p_long, _, _ = geohash.decode(_hash)
dist = distance((p_lat, p_long), (lat, long)) * conv
if dist < radius:
results.append(GeoResult(name, p_long, p_lat, _hash, dist))
if count_any and len(results) >= count:
break
results = sorted(results, key=lambda x: x.distance, reverse=desc)
if count:
results = results[:count]
return results


class GeoCommandsMixin:
# TODO
# GEORADIUS, GEORADIUS_RO,
# GEORADIUSBYMEMBER, GEORADIUSBYMEMBER_RO,
# GEOSEARCH, GEOSEARCHSTORE
def _store_geo_results(self, item_name: bytes, geo_results: List[GeoResult], scoredist: bool) -> int:
db_item = CommandItem(item_name, self._db, item=self._db.get(item_name), default=ZSet())
db_item.value = ZSet()
for item in geo_results:
val = item.distance if scoredist else item.hash
db_item.value.add(item.name, val)
db_item.writeback()
return len(geo_results)

@command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,))
def geoadd(self, key, *args):
Expand Down Expand Up @@ -83,42 +140,51 @@ def geodist(self, key, m1, m2, *args):
unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1
return res * unit

def _parse_results(
self, items: List[GeoResult],
withcoord: bool, withdist: bool, withhash: bool,
count: Optional[int], desc: bool) -> List[Any]:
items = sorted(items, key=lambda x: x.distance, reverse=desc)
if count:
items = items[:count]
res = list()
for item in items:
new_item = [item.name, ]
if withdist:
new_item.append(self._encodefloat(item.distance, False))
if withcoord:
new_item.append([self._encodefloat(item.long, False),
self._encodefloat(item.lat, False)])
if len(new_item) == 1:
new_item = new_item[0]
res.append(new_item)
return res
def _search(
self, key, long, lat, radius, conv,
withcoord, withdist, withhash, count, count_any, desc, store, storedist):
zset = key.value
geo_results = _find_near(zset, lat, long, radius, conv, count, count_any, desc)

if store:
self._store_geo_results(store, geo_results, scoredist=False)
return len(geo_results)
if storedist:
self._store_geo_results(storedist, geo_results, scoredist=True)
return len(geo_results)
ret = _parse_results(geo_results, withcoord, withdist)
return ret

@command(name='GEORADIUS_RO', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,))
def georadius_ro(self, key, long, lat, radius, *args):
(withcoord, withdist, withhash, count, count_any, desc), left_args = extract_args(
args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc',),
error_on_unexpected=False, left_from_first_unexpected=False)
count = count or sys.maxsize
conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
return self._search(
key, long, lat, radius, conv,
withcoord, withdist, withhash, count, count_any, desc, False, False)

@command(name='GEORADIUS', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,))
def georadius(self, key, long, lat, radius, *args):
zset = key.value
results = list()
(withcoord, withdist, withhash, count, count_any, desc, store, storedist), left_args = extract_args(
args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc', '*store', '*storedist'),
error_on_unexpected=False, left_from_first_unexpected=False)
unit = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
count = count or sys.maxsize
conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
return self._search(
key, long, lat, radius, conv,
withcoord, withdist, withhash, count, count_any, desc, store, storedist)

for name, _hash in zset.items():
p_lat, p_long, _, _ = geohash.decode(_hash)
dist = distance((p_lat, p_long), (lat, long)) * unit
if dist < radius:
results.append(GeoResult(name, p_long, p_lat, _hash, dist))
if count_any and len(results) >= count:
break
@command(name='GEORADIUSBYMEMBER', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,))
def georadiusbymember(self, key, member_name, radius, *args):
member_score = key.value.get(member_name)
lat, long, _, _ = geohash.decode(member_score)
return self.georadius(key, long, lat, radius, *args)

return self._parse_results(results, withcoord, withdist, withhash, count, desc)
@command(name='GEORADIUSBYMEMBER_RO', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,))
def georadiusbymember_ro(self, key, member_name, radius, *args):
member_score = key.value.get(member_name)
lat, long, _, _ = geohash.decode(member_score)
return self.georadius_ro(key, long, lat, radius, *args)
104 changes: 65 additions & 39 deletions test/test_mixins/test_geo_commands.py
@@ -1,6 +1,10 @@
from typing import Dict, Any

import pytest
import redis

from test import testtools


def test_geoadd(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
Expand All @@ -10,13 +14,20 @@ def test_geoadd(r: redis.Redis):

values = (2.1909389952632, 41.433791470673, "place1")
assert r.geoadd("a", values) == 1

values = ((2.1909389952632, 31.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))
assert r.geoadd("a", values, ch=True) == 2
assert r.zrange("a", 0, -1) == [b"place1", b"place2"]

with pytest.raises(redis.RedisError):
with pytest.raises(redis.DataError):
r.geoadd("barcelona", (1, 2))
with pytest.raises(redis.DataError):
r.geoadd("t", values, ch=True, nx=True, xx=True)
with pytest.raises(redis.ResponseError):
testtools.raw_command(r, "geoadd", "barcelona", "1", "2")
with pytest.raises(redis.ResponseError):
testtools.raw_command(r, "geoadd", "barcelona", "nx", "xx", *values, )


def test_geoadd_xx(r: redis.Redis):
Expand Down Expand Up @@ -91,56 +102,38 @@ def test_geodist_missing_one_member(r: redis.Redis):
assert r.geodist("barcelona", "place1", "missing_member", "km") is None


def test_georadius(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, b"\x80place2"))

r.geoadd("barcelona", values)
assert r.georadius("barcelona", 2.191, 41.433, 1000) == [b"place1"]
assert r.georadius("barcelona", 2.187, 41.406, 1000) == [b"\x80place2"]


def test_georadius_no_values(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))

r.geoadd("barcelona", values)
assert r.georadius("barcelona", 1, 2, 1000) == []


def test_georadius_units(r: redis.Redis):
@pytest.mark.parametrize(
"long,lat,radius,extra,expected", [
(2.191, 41.433, 1000, {}, [b"place1"]),
(2.187, 41.406, 1000, {}, [b"place2"]),
(1, 2, 1000, {}, []),
(2.191, 41.433, 1, {"unit": "km"}, [b"place1"]),
(2.191, 41.433, 3000, {"count": 1}, [b"place1"]),
])
def test_georadius(
r: redis.Redis, long: float, lat: float, radius: float,
extra: Dict[str, Any],
expected):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))

(2.1873744593677, 41.406342043777, b"place2"))
r.geoadd("barcelona", values)
assert r.georadius("barcelona", 2.191, 41.433, 1, unit="km") == [b"place1"]
assert r.georadius("barcelona", long, lat, radius, **extra) == expected


def test_georadius_with(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))

r.geoadd("barcelona", values)

# test a bunch of combinations to test the parse response
# function.
# test a bunch of combinations to test the parse response function.
res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True, )
assert res == [pytest.approx([
b"place1",
0.0881,
pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)
], 0.001)]
assert res == [pytest.approx([b"place1", 0.0881, pytest.approx((2.1909, 41.4337), 0.0001)], 0.001)]

res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True)
assert res == [pytest.approx([
b"place1",
0.0881,
pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)
], 0.001)]
assert res == [pytest.approx([b"place1", 0.0881, pytest.approx((2.1909, 41.4337), 0.0001)], 0.001)]

assert r.georadius(
"barcelona", 2.191, 41.433, 1, unit="km", withcoord=True
) == [[b"place1", pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)]]
res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withcoord=True)
assert res == [[b"place1", pytest.approx((2.1909, 41.4337), 0.0001)]]

# test no values.
assert (r.georadius("barcelona", 2, 1, 1, unit="km", withdist=True, withcoord=True, ) == [])
Expand All @@ -151,6 +144,39 @@ def test_georadius_count(r: redis.Redis):
(2.1873744593677, 41.406342043777, "place2",))

r.geoadd("barcelona", values)
assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1) == [b"place1"]

assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1, store='barcelona') == 1
assert r.georadius("barcelona", 2.191, 41.433, 3000, store_dist='extract') == 1
assert r.zcard("extract") == 1
res = r.georadius("barcelona", 2.191, 41.433, 3000, count=1, any=True)
assert (res == [b"place2"]) or res == [b'place1']

values = ((13.361389, 38.115556, "Palermo") +
(15.087269, 37.502669, "Catania",))

r.geoadd("Sicily", values)
assert testtools.raw_command(
r, "GEORADIUS", "Sicily", "15", "37", "200", "km",
"STOREDIST", "neardist", "STORE", "near") == 2
assert r.zcard("near") == 2
assert r.zcard("neardist") == 0


def test_georadius_errors(r: redis.Redis):
values = ((13.361389, 38.115556, "Palermo") +
(15.087269, 37.502669, "Catania",))

r.geoadd("Sicily", values)

with pytest.raises(redis.DataError): # Unsupported unit
r.georadius("barcelona", 2.191, 41.433, 3000, unit='dsf')
with pytest.raises(redis.ResponseError): # Unsupported unit
testtools.raw_command(
r, "GEORADIUS", "Sicily", "15", "37", "200", "ddds",
"STOREDIST", "neardist", "STORE", "near")

bad_values = (13.361389, 38.115556, "Palermo", 15.087269, "Catania",)
with pytest.raises(redis.DataError):
r.geoadd('newgroup', bad_values)
with pytest.raises(redis.ResponseError):
testtools.raw_command(r, 'geoadd', 'newgroup', *bad_values)

0 comments on commit ed00059

Please sign in to comment.