diff --git a/docs/tutorial/tutorial_01.rst b/docs/tutorial/tutorial_01.rst index fdb1c3edc..3eaf4cc99 100644 --- a/docs/tutorial/tutorial_01.rst +++ b/docs/tutorial/tutorial_01.rst @@ -8,17 +8,16 @@ You want to make your own :term:`Authorization Server` to issue access tokens to Start Your App -------------- During this tutorial you will make an XHR POST from a Heroku deployed app to your localhost instance. -Since the domain that will originate the request (the app on Heroku) is different than the destination domain (your local instance), -you will need to install the `django-cors-headers `_ app. +Since the domain that will originate the request (the app on Heroku) is different than the destination domain (your local instance), you will need to use the cors-middleware that we're providing. These "cross-domain" requests are by default forbidden by web browsers unless you use `CORS `_. -Create a virtualenv and install `django-oauth-toolkit` and `django-cors-headers`: +Create a virtualenv and install `django-oauth-toolkit`: :: - pip install django-oauth-toolkit django-cors-headers + pip install django-oauth-toolkit -Start a Django project, add `oauth2_provider` and `corsheaders` to the installed apps, and enable admin: +Start a Django project, add `oauth2_provider` to the installed apps, and enable admin: .. code-block:: python @@ -26,7 +25,6 @@ Start a Django project, add `oauth2_provider` and `corsheaders` to the installed 'django.contrib.admin', # ... 'oauth2_provider', - 'corsheaders', } Include the Django OAuth Toolkit urls in your `urls.py`, choosing the urlspace you prefer. For example: @@ -46,17 +44,11 @@ Include the CORS middleware in your `settings.py`: MIDDLEWARE_CLASSES = ( # ... - 'corsheaders.middleware.CorsMiddleware', + 'oauth2_provider.middleware.CorsMiddleware', # ... ) -Allow CORS requests from all domains (just for the scope of this tutorial): - -.. code-block:: python - - CORS_ORIGIN_ALLOW_ALL = True - -.. _loginTemplate: +This will allow CORS requests from the redirect uris of your applications. Include the required hidden input in your login template, `registration/login.html`. The ``{{ next }}`` template context variable will be populated with the correct diff --git a/oauth2_provider/middleware.py b/oauth2_provider/middleware.py index 33eab12d5..4d9cf160c 100644 --- a/oauth2_provider/middleware.py +++ b/oauth2_provider/middleware.py @@ -1,6 +1,9 @@ +from django import http from django.contrib.auth import authenticate from django.utils.cache import patch_vary_headers +from .models import Application + class OAuth2TokenMiddleware(object): """ @@ -32,3 +35,46 @@ def process_request(self, request): def process_response(self, request, response): patch_vary_headers(response, ('Authorization',)) return response + +HEADERS = ('x-requested-with', 'content-type', 'accept', 'origin', + 'authorization', 'x-csrftoken') +METHODS = ('GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS') + + +class CorsMiddleware(object): + def process_request(self, request): + '''If this is a preflight-request, we must always return 200''' + if (request.method == 'OPTIONS' and + 'HTTP_ACCESS_CONTROL_REQUEST_METHOD' in request.META): + return http.HttpResponse() + return None + + def process_response(self, request, response): + '''Add cors-headers to request if they can be derived correctly''' + try: + cors_allow_origin = _get_cors_allow_origin_header(request) + except Application.NoSuitableOriginFoundError: + pass + else: + response['Access-Control-Allow-Origin'] = cors_allow_origin + response['Access-Control-Allow-Credentials'] = 'true' + if request.method == 'OPTIONS': + response['Access-Control-Allow-Headers'] = ', '.join(HEADERS) + response['Access-Control-Allow-Methods'] = ', '.join(METHODS) + return response + + +def _get_cors_allow_origin_header(request): + '''Fetch the oauth-application that is responsible for making the + request and return a sutible cors-header, or None + ''' + origin = request.META.get('HTTP_ORIGIN') + if origin: + try: + app = Application.objects.filter(redirect_uris__contains=origin)[0] + except IndexError: + # No application for this origin found + pass + else: + return app.get_cors_header(origin) + raise Application.NoSuitableOriginFoundError diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index f87395002..c07fb25ba 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -120,12 +120,31 @@ def clean(self): error = _('Redirect_uris could not be empty with {0} grant_type') raise ValidationError(error.format(self.authorization_grant_type)) + def get_cors_header(self, origin): + '''Return a proper cors-header for this origin, in the context of this + application. + + :param origin: Origin-url from HTTP-request. + :raises: Application.NoSuitableOriginFoundError + ''' + parsed_origin = urlparse(origin) + for allowed_uri in self.redirect_uris.split(): + parsed_allowed_uri = urlparse(allowed_uri) + if (parsed_allowed_uri.scheme == parsed_origin.scheme and + parsed_allowed_uri.netloc == parsed_origin.netloc and + parsed_allowed_uri.port == parsed_origin.port): + return origin + raise Application.NoSuitableOriginFoundError + def get_absolute_url(self): return reverse('oauth2_provider:detail', args=[str(self.id)]) def __str__(self): return self.name or self.client_id + class NoSuitableOriginFoundError(Exception): + pass + class Application(AbstractApplication): class Meta(AbstractApplication.Meta): diff --git a/oauth2_provider/tests/test_cors_middleware.py b/oauth2_provider/tests/test_cors_middleware.py new file mode 100644 index 000000000..f7d03e3da --- /dev/null +++ b/oauth2_provider/tests/test_cors_middleware.py @@ -0,0 +1,85 @@ +from datetime import timedelta + +from django.test import TestCase, Client, override_settings +from django.utils import timezone +from django.conf.urls import patterns, url +from django.http import HttpResponse +from django.views.generic import View + +from ..models import AccessToken, get_application_model +from django.contrib.auth import get_user_model + + +Application = get_application_model() +UserModel = get_user_model() + + +class MockView(View): + def post(self, request): + return HttpResponse() + +urlpatterns = patterns( + '', + url(r'^cors-test/$', MockView.as_view()), +) + + +@override_settings( + ROOT_URLCONF='oauth2_provider.tests.test_cors_middleware', + AUTHENTICATION_BACKENDS=('oauth2_provider.backends.OAuth2Backend',), + MIDDLEWARE_CLASSES=( + 'oauth2_provider.middleware.OAuth2TokenMiddleware', + 'oauth2_provider.middleware.CorsMiddleware', + )) +class TestCORSMiddleware(TestCase): + def setUp(self): + self.user = UserModel.objects.create_user('test_user', 'test@user.com') + self.application = Application.objects.create( + name='Test Application', + redirect_uris='https://foo.bar', + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) + + self.access_token = AccessToken.objects.create( + user=self.user, + scope='read write', + expires=timezone.now() + timedelta(seconds=300), + token='secret-access-token-key', + application=self.application + ) + + auth_header = "Bearer {0}".format(self.access_token.token) + self.client = Client(HTTP_AUTHORIZATION=auth_header) + + def test_cors_successful(self): + '''Ensure that we get cors-headers according to our oauth-app''' + resp = self.client.post('/cors-test/', HTTP_ORIGIN='https://foo.bar') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp['Access-Control-Allow-Origin'], 'https://foo.bar') + self.assertEqual(resp['Access-Control-Allow-Credentials'], 'true') + + def test_cors_no_auth(self): + '''Ensure that CORS-headers are sent non-authenticated requests''' + client = Client() + resp = client.post('/cors-test/', HTTP_ORIGIN='https://foo.bar') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp['Access-Control-Allow-Origin'], 'https://foo.bar') + self.assertEqual(resp['Access-Control-Allow-Credentials'], 'true') + + def test_cors_wrong_origin(self): + '''Ensure that CORS-headers aren't sent to requests from wrong origin''' + resp = self.client.post('/cors-test/', HTTP_ORIGIN='https://bar.foo') + self.assertEqual(resp.status_code, 200) + self.assertFalse(resp.has_header('Access-Control-Allow-Origin')) + + def test_cors_200_preflight(self): + '''Ensure that preflight always get 200 responses''' + resp = self.client.options('/cors-test/', + HTTP_ACCESS_CONTROL_REQUEST_METHOD='GET', + HTTP_ORIGIN='https://foo.bar') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp['Access-Control-Allow-Origin'], 'https://foo.bar') + self.assertTrue(resp.has_header('Access-Control-Allow-Headers')) + self.assertTrue(resp.has_header('Access-Control-Allow-Methods'))