diff --git a/CHANGES.md b/CHANGES.md index 64b5f9c0..5efcf9b6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,7 @@ - [NEW] Add custom JSON encoder/decoder option to `Document` constructor. - [NEW] Add new view parameters, `stable` and `update`, as keyword arguments to `get_view_result`. +- [NEW] Allow arbitrary query parameters to be passed to custom changes filters. - [FIXED] Case where an exception was raised after successful retry when using `doc.update_field`. - [FIXED] Removed unnecessary request when retrieving a Result collection that is less than the 'page_size' value diff --git a/src/cloudant/_common_util.py b/src/cloudant/_common_util.py index 2af49b34..1904a0a4 100644 --- a/src/cloudant/_common_util.py +++ b/src/cloudant/_common_util.py @@ -47,6 +47,9 @@ # Argument Types +ANY_ARG = object() +ANY_TYPE = object() + RESULT_ARG_TYPES = { 'descending': (bool,), 'endkey': (int, LONGTYPE, STRTYPE, Sequence,), @@ -101,6 +104,7 @@ 'filter': (STRTYPE,), 'include_docs': (bool,), 'style': (STRTYPE,), + ANY_ARG: ANY_TYPE # pass arbitrary query parameters to a custom filter } _CHANGES_ARG_TYPES.update(_DB_UPDATES_ARG_TYPES) diff --git a/src/cloudant/feed.py b/src/cloudant/feed.py index 4a93fc16..ef2c90fc 100644 --- a/src/cloudant/feed.py +++ b/src/cloudant/feed.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2015, 2016 IBM. All rights reserved. +# Copyright (c) 2015, 2018 IBM. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ from ._2to3 import iteritems_, next_, unicode_, STRTYPE, NONETYPE from .error import CloudantArgumentError, CloudantFeedException -from ._common_util import feed_arg_types, TYPE_CONVERTERS +from ._common_util import ANY_ARG, ANY_TYPE, feed_arg_types, TYPE_CONVERTERS class Feed(object): """ @@ -100,7 +100,7 @@ def _translate(self, options): if isinstance(val, STRTYPE): translation[key] = val elif not isinstance(val, NONETYPE): - arg_converter = TYPE_CONVERTERS.get(type(val)) + arg_converter = TYPE_CONVERTERS.get(type(val), json.dumps) translation[key] = arg_converter(val) except Exception as ex: raise CloudantArgumentError(115, key, ex) @@ -111,11 +111,18 @@ def _validate(self, key, val, arg_types): Ensures that the key and the value are valid arguments to be used with the feed. """ - if key not in arg_types: - raise CloudantArgumentError(116, key) - if (not isinstance(val, arg_types[key]) or - (isinstance(val, bool) and int in arg_types[key])): - raise CloudantArgumentError(117, key, arg_types[key]) + if key in arg_types: + arg_type = arg_types[key] + else: + if ANY_ARG not in arg_types: + raise CloudantArgumentError(116, key) + arg_type = arg_types[ANY_ARG] + + if arg_type == ANY_TYPE: + return + if (not isinstance(val, arg_type) or + (isinstance(val, bool) and int in arg_type)): + raise CloudantArgumentError(117, key, arg_type) if isinstance(val, int) and val < 0 and not isinstance(val, bool): raise CloudantArgumentError(118, key, val) if key == 'feed': diff --git a/tests/unit/changes_tests.py b/tests/unit/changes_tests.py index ee3dce13..007e07a6 100644 --- a/tests/unit/changes_tests.py +++ b/tests/unit/changes_tests.py @@ -465,14 +465,20 @@ def test_get_feed_using_doc_ids(self): self.assertSetEqual(set([x['id'] for x in changes]), expected) self.assertTrue(str(feed.last_seq).startswith('100')) - def test_invalid_argument(self): + def test_get_feed_with_custom_filter_query_params(self): """ - Test that an invalid argument is caught and an exception is raised + Test using feed with custom filter query parameters. """ - feed = Feed(self.db, foo='bar') - with self.assertRaises(CloudantArgumentError) as cm: - invalid_feed = [x for x in feed] - self.assertEqual(str(cm.exception), 'Invalid argument foo') + feed = Feed( + self.db, + filter='mailbox/new_mail', + foo='bar', # query parameters to a custom filter + include_docs=False + ) + params = feed._translate(feed._options) + self.assertEquals(params['filter'], 'mailbox/new_mail') + self.assertEquals(params['foo'], 'bar') + self.assertEquals(params['include_docs'], 'false') def test_invalid_argument_type(self): """