diff --git a/social/backends/saml.py b/social/backends/saml.py index 206b26c6f..a249c0b40 100644 --- a/social/backends/saml.py +++ b/social/backends/saml.py @@ -11,7 +11,7 @@ from onelogin.saml2.settings import OneLogin_Saml2_Settings from social.backends.base import BaseAuth -from social.exceptions import AuthFailed +from social.exceptions import AuthFailed, AuthMissingParameter # Helpful constants: OID_COMMON_NAME = "urn:oid:2.5.4.3" @@ -256,7 +256,10 @@ def _create_saml_auth(self, idp): def auth_url(self): """Get the URL to which we must redirect in order to authenticate the user""" - idp_name = self.strategy.request_data()['idp'] + try: + idp_name = self.strategy.request_data()['idp'] + except KeyError: + raise AuthMissingParameter(self, 'idp') auth = self._create_saml_auth(idp=self.get_idp(idp_name)) # Below, return_to sets the RelayState, which can contain # arbitrary data. We use it to store the specific SAML IdP diff --git a/social/tests/backends/test_saml.py b/social/tests/backends/test_saml.py index 2cd552087..f9fd4d41b 100644 --- a/social/tests/backends/test_saml.py +++ b/social/tests/backends/test_saml.py @@ -15,6 +15,7 @@ pass from social.tests.backends.base import BaseBackendTest +from social.exceptions import AuthMissingParameter from social.p3 import urlparse, urlunparse, urlencode, parse_qs DATA_DIR = path.join(path.dirname(__file__), 'data') @@ -64,8 +65,6 @@ def install_http_intercepts(self, start_url, return_url): body='foobar') def do_start(self): - # pretend we've started with a URL like /login/saml/?idp=testshib: - self.strategy.set_request_data({'idp': 'testshib'}, self.backend) start_url = self.backend.start().url # Modify the start URL to make the SAML request consistent # from test to test: @@ -91,8 +90,15 @@ def test_metadata_generation(self): def test_login(self): """Test that we can authenticate with a SAML IdP (TestShib)""" + # pretend we've started with a URL like /login/saml/?idp=testshib: + self.strategy.set_request_data({'idp': 'testshib'}, self.backend) self.do_login() + def test_login_no_idp(self): + """Logging in without an idp param should raise AuthMissingParameter""" + with self.assertRaises(AuthMissingParameter): + self.do_start() + def modify_start_url(self, start_url): """ Given a SAML redirect URL, parse it and change the ID to