Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Share the code for handling required attributes between CAS & SAML. #9326

Merged
merged 4 commits into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9326.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Share the code for handling required attributes between the CAS and SAML handlers.
32 changes: 30 additions & 2 deletions synapse/config/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List

from synapse.config.sso import SsoAttributeRequirement

from ._base import Config
from ._util import validate_config


class CasConfig(Config):
Expand All @@ -38,12 +43,16 @@ def read_config(self, config, **kwargs):
public_base_url + "_matrix/client/r0/login/cas/ticket"
)
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes") or {}
required_attributes = cas_config.get("required_attributes") or {}
self.cas_required_attributes = _parsed_required_attributes_def(
required_attributes
)

else:
self.cas_server_url = None
self.cas_service_url = None
self.cas_displayname_attribute = None
self.cas_required_attributes = {}
self.cas_required_attributes = []

def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
Expand Down Expand Up @@ -75,3 +84,22 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
# userGroup: "staff"
# department: None
"""


# CAS uses a legacy required attributes mapping, not the one provided by
# SsoAttributeRequirement.
REQUIRED_ATTRIBUTES_SCHEMA = {
"type": "object",
"additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
}


def _parsed_required_attributes_def(
required_attributes: Any,
) -> List[SsoAttributeRequirement]:
validate_config(
REQUIRED_ATTRIBUTES_SCHEMA,
required_attributes,
config_path=("cas_config", "required_attributes"),
)
return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
25 changes: 5 additions & 20 deletions synapse/config/saml2_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
import logging
from typing import Any, List

import attr

from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module

Expand Down Expand Up @@ -396,32 +395,18 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
}


@attr.s(frozen=True)
class SamlAttributeRequirement:
"""Object describing a single requirement for SAML attributes."""

attribute = attr.ib(type=str)
value = attr.ib(type=str)

JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}


ATTRIBUTE_REQUIREMENTS_SCHEMA = {
"type": "array",
"items": SamlAttributeRequirement.JSON_SCHEMA,
"items": SsoAttributeRequirement.JSON_SCHEMA,
}


def _parse_attribute_requirements_def(
attribute_requirements: Any,
) -> List[SamlAttributeRequirement]:
) -> List[SsoAttributeRequirement]:
validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements,
config_path=["saml2_config", "attribute_requirements"],
config_path=("saml2_config", "attribute_requirements"),
)
return [SamlAttributeRequirement(**x) for x in attribute_requirements]
return [SsoAttributeRequirement(**x) for x in attribute_requirements]
19 changes: 18 additions & 1 deletion synapse/config/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from typing import Any, Dict, Optional

import attr

from ._base import Config


@attr.s(frozen=True)
class SsoAttributeRequirement:
"""Object describing a single requirement for SSO attributes."""

attribute = attr.ib(type=str)
# If a value is not given, than the attribute must simply exist.
value = attr.ib(type=Optional[str])

JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}


class SSOConfig(Config):
"""SSO Configuration
"""
Expand Down
40 changes: 11 additions & 29 deletions synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET

import attr
Expand Down Expand Up @@ -49,7 +49,7 @@ def __str__(self):
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])
attributes = attr.ib(type=Dict[str, List[Optional[str]]])


class CasHandler:
Expand Down Expand Up @@ -169,7 +169,7 @@ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:

# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
attributes = {} # type: Dict[str, List[Optional[str]]]
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
Expand All @@ -182,7 +182,7 @@ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
attributes.setdefault(tag, []).append(attribute.text)

# Ensure a user was found.
if user is None:
Expand Down Expand Up @@ -303,29 +303,10 @@ async def _handle_cas_response(

# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return

# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
if not self._sso_handler.check_required_attributes(
request, cas_response.attributes, self._cas_required_attributes
):
return

# Call the mapper to register/login the user

Expand Down Expand Up @@ -372,9 +353,10 @@ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")

# Arbitrarily use the first attribute found.
display_name = cas_response.attributes.get(
self._cas_displayname_attribute, None
)
self._cas_displayname_attribute, [None]
)[0]

return UserAttributes(localpart=localpart, display_name=display_name)

Expand Down
26 changes: 4 additions & 22 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
Expand Down Expand Up @@ -239,12 +238,10 @@ async def _handle_authn_response(

# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
if not self._sso_handler.check_required_attributes(
request, saml2_auth.ava, self._saml2_attribute_requirements
):
return

# Call the mapper to register/login the user
try:
Expand Down Expand Up @@ -373,21 +370,6 @@ def expire_sessions(self):
del self._outstanding_requests_dict[reqid]


def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return True

logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
return False


DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
Expand Down
71 changes: 71 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Expand All @@ -34,6 +36,7 @@

from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
Expand Down Expand Up @@ -893,6 +896,41 @@ def _expire_old_sessions(self):
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]

def check_required_attributes(
self,
request: SynapseRequest,
attributes: Mapping[str, List[Any]],
attribute_requirements: Iterable[SsoAttributeRequirement],
) -> bool:
"""
Confirm that the required attributes were present in the SSO response.

If all requirements are met, this will return True.

If any requirement is not met, then the request will be finalized by
showing an error page to the user and False will be returned.

Args:
request: The request to (potentially) respond to.
attributes: The attributes from the SSO IdP.
attribute_requirements: The requirements that attributes must meet.

Returns:
True if all requirements are met, False if any attribute fails to
meet the requirement.

"""
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in attribute_requirements:
if not _check_attribute_requirement(attributes, requirement):
self.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return False

return True


def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
Expand All @@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")


def _check_attribute_requirement(
attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
) -> bool:
"""Check if SSO attributes meet the proper requirements.

Args:
attributes: A mapping of attributes to an iterable of one or more values.
requirement: The configured requirement to check.

Returns:
True if the required attribute was found and had a proper value.
"""
if req.attribute not in attributes:
logger.info("SSO attribute missing: %s", req.attribute)
return False

# If the requirement is None, the attribute existing is enough.
if req.value is None:
return True

values = attributes[req.attribute]
if req.value in values:
return True

logger.info(
"SSO attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
return False
Loading