diff --git a/docs/source/hooks/examples.md b/docs/source/hooks/examples.md index 9b2d490cce..f5008da75e 100644 --- a/docs/source/hooks/examples.md +++ b/docs/source/hooks/examples.md @@ -178,7 +178,7 @@ class DataValidationHooks: great_expectations checkpoint new raw_companies_dataset_checkpoint ``` -* Remove `data_connector_query` from the `batch_request` in the checkpoint config file: +* Remove `data_connector_query` from the `batch_request` in the checkpoint config file: ```python yaml_config = f""" @@ -249,7 +249,7 @@ class DataValidationHooks: }, "batch_identifiers": { "runtime_batch_identifier_name": dataset_name - } + }, }, run_name=session_id, ) diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index 3b4f9c2be6..5d954c19af 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -231,12 +231,11 @@ def load_and_merge_dir_config(self, conf_path: str, patterns: Iterable[str]): aggregate_config = config_per_file.values() self._check_duplicates(seen_file_to_keys) - if aggregate_config: - if len(aggregate_config) > 1: - merged_config = OmegaConf.merge(*aggregate_config) - return OmegaConf.to_container(merged_config) - return OmegaConf.to_container(list(aggregate_config)[0]) - return {} + if not aggregate_config: + return {} + if len(aggregate_config) == 1: + return list(aggregate_config)[0] + return dict(OmegaConf.merge(*aggregate_config)) @staticmethod def _is_valid_config_path(path): diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index 6b751b910c..a435ffd582 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Iterable, Optional, Union import click +from omegaconf import OmegaConf, omegaconf from kedro import __version__ as kedro_version from kedro.config import ConfigLoader, MissingConfigException @@ -192,6 +193,8 @@ def create( # pylint: disable=too-many-arguments def _get_logging_config(self) -> Dict[str, Any]: logging_config = self._get_config_loader()["logging"] + if isinstance(logging_config, omegaconf.DictConfig): + logging_config = OmegaConf.to_container(logging_config) # turn relative paths in logging config into absolute path # before initialising loggers logging_config = _convert_paths_to_absolute_posix( diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py index af8c734179..24b6652be2 100644 --- a/tests/framework/session/test_session.py +++ b/tests/framework/session/test_session.py @@ -9,7 +9,7 @@ import yaml from kedro import __version__ as kedro_version -from kedro.config import AbstractConfigLoader, ConfigLoader +from kedro.config import AbstractConfigLoader, ConfigLoader, OmegaConfLoader from kedro.framework.context import KedroContext from kedro.framework.project import ( ValidationError, @@ -91,6 +91,16 @@ class MockSettings(_ProjectSettings): return _mock_imported_settings_paths(mocker, MockSettings()) +@pytest.fixture +def mock_settings_omega_config_loader_class(mocker): + class MockSettings(_ProjectSettings): + _CONFIG_LOADER_CLASS = _HasSharedParentClassValidator( + "CONFIG_LOADER_CLASS", default=lambda *_: OmegaConfLoader + ) + + return _mock_imported_settings_paths(mocker, MockSettings()) + + @pytest.fixture def mock_settings_config_loader_args(mocker): class MockSettings(_ProjectSettings): @@ -889,3 +899,20 @@ def test_setup_logging_using_absolute_path( ).as_posix() actual_log_filepath = call_args["handlers"]["info_file_handler"]["filename"] assert actual_log_filepath == expected_log_filepath + + +@pytest.mark.usefixtures("mock_settings_omega_config_loader_class") +def test_setup_logging_using_omega_config_loader_class( + fake_project_with_logging_file_handler, mocker, mock_package_name +): + mocked_logging = mocker.patch("logging.config.dictConfig") + KedroSession.create(mock_package_name, fake_project_with_logging_file_handler) + + mocked_logging.assert_called_once() + call_args = mocked_logging.call_args[0][0] + + expected_log_filepath = ( + fake_project_with_logging_file_handler / "logs" / "info.log" + ).as_posix() + actual_log_filepath = call_args["handlers"]["info_file_handler"]["filename"] + assert actual_log_filepath == expected_log_filepath