Skip to content

Commit

Permalink
Merge pull request #208 from johnpaulett/extra-params
Browse files Browse the repository at this point in the history
get_extra_params and OIDC_AUTHENTICATE_CLASS
  • Loading branch information
akatsoulas committed Mar 29, 2018
2 parents d030c01 + cf91ccd commit a1e76b7
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
12 changes: 12 additions & 0 deletions docs/settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ of ``mozilla-django-oidc``.
When using a custom callback view, it is generally a good idea to subclass the
default ``OIDCAuthenticationCallbackView`` and override the methods you want to change.

.. py:attribute:: OIDC_AUTHENTICATE_CLASS
:default: ``mozilla_django_oidc.views.OIDCAuthenticationRequestView``

Allows you to substitute a custom class-based view to be used as OpenID Connect
authenticate URL.

.. note::

When using a custom authenticate view, it is generally a good idea to subclass the
default ``OIDCAuthenticationRequestView`` and override the methods you want to change.

.. py:attribute:: OIDC_RP_SCOPES
:default: ``openid email``
Expand Down
10 changes: 9 additions & 1 deletion mozilla_django_oidc/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@

OIDCCallbackClass = import_string(CALLBACK_CLASS_PATH)


DEFAULT_AUTHENTICATE_CLASS = 'mozilla_django_oidc.views.OIDCAuthenticationRequestView'
AUTHENTICATE_CLASS_PATH = import_from_settings(
'OIDC_AUTHENTICATE_CLASS', DEFAULT_AUTHENTICATE_CLASS
)

OIDCAuthenticateClass = import_string(AUTHENTICATE_CLASS_PATH)

urlpatterns = [
url(r'^callback/$', OIDCCallbackClass.as_view(),
name='oidc_authentication_callback'),
url(r'^authenticate/$', views.OIDCAuthenticationRequestView.as_view(),
url(r'^authenticate/$', OIDCAuthenticateClass.as_view(),
name='oidc_authentication_init'),
url(r'^logout/$', views.OIDCLogoutView.as_view(), name='oidc_logout'),
]
6 changes: 4 additions & 2 deletions mozilla_django_oidc/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def get(self, request):
'state': state,
}

extra = import_from_settings('OIDC_AUTH_REQUEST_EXTRA_PARAMS', {})
params.update(extra)
params.update(self.get_extra_params(request))

if import_from_settings('OIDC_USE_NONCE', True):
nonce = get_random_string(import_from_settings('OIDC_NONCE_SIZE', 32))
Expand All @@ -163,6 +162,9 @@ def get(self, request):
redirect_url = '{url}?{query}'.format(url=self.OIDC_OP_AUTH_ENDPOINT, query=query)
return HttpResponseRedirect(redirect_url)

def get_extra_params(self, request):
return import_from_settings('OIDC_AUTH_REQUEST_EXTRA_PARAMS', {})


class OIDCLogoutView(View):
"""Logout helper view"""
Expand Down
35 changes: 35 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,41 @@ def test_get_with_audience(self, mock_random_string):
self.assertEqual(o.hostname, 'server.example.com')
self.assertEqual(o.path, '/auth')

@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT='https://server.example.com/auth')
@override_settings(OIDC_RP_CLIENT_ID='example_id')
@patch('mozilla_django_oidc.views.get_random_string')
@patch('mozilla_django_oidc.views.OIDCAuthenticationRequestView.get_extra_params')
def test_get_with_overridden_extra_params(self, mock_extra_params, mock_random_string):
"""Test overriding OIDCAuthenticationRequestView.get_extra_params()."""
mock_random_string.return_value = 'examplestring'

mock_extra_params.return_value = {
'connection': 'foo'
}

url = reverse('oidc_authentication_init')
request = self.factory.get(url)
request.session = dict()
login_view = views.OIDCAuthenticationRequestView.as_view()
response = login_view(request)
self.assertEqual(response.status_code, 302)

o = urlparse(response.url)
expected_query = {
'response_type': ['code'],
'scope': ['openid email'],
'client_id': ['example_id'],
'redirect_uri': ['http://testserver/callback/'],
'state': ['examplestring'],
'nonce': ['examplestring'],
'connection': ['foo'],
}
self.assertDictEqual(parse_qs(o.query), expected_query)
self.assertEqual(o.hostname, 'server.example.com')
self.assertEqual(o.path, '/auth')

mock_extra_params.assert_called_with(request)

@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT='https://server.example.com/auth')
@override_settings(OIDC_RP_CLIENT_ID='example_id')
def test_next_url(self):
Expand Down

0 comments on commit a1e76b7

Please sign in to comment.