
Fix dataset factory patterns in Experiment Tracking #1588

Merged: 12 commits, Oct 27, 2023
5 changes: 5 additions & 0 deletions RELEASE.md
@@ -5,6 +5,11 @@ Please follow the established format:
- Use present tense (e.g. 'Add new feature')
- Include the ID number for the related PR (or PRs) in parentheses
-->
# Upcoming Release

## Bug fixes and other changes

- Fix dataset factory patterns in Experiment Tracking. (#1588)

# Release 6.6.1

15 changes: 15 additions & 0 deletions package/kedro_viz/data_access/managers.py
@@ -69,6 +69,21 @@ def set_db_session(self, db_session_class: sessionmaker):
"""Set db session on repositories that need it."""
self.runs.set_db_session(db_session_class)

def resolve_dataset_factory_patterns(
self, catalog: DataCatalog, pipelines: Dict[str, KedroPipeline]
ravi-kumar-pilla marked this conversation as resolved.
Show resolved Hide resolved
):
"""Resolve dataset factory patterns in data catalog by matching
them against the datasets in the pipelines.
"""
for pipeline in pipelines.values():
if hasattr(pipeline, "datasets"):
datasets = pipeline.datasets() # kedro 0.19.0 onwards
else:
datasets = pipeline.data_sets()

for dataset_name in datasets:
catalog.exists(dataset_name)

def add_catalog(self, catalog: DataCatalog):
"""Add a catalog to the CatalogRepository and relevant tracking datasets to
TrackingDatasetRepository.
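Why this works: in Kedro, `DataCatalog.exists()` looks the dataset up internally, and a name that matches a dataset factory pattern is materialised into a concrete catalog entry as a side effect. A minimal sketch of that behaviour (assuming Kedro >= 0.18.12, where `DataCatalog.from_config` understands factory patterns, and kedro-datasets for `pandas.CSVDataset`):

```python
from kedro.io import DataCatalog

# "{dataset_name}#csv" is a factory pattern, not a concrete dataset:
# nothing is registered for it until a matching name is requested.
catalog = DataCatalog.from_config(
    {
        "{dataset_name}#csv": {
            "type": "pandas.CSVDataset",
            "filepath": "data/01_raw/{dataset_name}.csv",
        },
    }
)

assert "model_inputs#csv" not in catalog.list()

# `exists` fetches the dataset internally, which resolves the pattern and
# registers "model_inputs#csv" as a concrete entry -- the side effect the
# new method above relies on.
catalog.exists("model_inputs#csv")

assert "model_inputs#csv" in catalog.list()
```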
4 changes: 4 additions & 0 deletions package/kedro_viz/server.py
@@ -36,6 +36,10 @@ def populate_data(
    session_class = make_db_session_factory(session_store.location)
    data_access_manager.set_db_session(session_class)

    # Resolve dataset factory patterns before the catalog is added, so the
    # materialised entries are visible when Viz snapshots the catalog.
    data_access_manager.resolve_dataset_factory_patterns(catalog, pipelines)

    # add catalog and relevant tracking datasets
    data_access_manager.add_catalog(catalog)

    # add dataset stats before adding pipelines
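For orientation, the new call is deliberately placed before `add_catalog` in Viz's startup flow. A sketch of the resulting ordering contract in `populate_data` (abbreviated from this file; the real signature has more parameters):

```python
def populate_data(data_access_manager, catalog, pipelines, session_store):
    # 1. Wire up the DB session used by the runs repository.
    session_class = make_db_session_factory(session_store.location)
    data_access_manager.set_db_session(session_class)

    # 2. Materialise pattern-matched datasets into the catalog first ...
    data_access_manager.resolve_dataset_factory_patterns(catalog, pipelines)

    # 3. ... so that add_catalog registers them along with the static entries.
    data_access_manager.add_catalog(catalog)

    # 4. Dataset stats and pipelines are added afterwards.
```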
22 changes: 22 additions & 0 deletions package/tests/conftest.py
@@ -110,6 +110,12 @@ def example_catalog():
            },
            "model_inputs": {"model_inputs"},
        },
        dataset_patterns={
            "{dataset_name}#csv": {
                "type": "pandas.CSVDataset",
                "filepath": "data/01_raw/{dataset_name}#csv.csv",
            },
        },
    )


@@ -290,3 +296,19 @@ def example_csv_dataset(tmp_path, example_data_frame):
    )
    new_csv_dataset.save(example_data_frame)
    yield new_csv_dataset


# Mock KedroPipeline exposing the `datasets` method (kedro >= 0.19.0).
@pytest.fixture
def pipeline_with_datasets_mock():
    pipeline = mock.MagicMock()
    pipeline.datasets.return_value = ["model_inputs#csv"]
    return pipeline


# Mock KedroPipeline exposing only the legacy `data_sets` method
# (kedro < 0.19.0). `datasets` is deleted so the `hasattr` check in
# resolve_dataset_factory_patterns falls through to the `data_sets()` branch.
@pytest.fixture
def pipeline_with_data_sets_mock():
    pipeline = mock.MagicMock()
    del pipeline.datasets
    pipeline.data_sets.return_value = ["model_inputs#csv"]
    return pipeline
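A note on the second fixture: a bare `MagicMock` fabricates any attribute on access, so `hasattr(pipeline, "datasets")` would be true even for the legacy mock; that is why `datasets` is deleted above. A quick standard-library illustration:

```python
from unittest import mock

pipeline = mock.MagicMock()
assert hasattr(pipeline, "datasets")  # MagicMock auto-creates attributes

del pipeline.datasets  # further access now raises AttributeError
assert not hasattr(pipeline, "datasets")
```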
23 changes: 23 additions & 0 deletions package/tests/test_data_access/test_managers.py
@@ -9,6 +9,7 @@

from kedro_viz.constants import DEFAULT_REGISTERED_PIPELINE_ID, ROOT_MODULAR_PIPELINE_ID
from kedro_viz.data_access.managers import DataAccessManager
from kedro_viz.data_access.repositories.catalog import CatalogRepository
from kedro_viz.models.flowchart import (
    DataNode,
    GraphEdge,
@@ -464,3 +465,25 @@ def test_add_pipelines_with_circular_modular_pipelines(
        digraph.add_edge(edge.source, edge.target)
        with pytest.raises(nx.NetworkXNoCycle):
            nx.find_cycle(digraph)


class TestResolveDatasetFactoryPatterns:
    def test_resolve_dataset_factory_patterns(
        self,
        example_catalog,
        pipeline_with_datasets_mock,
        pipeline_with_data_sets_mock,
        data_access_manager: DataAccessManager,
    ):
        pipelines = {
            "pipeline1": pipeline_with_datasets_mock,
            "pipeline2": pipeline_with_data_sets_mock,
        }
        new_catalog = CatalogRepository()
        new_catalog.set_catalog(example_catalog)

        assert "model_inputs#csv" not in new_catalog.as_dict().keys()

        data_access_manager.resolve_dataset_factory_patterns(example_catalog, pipelines)

        assert "model_inputs#csv" in new_catalog.as_dict().keys()
3 changes: 3 additions & 0 deletions package/tests/test_server.py
@@ -71,6 +71,9 @@ def test_run_server_from_project(
):
    run_server()
    # Assert that, when running the server, data is added correctly to the
    # data access manager.
    patched_data_access_manager.resolve_dataset_factory_patterns.assert_called_once_with(
        example_catalog, example_pipelines
    )
    patched_data_access_manager.add_catalog.assert_called_once_with(example_catalog)
    patched_data_access_manager.add_pipelines.assert_called_once_with(
        example_pipelines