Skip to content

Commit

Permalink
Merge branch 'main' into notebooknode-input
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitsanj committed Aug 12, 2022
2 parents 9c4ff1f + 9f02383 commit 17be0f5
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 13 deletions.
22 changes: 20 additions & 2 deletions papermill/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .exceptions import PapermillException
from .clientwrap import PapermillNotebookClient
from .iorw import write_ipynb
from .utils import merge_kwargs, remove_args
from .utils import merge_kwargs, remove_args, nb_kernel_name, nb_language


class PapermillEngines(object):
Expand Down Expand Up @@ -48,6 +48,14 @@ def execute_notebook_with_engine(self, engine_name, nb, kernel_name, **kwargs):
"""Fetch a named engine and execute the nb object against it."""
return self.get_engine(engine_name).execute_notebook(nb, kernel_name, **kwargs)

def nb_kernel_name(self, engine_name, nb, name=None):
"""Fetch kernel name from the document by dropping-down into the provided engine."""
return self.get_engine(engine_name).nb_kernel_name(nb, name)

def nb_language(self, engine_name, nb, language=None):
"""Fetch language from the document by dropping-down into the provided engine."""
return self.get_engine(engine_name).nb_language(nb, language)


def catch_nb_assignment(func):
"""
Expand Down Expand Up @@ -368,6 +376,16 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
"""An abstract method where implementation will be defined in a subclass."""
raise NotImplementedError("'execute_managed_notebook' is not implemented for this engine")

@classmethod
def nb_kernel_name(cls, nb, name=None):
"""Use default implementation to fetch kernel name from the notebook object"""
return nb_kernel_name(nb, name)

@classmethod
def nb_language(cls, nb, language=None):
"""Use default implementation to fetch programming language from the notebook object"""
return nb_language(nb, language)


class NBClientEngine(Engine):
"""
Expand All @@ -393,7 +411,7 @@ def execute_managed_notebook(
Performs the actual execution of the parameterized notebook locally.
Args:
nb (NotebookNode): Executable notebook object.
nb_man (NotebookExecutionManager): Wrapper for execution state of a notebook.
kernel_name (str): Name of kernel to execute the notebook against.
log_output (bool): Flag for whether or not to write notebook output to the
configured logger.
Expand Down
15 changes: 11 additions & 4 deletions papermill/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .exceptions import PapermillExecutionError
from .iorw import get_pretty_path, local_file_io_cwd, load_notebook_node, write_ipynb
from .engines import papermill_engines
from .utils import chdir, nb_kernel_name
from .utils import chdir
from .parameterize import add_builtin_parameters, parameterize_notebook, parameterize_path


Expand Down Expand Up @@ -92,16 +92,23 @@ def execute_notebook(
# Parameterize the Notebook.
if parameters:
nb = parameterize_notebook(
nb, parameters, report_mode, kernel_name=kernel_name, language=language
nb,
parameters,
report_mode,
kernel_name=kernel_name,
language=language,
engine_name=engine_name,
)

nb = prepare_notebook_metadata(nb, input_path, output_path, report_mode)
# clear out any existing error markers from previous papermill runs
nb = remove_error_markers(nb)

if not prepare_only:
# Fetch out the name from the notebook document
kernel_name = nb_kernel_name(nb, kernel_name)
# Dropdown to the engine to fetch the kernel name from the notebook document
kernel_name = papermill_engines.nb_kernel_name(
engine_name=engine_name, nb=nb, name=kernel_name
)
# Execute the Notebook in `cwd` if it is set
with chdir(cwd):
nb = papermill_engines.execute_notebook_with_engine(
Expand Down
17 changes: 12 additions & 5 deletions papermill/parameterize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import copy
import nbformat

from .engines import papermill_engines
from .log import logger
from .exceptions import PapermillMissingParameterException
from .iorw import read_yaml_file
from .translators import translate_parameters
from .utils import find_first_tagged_cell_index, nb_kernel_name, nb_language
from .utils import find_first_tagged_cell_index

from uuid import uuid4
from datetime import datetime
Expand Down Expand Up @@ -57,7 +58,13 @@ def parameterize_path(path, parameters):


def parameterize_notebook(
nb, parameters, report_mode=False, comment='Parameters', kernel_name=None, language=None
nb,
parameters,
report_mode=False,
comment='Parameters',
kernel_name=None,
language=None,
engine_name=None,
):
"""Assigned parameters into the appropriate place in the input notebook
Expand All @@ -79,9 +86,9 @@ def parameterize_notebook(
# Copy the nb object to avoid polluting the input
nb = copy.deepcopy(nb)

# Fetch out the name and language from the notebook document
kernel_name = nb_kernel_name(nb, kernel_name)
language = nb_language(nb, language)
# Fetch out the name and language from the notebook document by dropping-down into the engine's implementation
kernel_name = papermill_engines.nb_kernel_name(engine_name, nb, kernel_name)
language = papermill_engines.nb_language(engine_name, nb, language)

# Generate parameter content based on the kernel_name
param_content = translate_parameters(kernel_name, language, parameters, comment)
Expand Down
54 changes: 52 additions & 2 deletions papermill/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import shutil
import tempfile
import unittest
from unittest.mock import patch
from copy import deepcopy
from unittest.mock import patch, ANY

from functools import partial
from pathlib import Path

import nbformat
from nbformat import validate

from .. import engines
from .. import engines, translators
from ..log import logger
from ..iorw import load_notebook_node
from ..utils import chdir
Expand Down Expand Up @@ -389,6 +390,55 @@ def test_no_v3_language_backport(self):
validate(nb)


class TestExecuteWithCustomEngine(unittest.TestCase):
class CustomEngine(engines.Engine):
@classmethod
def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs):
pass

@classmethod
def nb_kernel_name(cls, nb, name=None):
return "my_custom_kernel"

@classmethod
def nb_language(cls, nb, language=None):
return "my_custom_language"

def setUp(self):
self.test_dir = tempfile.mkdtemp()
self.notebook_path = get_notebook_path('simple_execute.ipynb')
self.nb_test_executed_fname = os.path.join(
self.test_dir, 'output_{}'.format('simple_execute.ipynb')
)

self._orig_papermill_engines = deepcopy(engines.papermill_engines)
self._orig_translators = deepcopy(translators.papermill_translators)
engines.papermill_engines.register(
"custom_engine", self.CustomEngine
)
translators.papermill_translators.register("my_custom_language", translators.PythonTranslator())

def tearDown(self):
shutil.rmtree(self.test_dir)
engines.papermill_engines = self._orig_papermill_engines
translators.papermill_translators = self._orig_translators

@patch.object(CustomEngine, "execute_managed_notebook", wraps=CustomEngine.execute_managed_notebook)
@patch("papermill.parameterize.translate_parameters", wraps=translators.translate_parameters)
def test_custom_kernel_name_and_language(self, translate_parameters, execute_managed_notebook):
"""Tests execute against engine with custom implementations to fetch
kernel name and language from the notebook object
"""
execute_notebook(
self.notebook_path,
self.nb_test_executed_fname,
engine_name="custom_engine",
parameters={"msg": "fake msg"},
)
self.assertEqual(execute_managed_notebook.call_args[0], (ANY, "my_custom_kernel"))
self.assertEqual(translate_parameters.call_args[0], (ANY, 'my_custom_language', {"msg": "fake msg"}, ANY))


class TestNotebookNodeInput(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.TemporaryDirectory()
Expand Down

0 comments on commit 17be0f5

Please sign in to comment.