diff --git a/complaint_search/es_builders.py b/complaint_search/es_builders.py
new file mode 100644
index 00000000..01bdf3a6
--- /dev/null
+++ b/complaint_search/es_builders.py
@@ -0,0 +1,270 @@
+import abc
+from collections import defaultdict, namedtuple
+
+class BaseBuilder(object):
+    __metaclass__ = abc.ABCMeta
+
+    # Filters for those with string type
+    _OPTIONAL_FILTERS = ("product", "issue", "company", "state", "zip_code", "timely",
+        "company_response", "company_public_response",
+        "consumer_consent_provided", "submitted_via", "tag", "consumer_disputed")
+
+    # Filters for those that need conversion from string to boolean
+    _OPTIONAL_FILTERS_STRING_TO_BOOL = ("has_narratives",)
+
+    _OPTIONAL_FILTERS_PARAM_TO_ES_MAP = {
+        "product": "product.raw",
+        "sub_product": "sub_product.raw",
+        "issue": "issue.raw",
+        "sub_issue": "sub_issue.raw",
+        "company_public_response": "company_public_response.raw",
+        "consumer_consent_provided": "consumer_consent_provided.raw",
+        "consumer_disputed": "consumer_disputed.raw"
+    }
+
+    _OPTIONAL_FILTERS_CHILD_MAP = {
+        "product": "sub_product",
+        "issue": "sub_issue"
+    }
+
+    def __init__(self):
+        self.params = {}
+
+    def add(self, **kwargs):
+        self.params.update(**kwargs)
+
+    @abc.abstractmethod
+    def build(self):
+        """Method that will build the body dictionary."""
+
+    def _create_bool_should_clauses(self, es_field_name, value_list,
+            with_subitems=False, es_subitem_field_name=None):
+        if value_list:
+            if not with_subitems:
+                term_list = [ {"terms": {es_field_name: [value]}}
+                    for value in value_list ]
+                return {"bool": {"should": term_list}}
+            else:
+                item_dict = defaultdict(list)
+                for v in value_list:
+                    # -*- coding: utf-8 -*-
+                    v_pair = v.split(u'\u2022')
+                    # No subitem
+                    if len(v_pair) == 1:
+                        # This will initialize empty list for item if not in item_dict yet
+                        item_dict[v_pair[0]]
+                    elif len(v_pair) == 2:
+                        # put subproduct into list
+                        item_dict[v_pair[0]].append(v_pair[1])
+
+                # Go through item_dict to create filters
+                f_list = []
+                for item, subitems in item_dict.iteritems():
+                    item_term = {"terms": {es_field_name: [item]}}
+                    # Item without any subitems
+                    if not subitems:
+                        f_list.append(item_term)
+                    else:
+                        subitem_term = {"terms": {es_subitem_field_name: subitems}}
+                        f_list.append({"and": {"filters": [item_term, subitem_term]}})
+
+                return {"bool": {"should": f_list}}
+
+    def _create_and_append_bool_should_clauses(self, es_field_name, value_list,
+            filter_list, with_subitems=False, es_subitem_field_name=None):
+
+        filter_clauses = self._create_bool_should_clauses(es_field_name, value_list,
+            with_subitems, es_subitem_field_name)
+
+        if filter_clauses:
+            filter_list.append(filter_clauses)
+
+
+class SearchBuilder(BaseBuilder):
+    def __init__(self):
+        self.params = {
+            "format": "json",
+            "field": "complaint_what_happened",
+            "size": 10,
+            "frm": 0,
+            "sort": "relevance_desc"
+        }
+
+    def build(self):
+        search = {
+            "from": self.params.get("frm"),
+            "size": self.params.get("size"),
+            "query": {"match_all": {}},
+            "highlight": {
+                "fields": {
+                    self.params.get("field"): {}
+                },
+                "number_of_fragments": 1,
+                "fragment_size": 500
+            }
+        }
+
+        # sort
+        sort_field, sort_order = self.params.get("sort").rsplit("_", 1)
+        sort_field = "_score" if sort_field == "relevance" else sort_field
+        search["sort"] = [{sort_field: {"order": sort_order}}]
+
+        # query
+        if self.params.get("search_term"):
+            search["query"] = {
+                "match": {
+                    self.params.get("field"): {
+                        "query": self.params.get("search_term"),
+                        "operator": "and"
+                    }
+                }
+            }
+        else:
+            search["query"] = {
"query_string": { + "query": "*", + "fields": [ + self.params.get("field") + ], + "default_operator": "AND" + } + } + + return search + +class PostFilterBuilder(BaseBuilder): + + def build(self): + post_filter = {"and": {"filters": []}} + + ## date + if self.params.get("min_date") or self.params.get("max_date"): + date_clause = {"range": {"date_received": {}}} + if self.params.get("min_date"): + date_clause["range"]["date_received"]["from"] = self.params.get("min_date") + if self.params.get("max_date"): + date_clause["range"]["date_received"]["to"] = self.params.get("max_date") + + post_filter["and"]["filters"].append(date_clause) + + ## Create bool should clauses for fields in self._OPTIONAL_FILTERS + for field in self._OPTIONAL_FILTERS: + if field in self._OPTIONAL_FILTERS_CHILD_MAP: + self._create_and_append_bool_should_clauses(self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(field, field), + self.params.get(field), post_filter["and"]["filters"], with_subitems=True, + es_subitem_field_name=self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(self._OPTIONAL_FILTERS_CHILD_MAP.get(field), + self._OPTIONAL_FILTERS_CHILD_MAP.get(field))) + else: + self._create_and_append_bool_should_clauses(self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(field, field), + self.params.get(field), post_filter["and"]["filters"]) + + for field in self._OPTIONAL_FILTERS_STRING_TO_BOOL: + if self.params.get(field): + self._create_and_append_bool_should_clauses(field, + [ 0 if cd.lower() == "no" else 1 for cd in self.params.get(field) ], + post_filter["and"]["filters"]) + + return post_filter + +class AggregationBuilder(BaseBuilder): + + def build(self): + # All fields that need to have an aggregation entry + Field = namedtuple('Field', 'name size has_subfield') + fields = [ + Field('has_narratives', 10, False), + Field('company', 10000, False), + Field('product', 10000, True), + Field('issue', 10000, True), + Field('state', 50, False), + Field('zip_code', 1000, False), + Field('timely', 10, False), + Field('company_response', 100, False), + Field('company_public_response', 100, False), + Field('consumer_disputed', 100, False), + Field('consumer_consent_provided', 100, False), + Field('tag', 100, False), + Field('submitted_via', 100, False) + ] + aggs = {} + + # Creating aggregation object for each field above + for field in fields: + field_aggs = { + "filter": { + "and": { + "filters": [ + + ] + } + } + } + + es_field_name = self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(field.name, field.name) + es_subfield_name = None + if field.has_subfield: + es_subfield_name = self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(self._OPTIONAL_FILTERS_CHILD_MAP.get(field.name)) + field_aggs["aggs"] = { + field.name: { + "terms": { + "field": es_field_name, + "size": field.size + }, + "aggs": { + es_subfield_name: { + "terms": { + "field": es_subfield_name, + "size": field.size + } + } + } + } + } + else: + field_aggs["aggs"] = { + field.name: { + "terms": { + "field": es_field_name, + "size": field.size + } + } + } + + date_filter = { + "range": { + "date_received": { + + } + } + } + if "min_date" in self.params: + date_filter["range"]["date_received"]["from"] = self.params["min_date"] + if "max_date" in self.params: + date_filter["range"]["date_received"]["to"] = self.params["max_date"] + + field_aggs["filter"]["and"]["filters"].append(date_filter) + + # Add filter clauses to aggregation entries (only those that are not the same as field name) + for item in self.params: + if item in self._OPTIONAL_FILTERS and item != field.name: + clauses = 
self._create_and_append_bool_should_clauses(self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(item, item), + self.params[item], field_aggs["filter"]["and"]["filters"], + with_subitems=item in self._OPTIONAL_FILTERS_CHILD_MAP, + es_subitem_field_name=self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(self._OPTIONAL_FILTERS_CHILD_MAP.get(item))) + elif item in self._OPTIONAL_FILTERS_STRING_TO_BOOL and item != field.name: + clauses = self._create_and_append_bool_should_clauses(self._OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(item, item), + [ 0 if cd.lower() == "no" else 1 for cd in self.params[item] ], + field_aggs["filter"]["and"]["filters"]) + + aggs[field.name] = field_aggs + + return aggs + +if __name__ == "__main__": + searchbuilder = SearchBuilder() + print searchbuilder.build() + pfbuilder = PostFilterBuilder() + print pfbuilder.build() + aggbuilder = AggregationBuilder() + print aggbuilder.build() + diff --git a/complaint_search/es_interface.py b/complaint_search/es_interface.py index f1588756..5c964103 100644 --- a/complaint_search/es_interface.py +++ b/complaint_search/es_interface.py @@ -4,6 +4,7 @@ from collections import defaultdict, namedtuple import requests from elasticsearch import Elasticsearch +from complaint_search.es_builders import SearchBuilder, PostFilterBuilder, AggregationBuilder _ES_URL = "{}://{}:{}".format("http", os.environ.get('ES_HOST', 'localhost'), os.environ.get('ES_PORT', '9200')) @@ -15,29 +16,6 @@ _COMPLAINT_ES_INDEX = os.environ.get('COMPLAINT_ES_INDEX', 'complaint-index') _COMPLAINT_DOC_TYPE = os.environ.get('COMPLAINT_DOC_TYPE', 'complaint-doctype') -# Filters for those with string type -_OPTIONAL_FILTERS = ("product", "issue", "company", "state", "zip_code", "timely", - "company_response", "company_public_response", - "consumer_consent_provided", "submitted_via", "tag", "consumer_disputed") - -# Filters for those that need conversion from string to boolean -_OPTIONAL_FILTERS_STRING_TO_BOOL = ("has_narratives",) - -_OPTIONAL_FILTERS_PARAM_TO_ES_MAP = { - "product": "product.raw", - "sub_product": "sub_product.raw", - "issue": "issue.raw", - "sub_issue": "sub_issue.raw", - "company_public_response": "company_public_response.raw", - "consumer_consent_provided": "consumer_consent_provided.raw", - "consumer_disputed": "consumer_disputed.raw" -} - -_OPTIONAL_FILTERS_CHILD_MAP = { - "product": "sub_product", - "issue": "sub_issue" -} - def get_es(): global _ES_INSTANCE if _ES_INSTANCE is None: @@ -45,145 +23,8 @@ def get_es(): timeout=100) return _ES_INSTANCE - - -def _create_aggregation(**kwargs): - - # All fields that need to have an aggregation entry - Field = namedtuple('Field', 'name size has_subfield') - fields = [ - Field('has_narratives', 10, False), - Field('company', 10000, False), - Field('product', 10000, True), - Field('issue', 10000, True), - Field('state', 50, False), - Field('zip_code', 1000, False), - Field('timely', 10, False), - Field('company_response', 100, False), - Field('company_public_response', 100, False), - Field('consumer_disputed', 100, False), - Field('consumer_consent_provided', 100, False), - Field('tag', 100, False), - Field('submitted_via', 100, False) - ] - aggs = {} - - # Creating aggregation object for each field above - for field in fields: - field_aggs = { - "filter": { - "and": { - "filters": [ - - ] - } - } - } - - es_field_name = _OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(field.name, field.name) - es_subfield_name = None - if field.has_subfield: - es_subfield_name = 
_OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(_OPTIONAL_FILTERS_CHILD_MAP.get(field.name)) - field_aggs["aggs"] = { - field.name: { - "terms": { - "field": es_field_name, - "size": field.size - }, - "aggs": { - es_subfield_name: { - "terms": { - "field": es_subfield_name, - "size": field.size - } - } - } - } - } - else: - field_aggs["aggs"] = { - field.name: { - "terms": { - "field": es_field_name, - "size": field.size - } - } - } - - date_filter = { - "range": { - "date_received": { - - } - } - } - if "min_date" in kwargs: - date_filter["range"]["date_received"]["from"] = kwargs["min_date"] - if "max_date" in kwargs: - date_filter["range"]["date_received"]["to"] = kwargs["max_date"] - - field_aggs["filter"]["and"]["filters"].append(date_filter) - - # Add filter clauses to aggregation entries (only those that are not the same as field name) - for item in kwargs: - if item in _OPTIONAL_FILTERS and item != field.name: - clauses = _create_and_append_bool_should_clauses(_OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(item, item), - kwargs[item], field_aggs["filter"]["and"]["filters"], - with_subitems=item in _OPTIONAL_FILTERS_CHILD_MAP, - es_subitem_field_name=_OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(_OPTIONAL_FILTERS_CHILD_MAP.get(item))) - elif item in _OPTIONAL_FILTERS_STRING_TO_BOOL and item != field.name: - clauses = _create_and_append_bool_should_clauses(_OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(item, item), - [ 0 if cd.lower() == "no" else 1 for cd in kwargs[item] ], - field_aggs["filter"]["and"]["filters"]) - - aggs[field.name] = field_aggs - - return aggs - -def _create_bool_should_clauses(es_field_name, value_list, - with_subitems=False, es_subitem_field_name=None): - if value_list: - if not with_subitems: - term_list = [ {"terms": {es_field_name: [value]}} - for value in value_list ] - return {"bool": {"should": term_list}} - else: - item_dict = defaultdict(list) - for v in value_list: - # -*- coding: utf-8 -*- - v_pair = v.split(u'\u2022') - # No subitem - if len(v_pair) == 1: - # This will initialize empty list for item if not in item_dict yet - item_dict[v_pair[0]] - elif len(v_pair) == 2: - # put subproduct into list - item_dict[v_pair[0]].append(v_pair[1]) - - # Go through item_dict to create filters - f_list = [] - for item, subitems in item_dict.iteritems(): - item_term = {"terms": {es_field_name: [item]}} - # Item without any subitems - if not subitems: - f_list.append(item_term) - else: - subitem_term = {"terms": {es_subitem_field_name: subitems}} - f_list.append({"and": {"filters": [item_term, subitem_term]}}) - - return {"bool": {"should": f_list}} - -def _create_and_append_bool_should_clauses(es_field_name, value_list, - filter_list, with_subitems=False, es_subitem_field_name=None): - - filter_clauses = _create_bool_should_clauses(es_field_name, value_list, - with_subitems, es_subitem_field_name) - - if filter_clauses: - filter_list.append(filter_clauses) - # List of possible arguments: -# - fmt: format to be returned: "json", "csv", "xls", or "xlsx" +# - format: format to be returned: "json", "csv", "xls", or "xlsx" # - field: field you want to search in: "complaint_what_happened", "company_public_response", "_all" # - size: number of complaints to return # - frm: from which index to start returning @@ -206,97 +47,26 @@ def _create_and_append_bool_should_clauses(es_field_name, value_list, # - tag - filters a list of tags def search(**kwargs): - # base default parameters - params = { - "fmt": "json", - "field": "complaint_what_happened", - "size": 10, - "frm": 0, - "sort": "relevance_desc" - } + 
search_builder = SearchBuilder() + search_builder.add(**kwargs) + body = search_builder.build() - params.update(**kwargs) + post_filter_builder = PostFilterBuilder() + post_filter_builder.add(**kwargs) + body["post_filter"] = post_filter_builder.build() + # format res = None - body = { - "from": params.get("frm"), - "size": params.get("size"), - "query": {"match_all": {}}, - "highlight": { - "fields": { - params.get("field"): {} - }, - "number_of_fragments": 1, - "fragment_size": 500 - } - } - - # sort - sort_field, sort_order = params.get("sort").rsplit("_", 1) - sort_field = "_score" if sort_field == "relevance" else sort_field - body["sort"] = [{sort_field: {"order": sort_order}}] - - # query - if params.get("search_term"): - body["query"] = { - "match": { - params.get("field"): { - "query": params.get("search_term"), - "operator": "and" - } - } - } - else: - body["query"] = { - "query_string": { - "query": "*", - "fields": [ - params.get("field") - ], - "default_operator": "AND" - } - } - - # post-filter - body["post_filter"] = {"and": {"filters": []}} + format = kwargs.get("format", "json") + if format == "json": + aggregation_builder = AggregationBuilder() + aggregation_builder.add(**kwargs) + body["aggs"] = aggregation_builder.build() - - ## date - if params.get("min_date") or params.get("max_date"): - date_clause = {"range": {"date_received": {}}} - if params.get("min_date"): - date_clause["range"]["date_received"]["from"] = params.get("min_date") - if params.get("max_date"): - date_clause["range"]["date_received"]["to"] = params.get("max_date") - - body["post_filter"]["and"]["filters"].append(date_clause) - - - ## Create bool should clauses for fields in _OPTIONAL_FILTERS - for field in _OPTIONAL_FILTERS: - if field in _OPTIONAL_FILTERS_CHILD_MAP: - _create_and_append_bool_should_clauses(_OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(field, field), - params.get(field), body["post_filter"]["and"]["filters"], with_subitems=True, - es_subitem_field_name=_OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(_OPTIONAL_FILTERS_CHILD_MAP.get(field), - _OPTIONAL_FILTERS_CHILD_MAP.get(field))) - else: - _create_and_append_bool_should_clauses(_OPTIONAL_FILTERS_PARAM_TO_ES_MAP.get(field, field), - params.get(field), body["post_filter"]["and"]["filters"]) - - for field in _OPTIONAL_FILTERS_STRING_TO_BOOL: - if params.get(field): - _create_and_append_bool_should_clauses(field, - [ 0 if cd.lower() == "no" else 1 for cd in params.get(field) ], - body["post_filter"]["and"]["filters"]) - - # format - if params.get("fmt") == "json": - ## Create base aggregation - body["aggs"] = _create_aggregation(**kwargs) res = get_es().search(index=_COMPLAINT_ES_INDEX, body=body) - elif params.get("fmt") in ["csv", "xls", "xlsx"]: - p = {"format": params.get("fmt"), - "source": json.dumps(body)} + + elif format in ("csv", "xls", "xlsx"): + p = {"format": format, "source": json.dumps(body)} p = urllib.urlencode(p) url = "{}/{}/{}/_data?{}".format(_ES_URL, _COMPLAINT_ES_INDEX, _COMPLAINT_DOC_TYPE, p) diff --git a/complaint_search/renderers.py b/complaint_search/renderers.py new file mode 100644 index 00000000..04512ebc --- /dev/null +++ b/complaint_search/renderers.py @@ -0,0 +1,26 @@ +from rest_framework.renderers import BaseRenderer + +class CSVRenderer(BaseRenderer): + media_type = 'text/csv' + format = 'csv' + + def render(self, data, media_type=None, renderer_context=None): + return data + +class XLSRenderer(BaseRenderer): + media_type = 'application/vnd.ms-excel' + format = 'xls' + charset = None + render_style = 'binary' + + def 
render(self, data, media_type=None, renderer_context=None): + return data + +class XLSXRenderer(BaseRenderer): + media_type = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' + format = 'xlsx' + charset = None + render_style = 'binary' + + def render(self, data, media_type=None, renderer_context=None): + return data \ No newline at end of file diff --git a/complaint_search/serializer.py b/complaint_search/serializer.py index 93701a63..ae7bb05a 100644 --- a/complaint_search/serializer.py +++ b/complaint_search/serializer.py @@ -42,7 +42,7 @@ class SearchInputSerializer(serializers.Serializer): (SORT_CREATED_DATE_DESC, 'Descending Created Date'), (SORT_CREATED_DATE_ASC, 'Ascending Created Date'), ) - fmt = serializers.ChoiceField(FORMAT_CHOICES, required=False) + format = serializers.ChoiceField(FORMAT_CHOICES, required=False) field = serializers.ChoiceField(FIELD_CHOICES, required=False) size = serializers.IntegerField(min_value=1, max_value=100000, required=False) frm = serializers.IntegerField(min_value=0, max_value=100000, required=False) diff --git a/complaint_search/tests/expected_results/search_with_fmt_nonjson__valid.json b/complaint_search/tests/expected_results/search_with_format_nonjson__valid.json similarity index 100% rename from complaint_search/tests/expected_results/search_with_fmt_nonjson__valid.json rename to complaint_search/tests/expected_results/search_with_format_nonjson__valid.json diff --git a/complaint_search/tests/test_es_interface.py b/complaint_search/tests/test_es_interface.py index c65d6d62..78d79b42 100644 --- a/complaint_search/tests/test_es_interface.py +++ b/complaint_search/tests/test_es_interface.py @@ -53,23 +53,23 @@ def test_search_no_param__valid(self, mock_rget, mock_search): @mock.patch('requests.get', ok=True, content="RGET_OK") @mock.patch('json.dumps') @mock.patch('urllib.urlencode') - def test_search_with_fmt_nonjson__valid(self, mock_urlencode, mock_jdump, mock_rget, mock_search): + def test_search_with_format_nonjson__valid(self, mock_urlencode, mock_jdump, mock_rget, mock_search): mock_search.return_value = 'OK' mock_jdump.return_value = 'JDUMPS_OK' - body = self.load("search_with_fmt_nonjson__valid") - for fmt in ["csv", "xls", "xlsx"]: - res = search(fmt=fmt) + body = self.load("search_with_format_nonjson__valid") + for format in ["csv", "xls", "xlsx"]: + res = search(format=format) self.assertEqual(len(mock_jdump.call_args), 2) self.assertEqual(1, len(mock_jdump.call_args[0])) act_body = mock_jdump.call_args[0][0] diff = deep.diff(body, act_body) if diff: - print "fmt={}".format(fmt) + print "format={}".format(format) diff.print_full() self.assertIsNone(deep.diff(body, act_body)) self.assertEqual(len(mock_urlencode.call_args), 2) self.assertEqual(1, len(mock_urlencode.call_args[0])) - param = {"format": fmt, "source": "JDUMPS_OK"} + param = {"format": format, "source": "JDUMPS_OK"} act_param = mock_urlencode.call_args[0][0] self.assertEqual(param, act_param) @@ -95,9 +95,9 @@ def test_search_with_field__valid(self, mock_search): @mock.patch.object(Elasticsearch, 'search') @mock.patch('requests.get', ok=True, content="RGET_OK") - def test_search_with_fmt__invalid(self, mock_rget, mock_search): + def test_search_with_format__invalid(self, mock_rget, mock_search): mock_search.return_value = 'OK' - res = search(fmt="pdf") + res = search(format="pdf") self.assertIsNone(res) mock_search.assert_not_called() mock_rget.assert_not_called() diff --git a/complaint_search/tests/test_views_search.py 
b/complaint_search/tests/test_views_search.py index ee303b16..e05d610b 100644 --- a/complaint_search/tests/test_views_search.py +++ b/complaint_search/tests/test_views_search.py @@ -40,28 +40,27 @@ def test_search_cors_headers(self, mock_essearch): @mock.patch('complaint_search.views.datetime') @mock.patch('complaint_search.es_interface.search') - def test_search_with_fmt(self, mock_essearch, mock_dt): + def test_search_with_format(self, mock_essearch, mock_dt): """ - Searching with fmt + Searching with format """ - FMT_CONTENT_TYPE_MAP = { + FORMAT_CONTENT_TYPE_MAP = { "csv": "text/csv", "xls": "application/vnd.ms-excel", "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" } - for k, v in FMT_CONTENT_TYPE_MAP.iteritems(): + for k, v in FORMAT_CONTENT_TYPE_MAP.iteritems(): url = reverse('complaint_search:search') - params = {"fmt": k} + params = {"format": k} mock_essearch.return_value = 'OK' mock_dt.now.return_value = datetime(2017,1,1,12,0) response = self.client.get(url, params) self.assertEqual(response.status_code, status.HTTP_200_OK) - # mock_essearch.assert_called_once_with(fmt=k) - self.assertEqual(response.get('Content-Type'), v) + self.assertIn(v, response.get('Content-Type')) self.assertEqual(response.get('Content-Disposition'), 'attachment; filename="complaints-2017-01-01_12_00.{}"'.format(k)) self.assertEqual('OK', response.content) - mock_essearch.has_calls([ mock.call(fmt=k) for k in FMT_CONTENT_TYPE_MAP ], any_order=True) + mock_essearch.has_calls([ mock.call(format=k) for k in FORMAT_CONTENT_TYPE_MAP ], any_order=True) self.assertEqual(3, mock_essearch.call_count) @mock.patch('complaint_search.es_interface.search') diff --git a/complaint_search/views.py b/complaint_search/views.py index 86dbc014..9aa80c28 100644 --- a/complaint_search/views.py +++ b/complaint_search/views.py @@ -1,18 +1,22 @@ from rest_framework import status -from rest_framework.decorators import api_view +from rest_framework.decorators import api_view, renderer_classes +from rest_framework.renderers import JSONRenderer, BrowsableAPIRenderer from rest_framework.response import Response from django.http import HttpResponse from django.conf import settings from datetime import datetime from elasticsearch import TransportError import es_interface +from complaint_search.renderers import CSVRenderer, XLSRenderer, XLSXRenderer from complaint_search.decorators import catch_es_error from complaint_search.serializer import SearchInputSerializer, SuggestInputSerializer @api_view(['GET']) +@renderer_classes((JSONRenderer, CSVRenderer, XLSRenderer, XLSXRenderer, BrowsableAPIRenderer)) @catch_es_error def search(request): + fixed_qparam = request.query_params QPARAMS_VARS = ('fmt', 'field', 'size', 'frm', 'sort', 'search_term', 'min_date', 'max_date') @@ -28,6 +32,12 @@ def search(request): # for param in request.query_params if param in QPARAMS_VARS + QPARAMS_LISTS} data = {} + + # Add format to data (only checking if it is csv, xls, xlsx, then specific them) + format = request.accepted_renderer.format + if format and format in ('csv', 'xls', 'xlsx'): + data['format'] = format + for param in request.query_params: if param in QPARAMS_VARS: data[param] = request.query_params.get(param) @@ -40,13 +50,6 @@ def search(request): if serializer.is_valid(): results = es_interface.search(**serializer.validated_data) - FMT_CONTENT_TYPE_MAP = { - "json": "application/json", - "csv": "text/csv", - "xls": "application/vnd.ms-excel", - "xlsx": 
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" - } - # Local development requires CORS support headers = {} if settings.DEBUG: @@ -56,20 +59,15 @@ def search(request): 'Access-Control-Allow-Methods': 'GET' } - # Putting response together based on format - fmt = serializer.validated_data.get("fmt", 'json') - if fmt == 'json': - return Response(results, headers=headers) - elif fmt in ('csv', 'xls', 'xlsx'): - media_type = FMT_CONTENT_TYPE_MAP.get(fmt) - response = HttpResponse(results, content_type=media_type) - for header in headers: - response[header] = headers[header] + # If format is in csv, xls, xlsx, update its attachment response + # with a filename + if format in ('csv', 'xls', 'xlsx'): + filename = 'complaints-{}.{}'.format( + datetime.now().strftime('%Y-%m-%d_%H_%M'), format) + headers['Content-Disposition'] = 'attachment; filename="{}"'.format(filename) - filename = 'complaints-{}.{}'.format(datetime.now().strftime('%Y-%m-%d_%H_%M'), fmt) - response['Content-Disposition'] = 'attachment; filename="{}"'.format(filename) + return Response(results, headers=headers) - return response else: return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)