Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve filtering #121

Merged
merged 5 commits into from Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
107 changes: 59 additions & 48 deletions gapipy/query.py
@@ -1,3 +1,4 @@
from copy import deepcopy
from functools import wraps
from itertools import islice

Expand Down Expand Up @@ -25,35 +26,56 @@


def _check_listable(func):

"""
decorator to ensure the Query we're attempting to call func on is listable
"""
@wraps(func)
def inner(query, *args, **kwargs):
def wrapper(query, *args, **kwargs):
if not (query.resource._is_listable or query.parent):
raise ValueError(
'The {0} resource is not listable and/or is only available as a subresource'.format(
query.resource.__name__))

"The {} resource is not listable and/or is only available as a subresource".format(
query.resource.__name__,
)
)
return func(query, *args, **kwargs)

return inner
return wrapper


class Query(object):

def __init__(self, client, resource, filters=None, parent=None, raw_data=None):
self._client = client
self._raw_data = raw_data or {}
self.parent = parent
self.resource = resource
self._client = client
self._filters = filters or {}
self.parent = parent
self._raw_data = raw_data or {}

def __iter__(self):
"""Provided as a convenience so that Query objects can be iterated
without calling `all`.

i.e. `for dossier in dossiers.filter(name="Peru")`
instead of `for dossier in dossiers.filter(name="Peru").all()`
"""
return self.all()

def _to_dict(self):
# Used by Resource when converting nested Query objects into
# serializable types.
return self._raw_data

def options(self):
return self.resource.options(client=self._client)
def _clone(self):
"""
create a clone of this Query, with deep copies of _filter & _raw_data
"""
return Query(
self._client,
self.resource,
filters=deepcopy(self._filters),
parent=self.parent,
raw_data=deepcopy(self._raw_data),
)

def get(self, resource_id, variation_id=None, cached=True, headers=None,
httperrors_mapped_to_none=HTTPERRORS_MAPPED_TO_NONE):
Expand All @@ -74,8 +96,6 @@ def get(self, resource_id, variation_id=None, cached=True, headers=None,
something Falsey as `httperrors_mapped_to_none` like a `None` or an
empty list.
"""
key = self.query_key(resource_id, variation_id)

try:
data = self.get_resource_data(
resource_id,
Expand All @@ -91,10 +111,10 @@ def get(self, resource_id, variation_id=None, cached=True, headers=None,
return resource_object

def get_resource_data(self, resource_id, variation_id=None, cached=True, headers=None):
'''
"""
Returns a dictionary of resource data, which is used to initialize
a Resource object in the `get` method.
'''
"""
key = self.query_key(resource_id, variation_id)
resource_data = None
if cached:
Expand Down Expand Up @@ -132,17 +152,23 @@ def query_key(self, resource_id=None, variation_id=None):
parts.append(self._client.api_language)

if self._client.application_key:
part = self._client.application_key.split('_')[0]
if part == self._client.application_key or part.strip(' ') != 'test':
return ':'.join(parts)
part = self._client.application_key.split("_")[0]
if part == self._client.application_key or part.strip(" ") != "test":
return ":".join(parts)
parts.append(part)
return ':'.join(parts)
return ":".join(parts)

@_check_listable
def all(self, limit=None):
"""Generator of instances of the query resource. If limit is set to a
positive integer `n`, only return the first `n` results.
"""
# check limit is valid integer value
if limit is not None:
if not isinstance(limit, int):
raise TypeError("limit must be an integer")
elif limit <= 0:
raise ValueError("limit must be a positive integer")

requestor = APIRequestor(
self._client,
Expand All @@ -152,47 +178,35 @@ def all(self, limit=None):
)
# use href when available; this change should be transparent
# introduced: 2.20.0
href = None
href = None
if isinstance(self._raw_data, dict):
href = self._raw_data.get('href')
# generator to fetch list resources
generator = requestor.list(href)
# reset filters in case they were set on this query
self._filters = {}

if limit:
if isinstance(limit, int) and limit > 0:
generator = islice(generator, limit)
else:
raise ValueError('`limit` must be a positive integer')
href = self._raw_data.get("href")

for result in generator:
# generator to fetch list resources
for result in islice(requestor.list(href), limit):
yield self.resource(result, client=self._client, stub=True)

def filter(self, **kwargs):
"""Add filter arguments to the query.

For example, if `query` is a Query for the TourDossier ressource, then
`query.filter(name='Amazing Adventure')` will return a query containing
only dossiers whose names contain 'Amazing Adventure'.
`query.filter(name="Amazing Adventure")` will return a query containing
only dossiers whose names contain "Amazing Adventure".
"""
self._filters.update(kwargs)
return self
clone = self._clone()
clone._filters.update(kwargs)
return clone

@_check_listable
def count(self):
"""Returns the number of element in the query."""

requestor = APIRequestor(
self._client,
self.resource,
params=self._filters,
parent=self.parent
)
response = requestor.list_raw()
out = response.get('count')
self._filters = {}
return out
return requestor.list_raw().get("count")

def create(self, data_dict, headers=None):
"""Create an instance of the query resource using the given data"""
Expand All @@ -204,11 +218,8 @@ def first(self):
"""
return next(self.all(), None)

def __iter__(self):
"""Provided as a convenience so that Query objects can be iterated
without calling `all`.

i.e. `for dossier in dossiers.filter(name='Peru')`
instead of `for dossier in dossiers.filter(name='Peru').all()`
def options(self):
"""
return self.all()
return the OPTIONS response for the resource bound to this Query
"""
return self.resource.options(client=self._client)
63 changes: 42 additions & 21 deletions tests/test_query.py
Expand Up @@ -194,11 +194,25 @@ def test_get_instance_by_id_with_non_404_error(self, mock_request):
self.assertIsNone(
query.get(1234, httperrors_mapped_to_none=[response.status_code]))

@patch('gapipy.request.APIRequestor._request')
def test_filtered_query_returns_new_object(self, mock_request):
"""
Arguments passed to .filter() are stored on new (copied) Query instance
that will be different from the one it derived from.
Instances should be different and filters should not stack up
Every new filter returns a new object.
"""
query = Query(self.client, Tour).filter(tour_dossier_code='PPP')
query1 = Query(self.client, Tour).filter(tour_dossier_code='DJNN')

self.assertFalse(query is query1)
self.assertNotEqual(query._filters, query1._filters)

@patch('gapipy.request.APIRequestor._request')
def test_filtered_query(self, mock_request):
"""
Arguments passed to .filter() are stored on the Query instance but are
cleared when that query is evaluated.
Arguments passed to .filter() are stored on new (copied) Query instance
and are not supposed to be cleared when that query is evaluated.
"""
# Create a basic filter query for PPP...
query = Query(self.client, Tour).filter(tour_dossier_code='PPP')
Expand All @@ -219,27 +233,20 @@ def test_filtered_query(self, mock_request):
})
mock_request.reset_mock()

# ... our stored filter args got reset.
self.assertEqual(len(query._filters), 0)

# Check .count() also clears stored filter args appropriately:
query.filter(
tour_dossier_code='PPP',
order_by__desc='departures_start_date').count()
mock_request.assert_called_once_with(
'/tours', 'GET', params={
'tour_dossier_code': 'PPP',
'order_by__desc': 'departures_start_date',
})
mock_request.reset_mock()
self.assertEqual(len(query._filters), 0)
# ... our stored filter args remain for the current instance.
self.assertEqual(len(query._filters), 2)

# .count() should remain the query filters and
# respond from a new instance with the count value:
query.count()
self.assertEqual(len(query._filters), 2)

@patch('gapipy.request.APIRequestor._request')
def test_query_reset_filter(self, mock_request):
def test_query_persist_filter_on_count(self, mock_request):
query = Query(self.client, Tour)
query.filter(tour_dossier_code='PPP').count()
self.assertEqual(query._filters, {})
my_query = query.filter(tour_dossier_code='PPP')
my_query.count()
self.assertEqual(my_query._filters, {'tour_dossier_code': 'PPP'})

def test_listing_non_listable_resource_fails(self):
message = 'The Activity resource is not listable and/or is only available as a subresource'
Expand Down Expand Up @@ -302,8 +309,8 @@ def test_fetch_all_with_limit(self, mock_request):
mock_request.assert_called_once_with(
'/tour_dossiers', 'GET', params={})

def test_fetch_all_with_wrong_argument_for_limit(self):
message = '`limit` must be a positive integer'
def test_fetch_all_with_negative_arg_for_limit(self):
message = 'limit must be a positive integer'
if sys.version_info.major < 3:
with self.assertRaisesRegexp(ValueError, message):
query = Query(self.client, Tour).all(limit=-1)
Expand All @@ -313,6 +320,20 @@ def test_fetch_all_with_wrong_argument_for_limit(self):
query = Query(self.client, Tour).all(limit=-1)
list(query) # force the query to evaluate

def test_fetch_all_with_wrong_arg_types_for_limit(self):
wrong_arg_types = ['', [], {}, object()]
message = 'limit must be an integer'

for wrong_arg in wrong_arg_types:
if sys.version_info.major < 3:
with self.assertRaisesRegexp(TypeError, message):
query = Query(self.client, Tour).all(limit=wrong_arg)
list(query) # force the query to evaluate
else:
with self.assertRaisesRegex(TypeError, message):
query = Query(self.client, Tour).all(limit=wrong_arg)
list(query) # force the query to evaluate


class QueryCacheTestCase(unittest.TestCase):
def setUp(self):
Expand Down