diff --git a/docs/sections/settings.rst b/docs/sections/settings.rst index 984d2df2..f7dadd5a 100644 --- a/docs/sections/settings.rst +++ b/docs/sections/settings.rst @@ -55,6 +55,20 @@ OPTIONAL. ``int``. Code object expiration after been delivered. Expressed in seconds. Default is ``60*10``. +OIDC_DISCOVERY_CACHE_ENABLE +================ + +OPTIONAL. ``bool``. Enable caching the response on the discovery endpoint, by using default cache. Cache key will be a combination of site URL and types supported by the provider, changing any of these will invalidate stored value. + +Default is ``False``. + +OIDC_DISCOVERY_CACHE_EXPIRE +================ + +OPTIONAL. ``int``. Discovery endpoint cache expiration time expressed in seconds. + +Expressed in seconds. Default is ``60*10``. + OIDC_EXTRA_SCOPE_CLAIMS ======================= diff --git a/example/app/urls.py b/example/app/urls.py index c9755745..acc75adc 100644 --- a/example/app/urls.py +++ b/example/app/urls.py @@ -9,7 +9,7 @@ urlpatterns = [ url(r'^$', TemplateView.as_view(template_name='home.html'), name='home'), - url(r'^accounts/login/$', auth_views.LoginView.as_view(template_name='login.html'), name='login'), + url(r'^accounts/login/$', auth_views.LoginView.as_view(template_name='login.html'), name='login'), # noqa url(r'^accounts/logout/$', auth_views.LogoutView.as_view(next_page='/'), name='logout'), url(r'^', include('oidc_provider.urls', namespace='oidc_provider')), url(r'^admin/', admin.site.urls), diff --git a/oidc_provider/settings.py b/oidc_provider/settings.py index 90750fda..08274042 100644 --- a/oidc_provider/settings.py +++ b/oidc_provider/settings.py @@ -14,7 +14,7 @@ def __init__(self): @property def OIDC_LOGIN_URL(self): """ - REQUIRED. Used to log the user in. By default Django's LOGIN_URL will be used. + OPTIONAL. Used to log the user in. By default Django's LOGIN_URL will be used. """ return settings.LOGIN_URL @@ -48,6 +48,20 @@ def OIDC_CODE_EXPIRE(self): """ return 60*10 + @property + def OIDC_DISCOVERY_CACHE_ENABLE(self): + """ + OPTIONAL. Enable caching the response on the discovery endpoint. + """ + return False + + @property + def OIDC_DISCOVERY_CACHE_EXPIRE(self): + """ + OPTIONAL. Discovery endpoint cache expiration time expressed in seconds. + """ + return 60*60*24 + @property def OIDC_EXTRA_SCOPE_CLAIMS(self): """ diff --git a/oidc_provider/tests/cases/test_authorize_endpoint.py b/oidc_provider/tests/cases/test_authorize_endpoint.py index 508b3e70..7ccd0a23 100644 --- a/oidc_provider/tests/cases/test_authorize_endpoint.py +++ b/oidc_provider/tests/cases/test_authorize_endpoint.py @@ -1,5 +1,3 @@ -from oidc_provider.lib.errors import RedirectUriError - try: from urllib.parse import urlencode, quote except ImportError: @@ -25,15 +23,16 @@ from jwkest.jwt import JWT from oidc_provider import settings +from oidc_provider.lib.endpoints.authorize import AuthorizeEndpoint +from oidc_provider.lib.errors import RedirectUriError +from oidc_provider.lib.utils.authorize import strip_prompt_login from oidc_provider.tests.app.utils import ( create_fake_user, create_fake_client, FAKE_CODE_CHALLENGE, is_code_valid, ) -from oidc_provider.lib.utils.authorize import strip_prompt_login from oidc_provider.views import AuthorizeView -from oidc_provider.lib.endpoints.authorize import AuthorizeEndpoint class AuthorizeEndpointMixin(object): diff --git a/oidc_provider/tests/cases/test_provider_info_endpoint.py b/oidc_provider/tests/cases/test_provider_info_endpoint.py index 2265ef66..1dfe2777 100644 --- a/oidc_provider/tests/cases/test_provider_info_endpoint.py +++ b/oidc_provider/tests/cases/test_provider_info_endpoint.py @@ -1,9 +1,13 @@ +from mock import patch + +from django.core.cache import cache try: from django.urls import reverse except ImportError: from django.core.urlresolvers import reverse from django.test import RequestFactory -from django.test import TestCase +from django.test import TestCase, override_settings + from oidc_provider.views import ProviderInfoView @@ -13,7 +17,11 @@ class ProviderInfoTestCase(TestCase): def setUp(self): self.factory = RequestFactory() - def test_response(self): + def tearDown(self): + cache.clear() + + @patch('oidc_provider.views.ProviderInfoView._build_cache_key') + def test_response(self, build_cache_key): """ See if the endpoint is returning the corresponding server information by checking status, content type, etc. @@ -24,6 +32,32 @@ def test_response(self): response = ProviderInfoView.as_view()(request) + # Caching not available by default. + build_cache_key.assert_not_called() + + self.assertEqual(response.status_code, 200) + self.assertEqual(response['Content-Type'] == 'application/json', True) + self.assertEqual(bool(response.content), True) + + @override_settings(OIDC_DISCOVERY_CACHE_ENABLE=True) + @patch('oidc_provider.views.ProviderInfoView._build_cache_key') + def test_response_with_cache_enabled(self, build_cache_key): + """ + Enable caching on the discovery endpoint and ensure data is being saved on cache. + """ + build_cache_key.return_value = 'key' + + url = reverse('oidc_provider:provider-info') + + request = self.factory.get(url) + + response = ProviderInfoView.as_view()(request) + self.assertEqual(response.status_code, 200) + build_cache_key.assert_called_once() + + assert 'authorization_endpoint' in cache.get('key') + + response = ProviderInfoView.as_view()(request) self.assertEqual(response.status_code, 200) self.assertEqual(response['Content-Type'] == 'application/json', True) self.assertEqual(bool(response.content), True) diff --git a/oidc_provider/views.py b/oidc_provider/views.py index 79f0fb88..4a2c94ce 100644 --- a/oidc_provider/views.py +++ b/oidc_provider/views.py @@ -1,3 +1,4 @@ +import hashlib import logging from django.views.decorators.csrf import csrf_exempt @@ -20,6 +21,7 @@ from django.core.urlresolvers import reverse from django.db import transaction from django.contrib.auth import logout as django_user_logout +from django.core.cache import cache from django.http import JsonResponse, HttpResponse from django.shortcuts import render from django.template.loader import render_to_string @@ -256,7 +258,16 @@ def set_headers(response): class ProviderInfoView(View): - def get(self, request, *args, **kwargs): + _types_supported = None + + @property + def types_supported(self): + if self._types_supported is None: + self._types_supported = [ + response_type.value for response_type in ResponseType.objects.all()] + return self._types_supported + + def _build_response_dict(self, request): dic = dict() site_url = get_site_url(request=request) @@ -268,8 +279,7 @@ def get(self, request, *args, **kwargs): dic['end_session_endpoint'] = site_url + reverse('oidc_provider:end-session') dic['introspection_endpoint'] = site_url + reverse('oidc_provider:token-introspection') - types_supported = [response_type.value for response_type in ResponseType.objects.all()] - dic['response_types_supported'] = types_supported + dic['response_types_supported'] = self.types_supported dic['jwks_uri'] = site_url + reverse('oidc_provider:jwks') @@ -284,7 +294,29 @@ def get(self, request, *args, **kwargs): if settings.get('OIDC_SESSION_MANAGEMENT_ENABLE'): dic['check_session_iframe'] = site_url + reverse('oidc_provider:check-session-iframe') - response = JsonResponse(dic) + return dic + + def _build_cache_key(self, request): + """ + Cache key will be a combination of site URL and types supported by the provider. + """ + key_data = get_site_url(request=request) + ''.join(self.types_supported) + key_hash = hashlib.md5(key_data.encode('utf-8')).hexdigest() + return f'oidc_discovery_{key_hash}' + + def get(self, request): + if settings.get('OIDC_DISCOVERY_CACHE_ENABLE'): + cache_key = self._build_cache_key(request) + cached_dict = cache.get(cache_key) + if cached_dict: + response_dict = cached_dict + else: + response_dict = self._build_response_dict(request) + cache.set(cache_key, response_dict, settings.get('OIDC_DISCOVERY_CACHE_EXPIRE')) + else: + response_dict = self._build_response_dict(request) + + response = JsonResponse(response_dict) response['Access-Control-Allow-Origin'] = '*' return response