# Nok's experiment with nbdev and DebugRunner

> Demo

In [None]:
#| default_exp core

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# | export
from collections import Counter
from itertools import chain
from typing import Any, Dict, Iterable, List, Set

from kedro.framework.hooks.manager import _NullPluginManager
from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node
from kedro.runner import SequentialRunner
from kedro.runner.runner import AbstractRunner, run_node
from pluggy import PluginManager


class DebugRunner(SequentialRunner):
    """``SequentialRunner`` variant that can keep and return chosen datasets.

    ``run`` returns every pipeline output plus the datasets named in
    ``dataset_names``. The named datasets are never released while the
    pipeline runs, so intermediate results can still be loaded afterwards.
    """

    def run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        dataset_names: List[str] = None,
        hook_manager: PluginManager = None,
        session_id: str = None,
    ) -> Dict[str, Any]:
        """Run the ``Pipeline`` using the datasets provided by ``catalog``
        and return its outputs plus any extra datasets requested.

        NOTE(review): ``dataset_names`` occupies the third positional slot,
        where other runners' ``run`` (e.g. ``GreedySequentialRunner.run`` in
        this notebook) take ``hook_manager``. Pass it by keyword, and confirm
        that any framework caller passes ``hook_manager`` by keyword too.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: The ``DataCatalog`` from which to fetch data. The run
                works on ``catalog.shallow_copy()``; default datasets for
                unregistered pipeline datasets are added to that copy.
            dataset_names: Names of datasets to keep (never release) during
                the run and to include in the result. A single string is
                treated as one name. ``None`` or empty means only the
                pipeline outputs are returned.
            hook_manager: The ``PluginManager`` to activate hooks. A
                ``_NullPluginManager`` is used when omitted.
            session_id: The id of the session.

        Raises:
            ValueError: Raised when ``Pipeline`` inputs cannot be satisfied.

        Returns:
            A dictionary mapping every name in ``pipeline.outputs()`` and in
            ``dataset_names`` to the data loaded from the catalog after the
            run. Names are included whether or not they were registered in
            the catalog beforehand.
        """
        if isinstance(dataset_names, str):
            # Without this, a bare name would be split into characters by
            # ``set`` and matched as a substring by ``in``.
            dataset_names = [dataset_names]
        keep = set(dataset_names or ())
        hook_manager = hook_manager or _NullPluginManager()
        catalog = catalog.shallow_copy()

        unsatisfied = pipeline.inputs() - set(catalog.list())
        if unsatisfied:
            raise ValueError(
                f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
            )

        # Return every pipeline output, regardless of whether it is in the
        # catalog, plus the explicitly requested datasets.
        free_outputs = pipeline.outputs() | keep
        unregistered_ds = pipeline.data_sets() - set(catalog.list())
        for ds_name in unregistered_ds:
            catalog.add(ds_name, self.create_default_data_set(ds_name))

        if self._is_async:
            self._logger.info(
                "Asynchronous mode is enabled for loading and saving data"
            )
        self._run(pipeline, catalog, keep, hook_manager, session_id)

        self._logger.info("Pipeline execution completed successfully.")

        return {ds_name: catalog.load(ds_name) for ds_name in free_outputs}

    def _run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        dataset_names: Iterable[str],
        hook_manager: PluginManager,
        session_id: str = None,
    ) -> None:
        """Run the pipeline's nodes one by one, releasing datasets that are
        no longer needed unless they are listed in ``dataset_names``.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: The ``DataCatalog`` from which to fetch data.
            dataset_names: Names of datasets that must never be released, so
                they can still be loaded after the run.
            hook_manager: The ``PluginManager`` to activate hooks.
            session_id: The id of the session.

        Raises:
            Exception: in case of any downstream node failure.
        """
        nodes = pipeline.nodes
        done_nodes = set()
        keep = set(dataset_names or ())
        # Loop-invariant: computed once instead of once per dataset per node.
        pipeline_inputs = pipeline.inputs()
        pipeline_outputs = pipeline.outputs()

        # Number of nodes that still have to load each dataset.
        load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))

        for exec_index, node in enumerate(nodes):
            try:
                run_node(node, catalog, hook_manager, self._is_async, session_id)
                done_nodes.add(node)
            except Exception:
                # Let the base runner suggest a resume scenario, then re-raise.
                self._suggest_resume_scenario(pipeline, done_nodes, catalog)
                raise

            # Decrement load counts and release any data sets we've finished
            # with, except pipeline inputs/outputs and requested datasets.
            for data_set in node.inputs:
                load_counts[data_set] -= 1
                if (
                    load_counts[data_set] < 1
                    and data_set not in pipeline_inputs
                    and data_set not in keep
                ):
                    catalog.release(data_set)
            for data_set in node.outputs:
                if (
                    load_counts[data_set] < 1
                    and data_set not in pipeline_outputs
                    and data_set not in keep
                ):
                    catalog.release(data_set)

            self._logger.info(
                "Completed %d out of %d tasks", exec_index + 1, len(nodes)
            )


In [None]:
# `DebugRunner` has to be used differently because `session.run` doesn't accept extra arguments (such as `dataset_names`), so we take a lower-level approach and construct the `Runner`, `Pipeline` and `DataCatalog` ourselves.

# Testing Kedro Project: https://github.com/noklam/kedro_gallery/tree/master/kedro-debug-runner-demo
%load_ext kedro.ipython
%reload_kedro ~/dev/kedro_gallery/kedro-debug-runner-demo

The kedro.ipython extension is already loaded. To reload it, use:
  %reload_ext kedro.ipython


In [None]:
%reload_kedro ~/dev/kedro_gallery/kedro-debug-runner-demo
# `pipelines` and `catalog` are presumably injected by `%reload_kedro` above.
# Baseline run with no extra datasets: run_1 holds only the pipeline outputs.
runner = DebugRunner()
default_pipeline = pipelines["__default__"]
run_1 = runner.run(default_pipeline, catalog)


In [None]:
# Also request example_iris_data: it is kept during the run and returned in
# run_2 alongside the pipeline outputs.
runner = DebugRunner()
default_pipeline = pipelines["__default__"]
run_2 = runner.run(default_pipeline, catalog, dataset_names=["example_iris_data"])


In [None]:
# Request X_train (noted by the author as an input dataset): DebugRunner does
# not release it during the run and returns it in run_3 with the outputs.
runner = DebugRunner()
default_pipeline = pipelines["__default__"]
run_3 = runner.run(default_pipeline, catalog, dataset_names=["X_train"])


In [None]:
# Pipeline outputs only.
run_1

In [None]:
# Pipeline outputs plus example_iris_data.
run_2

In [None]:
# Pipeline outputs plus X_train.
run_3

In [None]:
#| export


class GreedySequentialRunner(SequentialRunner):
    """``SequentialRunner`` variant whose ``run`` returns every pipeline output.

    After the run, each output is loaded back from the catalog. Outputs
    registered in the catalog are therefore returned as well as the ones
    that only received a default dataset.
    """

    def run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        hook_manager: PluginManager = None,
        session_id: str = None,
    ) -> Dict[str, Any]:
        """Execute ``pipeline`` against a shallow copy of ``catalog``.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: The ``DataCatalog`` providing inputs. The run works on
                ``catalog.shallow_copy()``; default datasets for unregistered
                pipeline datasets are added to that copy.
            hook_manager: The ``PluginManager`` to activate hooks. A
                ``_NullPluginManager`` is used when it is falsy.
            session_id: The id of the session.

        Raises:
            ValueError: If any ``Pipeline`` input is missing from the catalog.

        Returns:
            A dictionary mapping every name in ``pipeline.outputs()`` to the
            data loaded from the catalog once all nodes have run.
        """
        if not hook_manager:
            hook_manager = _NullPluginManager()
        run_catalog = catalog.shallow_copy()

        # Nothing mutates the copy until the defaults are added below, so one
        # snapshot of the registered names serves both checks.
        registered = set(run_catalog.list())

        missing_inputs = pipeline.inputs() - registered
        if missing_inputs:
            raise ValueError(
                f"Pipeline input(s) {missing_inputs} not found in the DataCatalog"
            )

        # Everything the pipeline produces is handed back, catalogued or not.
        wanted = pipeline.outputs()

        for name in pipeline.data_sets() - registered:
            run_catalog.add(name, self.create_default_data_set(name))

        if self._is_async:
            self._logger.info(
                "Asynchronous mode is enabled for loading and saving data"
            )
        self._run(pipeline, run_catalog, hook_manager, session_id)
        self._logger.info("Pipeline execution completed successfully.")

        results = {}
        for name in wanted:
            results[name] = run_catalog.load(name)
        return results


In [None]:
#| hide
# Export the cells marked `#| export` to the module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()