Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix a bug with saml attribute maps. #6069

Merged
merged 3 commits into from Sep 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6069.bugfix
@@ -0,0 +1 @@
Fix a bug which caused SAML attribute maps to be overridden by defaults.
48 changes: 42 additions & 6 deletions synapse/config/saml2_config.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,11 +13,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_python_module

from ._base import Config, ConfigError


def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts

Recursively merges `merge_dict` into `into_dict`:
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
are recursively merged
* For all other keys, the values in `into_dict` (if any) are overwritten with
the value from `merge_dict`.

Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
"""
for k, v in merge_dict.items():
if k not in into_dict:
into_dict[k] = v
continue

current_val = into_dict[k]

if isinstance(v, dict) and isinstance(current_val, dict):
_dict_merge(v, current_val)
continue

# otherwise we just overwrite
into_dict[k] = v


class SAML2Config(Config):
def read_config(self, config, **kwargs):
self.saml2_enabled = False
Expand All @@ -33,15 +64,20 @@ def read_config(self, config, **kwargs):

self.saml2_enabled = True

import saml2.config

self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(self._default_saml_config_dict())
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
saml2_config_dict = self._default_saml_config_dict()
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)

config_path = saml2_config.get("config_path", None)
if config_path is not None:
self.saml2_sp_config.load_file(config_path)
mod = load_python_module(config_path)
_dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)

import saml2.config

self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(saml2_config_dict)

# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
Expand Down
20 changes: 19 additions & 1 deletion synapse/util/module_loader.py
Expand Up @@ -14,12 +14,13 @@
# limitations under the License.

import importlib
import importlib.util

from synapse.config._base import ConfigError


def load_module(provider):
""" Loads a module with its config
""" Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).

Expand All @@ -38,3 +39,20 @@ def load_module(provider):
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))

return provider_class, provider_config


def load_python_module(location: str):
"""Load a python module, and return a reference to its global namespace

Args:
location (str): path to the module

Returns:
python module object
"""
spec = importlib.util.spec_from_file_location(location, location)
if spec is None:
raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod