diff --git a/elasticutils/contrib/django/__init__.py b/elasticutils/contrib/django/__init__.py index a0cfd41..f347c66 100644 --- a/elasticutils/contrib/django/__init__.py +++ b/elasticutils/contrib/django/__init__.py @@ -9,6 +9,7 @@ try: from django.conf import settings from django.shortcuts import render + from django.utils.decorators import decorator_from_middleware_with_args except ImportError: pass @@ -16,6 +17,14 @@ log = logging.getLogger('elasticutils') +ES_EXCEPTIONS = ( + pyelasticsearch.exceptions.ConnectionError, + pyelasticsearch.exceptions.ElasticHttpError, + pyelasticsearch.exceptions.ElasticHttpNotFoundError, + pyelasticsearch.exceptions.Timeout +) + + def get_es(**overrides): """Return a pyelasticsearch ElasticSearch object using settings from ``settings.py``. @@ -58,12 +67,8 @@ def wrapper(*args, **kw): return wrapper -def es_required_or_50x(disabled_template='elasticutils/501.html', - error_template='elasticutils/503.html'): - """Wrap a Django view and handle ElasticSearch errors. - - This wraps a Django view and returns 501 or 503 status codes and - pages if things go awry. +class ESExceptionMiddleware(object): + """Middleware to handle ElasticSearch errors. HTTP 501 Returned when ``ES_DISABLED`` is True. @@ -80,57 +85,65 @@ def es_required_or_50x(disabled_template='elasticutils/501.html', * error: A string version of the exception thrown. - :arg disabled_template: The template to use when ES_DISABLED is - True. + :arg disabled_template: The template to use when ES_DISABLED is True. Defaults to ``elasticutils/501.html``. + :arg error_template: The template to use when ElasticSearch isn't working properly, is missing an index, or something along those lines. Defaults to ``elasticutils/503.html``. + """ - Examples:: + def __init__(self, disabled_template=None, error_template=None): + self.disabled_template = ( + disabled_template or 'elasticutils/501.html') + self.error_template = ( + error_template or 'elasticutils/503.html') - # This creates a home_view and decorates it to use the - # default templates. + def process_request(self, request): + if getattr(settings, 'ES_DISABLED', False): + response = render(request, self.disabled_template) + response.status_code = 501 + return response - @es_required_or_50x() - def home_view(request): - ... + def process_exception(self, request, exception): + if issubclass(exception.__class__, ES_EXCEPTIONS): + response = render(request, self.error_template, + {'error': exception}) + response.status_code = 503 + return response - # This creates a search_view and overrides the templates +""" +The following decorator wraps a Django view and handles ElasticSearch errors. - @es_required_or_50x(disabled_template='search/es_disabled.html', - error_template('search/es_down.html') - def search_view(request): - ... +This wraps a Django view and returns 501 or 503 status codes and +pages if things go awry. - """ - def wrap(fun): - @wraps(fun) - def wrapper(request, *args, **kw): - if getattr(settings, 'ES_DISABLED', False): - response = render(request, disabled_template) - response.status_code = 501 - return response - - try: - return fun(request, *args, **kw) - - except (pyelasticsearch.exceptions.ConnectionError, - pyelasticsearch.exceptions.ElasticHttpError, - pyelasticsearch.exceptions.ElasticHttpNotFoundError, - pyelasticsearch.exceptions.Timeout) as exc: - response = render(request, error_template, {'error': exc}) - response.status_code = 503 - return response - - return wrapper - - return wrap +See the above middleware for explanation of the arguments. + +Examples:: + + # This creates a home_view and decorates it to use the + # default templates. + + @es_required_or_50x() + def home_view(request): + ... + + + # This creates a search_view and overrides the templates + + @es_required_or_50x(disabled_template='search/es_disabled.html', + error_template('search/es_down.html') + def search_view(request): + ... + +""" +es_required_or_50x = decorator_from_middleware_with_args(ESExceptionMiddleware) class S(elasticutils.S): diff --git a/elasticutils/templates/elasticutils/501.html b/elasticutils/templates/elasticutils/501.html new file mode 100644 index 0000000..e69de29 diff --git a/elasticutils/templates/elasticutils/503.html b/elasticutils/templates/elasticutils/503.html new file mode 100644 index 0000000..e69de29 diff --git a/elasticutils/tests/test_django.py b/elasticutils/tests/test_django.py index f83ab78..fba2bff 100644 --- a/elasticutils/tests/test_django.py +++ b/elasticutils/tests/test_django.py @@ -20,9 +20,12 @@ try: from django.conf import settings + from django.test import RequestFactory + from django.test.utils import override_settings from elasticutils.contrib.django import ( - S, F, get_es, InvalidFieldActionError) + S, F, get_es, InvalidFieldActionError, ES_EXCEPTIONS, + ESExceptionMiddleware, es_required_or_50x) from elasticutils.contrib.django.tasks import ( index_objects, unindex_objects) from elasticutils.tests.django_utils import ( @@ -368,3 +371,49 @@ def test_tasks(self): unindex_objects(FakeDjangoMappingType, [1, 2, 3]) FakeDjangoMappingType.refresh_index() eq_(FakeDjangoMappingType.search().count(), 0) + + +class MiddlewareTest(DjangoElasticTestCase): + + def setUp(self): + super(MiddlewareTest, self).setUp() + + def view(request, exc): + raise exc + + self.func = view + self.fake_request = RequestFactory().get('/') + + def test_exceptions(self): + for exc in ES_EXCEPTIONS: + response = ESExceptionMiddleware().process_exception( + self.fake_request, exc(Exception)) + eq_(response.status_code, 503) + + @override_settings(ES_DISABLED=True) + def test_es_disabled(self): + response = ESExceptionMiddleware().process_request(self.fake_request) + eq_(response.status_code, 501) + + +class DecoratorTest(DjangoElasticTestCase): + + def setUp(self): + super(DecoratorTest, self).setUp() + + @es_required_or_50x() + def view(request, exc): + raise exc + + self.func = view + self.fake_request = RequestFactory().get('/') + + def test_exceptions(self): + for exc in ES_EXCEPTIONS: + response = self.func(self.fake_request, exc(Exception)) + eq_(response.status_code, 503) + + @override_settings(ES_DISABLED=True) + def test_es_disabled(self): + response = self.func(self.fake_request) + eq_(response.status_code, 501) diff --git a/test_settings.py b/test_settings.py index 424804d..6db843c 100644 --- a/test_settings.py +++ b/test_settings.py @@ -1,6 +1,15 @@ +import os + + +ROOT = os.path.abspath(os.path.dirname(__file__)) + + ES_URLS = ['http://localhost:9200'] ES_INDEXES = {'default': ['elasticutilstest']} ES_TIMEOUT = 10 ES_DISABLED = False CELERY_ALWAYS_EAGER = True + +SECRET_KEY = 'super_secret' +TEMPLATE_DIRS = ('%s/elasticutils/templates' % ROOT,)