From d3761f845d1fdf653f08d6a2c5db64884a90c36b Mon Sep 17 00:00:00 2001 From: Vitor Baptista Date: Thu, 3 Jul 2014 16:35:13 -0300 Subject: [PATCH] [#1830] datastore_search accepts multiple filters values as array With this patch, if you query the datastore_search with filters: ```json { "filters": { "country": ["Brazil", "Argentina"] } } ``` It'll return the rows that have `country IN ("Brazil", "Argentina")`. I had to change the `datastore_search` extension point on `IDatastore` because I needed the fields' types. I needed the fields' types because the Datastore accept fields with array type (e.g. `_text`). When filtering a field that has an array type, I want to query literally. In the sample query above, the Datastore can create two different queries, depending on the type of the `country` field. If it's not an array type, it'll do as I described above, querying `country IN ("Brazil", "Argentina")`. If it's of an array type, it'll query literally `country = ["Brazil", "Argentina"]`. --- ckanext/datastore/db.py | 13 +++++++++---- ckanext/datastore/interfaces.py | 7 ++++--- ckanext/datastore/plugin.py | 25 ++++++++++++++++++------- ckanext/datastore/tests/test_search.py | 2 +- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/ckanext/datastore/db.py b/ckanext/datastore/db.py index ff88870ff67..453a5882b26 100644 --- a/ckanext/datastore/db.py +++ b/ckanext/datastore/db.py @@ -223,6 +223,13 @@ def _get_fields(context, data_dict): return fields +def _get_fields_types(context, data_dict): + all_fields = _get_fields(context, data_dict) + all_fields.insert(0, {'id': '_id', 'type': 'int'}) + field_types = dict([(f['id'], f['type']) for f in all_fields]) + return field_types + + def json_get_values(obj, current_list=None): if current_list is None: current_list = [] @@ -845,9 +852,7 @@ def validate(context, data_dict): def search_data(context, data_dict): validate(context, data_dict) - all_fields = _get_fields(context, data_dict) - column_names = _pluck('id', all_fields) - column_names.insert(0, '_id') + fields_types = _get_fields_types(context, data_dict) query_dict = { 'select': [], @@ -857,7 +862,7 @@ def search_data(context, data_dict): for plugin in p.PluginImplementations(interfaces.IDatastore): query_dict = plugin.datastore_search(context, data_dict, - column_names, query_dict) + fields_types, query_dict) where_clause, where_values = _where(query_dict['where']) diff --git a/ckanext/datastore/interfaces.py b/ckanext/datastore/interfaces.py index fc2194c0fe8..f0f59f13e97 100644 --- a/ckanext/datastore/interfaces.py +++ b/ckanext/datastore/interfaces.py @@ -34,7 +34,7 @@ def datastore_validate(self, context, data_dict, column_names): ''' return data_dict - def datastore_search(self, context, data_dict, column_names, query_dict): + def datastore_search(self, context, data_dict, fields_types, query_dict): '''Modify queries made on datastore_search The overall design is that every IDatastore extension will receive the @@ -79,8 +79,9 @@ def datastore_search(self, context, data_dict, column_names, query_dict): :type context: dictionary :param data_dict: the parameters received from the user :type data_dict: dictionary - :param column_names: the current resource's column names - :type column_names: list + :param fields_types: the current resource's fields as dict keys and + their types as values + :type fields_types: dictionary :param query_dict: the current query_dict, as changed by the IDatastore extensions that ran before yours :type query_dict: dictionary diff --git a/ckanext/datastore/plugin.py b/ckanext/datastore/plugin.py index 2755ba279b4..a236288466b 100644 --- a/ckanext/datastore/plugin.py +++ b/ckanext/datastore/plugin.py @@ -343,8 +343,9 @@ def datastore_delete(self, context, data_dict, column_names, query_dict): query_dict['where'] += self._where(data_dict, column_names) return query_dict - def datastore_search(self, context, data_dict, column_names, query_dict): + def datastore_search(self, context, data_dict, fields_types, query_dict): fields = data_dict.get('fields') + column_names = fields_types.keys() if fields: field_ids = datastore_helpers.get_list(fields) @@ -355,8 +356,8 @@ def datastore_search(self, context, data_dict, column_names, query_dict): limit = data_dict.get('limit', 100) offset = data_dict.get('offset', 0) - sort = self._sort(data_dict, field_ids) - where = self._where(data_dict, field_ids) + sort = self._sort(data_dict) + where = self._where(data_dict, fields_types) select_cols = [u'"{0}"'.format(field_id) for field_id in field_ids] +\ [u'count(*) over() as "_full_count" %s' % rank_column] @@ -370,13 +371,20 @@ def datastore_search(self, context, data_dict, column_names, query_dict): return query_dict - def _where(self, data_dict, column_names): + def _where(self, data_dict, fields_types): filters = data_dict.get('filters', {}) clauses = [] + for field, value in filters.iteritems(): - if field not in column_names: + if field not in fields_types: continue - clause = (u'"{0}" = %s'.format(field), value) + field_array_type = self._is_array_type(fields_types[field]) + if isinstance(value, list) and not field_array_type: + clause_str = (u'"{0}" in ({1})'.format(field, + ','.join(['%s'] * len(value)))) + clause = (clause_str,) + tuple(value) + else: + clause = (u'"{0}" = %s'.format(field), value) clauses.append(clause) # add full-text search where clause @@ -386,6 +394,9 @@ def _where(self, data_dict, column_names): return clauses + def _is_array_type(self, field_type): + return field_type.startswith('_') + def _textsearch_query(self, data_dict): q = data_dict.get('q') lang = data_dict.get(u'language', u'english') @@ -399,7 +410,7 @@ def _textsearch_query(self, data_dict): return statement.format(lang=lang, query=q), rank_column return '', '' - def _sort(self, data_dict, field_ids): + def _sort(self, data_dict): sort = data_dict.get('sort') if not sort: if data_dict.get('q'): diff --git a/ckanext/datastore/tests/test_search.py b/ckanext/datastore/tests/test_search.py index e2691ffd192..20dd988d489 100644 --- a/ckanext/datastore/tests/test_search.py +++ b/ckanext/datastore/tests/test_search.py @@ -222,7 +222,7 @@ def test_search_filter_array_field(self): assert res_dict['success'] is True result = res_dict['result'] assert result['total'] == 1 - assert result['records'] == [self.expected_records[0]] + assert_equals(result['records'], [self.expected_records[0]]) def test_search_filter_normal_field_passing_multiple_values_in_array(self): data = {'resource_id': self.data['resource_id'],