diff --git a/complaint_search/es_interface.py b/complaint_search/es_interface.py index f1588756..71ed18d7 100644 --- a/complaint_search/es_interface.py +++ b/complaint_search/es_interface.py @@ -183,7 +183,7 @@ def _create_and_append_bool_should_clauses(es_field_name, value_list, filter_list.append(filter_clauses) # List of possible arguments: -# - fmt: format to be returned: "json", "csv", "xls", or "xlsx" +# - format: format to be returned: "json", "csv", "xls", or "xlsx" # - field: field you want to search in: "complaint_what_happened", "company_public_response", "_all" # - size: number of complaints to return # - frm: from which index to start returning @@ -208,7 +208,7 @@ def search(**kwargs): # base default parameters params = { - "fmt": "json", + "format": "json", "field": "complaint_what_happened", "size": 10, "frm": 0, @@ -216,7 +216,6 @@ def search(**kwargs): } params.update(**kwargs) - res = None body = { "from": params.get("frm"), @@ -290,12 +289,12 @@ def search(**kwargs): body["post_filter"]["and"]["filters"]) # format - if params.get("fmt") == "json": + if params.get("format") == "json": ## Create base aggregation body["aggs"] = _create_aggregation(**kwargs) res = get_es().search(index=_COMPLAINT_ES_INDEX, body=body) - elif params.get("fmt") in ["csv", "xls", "xlsx"]: - p = {"format": params.get("fmt"), + elif params.get("format") in ("csv", "xls", "xlsx"): + p = {"format": params.get("format"), "source": json.dumps(body)} p = urllib.urlencode(p) url = "{}/{}/{}/_data?{}".format(_ES_URL, _COMPLAINT_ES_INDEX, diff --git a/complaint_search/renderers.py b/complaint_search/renderers.py new file mode 100644 index 00000000..04512ebc --- /dev/null +++ b/complaint_search/renderers.py @@ -0,0 +1,26 @@ +from rest_framework.renderers import BaseRenderer + +class CSVRenderer(BaseRenderer): + media_type = 'text/csv' + format = 'csv' + + def render(self, data, media_type=None, renderer_context=None): + return data + +class XLSRenderer(BaseRenderer): + media_type = 'application/vnd.ms-excel' + format = 'xls' + charset = None + render_style = 'binary' + + def render(self, data, media_type=None, renderer_context=None): + return data + +class XLSXRenderer(BaseRenderer): + media_type = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' + format = 'xlsx' + charset = None + render_style = 'binary' + + def render(self, data, media_type=None, renderer_context=None): + return data \ No newline at end of file diff --git a/complaint_search/serializer.py b/complaint_search/serializer.py index 93701a63..ae7bb05a 100644 --- a/complaint_search/serializer.py +++ b/complaint_search/serializer.py @@ -42,7 +42,7 @@ class SearchInputSerializer(serializers.Serializer): (SORT_CREATED_DATE_DESC, 'Descending Created Date'), (SORT_CREATED_DATE_ASC, 'Ascending Created Date'), ) - fmt = serializers.ChoiceField(FORMAT_CHOICES, required=False) + format = serializers.ChoiceField(FORMAT_CHOICES, required=False) field = serializers.ChoiceField(FIELD_CHOICES, required=False) size = serializers.IntegerField(min_value=1, max_value=100000, required=False) frm = serializers.IntegerField(min_value=0, max_value=100000, required=False) diff --git a/complaint_search/tests/expected_results/search_with_fmt_nonjson__valid.json b/complaint_search/tests/expected_results/search_with_format_nonjson__valid.json similarity index 100% rename from complaint_search/tests/expected_results/search_with_fmt_nonjson__valid.json rename to complaint_search/tests/expected_results/search_with_format_nonjson__valid.json diff --git a/complaint_search/tests/test_es_interface.py b/complaint_search/tests/test_es_interface.py index c65d6d62..78d79b42 100644 --- a/complaint_search/tests/test_es_interface.py +++ b/complaint_search/tests/test_es_interface.py @@ -53,23 +53,23 @@ def test_search_no_param__valid(self, mock_rget, mock_search): @mock.patch('requests.get', ok=True, content="RGET_OK") @mock.patch('json.dumps') @mock.patch('urllib.urlencode') - def test_search_with_fmt_nonjson__valid(self, mock_urlencode, mock_jdump, mock_rget, mock_search): + def test_search_with_format_nonjson__valid(self, mock_urlencode, mock_jdump, mock_rget, mock_search): mock_search.return_value = 'OK' mock_jdump.return_value = 'JDUMPS_OK' - body = self.load("search_with_fmt_nonjson__valid") - for fmt in ["csv", "xls", "xlsx"]: - res = search(fmt=fmt) + body = self.load("search_with_format_nonjson__valid") + for format in ["csv", "xls", "xlsx"]: + res = search(format=format) self.assertEqual(len(mock_jdump.call_args), 2) self.assertEqual(1, len(mock_jdump.call_args[0])) act_body = mock_jdump.call_args[0][0] diff = deep.diff(body, act_body) if diff: - print "fmt={}".format(fmt) + print "format={}".format(format) diff.print_full() self.assertIsNone(deep.diff(body, act_body)) self.assertEqual(len(mock_urlencode.call_args), 2) self.assertEqual(1, len(mock_urlencode.call_args[0])) - param = {"format": fmt, "source": "JDUMPS_OK"} + param = {"format": format, "source": "JDUMPS_OK"} act_param = mock_urlencode.call_args[0][0] self.assertEqual(param, act_param) @@ -95,9 +95,9 @@ def test_search_with_field__valid(self, mock_search): @mock.patch.object(Elasticsearch, 'search') @mock.patch('requests.get', ok=True, content="RGET_OK") - def test_search_with_fmt__invalid(self, mock_rget, mock_search): + def test_search_with_format__invalid(self, mock_rget, mock_search): mock_search.return_value = 'OK' - res = search(fmt="pdf") + res = search(format="pdf") self.assertIsNone(res) mock_search.assert_not_called() mock_rget.assert_not_called() diff --git a/complaint_search/tests/test_views_search.py b/complaint_search/tests/test_views_search.py index ee303b16..e05d610b 100644 --- a/complaint_search/tests/test_views_search.py +++ b/complaint_search/tests/test_views_search.py @@ -40,28 +40,27 @@ def test_search_cors_headers(self, mock_essearch): @mock.patch('complaint_search.views.datetime') @mock.patch('complaint_search.es_interface.search') - def test_search_with_fmt(self, mock_essearch, mock_dt): + def test_search_with_format(self, mock_essearch, mock_dt): """ - Searching with fmt + Searching with format """ - FMT_CONTENT_TYPE_MAP = { + FORMAT_CONTENT_TYPE_MAP = { "csv": "text/csv", "xls": "application/vnd.ms-excel", "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" } - for k, v in FMT_CONTENT_TYPE_MAP.iteritems(): + for k, v in FORMAT_CONTENT_TYPE_MAP.iteritems(): url = reverse('complaint_search:search') - params = {"fmt": k} + params = {"format": k} mock_essearch.return_value = 'OK' mock_dt.now.return_value = datetime(2017,1,1,12,0) response = self.client.get(url, params) self.assertEqual(response.status_code, status.HTTP_200_OK) - # mock_essearch.assert_called_once_with(fmt=k) - self.assertEqual(response.get('Content-Type'), v) + self.assertIn(v, response.get('Content-Type')) self.assertEqual(response.get('Content-Disposition'), 'attachment; filename="complaints-2017-01-01_12_00.{}"'.format(k)) self.assertEqual('OK', response.content) - mock_essearch.has_calls([ mock.call(fmt=k) for k in FMT_CONTENT_TYPE_MAP ], any_order=True) + mock_essearch.has_calls([ mock.call(format=k) for k in FORMAT_CONTENT_TYPE_MAP ], any_order=True) self.assertEqual(3, mock_essearch.call_count) @mock.patch('complaint_search.es_interface.search') diff --git a/complaint_search/views.py b/complaint_search/views.py index 86dbc014..9aa80c28 100644 --- a/complaint_search/views.py +++ b/complaint_search/views.py @@ -1,18 +1,22 @@ from rest_framework import status -from rest_framework.decorators import api_view +from rest_framework.decorators import api_view, renderer_classes +from rest_framework.renderers import JSONRenderer, BrowsableAPIRenderer from rest_framework.response import Response from django.http import HttpResponse from django.conf import settings from datetime import datetime from elasticsearch import TransportError import es_interface +from complaint_search.renderers import CSVRenderer, XLSRenderer, XLSXRenderer from complaint_search.decorators import catch_es_error from complaint_search.serializer import SearchInputSerializer, SuggestInputSerializer @api_view(['GET']) +@renderer_classes((JSONRenderer, CSVRenderer, XLSRenderer, XLSXRenderer, BrowsableAPIRenderer)) @catch_es_error def search(request): + fixed_qparam = request.query_params QPARAMS_VARS = ('fmt', 'field', 'size', 'frm', 'sort', 'search_term', 'min_date', 'max_date') @@ -28,6 +32,12 @@ def search(request): # for param in request.query_params if param in QPARAMS_VARS + QPARAMS_LISTS} data = {} + + # Add format to data (only checking if it is csv, xls, xlsx, then specific them) + format = request.accepted_renderer.format + if format and format in ('csv', 'xls', 'xlsx'): + data['format'] = format + for param in request.query_params: if param in QPARAMS_VARS: data[param] = request.query_params.get(param) @@ -40,13 +50,6 @@ def search(request): if serializer.is_valid(): results = es_interface.search(**serializer.validated_data) - FMT_CONTENT_TYPE_MAP = { - "json": "application/json", - "csv": "text/csv", - "xls": "application/vnd.ms-excel", - "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" - } - # Local development requires CORS support headers = {} if settings.DEBUG: @@ -56,20 +59,15 @@ def search(request): 'Access-Control-Allow-Methods': 'GET' } - # Putting response together based on format - fmt = serializer.validated_data.get("fmt", 'json') - if fmt == 'json': - return Response(results, headers=headers) - elif fmt in ('csv', 'xls', 'xlsx'): - media_type = FMT_CONTENT_TYPE_MAP.get(fmt) - response = HttpResponse(results, content_type=media_type) - for header in headers: - response[header] = headers[header] + # If format is in csv, xls, xlsx, update its attachment response + # with a filename + if format in ('csv', 'xls', 'xlsx'): + filename = 'complaints-{}.{}'.format( + datetime.now().strftime('%Y-%m-%d_%H_%M'), format) + headers['Content-Disposition'] = 'attachment; filename="{}"'.format(filename) - filename = 'complaints-{}.{}'.format(datetime.now().strftime('%Y-%m-%d_%H_%M'), fmt) - response['Content-Disposition'] = 'attachment; filename="{}"'.format(filename) + return Response(results, headers=headers) - return response else: return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)