Skip to content

Commit

Permalink
[KED-1996] Create an IPython extension for Kedro (#853)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorena Bălan authored Nov 16, 2020
1 parent 6b2d676 commit 80e2202
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 4 deletions.
30 changes: 30 additions & 0 deletions kedro/extras/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains an IPython extension.
"""
133 changes: 133 additions & 0 deletions kedro/extras/extensions/ipython.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel,global-statement,invalid-name
"""
This script creates an IPython extension to load Kedro-related variables in
local scope.
"""
import logging.config
import sys
from pathlib import Path

from IPython import get_ipython
from IPython.core.magic import needs_local_scope, register_line_magic

# Root directory of the Kedro project. Defaults to the current working
# directory and is rebound by `init_kedro` when a path is supplied.
project_path = Path.cwd()
# Placeholders rebound (via `global`) by `load_kedro_objects`, which also
# pushes the same objects into the IPython user namespace.
catalog = None
context = None
session = None


def _remove_cached_modules(package_name):
    """Evict ``package_name`` and all of its submodules from ``sys.modules``.

    Only the package itself and its dotted children (``pkg``, ``pkg.a``,
    ``pkg.a.b``) are removed. Modules that merely share the textual prefix
    (e.g. ``pkg_utils`` when ``package_name`` is ``pkg``) are left untouched.

    Args:
        package_name: Importable name of the package to evict. A falsy value
            is treated as "nothing to remove", because an empty prefix would
            otherwise match every loaded module.
    """
    if not package_name:
        return

    submodule_prefix = package_name + "."
    # Snapshot the names first: `sys.modules` must not change size while it
    # is being iterated.
    to_remove = [
        mod
        for mod in sys.modules
        if mod == package_name or mod.startswith(submodule_prefix)
    ]
    # `del` is used instead of `reload()` because: If the new version of a module does not
    # define a name that was defined by the old version, the old definition remains.
    for module in to_remove:
        del sys.modules[module]  # pragma: no cover


def _clear_hook_manager():
    """Unregister every plugin currently listed by Kedro's hook manager."""
    from kedro.framework.hooks import get_hook_manager

    manager = get_hook_manager()
    # The (name, plugin) pairs are fetched once, before anything is
    # unregistered.
    for plugin_name, plugin_obj in manager.list_name_plugin():
        manager.unregister(name=plugin_name, plugin=plugin_obj)  # pragma: no cover


def load_kedro_objects(path, line=None):  # pylint: disable=unused-argument
    """Line magic which reloads all Kedro default variables.

    Creates and activates a new ``KedroSession`` for the project at ``path``,
    drops the project's cached modules and registered hooks, loads the
    context, and pushes ``context``, ``catalog`` and ``session`` into the
    IPython user namespace (also rebinding the module-level globals of the
    same names). Finally, every callable returned by
    ``load_entry_points("line_magic")`` is registered as a line magic.

    Args:
        path: Location of the Kedro project. When falsy, the module-level
            ``project_path`` is used instead.
        line: Unused.
    """

    import kedro.config.default_logger  # noqa: F401 # pylint: disable=unused-import
    from kedro.framework.cli import load_entry_points
    from kedro.framework.context.context import _add_src_to_path
    from kedro.framework.project.metadata import _get_project_metadata
    from kedro.framework.session import KedroSession
    from kedro.framework.session.session import _activate_session

    global catalog, context, session

    if not path:
        path = project_path
    project_metadata = _get_project_metadata(path)
    _add_src_to_path(project_metadata.source_dir, path)

    session = KedroSession.create(path)
    _activate_session(session)

    _remove_cached_modules(project_metadata.package_name)

    # clear hook manager; hook implementations will be re-registered when the
    # context is instantiated again in `session.load_context()` below
    _clear_hook_manager()

    logging.debug("Loading the context from %s", str(path))
    # Reload context to fix `pickle` related error (it is unable to serialize reloaded objects)
    # Some details can be found here:
    # https://modwsgi.readthedocs.io/en/develop/user-guides/issues-with-pickle-module.html#packing-and-script-reloading
    context = session.load_context()
    catalog = context.catalog
    kedro_variables = {"context": context, "catalog": catalog, "session": session}
    get_ipython().push(variables=kedro_variables)

    logging.info("** Kedro project %s", str(project_metadata.project_name))
    logging.info("Defined global variable `context`, `session` and `catalog`")

    for line_magic in load_entry_points("line_magic"):
        scoped_magic = needs_local_scope(line_magic)
        register_line_magic(scoped_magic)
        logging.info("Registered line magic `%s`", line_magic.__name__)


def init_kedro(path=""):
    """Line magic to set path to Kedro project.

    `%reload_kedro` will default to this location.

    Args:
        path: New project location. When empty, the module-level
            ``project_path`` is left unchanged and only reported.
    """
    global project_path

    # Guard clause: nothing to update, just report the current location.
    if not path:
        logging.info("No path argument was provided. Using: %s", str(project_path))
        return

    project_path = Path(path).expanduser().resolve()
    logging.info("Updated path to Kedro project: %s", str(project_path))


def load_ipython_extension(ipython):
    """Main entry point when %load_ext is executed"""
    magic_registrations = (
        (init_kedro, "line"),
        (load_kedro_objects, "line", "reload_kedro"),
    )
    for registration in magic_registrations:
        ipython.register_magic_function(*registration)

    try:
        load_kedro_objects(project_path)
    except ImportError:
        # `ModuleNotFoundError` is a subclass of `ImportError`, so this single
        # clause handles both.
        logging.error("Kedro appears not to be installed in your current environment.")
    except Exception:  # pylint: disable=broad-except
        logging.error(
            "Could not register Kedro extension. Make sure you're in a valid Kedro project.",
            exc_info=True,
        )
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ pluggy~=0.13.0
python-json-logger~=0.1.9
PyYAML>=4.2, <6.0
setuptools>=38.0
toml~=0.10
toposort~=1.5 # Needs to be at least 1.5 to be able to raise CircularDependencyError
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ layers =
framework.context
framework.project
runner
extras
extras.datasets
io
pipeline
config
Expand All @@ -53,7 +53,7 @@ forbidden_modules =
kedro.runner
kedro.io
kedro.pipeline
kedro.extras
kedro.extras.datasets

[importlinter:contract:4]
name = Runner et al cannot import Config
Expand All @@ -62,7 +62,7 @@ source_modules =
kedro.runner
kedro.io
kedro.pipeline
kedro.extras
kedro.extras.datasets
forbidden_modules =
kedro.config
ignore_imports=
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _collect_requirements(requires):
"ipykernel>=4.8.1, <5.0",
],
"geopandas": _collect_requirements(geopandas_require),
"ipython": ["ipython~=7.0"],
"matplotlib": _collect_requirements(matplotlib_require),
"holoviews": _collect_requirements(holoviews_require),
"networkx": _collect_requirements(networkx_require),
Expand Down
Empty file.
173 changes: 173 additions & 0 deletions tests/extras/extensions/test_ipython.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright 2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel,reimported
import pytest

from kedro.extras.extensions.ipython import (
init_kedro,
load_ipython_extension,
load_kedro_objects,
)
from kedro.framework.project import ProjectMetadata
from kedro.framework.session.session import _deactivate_session


@pytest.fixture(autouse=True)
def project_path(mocker, tmp_path):
    """Patch the extension's module-level ``project_path`` with ``tmp_path``."""
    mocker.patch("kedro.extras.extensions.ipython.project_path", tmp_path)


@pytest.fixture(autouse=True)
def cleanup_session():
    """Deactivate the Kedro session after each test (teardown-only fixture)."""
    yield
    # Runs after the test body.
    _deactivate_session()


class TestInitKedro:
    """Tests for the ``init_kedro`` line magic."""

    def test_init_kedro(self, tmp_path, caplog):
        """A supplied path is resolved, stored in the module global and logged."""
        from kedro.extras.extensions import ipython as ipython_ext

        assert ipython_ext.project_path == tmp_path

        kedro_path = tmp_path / "here"
        init_kedro(str(kedro_path))
        expected_path = kedro_path.expanduser().resolve()

        logged = [record.getMessage() for record in caplog.records]
        assert f"Updated path to Kedro project: {expected_path}" in logged

        # make sure global variable updated
        assert ipython_ext.project_path == expected_path

    def test_init_kedro_no_path(self, tmp_path, caplog):
        """Without a path the module global is kept and only reported."""
        from kedro.extras.extensions import ipython as ipython_ext

        assert ipython_ext.project_path == tmp_path

        init_kedro()

        logged = [record.getMessage() for record in caplog.records]
        assert f"No path argument was provided. Using: {tmp_path}" in logged

        # make sure global variable stayed the same
        assert ipython_ext.project_path == tmp_path


class TestLoadKedroObjects:
    """Tests for ``load_kedro_objects``."""

    def test_load_kedro_objects(self, tmp_path, mocker):
        """Context, catalog and session are pushed and line magics registered."""
        fake_metadata = ProjectMetadata(
            source_dir=tmp_path / "src",  # default
            config_file=tmp_path / "pyproject.toml",
            package_name="fake_package_name",
            project_name="fake_project_name",
            project_version="0.1",
            context_path="hello.there",
        )
        mocker.patch(
            "kedro.framework.project.metadata._get_project_metadata",
            return_value=fake_metadata,
        )
        mocker.patch(
            "kedro.framework.session.session._get_project_metadata",
            return_value=fake_metadata,
        )
        mocker.patch("kedro.framework.context.context._add_src_to_path")
        mock_line_magic = mocker.MagicMock()
        mock_line_magic.__name__ = "abc"
        mocker.patch(
            "kedro.framework.cli.load_entry_points", return_value=[mock_line_magic]
        )
        mock_register_line_magic = mocker.patch(
            "kedro.extras.extensions.ipython.register_line_magic"
        )
        mock_context = mocker.patch("kedro.framework.session.KedroSession.load_context")
        mock_ipython = mocker.patch("kedro.extras.extensions.ipython.get_ipython")

        load_kedro_objects(tmp_path)

        # `return_value` is the object the code under test receives from the
        # patched callables; using it avoids adding extra calls to the mocks.
        mock_ipython.return_value.push.assert_called_once_with(
            variables={
                "context": mock_context.return_value,
                "catalog": mock_context.return_value.catalog,
                "session": mocker.ANY,
            }
        )
        assert mock_register_line_magic.call_count == 1

    def test_load_kedro_objects_not_in_kedro_project(self, tmp_path, mocker):
        """A metadata failure propagates and IPython is never touched."""
        mocker.patch(
            "kedro.framework.project.metadata._get_project_metadata",
            side_effect=[RuntimeError],
        )
        mock_ipython = mocker.patch("kedro.extras.extensions.ipython.get_ipython")

        with pytest.raises(RuntimeError):
            load_kedro_objects(tmp_path)

        # Assert on the patched `get_ipython` itself: `mock_ipython().called`
        # would inspect the return value and could never detect a call to
        # `get_ipython`.
        mock_ipython.assert_not_called()
        mock_ipython.return_value.push.assert_not_called()


class TestLoadIPythonExtension:
    """Tests for the ``load_ipython_extension`` entry point."""

    @pytest.mark.parametrize(
        "error,expected_log_message",
        [
            (
                ImportError,
                "Kedro appears not to be installed in your current environment.",
            ),
            (
                RuntimeError,
                "Could not register Kedro extension. Make sure you're in a valid Kedro project.",
            ),
        ],
    )
    def test_load_extension_not_in_kedro_env_or_project(
        self, error, expected_log_message, mocker, caplog
    ):
        """Errors while loading are logged, not raised, and IPython is untouched."""
        mocker.patch(
            "kedro.framework.project.metadata._get_project_metadata",
            side_effect=[error],
        )
        mock_ipython = mocker.patch("kedro.extras.extensions.ipython.get_ipython")

        load_ipython_extension(mocker.MagicMock())

        # Assert on the patched `get_ipython` itself: `mock_ipython().called`
        # would inspect the return value and could never detect a call to
        # `get_ipython`.
        mock_ipython.assert_not_called()
        mock_ipython.return_value.push.assert_not_called()

        log_messages = [
            record.getMessage()
            for record in caplog.records
            if record.levelname == "ERROR"
        ]
        assert log_messages == [expected_log_message]
2 changes: 1 addition & 1 deletion tests/framework/session/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_from_config_uncaught_error(self, mocker):
with pytest.raises(ValueError, match=re.escape(pattern)):
BaseSessionStore.from_config(config)

assert mocked_init.called_once_with(**config)
mocked_init.assert_called_once_with(**config)


@pytest.fixture
Expand Down

0 comments on commit 80e2202

Please sign in to comment.