Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Preparatory refactors of OidcHandler #9067

Merged
merged 4 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9067.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.
304 changes: 163 additions & 141 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from urllib.parse import urlencode

import attr
Expand All @@ -35,7 +35,6 @@
from twisted.web.client import readBody

from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
Expand Down Expand Up @@ -85,12 +84,15 @@ def __str__(self):
return self.error


class OidcHandler(BaseHandler):
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""

def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._store = hs.get_datastore()

self._token_generator = OidcSessionTokenGenerator(hs)

self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
Expand All @@ -116,7 +118,6 @@ def __init__(self, hs: "HomeServer"):

self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key

# identifier for the external_ids table
self.idp_id = "oidc"
Expand Down Expand Up @@ -519,11 +520,13 @@ async def handle_redirect_request(
if not client_redirect_url:
client_redirect_url = b""

cookie = self._generate_oidc_session_token(
cookie = self._token_generator.generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
session_data=OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
),
)
request.addCookie(
SESSION_COOKIE_NAME,
Expand Down Expand Up @@ -628,11 +631,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:

# Deserialize the session token and verify it.
try:
(
nonce,
client_redirect_url,
ui_auth_session_id,
) = self._verify_oidc_session_token(session, state)
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
Expand Down Expand Up @@ -674,14 +675,14 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
else:
logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return

# first check if we're doing a UIA
if ui_auth_session_id:
if session_data.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
Expand All @@ -690,141 +691,20 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
return

return await self._sso_handler.complete_sso_ui_auth_request(
self.idp_id, remote_user_id, ui_auth_session_id, request
self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
)

# otherwise, it's a login

# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
userinfo, token, request, client_redirect_url
userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))

def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.

When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.

Args:
state: The ``state`` parameter passed to the OIDC provider.
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.

Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
if ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))

return macaroon.serialize()

def _verify_oidc_session_token(
self, session: bytes, state: str
) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.

This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.

Args:
session: The session token to verify
state: The state the OIDC provider gave back

Returns:
The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)

v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)

v.verify(macaroon, self._macaroon_secret_key)

# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None

return nonce, client_redirect_url, ui_auth_session_id

def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.

Args:
macaroon: the token
key: the key of the caveat to extract

Returns:
The extracted value

Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))

def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self.clock.time_msec()
return now < expiry

async def _complete_oidc_login(
self,
userinfo: UserInfo,
Expand Down Expand Up @@ -901,8 +781,8 @@ async def grandfather_existing_users() -> Optional[str]:
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)

user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
Expand Down Expand Up @@ -954,6 +834,148 @@ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
return str(remote_user_id)


class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""

def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._server_name = hs.hostname
self._macaroon_secret_key = hs.config.key.macaroon_secret_key

def generate_oidc_session_token(
self,
state: str,
session_data: "OidcSessionData",
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.

When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.

Args:
state: The ``state`` parameter passed to the OIDC provider.
session_data: data to include in the session token.
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.

Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))

return macaroon.serialize()

def verify_oidc_session_token(
self, session: bytes, state: str
) -> "OidcSessionData":
"""Verifies and extract an OIDC session token.

This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.

Args:
session: The session token to verify
state: The state the OIDC provider gave back

Returns:
The data extracted from the session cookie
"""
macaroon = pymacaroons.Macaroon.deserialize(session)

v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)

v.verify(macaroon, self._macaroon_secret_key)

# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None

return OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
)

def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.

Args:
macaroon: the token
key: the key of the caveat to extract

Returns:
The extracted value

Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))

def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry


@attr.s(frozen=True, slots=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""

# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)

# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)

# The session ID of the ongoing UI Auth (None if this is a login)
ui_auth_session_id = attr.ib(type=Optional[str], default=None)


UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
)
Expand Down
Loading