Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix bugs in handling clientRedirectUrl #9128

Merged
merged 5 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9127.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.
1 change: 1 addition & 0 deletions changelog.d/9128.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login.
4 changes: 2 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,8 +1504,8 @@ def _expire_sso_extra_attributes(self) -> None:
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({param_name: param})
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, hs: "HomeServer"):
self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
}
} # type: Dict[str, OidcProvider]

async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Expand Down
4 changes: 3 additions & 1 deletion synapse/rest/synapse/client/pick_idp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def __init__(self, hs: "HomeServer"):
self._server_name = hs.hostname

async def _async_render_GET(self, request: SynapseRequest) -> None:
client_redirect_url = parse_string(request, "redirectUrl", required=True)
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding="utf-8"
)
idp = parse_string(request, "idp", required=False)

# if we need to pick an IdP, do so
Expand Down
146 changes: 91 additions & 55 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import time
import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
from typing import Any, Dict, Union
from urllib.parse import urlencode

from mock import Mock

Expand All @@ -38,6 +37,7 @@
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless

try:
Expand Down Expand Up @@ -69,6 +69,12 @@
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"

# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'

# the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]


class LoginRestServletTestCase(unittest.HomeserverTestCase):

Expand Down Expand Up @@ -389,23 +395,44 @@ def default_config(self) -> Dict[str, Any]:
},
}

# default OIDC provider
config["oidc_config"] = TEST_OIDC_CONFIG

# additional OIDC providers
config["oidc_providers"] = [
{
"idp_id": "idp1",
"idp_name": "IDP1",
"discover": False,
"issuer": "https://issuer1",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://issuer1/auth",
"token_endpoint": "https://issuer1/token",
"userinfo_endpoint": "https://issuer1/userinfo",
"user_mapping_provider": {
"config": {"localpart_template": "{{ user.sub }}"}
},
}
]
return config

def create_resource_dict(self) -> Dict[str, Resource]:
from synapse.rest.oidc import OIDCResource

d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d

def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
client_redirect_url = "https://x?<abc>"

# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
"/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]
Expand All @@ -415,46 +442,22 @@ def test_multi_sso_redirect(self):
self.assertEqual(channel.code, 200, channel.result)

# parse the form to check it has fields assumed elsewhere in this class
class FormPageParser(HTMLParser):
def __init__(self):
super().__init__()

# the values of the hidden inputs: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]

# the values of the radio buttons
self.radios = [] # type: List[Optional[str]]

def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "input":
if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
self.radios.append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
input_name = attr_dict["name"]
assert input_name
self.hiddens[input_name] = attr_dict["value"]

def error(_, message):
self.fail(message)

p = FormPageParser()
p = TestHtmlParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()

self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])

self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)

def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
client_redirect_url = "https://x?<abc>"

channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
"/_synapse/client/pick_idp?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
Expand All @@ -470,16 +473,14 @@ def test_multi_sso_redirect_to_cas(self):
service_uri = cas_uri_params["service"][0]
_, service_uri_query = service_uri.split("?", 1)
service_uri_params = urllib.parse.parse_qs(service_uri_query)
self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)

def test_multi_sso_redirect_to_saml(self):
"""If SAML is chosen, should redirect to the SAML server"""
client_redirect_url = "https://x?<abc>"

channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
Expand All @@ -492,16 +493,16 @@ def test_multi_sso_redirect_to_saml(self):
# the RelayState is used to carry the client redirect url
saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, client_redirect_url)
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)

def test_multi_sso_redirect_to_oidc(self):
def test_login_via_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
client_redirect_url = "https://x?<abc>"

# pick the default OIDC provider
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
Expand All @@ -521,9 +522,41 @@ def test_multi_sso_redirect_to_oidc(self):
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
self.assertEqual(
self._get_value_from_macaroon(macaroon, "client_redirect_url"),
client_redirect_url,
TEST_CLIENT_REDIRECT_URL,
)

channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})

# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(
channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
)
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()

# ... which should contain our redirect link
self.assertEqual(len(p.links), 1)
path, query = p.links[0].split("?", 1)
self.assertEqual(path, "https://x")

# it will have url-encoded the params properly, so we'll have to parse them
params = urllib.parse.parse_qsl(
query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
self.assertEqual(params[2][0], "loginToken")

# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
login_token = params[2][1]
chan = self.make_request(
"POST", "/login", content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")

def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
Expand Down Expand Up @@ -1082,7 +1115,7 @@ def default_config(self):

# whitelist this client URI so we redirect straight to it rather than
# serving a confirmation page
config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
config["sso"] = {"client_whitelist": ["https://x"]}
return config

def create_resource_dict(self) -> Dict[str, Resource]:
Expand All @@ -1095,11 +1128,10 @@ def create_resource_dict(self) -> Dict[str, Resource]:

def test_username_picker(self):
"""Test the happy path of a username picker flow."""
client_redirect_url = "https://whitelisted.client"

# do the start of the login flow
channel = self.helper.auth_via_oidc(
{"sub": "tester", "displayname": "Jonny"}, client_redirect_url
{"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
)

# that should redirect to the username picker
Expand All @@ -1122,7 +1154,7 @@ def test_username_picker(self):
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
self.assertEqual(session.display_name, "Jonny")
self.assertEqual(session.client_redirect_url, client_redirect_url)
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)

# the expiry time should be about 15 minutes away
expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
Expand All @@ -1146,15 +1178,19 @@ def test_username_picker(self):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
# ensure that the returned location starts with the requested redirect URL
self.assertEqual(
location_headers[0][: len(client_redirect_url)], client_redirect_url
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x")

# it will have url-encoded the params properly, so we'll have to parse them
params = urllib.parse.parse_qsl(
query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
self.assertEqual(params[2][0], "loginToken")

# fish the login token out of the returned redirect uri
parts = urlparse(location_headers[0])
query = parse_qs(parts.query)
login_token = query["loginToken"][0]
login_token = params[2][1]

# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
Expand Down
Loading