Skip to content

Commit

Permalink
Merge 4093498 into 2b6087e
Browse files Browse the repository at this point in the history
  • Loading branch information
Eugene Eeo committed Jul 23, 2014
2 parents 2b6087e + 4093498 commit 8e047dd
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 45 deletions.
25 changes: 25 additions & 0 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,28 @@ def test(value):
assert not query({'val': 40})
assert not query({'val': '44'})
assert not query({'': None})


def test_nested_query():
query = where('user.name') == 'john'

assert query({'user':{'name':'john'}})
assert not query({'user':{'name':'don'}})
assert not query({})


def test_each_query():
query = where('user').any('followers') == 'don'

assert query({'user':{'followers':['john', 'don']}})
assert not query({'user':{'followers':1}})
assert not query({})

query = ~query
assert query({})

query = where('user').any('followers').matches('\\d+')
assert query({'user':{'followers':['12']}})

query = where('user').any('followers').test(lambda x:x)
assert query({'user':{'followers':[0,1,0]}})
155 changes: 110 additions & 45 deletions tinydb/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,47 @@
__all__ = ('Query',)


def methodproxy(attribute, method):
"""
Utility function for delegating a specific method to another
given attribute.
:param attribute: The instance attribute to delegate to
:param method: A method that the attribute object must have
"""
def fn(self, *args, **kwargs):
getattr(getattr(self, attribute), method)(*args, **kwargs)
return self
return fn


def haskey(key, datum):
"""
Checks whether a nested key is in a datum.
:param key: A sequence of keys splitted by '.'
:param datum: The datum to test
"""
keys = key.split('.')
for key in keys[:-1]:
if not isinstance(datum, dict) or key not in datum:
return False
datum = datum[key]
return keys[-1] in datum


def getkey(key, datum):
"""
Provides nested fetching of values.
:param key: A sequence of keys splitted by '.'
:param datum: The datum to select from
"""
for item in key.split('.'):
datum = datum[item]
return datum


class AndOrMixin(object):
"""
A mixin providing methods calls ``&`` and ``|``.
Expand Down Expand Up @@ -54,6 +95,18 @@ def __and__(self, other):
"""
return QueryAnd(self, other)

def any(self, key):
"""
Create a compound query that will check if any of the
data under a given key (which must be an iterable) is
said to be correct with a given query function.
Example:
>>> (where('key').any('a') == 2)({'key': {'a': [1,2,3]}})
True
"""
return QueryAny(self._key, Query(key))


class Query(AndOrMixin):
"""
Expand All @@ -68,6 +121,7 @@ class Query(AndOrMixin):

def __init__(self, key):
self._key = key
self._func = lambda x: True
self._repr = 'has \'{0}\''.format(key)

def matches(self, regex):
Expand Down Expand Up @@ -108,7 +162,7 @@ def __eq__(self, other):
if isinstance(other, Query):
return self._repr == other._repr
else:
self._value_eq = other
self._func = lambda x: x == other
self._update_repr('==', other)
return self

Expand All @@ -119,7 +173,7 @@ def __ne__(self, other):
>>> where('f1') != 42
'f1' != 42
"""
self._value_ne = other
self._func = lambda x: x != other
self._update_repr('!=', other)
return self

Expand All @@ -130,7 +184,7 @@ def __lt__(self, other):
>>> where('f1') < 42
'f1' < 42
"""
self._value_lt = other
self._func = lambda x: x < other
self._update_repr('<', other)
return self

Expand All @@ -141,7 +195,7 @@ def __le__(self, other):
>>> where('f1') <= 42
'f1' <= 42
"""
self._value_le = other
self._func = lambda x: x <= other
self._update_repr('<=', other)
return self

Expand All @@ -152,7 +206,7 @@ def __gt__(self, other):
>>> where('f1') > 42
'f1' > 42
"""
self._value_gt = other
self._func = lambda x: x > other
self._update_repr('>', other)
return self

Expand All @@ -163,7 +217,7 @@ def __ge__(self, other):
>>> where('f1') >= 42
'f1' >= 42
"""
self._value_ge = other
self._func = lambda x: x >= other
self._update_repr('>=', other)
return self

Expand All @@ -185,40 +239,8 @@ def __call__(self, element):
:param element: The dict that we will run our tests against.
:type element: dict
"""
if self._key not in element:
return False

try:
return element[self._key] == self._value_eq
except AttributeError:
pass

try:
return element[self._key] != self._value_ne
except AttributeError:
pass

try:
return element[self._key] < self._value_lt
except AttributeError:
pass

try:
return element[self._key] <= self._value_le
except AttributeError:
pass

try:
return element[self._key] > self._value_gt
except AttributeError:
pass

try:
return element[self._key] >= self._value_ge
except AttributeError:
pass

return True # _key exists in element (see above)
return (haskey(self._key, element)
and self._func(getkey(self._key, element)))

def _update_repr(self, operator, value):
""" Update the current test's ``repr``. """
Expand Down Expand Up @@ -304,14 +326,14 @@ class QueryRegex(AndOrMixin):
"""
def __init__(self, key, regex):
self.regex = regex
self._func = lambda x: re.match(self.regex, x)
self._key = key

def __call__(self, element):
"""
See :meth:`Query.__call__`.
"""
return bool(self._key in element
and re.match(self.regex, element[self._key]))
return haskey(self._key, element) and self._func(getkey(self._key, element))

def __repr__(self):
return '\'{0}\' ~= {1} '.format(self._key, self.regex)
Expand All @@ -325,14 +347,57 @@ class QueryCustom(AndOrMixin):
"""

def __init__(self, key, test):
self.test = test
self._func = test
self._key = key

def __call__(self, element):
"""
See :meth:`Query.__call__`.
"""
return self._key in element and self.test(element[self._key])
return haskey(self._key, element) and self._func(getkey(self._key, element))

def __repr__(self):
return '\'{0}\'.test({1})'.format(self._key, self._func)


class QueryAny(Query):
"""
Run a Query object against the data in the dict value.
"""
def __init__(self, key, query):
self._key = key
self._query = query

__eq__ = methodproxy('_query', '__eq__')
__ne__ = methodproxy('_query', '__ne__')
__gt__ = methodproxy('_query', '__gt__')
__ge__ = methodproxy('_query', '__ge__')
__lt__ = methodproxy('_query', '__lt__')
__le__ = methodproxy('_query', '__le__')

def matches(self, *args, **kwargs):
"""
Overrides the internal Query object with a QueryRegex object.
"""
self._query = self._query.matches(*args, **kwargs)
return self

def test(self, *args, **kwargs):
"""
Overrides the internal Query object with a QueryCustom object.
"""
self._query = self._query.test(*args, **kwargs)
return self

def __repr__(self):
return '\'{0}\'.test({1})'.format(self._key, self.test)
return "'{0}'.each({1})".format(self._key, self._query)

def __call__(self, element):
if haskey(self._key, element):
datum = getkey(self._key, element)
if haskey(self._query._key, datum):
iterable = getkey(self._query._key, datum)
if not hasattr(iterable, '__iter__'):
return False
return any(self._query._func(e) for e in iterable)
return False

0 comments on commit 8e047dd

Please sign in to comment.