Use user defined default dataset factory pattern over the one from the runner (#3859)

* Use user default pattern over the ones from the runner

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Update catalog CLI commands

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Fix lint

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Fix tests and add error for multiple catch-all patterns

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Remove note from docs

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Suggestions from code review, release notes

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Docs changes

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

---------

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
ankatiyar committed May 14, 2024
1 parent 360cd12 commit 9369227
Showing 5 changed files with 89 additions and 57 deletions.
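
To see the behaviour change end to end, here is a minimal sketch modelled on the new tests in this commit; the `{my_default}` pattern name and the dataset types are hypothetical:

```python
# Minimal sketch of the new precedence rule (pattern names hypothetical).
from kedro.io import DataCatalog

catalog_config = {
    "{dataset}s": {
        "type": "pandas.CSVDataset",
        "filepath": "data/01_raw/{dataset}s.csv",
    },
    # A user-defined catch-all: no literal characters outside the placeholder.
    "{my_default}": {
        "type": "pandas.ParquetDataset",
        "filepath": "data/{my_default}.pq",
    },
}
catalog = DataCatalog.from_config(catalog_config)

# from_config() now pops a trailing catch-all into its own bucket.
assert "{my_default}" in catalog._default_pattern

# The runner injects its own "{default}" pattern via shallow_copy();
# because a user default already exists, the runner's pattern is dropped.
catalog = catalog.shallow_copy(
    extra_dataset_patterns={"{default}": {"type": "MemoryDataset"}}
)
assert "{default}" not in catalog._dataset_patterns
```

Previously the runner's `{default}` pattern could shadow a user-defined catch-all unless the user picked a name that sorted earlier alphabetically; the diff below removes that workaround from the docs.
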
1 change: 1 addition & 0 deletions RELEASE.md
@@ -3,6 +3,7 @@
 ## Major features and improvements
 
 ## Bug fixes and other changes
+* User defined catch-all dataset factory patterns now override the default pattern provided by the runner.
 
 ## Breaking changes to the API

13 changes: 4 additions & 9 deletions docs/source/data/kedro_dataset_factories.md
@@ -223,19 +223,14 @@ The matches are ranked according to the following criteria:
 You can use dataset factories to define a catch-all pattern which will overwrite the default [`MemoryDataset`](/api/kedro.io.MemoryDataset) creation.
 
 ```yaml
-"{a_default_dataset}":
+"{default_dataset}":
   type: pandas.CSVDataset
-  filepath: data/{a_default_dataset}.csv
+  filepath: data/{default_dataset}.csv
 
 ```
 Kedro will now treat all the datasets mentioned in your project's pipelines that do not appear as specific patterns or explicit entries in your catalog
 as `pandas.CSVDataset`.
 
-```{note}
-Under the hood Kedro uses the pattern name "{default}" to generate the default datasets set in the runners. If you want to overwrite this pattern you should make sure you choose a name that comes
-before "default" in the alphabet for it to be resolved first.
-```
-
 ## CLI commands for dataset factories
 
 To manage your dataset factories, two new commands have been added to the Kedro CLI: `kedro catalog rank` (0.18.12) and `kedro catalog resolve` (0.18.13).
@@ -322,9 +317,9 @@ shuttles:
   type: pandas.ParquetDataset
   filepath: data/02_intermediate/preprocessed_{name}.pq
 
-"{a_default}":
+"{default}":
   type: pandas.ParquetDataset
-  filepath: data/03_primary/{a_default}.pq
+  filepath: data/03_primary/{default}.pq
 ```
 </details>

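
The "catch-all" notion above is defined by pattern specificity: the number of literal characters left once placeholders are stripped. A sketch consistent with the `_specificity` helper used in the `kedro/io/data_catalog.py` diff further below (the exact regex is an assumption):

```python
import re

def specificity(pattern: str) -> int:
    # Count the literal characters outside "{placeholder}" fields;
    # a catch-all such as "{default}" therefore scores zero.
    return len(re.sub(r"\{.*?\}", "", pattern))

assert specificity("{default}") == 0              # catch-all
assert specificity("{dataset}s") == 1             # literal "s"
assert specificity("{namespace}_{dataset}") == 1  # literal "_"
assert specificity("{another}#csv") == 4          # literal "#csv"
```
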
17 changes: 12 additions & 5 deletions kedro/framework/cli/catalog.py
@@ -87,10 +87,12 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
         for ds_name in default_ds:
             matched_pattern = data_catalog._match_pattern(
                 data_catalog._dataset_patterns, ds_name
-            )
+            ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
             if matched_pattern:
                 ds_config_copy = copy.deepcopy(
-                    data_catalog._dataset_patterns[matched_pattern]
+                    data_catalog._dataset_patterns.get(matched_pattern)
+                    or data_catalog._default_pattern.get(matched_pattern)
+                    or {}
                 )
 
                 ds_config = data_catalog._resolve_config(
@@ -215,7 +217,10 @@ def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None:
     session = _create_session(metadata.package_name, env=env)
     context = session.load_context()
 
-    catalog_factories = context.catalog._dataset_patterns
+    catalog_factories = {
+        **context.catalog._dataset_patterns,
+        **context.catalog._default_pattern,
+    }
     if catalog_factories:
         click.echo(yaml.dump(list(catalog_factories.keys())))
     else:
@@ -259,10 +264,12 @@ def resolve_patterns(metadata: ProjectMetadata, env: str) -> None:
 
         matched_pattern = data_catalog._match_pattern(
             data_catalog._dataset_patterns, ds_name
-        )
+        ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
        if matched_pattern:
             ds_config_copy = copy.deepcopy(
-                data_catalog._dataset_patterns[matched_pattern]
+                data_catalog._dataset_patterns.get(matched_pattern)
+                or data_catalog._default_pattern.get(matched_pattern)
+                or {}
             )
 
             ds_config = data_catalog._resolve_config(
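
Both CLI commands ultimately resolve a dataset name against the winning pattern. A rough sketch of that resolution step using the `parse` library, which Kedro's pattern matching is built on (the internals of `_resolve_config` may differ):

```python
from parse import parse  # pip install parse

pattern = "{dataset}s"
ds_name = "reviews"

match = parse(pattern, ds_name)
assert match.named == {"dataset": "review"}

# Substitute the captured fields back into the config template.
config_template = {
    "type": "pandas.CSVDataset",
    "filepath": "data/01_raw/{dataset}s.csv",
}
resolved = {key: value.format(**match.named) for key, value in config_template.items()}
assert resolved["filepath"] == "data/01_raw/reviews.csv"
```
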
34 changes: 30 additions & 4 deletions kedro/io/data_catalog.py
@@ -144,6 +144,7 @@ def __init__(  # noqa: PLR0913
         dataset_patterns: Patterns | None = None,
         load_versions: dict[str, str] | None = None,
         save_version: str | None = None,
+        default_pattern: Patterns | None = None,
     ) -> None:
         """``DataCatalog`` stores instances of ``AbstractDataset``
         implementations to provide ``load`` and ``save`` capabilities from
@@ -172,6 +173,8 @@ def __init__(  # noqa: PLR0913
                 case-insensitive string that conforms with operating system
                 filename limitations, b) always return the latest version when
                 sorted in lexicographical order.
+            default_pattern: A dictionary of the default catch-all pattern that overrides the default
+                pattern provided through the runners.
 
         Example:
         ::
@@ -190,6 +193,7 @@ def __init__(  # noqa: PLR0913
         self._dataset_patterns = dataset_patterns or {}
         self._load_versions = load_versions or {}
         self._save_version = save_version
+        self._default_pattern = default_pattern or {}
 
         if feed_dict:
             self.add_feed_dict(feed_dict)
@@ -281,6 +285,7 @@ class to be loaded is specified with the key ``type`` and their
         credentials = copy.deepcopy(credentials) or {}
         save_version = save_version or generate_timestamp()
         load_versions = copy.deepcopy(load_versions) or {}
+        user_default = {}
 
         for ds_name, ds_config in catalog.items():
             ds_config = _resolve_credentials(  # noqa: PLW2901
@@ -295,6 +300,12 @@ class to be loaded is specified with the key ``type`` and their
                 ds_name, ds_config, load_versions.get(ds_name), save_version
             )
         sorted_patterns = cls._sort_patterns(dataset_patterns)
+        if sorted_patterns:
+            # If the last pattern is a catch-all pattern, pop it and set it as the default
+            if cls._specificity(list(sorted_patterns.keys())[-1]) == 0:
+                last_pattern = sorted_patterns.popitem()
+                user_default = {last_pattern[0]: last_pattern[1]}
+
         missing_keys = [
             key
             for key in load_versions.keys()
@@ -311,6 +322,7 @@ class to be loaded is specified with the key ``type`` and their
             dataset_patterns=sorted_patterns,
             load_versions=load_versions,
             save_version=save_version,
+            default_pattern=user_default,
         )
 
     @staticmethod
@@ -346,6 +358,13 @@ def _sort_patterns(cls, dataset_patterns: Patterns) -> dict[str, dict[str, Any]]
                 pattern,
             ),
         )
+        catch_all = [
+            pattern for pattern in sorted_keys if cls._specificity(pattern) == 0
+        ]
+        if len(catch_all) > 1:
+            raise DatasetError(
+                f"Multiple catch-all patterns found in the catalog: {', '.join(catch_all)}. Only one catch-all pattern is allowed, remove the extras."
+            )
         return {key: dataset_patterns[key] for key in sorted_keys}
 
     @staticmethod
@@ -369,11 +388,17 @@ def _get_dataset(
         version: Version | None = None,
         suggest: bool = True,
     ) -> AbstractDataset:
-        matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name)
+        matched_pattern = self._match_pattern(
+            self._dataset_patterns, dataset_name
+        ) or self._match_pattern(self._default_pattern, dataset_name)
         if dataset_name not in self._datasets and matched_pattern:
             # If the dataset is a patterned dataset, materialise it and add it to
             # the catalog
-            config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
+            config_copy = copy.deepcopy(
+                self._dataset_patterns.get(matched_pattern)
+                or self._default_pattern.get(matched_pattern)
+                or {}
+            )
             dataset_config = self._resolve_config(
                 dataset_name, matched_pattern, config_copy
             )
@@ -385,7 +410,7 @@ def _get_dataset(
             )
             if (
                 self._specificity(matched_pattern) == 0
-                and matched_pattern != "{default}"
+                and matched_pattern in self._default_pattern
             ):
                 self._logger.warning(
                     "Config from the dataset factory pattern '%s' in the catalog will be used to "
@@ -721,7 +746,7 @@ def shallow_copy(
         Returns:
             Copy of the current object.
         """
-        if extra_dataset_patterns:
+        if not self._default_pattern and extra_dataset_patterns:
             unsorted_dataset_patterns = {
                 **self._dataset_patterns,
                 **extra_dataset_patterns,
@@ -734,6 +759,7 @@ def shallow_copy(
             dataset_patterns=dataset_patterns,
             load_versions=self._load_versions,
             save_version=self._save_version,
+            default_pattern=self._default_pattern,
         )
 
     def __eq__(self, other) -> bool:  # type: ignore[no-untyped-def]
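
One consequence of the `_sort_patterns` change above: declaring two catch-alls now fails fast at catalog construction instead of silently resolving alphabetically. A quick sketch mirroring the new test below (pattern names hypothetical):

```python
from kedro.io import DataCatalog, DatasetError

config = {
    "{default1}": {"type": "pandas.CSVDataset", "filepath": "data/{default1}.csv"},
    "{default2}": {"type": "pandas.ExcelDataset", "filepath": "data/{default2}.xlsx"},
}

try:
    DataCatalog.from_config(config)
except DatasetError as err:
    # "Multiple catch-all patterns found in the catalog: {default1}, {default2}. ..."
    print(err)
```
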
81 changes: 42 additions & 39 deletions tests/io/test_data_catalog.py
@@ -151,10 +151,6 @@ def config_with_dataset_factories_bad_pattern(config_with_dataset_factories):
 def config_with_dataset_factories_only_patterns():
     return {
         "catalog": {
-            "{default}": {
-                "type": "pandas.CSVDataset",
-                "filepath": "data/01_raw/{default}.csv",
-            },
             "{namespace}_{dataset}": {
                 "type": "pandas.CSVDataset",
                 "filepath": "data/01_raw/{namespace}_{dataset}.pq",
@@ -175,6 +171,14 @@ def config_with_dataset_factories_only_patterns():
     }
 
 
+@pytest.fixture
+def config_with_dataset_factories_only_patterns_no_default(
+    config_with_dataset_factories_only_patterns,
+):
+    del config_with_dataset_factories_only_patterns["catalog"]["{user_default}"]
+    return config_with_dataset_factories_only_patterns
+
+
 @pytest.fixture
 def dataset(filepath):
     return CSVDataset(filepath=filepath, save_args={"index": False})
@@ -841,23 +845,47 @@ def test_sorting_order_patterns(self, config_with_dataset_factories_only_patterns):
             "{country}_companies",
             "{namespace}_{dataset}",
             "{dataset}s",
-            "{default}",
             "{user_default}",
         ]
-        assert list(catalog._dataset_patterns.keys()) == sorted_keys_expected
+        assert (
+            list(catalog._dataset_patterns.keys())
+            + list(catalog._default_pattern.keys())
+            == sorted_keys_expected
+        )
 
-    def test_sorting_order_with_default_and_other_dataset_through_extra_pattern(
-        self, config_with_dataset_factories_only_patterns
+    def test_multiple_catch_all_patterns_not_allowed(
+        self, config_with_dataset_factories
     ):
-        """Check that the sorted order of the patterns is correct according to parsing rules when a default dataset is added through extra patterns (this would happen via the runner)."""
+        """Check that multiple catch-all patterns are not allowed"""
+        config_with_dataset_factories["catalog"]["{default1}"] = {
+            "filepath": "data/01_raw/{default1}.csv",
+            "type": "pandas.CSVDataset",
+        }
+        config_with_dataset_factories["catalog"]["{default2}"] = {
+            "filepath": "data/01_raw/{default2}.xlsx",
+            "type": "pandas.ExcelDataset",
+        }
+
+        with pytest.raises(
+            DatasetError, match="Multiple catch-all patterns found in the catalog"
+        ):
+            DataCatalog.from_config(**config_with_dataset_factories)
+
+    def test_sorting_order_with_other_dataset_through_extra_pattern(
+        self, config_with_dataset_factories_only_patterns_no_default
+    ):
+        """Check that the sorted order of the patterns is correct according to parsing rules when a default dataset
+        is added through extra patterns (this would happen via the runner) and user default is not present"""
         extra_dataset_patterns = {
             "{default}": {"type": "MemoryDataset"},
             "{another}#csv": {
                 "type": "pandas.CSVDataset",
                 "filepath": "data/{another}.csv",
             },
         }
-        catalog = DataCatalog.from_config(**config_with_dataset_factories_only_patterns)
+        catalog = DataCatalog.from_config(
+            **config_with_dataset_factories_only_patterns_no_default
+        )
         catalog_with_default = catalog.shallow_copy(
             extra_dataset_patterns=extra_dataset_patterns
         )
@@ -867,38 +895,13 @@ def test_sorting_order_with_default_and_other_dataset_through_extra_pattern(
             "{namespace}_{dataset}",
             "{dataset}s",
             "{default}",
-            "{user_default}",
         ]
         assert (
             list(catalog_with_default._dataset_patterns.keys()) == sorted_keys_expected
         )
 
-    def test_runner_default_overwrites_user_default(
-        self, config_with_dataset_factories_only_patterns
-    ):
-        """Check that the runner default overwrites the user default."""
-        catalog = DataCatalog.from_config(**config_with_dataset_factories_only_patterns)
-        assert catalog._dataset_patterns["{default}"] == {
-            "filepath": "data/01_raw/{default}.csv",
-            "type": "pandas.CSVDataset",
-        }
-
-        extra_dataset_patterns = {
-            "{default}": {"type": "MemoryDataset"},
-            "{another}#csv": {
-                "type": "pandas.CSVDataset",
-                "filepath": "data/{another}.csv",
-            },
-        }
-        catalog_with_runner_default = catalog.shallow_copy(
-            extra_dataset_patterns=extra_dataset_patterns
-        )
-        assert catalog_with_runner_default._dataset_patterns["{default}"] == {
-            "type": "MemoryDataset"
-        }
-
-    def test_user_default_overwrites_runner_default_alphabetically(self):
-        """Check that the runner default overwrites the user default if earlier in alphabet."""
+    def test_user_default_overwrites_runner_default(self):
+        """Check that the user default overwrites the runner default when both are present"""
         catalog_config = {
             "{dataset}s": {
                 "type": "pandas.CSVDataset",
@@ -921,13 +924,13 @@ def test_user_default_overwrites_runner_default_alphabetically(self):
             extra_dataset_patterns=extra_dataset_patterns
         )
         sorted_keys_expected = [
-            "{another}#csv",
             "{dataset}s",
             "{a_default}",
-            "{default}",
         ]
+        assert "{a_default}" in catalog_with_runner_default._default_pattern
         assert (
             list(catalog_with_runner_default._dataset_patterns.keys())
+            + list(catalog_with_runner_default._default_pattern.keys())
             == sorted_keys_expected
         )
Expand Down
