From 70ce95740303c8f2578e85793e001dbf9397ed4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Nov 2023 22:49:06 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?= =?UTF-8?q?=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/conf.py | 12 +- papermill/__init__.py | 3 +- papermill/__main__.py | 2 +- papermill/abs.py | 44 +- papermill/adl.py | 11 +- papermill/cli.py | 184 ++++---- papermill/clientwrap.py | 32 +- papermill/engines.py | 118 +++--- papermill/exceptions.py | 20 +- papermill/execute.py | 71 ++-- papermill/inspection.py | 34 +- papermill/iorw.py | 167 ++++---- papermill/log.py | 2 +- papermill/models.py | 10 +- papermill/parameterize.py | 32 +- papermill/s3.py | 146 +++---- papermill/tests/__init__.py | 6 +- papermill/tests/test_abs.py | 107 ++--- papermill/tests/test_adl.py | 61 ++- papermill/tests/test_autosave.py | 36 +- papermill/tests/test_cli.py | 470 +++++++++------------ papermill/tests/test_clientwrap.py | 31 +- papermill/tests/test_engines.py | 341 +++++++-------- papermill/tests/test_exceptions.py | 20 +- papermill/tests/test_execute.py | 352 ++++++---------- papermill/tests/test_gcs.py | 110 ++--- papermill/tests/test_hdfs.py | 20 +- papermill/tests/test_inspect.py | 106 +++-- papermill/tests/test_iorw.py | 232 +++++------ papermill/tests/test_parameterize.py | 160 +++---- papermill/tests/test_s3.py | 90 ++-- papermill/tests/test_translators.py | 603 +++++++++++++-------------- papermill/tests/test_utils.py | 38 +- papermill/translators.py | 267 ++++++------ papermill/utils.py | 21 +- papermill/version.py | 2 +- 36 files changed, 1763 insertions(+), 2198 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 00ddfde6..50adcd0f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -80,7 +80,7 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'UPDATE.md'] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" +pygments_style = 'sphinx' # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -90,14 +90,14 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = "furo" +html_theme = 'furo' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # html_theme_options = { - "sidebar_hide_name": True, + 'sidebar_hide_name': True, } # Add any paths that contain custom static files (such as style sheets) here, @@ -105,7 +105,7 @@ # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] -html_logo = "_static/images/papermill.png" +html_logo = '_static/images/papermill.png' # -- Options for HTMLHelp output ------------------------------------------ @@ -132,9 +132,7 @@ # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'papermill.tex', 'papermill Documentation', 'nteract team', 'manual') -] +latex_documents = [(master_doc, 'papermill.tex', 'papermill Documentation', 'nteract team', 'manual')] # -- Options for manual page output --------------------------------------- diff --git a/papermill/__init__.py b/papermill/__init__.py index af32a9d3..e3b98fb6 100644 --- a/papermill/__init__.py +++ b/papermill/__init__.py @@ -1,5 +1,4 @@ -from .version import version as __version__ - from .exceptions import PapermillException, PapermillExecutionError from .execute import execute_notebook from .inspection import inspect_notebook +from .version import version as __version__ diff --git a/papermill/__main__.py b/papermill/__main__.py index 1f08dacb..c386c2ff 100644 --- a/papermill/__main__.py +++ b/papermill/__main__.py @@ -1,4 +1,4 @@ from papermill.cli import papermill -if __name__ == "__main__": +if __name__ == '__main__': papermill() diff --git a/papermill/abs.py b/papermill/abs.py index 2c5d4a45..0378d45f 100644 --- a/papermill/abs.py +++ b/papermill/abs.py @@ -1,9 +1,9 @@ """Utilities for working with Azure blob storage""" -import re import io +import re -from azure.storage.blob import BlobServiceClient from azure.identity import EnvironmentCredential +from azure.storage.blob import BlobServiceClient class AzureBlobStore: @@ -20,7 +20,7 @@ class AzureBlobStore: def _blob_service_client(self, account_name, sas_token=None): blob_service_client = BlobServiceClient( - account_url=f"{account_name}.blob.core.windows.net", + account_url=f'{account_name}.blob.core.windows.net', credential=sas_token or EnvironmentCredential(), ) @@ -32,17 +32,15 @@ def _split_url(self, url): see: https://docs.microsoft.com/en-us/azure/storage/common/storage-dotnet-shared-access-signature-part-1 # noqa: E501 abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken """ - match = re.match( - r"abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/([^\?]*)\??(.*)$", url - ) + match = re.match(r'abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/([^\?]*)\??(.*)$', url) if not match: raise Exception(f"Invalid azure blob url '{url}'") else: params = { - "account": match.group(1), - "container": match.group(2), - "blob": match.group(3), - "sas_token": match.group(4), + 'account': match.group(1), + 'container': match.group(2), + 'blob': match.group(3), + 'sas_token': match.group(4), } return params @@ -50,32 +48,22 @@ def read(self, url): """Read storage at a given url""" params = self._split_url(url) output_stream = io.BytesIO() - blob_service_client = self._blob_service_client( - params["account"], params["sas_token"] - ) - blob_client = blob_service_client.get_blob_client( - params["container"], params["blob"] - ) + blob_service_client = self._blob_service_client(params['account'], params['sas_token']) + blob_client = blob_service_client.get_blob_client(params['container'], params['blob']) blob_client.download_blob().readinto(output_stream) output_stream.seek(0) - return [line.decode("utf-8") for line in output_stream] + return [line.decode('utf-8') for line in output_stream] def listdir(self, url): """Returns a list of the files under the specified path""" params = self._split_url(url) - blob_service_client = self._blob_service_client( - params["account"], params["sas_token"] - ) - container_client = blob_service_client.get_container_client(params["container"]) - return list(container_client.list_blobs(params["blob"])) + blob_service_client = self._blob_service_client(params['account'], params['sas_token']) + container_client = blob_service_client.get_container_client(params['container']) + return list(container_client.list_blobs(params['blob'])) def write(self, buf, url): """Write buffer to storage at a given url""" params = self._split_url(url) - blob_service_client = self._blob_service_client( - params["account"], params["sas_token"] - ) - blob_client = blob_service_client.get_blob_client( - params["container"], params["blob"] - ) + blob_service_client = self._blob_service_client(params['account'], params['sas_token']) + blob_client = blob_service_client.get_blob_client(params['container'], params['blob']) blob_client.upload_blob(data=buf, overwrite=True) diff --git a/papermill/adl.py b/papermill/adl.py index de7b64cb..4ad0f62a 100644 --- a/papermill/adl.py +++ b/papermill/adl.py @@ -21,7 +21,7 @@ def __init__(self): @classmethod def _split_url(cls, url): - match = re.match(r"adl://(.*)\.azuredatalakestore\.net\/(.*)$", url) + match = re.match(r'adl://(.*)\.azuredatalakestore\.net\/(.*)$', url) if not match: raise Exception(f"Invalid ADL url '{url}'") else: @@ -39,12 +39,7 @@ def listdir(self, url): """Returns a list of the files under the specified path""" (store_name, path) = self._split_url(url) adapter = self._create_adapter(store_name) - return [ - "adl://{store_name}.azuredatalakestore.net/{path_to_child}".format( - store_name=store_name, path_to_child=path_to_child - ) - for path_to_child in adapter.ls(path) - ] + return [f'adl://{store_name}.azuredatalakestore.net/{path_to_child}' for path_to_child in adapter.ls(path)] def read(self, url): """Read storage at a given url""" @@ -60,5 +55,5 @@ def write(self, buf, url): """Write buffer to storage at a given url""" (store_name, path) = self._split_url(url) adapter = self._create_adapter(store_name) - with adapter.open(path, "wb") as f: + with adapter.open(path, 'wb') as f: f.write(buf.encode()) diff --git a/papermill/cli.py b/papermill/cli.py index 3b76b00e..e80867df 100755 --- a/papermill/cli.py +++ b/papermill/cli.py @@ -1,23 +1,21 @@ """Main `papermill` interface.""" +import base64 +import logging import os +import platform import sys -from stat import S_ISFIFO -import nbclient import traceback - -import base64 -import logging +from stat import S_ISFIFO import click - +import nbclient import yaml -import platform +from . import __version__ as papermill_version from .execute import execute_notebook -from .iorw import read_yaml_file, NoDatesSafeLoader from .inspection import display_notebook_help -from . import __version__ as papermill_version +from .iorw import NoDatesSafeLoader, read_yaml_file click.disable_unicode_literals_warning = True @@ -28,155 +26,147 @@ def print_papermill_version(ctx, param, value): if not value: return - print( - "{version} from {path} ({pyver})".format( - version=papermill_version, path=__file__, pyver=platform.python_version() - ) - ) + print(f'{papermill_version} from {__file__} ({platform.python_version()})') ctx.exit() -@click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.command(context_settings=dict(help_option_names=['-h', '--help'])) @click.pass_context -@click.argument("notebook_path", required=not INPUT_PIPED) -@click.argument("output_path", default="") +@click.argument('notebook_path', required=not INPUT_PIPED) +@click.argument('output_path', default='') @click.option( - "--help-notebook", + '--help-notebook', is_flag=True, default=False, - help="Display parameters information for the given notebook path.", + help='Display parameters information for the given notebook path.', ) @click.option( - "--parameters", - "-p", + '--parameters', + '-p', nargs=2, multiple=True, - help="Parameters to pass to the parameters cell.", + help='Parameters to pass to the parameters cell.', ) @click.option( - "--parameters_raw", - "-r", + '--parameters_raw', + '-r', nargs=2, multiple=True, - help="Parameters to be read as raw string.", + help='Parameters to be read as raw string.', ) @click.option( - "--parameters_file", - "-f", + '--parameters_file', + '-f', multiple=True, - help="Path to YAML file containing parameters.", + help='Path to YAML file containing parameters.', ) @click.option( - "--parameters_yaml", - "-y", + '--parameters_yaml', + '-y', multiple=True, - help="YAML string to be used as parameters.", + help='YAML string to be used as parameters.', ) @click.option( - "--parameters_base64", - "-b", + '--parameters_base64', + '-b', multiple=True, - help="Base64 encoded YAML string as parameters.", + help='Base64 encoded YAML string as parameters.', ) @click.option( - "--inject-input-path", + '--inject-input-path', is_flag=True, default=False, - help="Insert the path of the input notebook as PAPERMILL_INPUT_PATH as a notebook parameter.", + help='Insert the path of the input notebook as PAPERMILL_INPUT_PATH as a notebook parameter.', ) @click.option( - "--inject-output-path", + '--inject-output-path', is_flag=True, default=False, - help="Insert the path of the output notebook as PAPERMILL_OUTPUT_PATH as a notebook parameter.", + help='Insert the path of the output notebook as PAPERMILL_OUTPUT_PATH as a notebook parameter.', ) @click.option( - "--inject-paths", + '--inject-paths', is_flag=True, default=False, help=( - "Insert the paths of input/output notebooks as PAPERMILL_INPUT_PATH/PAPERMILL_OUTPUT_PATH" - " as notebook parameters." + 'Insert the paths of input/output notebooks as PAPERMILL_INPUT_PATH/PAPERMILL_OUTPUT_PATH' + ' as notebook parameters.' ), ) +@click.option('--engine', help='The execution engine name to use in evaluating the notebook.') @click.option( - "--engine", help="The execution engine name to use in evaluating the notebook." -) -@click.option( - "--request-save-on-cell-execute/--no-request-save-on-cell-execute", + '--request-save-on-cell-execute/--no-request-save-on-cell-execute', default=True, - help="Request save notebook after each cell execution", + help='Request save notebook after each cell execution', ) @click.option( - "--autosave-cell-every", + '--autosave-cell-every', default=30, type=int, - help="How often in seconds to autosave the notebook during long cell executions (0 to disable)", + help='How often in seconds to autosave the notebook during long cell executions (0 to disable)', ) @click.option( - "--prepare-only/--prepare-execute", + '--prepare-only/--prepare-execute', default=False, - help="Flag for outputting the notebook without execution, but with parameters applied.", + help='Flag for outputting the notebook without execution, but with parameters applied.', ) @click.option( - "--kernel", - "-k", - help="Name of kernel to run. Ignores kernel name in the notebook document metadata.", + '--kernel', + '-k', + help='Name of kernel to run. Ignores kernel name in the notebook document metadata.', ) @click.option( - "--language", - "-l", - help="Language for notebook execution. Ignores language in the notebook document metadata.", + '--language', + '-l', + help='Language for notebook execution. Ignores language in the notebook document metadata.', ) -@click.option("--cwd", default=None, help="Working directory to run notebook in.") +@click.option('--cwd', default=None, help='Working directory to run notebook in.') @click.option( - "--progress-bar/--no-progress-bar", + '--progress-bar/--no-progress-bar', default=None, - help="Flag for turning on the progress bar.", + help='Flag for turning on the progress bar.', ) @click.option( - "--log-output/--no-log-output", + '--log-output/--no-log-output', default=False, - help="Flag for writing notebook output to the configured logger.", + help='Flag for writing notebook output to the configured logger.', ) @click.option( - "--stdout-file", - type=click.File(mode="w", encoding="utf-8"), - help="File to write notebook stdout output to.", + '--stdout-file', + type=click.File(mode='w', encoding='utf-8'), + help='File to write notebook stdout output to.', ) @click.option( - "--stderr-file", - type=click.File(mode="w", encoding="utf-8"), - help="File to write notebook stderr output to.", + '--stderr-file', + type=click.File(mode='w', encoding='utf-8'), + help='File to write notebook stderr output to.', ) @click.option( - "--log-level", - type=click.Choice(["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), - default="INFO", - help="Set log level", + '--log-level', + type=click.Choice(['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']), + default='INFO', + help='Set log level', ) @click.option( - "--start-timeout", - "--start_timeout", # Backwards compatible naming + '--start-timeout', + '--start_timeout', # Backwards compatible naming type=int, default=60, - help="Time in seconds to wait for kernel to start.", + help='Time in seconds to wait for kernel to start.', ) @click.option( - "--execution-timeout", + '--execution-timeout', type=int, - help="Time in seconds to wait for each cell before failing execution (default: forever)", + help='Time in seconds to wait for each cell before failing execution (default: forever)', ) +@click.option('--report-mode/--no-report-mode', default=False, help='Flag for hiding input.') @click.option( - "--report-mode/--no-report-mode", default=False, help="Flag for hiding input." -) -@click.option( - "--version", + '--version', is_flag=True, callback=print_papermill_version, expose_value=False, is_eager=True, - help="Flag for displaying the version.", + help='Flag for displaying the version.', ) def papermill( click_ctx, @@ -224,8 +214,8 @@ def papermill( """ # Jupyter deps use frozen modules, so we disable the python 3.11+ warning about debugger if running the CLI - if "PYDEVD_DISABLE_FILE_VALIDATION" not in os.environ: - os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + if 'PYDEVD_DISABLE_FILE_VALIDATION' not in os.environ: + os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1' if not help_notebook: required_output_path = not (INPUT_PIPED or OUTPUT_PIPED) @@ -233,35 +223,33 @@ def papermill( raise click.UsageError("Missing argument 'OUTPUT_PATH'") if INPUT_PIPED and notebook_path and not output_path: - input_path = "-" + input_path = '-' output_path = notebook_path else: - input_path = notebook_path or "-" - output_path = output_path or "-" + input_path = notebook_path or '-' + output_path = output_path or '-' - if output_path == "-": + if output_path == '-': # Save notebook to stdout just once request_save_on_cell_execute = False # Reduce default log level if we pipe to stdout - if log_level == "INFO": - log_level = "ERROR" + if log_level == 'INFO': + log_level = 'ERROR' elif progress_bar is None: progress_bar = not log_output - logging.basicConfig(level=log_level, format="%(message)s") + logging.basicConfig(level=log_level, format='%(message)s') # Read in Parameters parameters_final = {} if inject_input_path or inject_paths: - parameters_final["PAPERMILL_INPUT_PATH"] = input_path + parameters_final['PAPERMILL_INPUT_PATH'] = input_path if inject_output_path or inject_paths: - parameters_final["PAPERMILL_OUTPUT_PATH"] = output_path + parameters_final['PAPERMILL_OUTPUT_PATH'] = output_path for params in parameters_base64 or []: - parameters_final.update( - yaml.load(base64.b64decode(params), Loader=NoDatesSafeLoader) or {} - ) + parameters_final.update(yaml.load(base64.b64decode(params), Loader=NoDatesSafeLoader) or {}) for files in parameters_file or []: parameters_final.update(read_yaml_file(files) or {}) for params in parameters_yaml or []: @@ -301,11 +289,11 @@ def papermill( def _resolve_type(value): - if value == "True": + if value == 'True': return True - elif value == "False": + elif value == 'False': return False - elif value == "None": + elif value == 'None': return None elif _is_int(value): return int(value) diff --git a/papermill/clientwrap.py b/papermill/clientwrap.py index b6718a2f..f4d4a8b2 100644 --- a/papermill/clientwrap.py +++ b/papermill/clientwrap.py @@ -1,5 +1,5 @@ -import sys import asyncio +import sys from nbclient import NotebookClient from nbclient.exceptions import CellExecutionError @@ -27,9 +27,7 @@ def __init__(self, nb_man, km=None, raise_on_iopub_timeout=True, **kw): Optional kernel manager. If none is provided, a kernel manager will be created. """ - super().__init__( - nb_man.nb, km=km, raise_on_iopub_timeout=raise_on_iopub_timeout, **kw - ) + super().__init__(nb_man.nb, km=km, raise_on_iopub_timeout=raise_on_iopub_timeout, **kw) self.nb_man = nb_man def execute(self, **kwargs): @@ -39,18 +37,14 @@ def execute(self, **kwargs): self.reset_execution_trackers() # See https://bugs.python.org/issue37373 :( - if ( - sys.version_info[0] == 3 - and sys.version_info[1] >= 8 - and sys.platform.startswith("win") - ): + if sys.version_info[0] == 3 and sys.version_info[1] >= 8 and sys.platform.startswith('win'): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) with self.setup_kernel(**kwargs): - self.log.info("Executing notebook with kernel: %s" % self.kernel_name) + self.log.info('Executing notebook with kernel: %s' % self.kernel_name) self.papermill_execute_cells() info_msg = self.wait_for_reply(self.kc.kernel_info()) - self.nb.metadata["language_info"] = info_msg["content"]["language_info"] + self.nb.metadata['language_info'] = info_msg['content']['language_info'] self.set_widgets_metadata() return self.nb @@ -77,9 +71,7 @@ def papermill_execute_cells(self): self.nb_man.cell_start(cell, index) self.execute_cell(cell, index) except CellExecutionError as ex: - self.nb_man.cell_exception( - self.nb.cells[index], cell_index=index, exception=ex - ) + self.nb_man.cell_exception(self.nb.cells[index], cell_index=index, exception=ex) break finally: self.nb_man.cell_complete(self.nb.cells[index], cell_index=index) @@ -92,23 +84,23 @@ def log_output_message(self, output): :param output: nbformat.notebooknode.NotebookNode :return: """ - if output.output_type == "stream": - content = "".join(output.text) - if output.name == "stdout": + if output.output_type == 'stream': + content = ''.join(output.text) + if output.name == 'stdout': if self.log_output: self.log.info(content) if self.stdout_file: self.stdout_file.write(content) self.stdout_file.flush() - elif output.name == "stderr": + elif output.name == 'stderr': if self.log_output: # In case users want to redirect stderr differently, pipe to warning self.log.warning(content) if self.stderr_file: self.stderr_file.write(content) self.stderr_file.flush() - elif self.log_output and ("data" in output and "text/plain" in output.data): - self.log.info("".join(output.data["text/plain"])) + elif self.log_output and ('data' in output and 'text/plain' in output.data): + self.log.info(''.join(output.data['text/plain'])) def process_message(self, *arg, **kwargs): output = super().process_message(*arg, **kwargs) diff --git a/papermill/engines.py b/papermill/engines.py index 5200ff7d..3e87f52b 100644 --- a/papermill/engines.py +++ b/papermill/engines.py @@ -1,16 +1,16 @@ """Engines to perform different roles""" -import sys import datetime -import dateutil - +import sys from functools import wraps + +import dateutil import entrypoints -from .log import logger -from .exceptions import PapermillException from .clientwrap import PapermillNotebookClient +from .exceptions import PapermillException from .iorw import write_ipynb -from .utils import merge_kwargs, remove_args, nb_kernel_name, nb_language +from .log import logger +from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args class PapermillEngines: @@ -33,7 +33,7 @@ def register_entry_points(self): Load handlers provided by other packages """ - for entrypoint in entrypoints.get_group_all("papermill.engine"): + for entrypoint in entrypoints.get_group_all('papermill.engine'): self.register(entrypoint.name, entrypoint.load()) def get_engine(self, name=None): @@ -69,7 +69,7 @@ def catch_nb_assignment(func): @wraps(func) def wrapper(self, *args, **kwargs): - nb = kwargs.get("nb") + nb = kwargs.get('nb') if nb: # Reassign if executing notebook object was replaced self.nb = nb @@ -90,10 +90,10 @@ class NotebookExecutionManager: shared manner. """ - PENDING = "pending" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" + PENDING = 'pending' + RUNNING = 'running' + COMPLETED = 'completed' + FAILED = 'failed' def __init__( self, @@ -110,15 +110,13 @@ def __init__( self.end_time = None self.autosave_cell_every = autosave_cell_every self.max_autosave_pct = 25 - self.last_save_time = ( - self.now() - ) # Not exactly true, but simplifies testing logic + self.last_save_time = self.now() # Not exactly true, but simplifies testing logic self.pbar = None if progress_bar: # lazy import due to implict slow ipython import from tqdm.auto import tqdm - self.pbar = tqdm(total=len(self.nb.cells), unit="cell", desc="Executing") + self.pbar = tqdm(total=len(self.nb.cells), unit='cell', desc='Executing') def now(self): """Helper to return current UTC time""" @@ -169,7 +167,7 @@ def autosave_cell(self): # Autosave is taking too long, so exponentially back off. self.autosave_cell_every *= 2 logger.warning( - "Autosave too slow: {:.2f} sec, over {}% limit. Backing off to {} sec".format( + 'Autosave too slow: {:.2f} sec, over {}% limit. Backing off to {} sec'.format( save_elapsed, self.max_autosave_pct, self.autosave_cell_every ) ) @@ -187,14 +185,14 @@ def notebook_start(self, **kwargs): """ self.set_timer() - self.nb.metadata.papermill["start_time"] = self.start_time.isoformat() - self.nb.metadata.papermill["end_time"] = None - self.nb.metadata.papermill["duration"] = None - self.nb.metadata.papermill["exception"] = None + self.nb.metadata.papermill['start_time'] = self.start_time.isoformat() + self.nb.metadata.papermill['end_time'] = None + self.nb.metadata.papermill['duration'] = None + self.nb.metadata.papermill['exception'] = None for cell in self.nb.cells: # Reset the cell execution counts. - if cell.get("cell_type") == "code": + if cell.get('cell_type') == 'code': cell.execution_count = None # Clear out the papermill metadata for each cell. @@ -205,7 +203,7 @@ def notebook_start(self, **kwargs): duration=None, status=self.PENDING, # pending, running, completed ) - if cell.get("cell_type") == "code": + if cell.get('cell_type') == 'code': cell.outputs = [] self.save() @@ -219,17 +217,17 @@ def cell_start(self, cell, cell_index=None, **kwargs): metadata for a cell and save the notebook to the output path. """ if self.log_output: - ceel_num = cell_index + 1 if cell_index is not None else "" - logger.info(f"Executing Cell {ceel_num:-<40}") + ceel_num = cell_index + 1 if cell_index is not None else '' + logger.info(f'Executing Cell {ceel_num:-<40}') - cell.metadata.papermill["start_time"] = self.now().isoformat() - cell.metadata.papermill["status"] = self.RUNNING - cell.metadata.papermill["exception"] = False + cell.metadata.papermill['start_time'] = self.now().isoformat() + cell.metadata.papermill['status'] = self.RUNNING + cell.metadata.papermill['exception'] = False # injects optional description of the current cell directly in the tqdm cell_description = self.get_cell_description(cell) - if cell_description is not None and hasattr(self, "pbar") and self.pbar: - self.pbar.set_description(f"Executing {cell_description}") + if cell_description is not None and hasattr(self, 'pbar') and self.pbar: + self.pbar.set_description(f'Executing {cell_description}') self.save() @@ -242,9 +240,9 @@ def cell_exception(self, cell, cell_index=None, **kwargs): set the metadata on the notebook indicating the location of the failure. """ - cell.metadata.papermill["exception"] = True - cell.metadata.papermill["status"] = self.FAILED - self.nb.metadata.papermill["exception"] = True + cell.metadata.papermill['exception'] = True + cell.metadata.papermill['status'] = self.FAILED + self.nb.metadata.papermill['exception'] = True @catch_nb_assignment def cell_complete(self, cell, cell_index=None, **kwargs): @@ -257,20 +255,18 @@ def cell_complete(self, cell, cell_index=None, **kwargs): end_time = self.now() if self.log_output: - ceel_num = cell_index + 1 if cell_index is not None else "" - logger.info(f"Ending Cell {ceel_num:-<43}") + ceel_num = cell_index + 1 if cell_index is not None else '' + logger.info(f'Ending Cell {ceel_num:-<43}') # Ensure our last cell messages are not buffered by python sys.stdout.flush() sys.stderr.flush() - cell.metadata.papermill["end_time"] = end_time.isoformat() - if cell.metadata.papermill.get("start_time"): - start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"]) - cell.metadata.papermill["duration"] = ( - end_time - start_time - ).total_seconds() - if cell.metadata.papermill["status"] != self.FAILED: - cell.metadata.papermill["status"] = self.COMPLETED + cell.metadata.papermill['end_time'] = end_time.isoformat() + if cell.metadata.papermill.get('start_time'): + start_time = dateutil.parser.parse(cell.metadata.papermill['start_time']) + cell.metadata.papermill['duration'] = (end_time - start_time).total_seconds() + if cell.metadata.papermill['status'] != self.FAILED: + cell.metadata.papermill['status'] = self.COMPLETED self.save() if self.pbar: @@ -285,18 +281,16 @@ def notebook_complete(self, **kwargs): Called by Engine when execution concludes, regardless of exceptions. """ self.end_time = self.now() - self.nb.metadata.papermill["end_time"] = self.end_time.isoformat() - if self.nb.metadata.papermill.get("start_time"): - self.nb.metadata.papermill["duration"] = ( - self.end_time - self.start_time - ).total_seconds() + self.nb.metadata.papermill['end_time'] = self.end_time.isoformat() + if self.nb.metadata.papermill.get('start_time'): + self.nb.metadata.papermill['duration'] = (self.end_time - self.start_time).total_seconds() # Cleanup cell statuses in case callbacks were never called for cell in self.nb.cells: - if cell.metadata.papermill["status"] == self.FAILED: + if cell.metadata.papermill['status'] == self.FAILED: break - elif cell.metadata.papermill["status"] == self.PENDING: - cell.metadata.papermill["status"] = self.COMPLETED + elif cell.metadata.papermill['status'] == self.PENDING: + cell.metadata.papermill['status'] = self.COMPLETED self.complete_pbar() self.cleanup_pbar() @@ -304,12 +298,12 @@ def notebook_complete(self, **kwargs): # Force a final sync self.save() - def get_cell_description(self, cell, escape_str="papermill_description="): + def get_cell_description(self, cell, escape_str='papermill_description='): """Fetches cell description if present""" if cell is None: return None - cell_code = cell["source"] + cell_code = cell['source'] if cell_code is None or escape_str not in cell_code: return None @@ -317,13 +311,13 @@ def get_cell_description(self, cell, escape_str="papermill_description="): def complete_pbar(self): """Refresh progress bar""" - if hasattr(self, "pbar") and self.pbar: + if hasattr(self, 'pbar') and self.pbar: self.pbar.n = len(self.nb.cells) self.pbar.refresh() def cleanup_pbar(self): """Clean up a progress bar""" - if hasattr(self, "pbar") and self.pbar: + if hasattr(self, 'pbar') and self.pbar: self.pbar.close() self.pbar = None @@ -371,9 +365,7 @@ def execute_notebook( nb_man.notebook_start() try: - cls.execute_managed_notebook( - nb_man, kernel_name, log_output=log_output, **kwargs - ) + cls.execute_managed_notebook(nb_man, kernel_name, log_output=log_output, **kwargs) finally: nb_man.cleanup_pbar() nb_man.notebook_complete() @@ -383,9 +375,7 @@ def execute_notebook( @classmethod def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): """An abstract method where implementation will be defined in a subclass.""" - raise NotImplementedError( - "'execute_managed_notebook' is not implemented for this engine" - ) + raise NotImplementedError("'execute_managed_notebook' is not implemented for this engine") @classmethod def nb_kernel_name(cls, nb, name=None): @@ -431,12 +421,12 @@ def execute_managed_notebook( """ # Exclude parameters that named differently downstream - safe_kwargs = remove_args(["timeout", "startup_timeout"], **kwargs) + safe_kwargs = remove_args(['timeout', 'startup_timeout'], **kwargs) # Nicely handle preprocessor arguments prioritizing values set by engine final_kwargs = merge_kwargs( safe_kwargs, - timeout=execution_timeout if execution_timeout else kwargs.get("timeout"), + timeout=execution_timeout if execution_timeout else kwargs.get('timeout'), startup_timeout=start_timeout, kernel_name=kernel_name, log=logger, @@ -450,5 +440,5 @@ def execute_managed_notebook( # Instantiate a PapermillEngines instance, register Handlers and entrypoints papermill_engines = PapermillEngines() papermill_engines.register(None, NBClientEngine) -papermill_engines.register("nbclient", NBClientEngine) +papermill_engines.register('nbclient', NBClientEngine) papermill_engines.register_entry_points() diff --git a/papermill/exceptions.py b/papermill/exceptions.py index 38aab7e8..f78f95f7 100644 --- a/papermill/exceptions.py +++ b/papermill/exceptions.py @@ -33,10 +33,10 @@ def __str__(self): # when called with str(). In order to maintain compatability with previous versions which # passed only the message to the superclass constructor, __str__ method is implemented to # provide the same result as was produced in the past. - message = "\n" + 75 * "-" + "\n" + message = '\n' + 75 * '-' + '\n' message += 'Exception encountered at "In [%s]":\n' % str(self.exec_count) - message += "\n".join(self.traceback) - message += "\n" + message += '\n'.join(self.traceback) + message += '\n' return message @@ -59,10 +59,8 @@ class PapermillParameterOverwriteWarning(PapermillWarning): def missing_dependency_generator(package, dep): def missing_dep(): raise PapermillOptionalDependencyException( - "The {package} optional dependency is missing. " - "Please run pip install papermill[{dep}] to install this dependency".format( - package=package, dep=dep - ) + f'The {package} optional dependency is missing. ' + f'Please run pip install papermill[{dep}] to install this dependency' ) return missing_dep @@ -71,11 +69,9 @@ def missing_dep(): def missing_environment_variable_generator(package, env_key): def missing_dep(): raise PapermillOptionalDependencyException( - "The {package} optional dependency is present, but the environment " - "variable {env_key} is not set. Please set this variable as " - "required by {package} on your platform.".format( - package=package, env_key=env_key - ) + f'The {package} optional dependency is present, but the environment ' + f'variable {env_key} is not set. Please set this variable as ' + f'required by {package} on your platform.' ) return missing_dep diff --git a/papermill/execute.py b/papermill/execute.py index 3d0d23ae..1b683918 100644 --- a/papermill/execute.py +++ b/papermill/execute.py @@ -1,17 +1,18 @@ -import nbformat from pathlib import Path -from .log import logger -from .exceptions import PapermillExecutionError -from .iorw import get_pretty_path, local_file_io_cwd, load_notebook_node, write_ipynb +import nbformat + from .engines import papermill_engines -from .utils import chdir +from .exceptions import PapermillExecutionError +from .inspection import _infer_parameters +from .iorw import get_pretty_path, load_notebook_node, local_file_io_cwd, write_ipynb +from .log import logger from .parameterize import ( add_builtin_parameters, parameterize_notebook, parameterize_path, ) -from .inspection import _infer_parameters +from .utils import chdir def execute_notebook( @@ -83,23 +84,21 @@ def execute_notebook( input_path = parameterize_path(input_path, path_parameters) output_path = parameterize_path(output_path, path_parameters) - logger.info("Input Notebook: %s" % get_pretty_path(input_path)) - logger.info("Output Notebook: %s" % get_pretty_path(output_path)) + logger.info('Input Notebook: %s' % get_pretty_path(input_path)) + logger.info('Output Notebook: %s' % get_pretty_path(output_path)) with local_file_io_cwd(): if cwd is not None: - logger.info(f"Working directory: {get_pretty_path(cwd)}") + logger.info(f'Working directory: {get_pretty_path(cwd)}') nb = load_notebook_node(input_path) # Parameterize the Notebook. if parameters: - parameter_predefined = _infer_parameters( - nb, name=kernel_name, language=language - ) + parameter_predefined = _infer_parameters(nb, name=kernel_name, language=language) parameter_predefined = {p.name for p in parameter_predefined} for p in parameters: if p not in parameter_predefined: - logger.warning(f"Passed unknown parameter: {p}") + logger.warning(f'Passed unknown parameter: {p}') nb = parameterize_notebook( nb, parameters, @@ -115,9 +114,7 @@ def execute_notebook( if not prepare_only: # Dropdown to the engine to fetch the kernel name from the notebook document - kernel_name = papermill_engines.nb_kernel_name( - engine_name=engine_name, nb=nb, name=kernel_name - ) + kernel_name = papermill_engines.nb_kernel_name(engine_name=engine_name, nb=nb, name=kernel_name) # Execute the Notebook in `cwd` if it is set with chdir(cwd): nb = papermill_engines.execute_notebook_with_engine( @@ -160,40 +157,36 @@ def prepare_notebook_metadata(nb, input_path, output_path, report_mode=False): # Hide input if report-mode is set to True. if report_mode: for cell in nb.cells: - if cell.cell_type == "code": - cell.metadata["jupyter"] = cell.get("jupyter", {}) - cell.metadata["jupyter"]["source_hidden"] = True + if cell.cell_type == 'code': + cell.metadata['jupyter'] = cell.get('jupyter', {}) + cell.metadata['jupyter']['source_hidden'] = True # Record specified environment variable values. - nb.metadata.papermill["input_path"] = input_path - nb.metadata.papermill["output_path"] = output_path + nb.metadata.papermill['input_path'] = input_path + nb.metadata.papermill['output_path'] = output_path return nb -ERROR_MARKER_TAG = "papermill-error-cell-tag" +ERROR_MARKER_TAG = 'papermill-error-cell-tag' ERROR_STYLE = 'style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;"' ERROR_MESSAGE_TEMPLATE = ( - "" - "An Exception was encountered at 'In [%s]'." - "" + '' + 'An Exception was encountered at \'In [%s]\'.' + '' ) ERROR_ANCHOR_MSG = ( - '" - "Execution using papermill encountered an exception here and stopped:" - "" + '' + 'Execution using papermill encountered an exception here and stopped:' + '' ) def remove_error_markers(nb): - nb.cells = [ - cell - for cell in nb.cells - if ERROR_MARKER_TAG not in cell.metadata.get("tags", []) - ] + nb.cells = [cell for cell in nb.cells if ERROR_MARKER_TAG not in cell.metadata.get('tags', [])] return nb @@ -209,14 +202,12 @@ def raise_for_execution_errors(nb, output_path): """ error = None for index, cell in enumerate(nb.cells): - if cell.get("outputs") is None: + if cell.get('outputs') is None: continue for output in cell.outputs: - if output.output_type == "error": - if output.ename == "SystemExit" and ( - output.evalue == "" or output.evalue == "0" - ): + if output.output_type == 'error': + if output.ename == 'SystemExit' and (output.evalue == '' or output.evalue == '0'): continue error = PapermillExecutionError( cell_index=index, @@ -233,9 +224,9 @@ def raise_for_execution_errors(nb, output_path): # the relevant cell (by adding a note just before the failure with an HTML anchor) error_msg = ERROR_MESSAGE_TEMPLATE % str(error.exec_count) error_msg_cell = nbformat.v4.new_markdown_cell(error_msg) - error_msg_cell.metadata["tags"] = [ERROR_MARKER_TAG] + error_msg_cell.metadata['tags'] = [ERROR_MARKER_TAG] error_anchor_cell = nbformat.v4.new_markdown_cell(ERROR_ANCHOR_MSG) - error_anchor_cell.metadata["tags"] = [ERROR_MARKER_TAG] + error_anchor_cell.metadata['tags'] = [ERROR_MARKER_TAG] # Upgrade the Notebook to the latest v4 before writing into it nb = nbformat.v4.upgrade(nb) diff --git a/papermill/inspection.py b/papermill/inspection.py index b1ec68f7..db5a6136 100644 --- a/papermill/inspection.py +++ b/papermill/inspection.py @@ -1,7 +1,8 @@ """Deduce parameters of a notebook from the parameters cell.""" -import click from pathlib import Path +import click + from .iorw import get_pretty_path, load_notebook_node, local_file_io_cwd from .log import logger from .parameterize import add_builtin_parameters, parameterize_path @@ -17,7 +18,7 @@ def _open_notebook(notebook_path, parameters): path_parameters = add_builtin_parameters(parameters) input_path = parameterize_path(notebook_path, path_parameters) - logger.info("Input Notebook: %s" % get_pretty_path(input_path)) + logger.info('Input Notebook: %s' % get_pretty_path(input_path)) with local_file_io_cwd(): return load_notebook_node(input_path) @@ -38,7 +39,7 @@ def _infer_parameters(nb, name=None, language=None): """ params = [] - parameter_cell_idx = find_first_tagged_cell_index(nb, "parameters") + parameter_cell_idx = find_first_tagged_cell_index(nb, 'parameters') if parameter_cell_idx < 0: return params parameter_cell = nb.cells[parameter_cell_idx] @@ -50,11 +51,7 @@ def _infer_parameters(nb, name=None, language=None): try: params = translator.inspect(parameter_cell) except NotImplementedError: - logger.warning( - "Translator for '{}' language does not support parameter introspection.".format( - language - ) - ) + logger.warning(f"Translator for '{language}' language does not support parameter introspection.") return params @@ -74,7 +71,7 @@ def display_notebook_help(ctx, notebook_path, parameters): pretty_path = get_pretty_path(notebook_path) click.echo(f"\nParameters inferred for notebook '{pretty_path}':") - if not any_tagged_cell(nb, "parameters"): + if not any_tagged_cell(nb, 'parameters'): click.echo("\n No cell tagged 'parameters'") return 1 @@ -82,25 +79,22 @@ def display_notebook_help(ctx, notebook_path, parameters): if params: for param in params: p = param._asdict() - type_repr = p["inferred_type_name"] - if type_repr == "None": - type_repr = "Unknown type" + type_repr = p['inferred_type_name'] + if type_repr == 'None': + type_repr = 'Unknown type' - definition = " {}: {} (default {})".format( - p["name"], type_repr, p["default"] - ) + definition = ' {}: {} (default {})'.format(p['name'], type_repr, p['default']) if len(definition) > 30: - if len(p["help"]): - param_help = "".join((definition, "\n", 34 * " ", p["help"])) + if len(p['help']): + param_help = ''.join((definition, '\n', 34 * ' ', p['help'])) else: param_help = definition else: - param_help = "{:<34}{}".format(definition, p["help"]) + param_help = '{:<34}{}'.format(definition, p['help']) click.echo(param_help) else: click.echo( - "\n Can't infer anything about this notebook's parameters. " - "It may not have any parameter defined." + "\n Can't infer anything about this notebook's parameters. " 'It may not have any parameter defined.' ) return 0 diff --git a/papermill/iorw.py b/papermill/iorw.py index 961ee207..ecca680f 100644 --- a/papermill/iorw.py +++ b/papermill/iorw.py @@ -1,15 +1,14 @@ +import fnmatch +import json import os import sys -import json -import yaml -import fnmatch -import nbformat -import requests import warnings -import entrypoints - from contextlib import contextmanager +import entrypoints +import nbformat +import requests +import yaml from tenacity import ( retry, retry_if_exception_type, @@ -30,37 +29,37 @@ try: from .s3 import S3 except ImportError: - S3 = missing_dependency_generator("boto3", "s3") + S3 = missing_dependency_generator('boto3', 's3') try: from .adl import ADL except ImportError: - ADL = missing_dependency_generator("azure.datalake.store", "azure") + ADL = missing_dependency_generator('azure.datalake.store', 'azure') except KeyError as exc: - if exc.args[0] == "APPDATA": - ADL = missing_environment_variable_generator("azure.datalake.store", "APPDATA") + if exc.args[0] == 'APPDATA': + ADL = missing_environment_variable_generator('azure.datalake.store', 'APPDATA') else: raise try: from .abs import AzureBlobStore except ImportError: - AzureBlobStore = missing_dependency_generator("azure.storage.blob", "azure") + AzureBlobStore = missing_dependency_generator('azure.storage.blob', 'azure') try: from gcsfs import GCSFileSystem except ImportError: - GCSFileSystem = missing_dependency_generator("gcsfs", "gcs") + GCSFileSystem = missing_dependency_generator('gcsfs', 'gcs') try: - from pyarrow.fs import HadoopFileSystem, FileSelector + from pyarrow.fs import FileSelector, HadoopFileSystem except ImportError: - HadoopFileSystem = missing_dependency_generator("pyarrow", "hdfs") + HadoopFileSystem = missing_dependency_generator('pyarrow', 'hdfs') try: from github import Github except ImportError: - Github = missing_dependency_generator("pygithub", "github") + Github = missing_dependency_generator('pygithub', 'github') def fallback_gs_is_retriable(e): @@ -97,14 +96,14 @@ class PapermillIO: def __init__(self): self.reset() - def read(self, path, extensions=[".ipynb", ".json"]): + def read(self, path, extensions=['.ipynb', '.json']): # Handle https://github.com/nteract/papermill/issues/317 notebook_metadata = self.get_handler(path, extensions).read(path) if isinstance(notebook_metadata, (bytes, bytearray)): - return notebook_metadata.decode("utf-8") + return notebook_metadata.decode('utf-8') return notebook_metadata - def write(self, buf, path, extensions=[".ipynb", ".json"]): + def write(self, buf, path, extensions=['.ipynb', '.json']): return self.get_handler(path, extensions).write(buf, path) def listdir(self, path): @@ -122,7 +121,7 @@ def register(self, scheme, handler): def register_entry_points(self): # Load handlers provided by other packages - for entrypoint in entrypoints.get_group_all("papermill.io"): + for entrypoint in entrypoints.get_group_all('papermill.io'): self.register(entrypoint.name, entrypoint.load()) def get_handler(self, path, extensions=None): @@ -151,31 +150,21 @@ def get_handler(self, path, extensions=None): return NotebookNodeHandler() if extensions: - if not fnmatch.fnmatch(os.path.basename(path).split("?")[0], "*.*"): - warnings.warn( - "the file is not specified with any extension : " - + os.path.basename(path) - ) - elif not any( - fnmatch.fnmatch(os.path.basename(path).split("?")[0], "*" + ext) - for ext in extensions - ): - warnings.warn( - f"The specified file ({path}) does not end in one of {extensions}" - ) + if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'): + warnings.warn('the file is not specified with any extension : ' + os.path.basename(path)) + elif not any(fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext) for ext in extensions): + warnings.warn(f'The specified file ({path}) does not end in one of {extensions}') local_handler = None for scheme, handler in self._handlers: - if scheme == "local": + if scheme == 'local': local_handler = handler if path.startswith(scheme): return handler if local_handler is None: - raise PapermillException( - f"Could not find a registered schema handler for: {path}" - ) + raise PapermillException(f'Could not find a registered schema handler for: {path}') return local_handler @@ -183,11 +172,11 @@ def get_handler(self, path, extensions=None): class HttpHandler: @classmethod def read(cls, path): - return requests.get(path, headers={"Accept": "application/json"}).text + return requests.get(path, headers={'Accept': 'application/json'}).text @classmethod def listdir(cls, path): - raise PapermillException("listdir is not supported by HttpHandler") + raise PapermillException('listdir is not supported by HttpHandler') @classmethod def write(cls, buf, path): @@ -206,7 +195,7 @@ def __init__(self): def read(self, path): try: with chdir(self._cwd): - with open(path, encoding="utf-8") as f: + with open(path, encoding='utf-8') as f: return f.read() except OSError as e: try: @@ -227,7 +216,7 @@ def write(self, buf, path): dirname = os.path.dirname(path) if dirname and not os.path.exists(dirname): raise FileNotFoundError(f"output folder {dirname} doesn't exist.") - with open(path, "w", encoding="utf-8") as f: + with open(path, 'w', encoding='utf-8') as f: f.write(buf) def pretty_path(self, path): @@ -243,7 +232,7 @@ def cwd(self, new_path): class S3Handler: @classmethod def read(cls, path): - return "\n".join(S3().read(path)) + return '\n'.join(S3().read(path)) @classmethod def listdir(cls, path): @@ -269,7 +258,7 @@ def _get_client(self): def read(self, path): lines = self._get_client().read(path) - return "\n".join(lines) + return '\n'.join(lines) def listdir(self, path): return self._get_client().listdir(path) @@ -292,7 +281,7 @@ def _get_client(self): def read(self, path): lines = self._get_client().read(path) - return "\n".join(lines) + return '\n'.join(lines) def listdir(self, path): return self._get_client().listdir(path) @@ -339,13 +328,13 @@ def write(self, buf, path): ) def retry_write(): try: - with self._get_client().open(path, "w") as f: + with self._get_client().open(path, 'w') as f: return f.write(buf) except Exception as e: try: message = e.message except AttributeError: - message = f"Generic exception {type(e)} raised" + message = f'Generic exception {type(e)} raised' if gs_is_retriable(e): raise PapermillRateLimitException(message) # Reraise the original exception without retries @@ -363,7 +352,7 @@ def __init__(self): def _get_client(self): if self._client is None: - self._client = HadoopFileSystem(host="default") + self._client = HadoopFileSystem(host='default') return self._client def read(self, path): @@ -387,7 +376,7 @@ def __init__(self): def _get_client(self): if self._client is None: - token = os.environ.get("GITHUB_ACCESS_TOKEN", None) + token = os.environ.get('GITHUB_ACCESS_TOKEN', None) if token: self._client = Github(token) else: @@ -395,20 +384,20 @@ def _get_client(self): return self._client def read(self, path): - splits = path.split("/") + splits = path.split('/') org_id = splits[3] repo_id = splits[4] ref_id = splits[6] - sub_path = "/".join(splits[7:]) - repo = self._get_client().get_repo(org_id + "/" + repo_id) + sub_path = '/'.join(splits[7:]) + repo = self._get_client().get_repo(org_id + '/' + repo_id) content = repo.get_contents(sub_path, ref=ref_id) return content.decoded_content def listdir(self, path): - raise PapermillException("listdir is not supported by GithubHandler") + raise PapermillException('listdir is not supported by GithubHandler') def write(self, buf, path): - raise PapermillException("write is not supported by GithubHandler") + raise PapermillException('write is not supported by GithubHandler') def pretty_path(self, path): return path @@ -421,15 +410,15 @@ def read(self, path): return sys.stdin.read() def listdir(self, path): - raise PapermillException("listdir is not supported by Stream Handler") + raise PapermillException('listdir is not supported by Stream Handler') def write(self, buf, path): try: - return sys.stdout.buffer.write(buf.encode("utf-8")) + return sys.stdout.buffer.write(buf.encode('utf-8')) except AttributeError: # Originally required by https://github.com/nteract/papermill/issues/420 # Support Buffer.io objects - return sys.stdout.write(buf.encode("utf-8")) + return sys.stdout.write(buf.encode('utf-8')) def pretty_path(self, path): return path @@ -442,61 +431,59 @@ def read(self, path): return nbformat.writes(path) def listdir(self, path): - raise PapermillException("listdir is not supported by NotebookNode Handler") + raise PapermillException('listdir is not supported by NotebookNode Handler') def write(self, buf, path): - raise PapermillException("write is not supported by NotebookNode Handler") + raise PapermillException('write is not supported by NotebookNode Handler') def pretty_path(self, path): - return "NotebookNode object" + return 'NotebookNode object' class NoIOHandler: """Handler for output_path of None - intended to not write anything""" def read(self, path): - raise PapermillException("read is not supported by NoIOHandler") + raise PapermillException('read is not supported by NoIOHandler') def listdir(self, path): - raise PapermillException("listdir is not supported by NoIOHandler") + raise PapermillException('listdir is not supported by NoIOHandler') def write(self, buf, path): return def pretty_path(self, path): - return "Notebook will not be saved" + return 'Notebook will not be saved' # Hack to make YAML loader not auto-convert datetimes # https://stackoverflow.com/a/52312810 class NoDatesSafeLoader(yaml.SafeLoader): yaml_implicit_resolvers = { - k: [r for r in v if r[0] != "tag:yaml.org,2002:timestamp"] + k: [r for r in v if r[0] != 'tag:yaml.org,2002:timestamp'] for k, v in yaml.SafeLoader.yaml_implicit_resolvers.items() } # Instantiate a PapermillIO instance and register Handlers. papermill_io = PapermillIO() -papermill_io.register("local", LocalHandler()) -papermill_io.register("s3://", S3Handler) -papermill_io.register("adl://", ADLHandler()) -papermill_io.register("abs://", ABSHandler()) -papermill_io.register("http://", HttpHandler) -papermill_io.register("https://", HttpHandler) -papermill_io.register("gs://", GCSHandler()) -papermill_io.register("hdfs://", HDFSHandler()) -papermill_io.register("http://github.com/", GithubHandler()) -papermill_io.register("https://github.com/", GithubHandler()) -papermill_io.register("-", StreamHandler()) +papermill_io.register('local', LocalHandler()) +papermill_io.register('s3://', S3Handler) +papermill_io.register('adl://', ADLHandler()) +papermill_io.register('abs://', ABSHandler()) +papermill_io.register('http://', HttpHandler) +papermill_io.register('https://', HttpHandler) +papermill_io.register('gs://', GCSHandler()) +papermill_io.register('hdfs://', HDFSHandler()) +papermill_io.register('http://github.com/', GithubHandler()) +papermill_io.register('https://github.com/', GithubHandler()) +papermill_io.register('-', StreamHandler()) papermill_io.register_entry_points() def read_yaml_file(path): """Reads a YAML file from the location specified at 'path'.""" - return yaml.load( - papermill_io.read(path, [".json", ".yaml", ".yml"]), Loader=NoDatesSafeLoader - ) + return yaml.load(papermill_io.read(path, ['.json', '.yaml', '.yml']), Loader=NoDatesSafeLoader) def write_ipynb(nb, path): @@ -523,27 +510,27 @@ def load_notebook_node(notebook_path): if nb_upgraded is not None: nb = nb_upgraded - if not hasattr(nb.metadata, "papermill"): - nb.metadata["papermill"] = { - "default_parameters": dict(), - "parameters": dict(), - "environment_variables": dict(), - "version": __version__, + if not hasattr(nb.metadata, 'papermill'): + nb.metadata['papermill'] = { + 'default_parameters': dict(), + 'parameters': dict(), + 'environment_variables': dict(), + 'version': __version__, } for cell in nb.cells: - if not hasattr(cell.metadata, "tags"): - cell.metadata["tags"] = [] # Create tags attr if one doesn't exist. + if not hasattr(cell.metadata, 'tags'): + cell.metadata['tags'] = [] # Create tags attr if one doesn't exist. - if not hasattr(cell.metadata, "papermill"): - cell.metadata["papermill"] = dict() + if not hasattr(cell.metadata, 'papermill'): + cell.metadata['papermill'] = dict() return nb def list_notebook_files(path): """Returns a list of all the notebook files in a directory.""" - return [p for p in papermill_io.listdir(path) if p.endswith(".ipynb")] + return [p for p in papermill_io.listdir(path) if p.endswith('.ipynb')] def get_pretty_path(path): @@ -553,14 +540,14 @@ def get_pretty_path(path): @contextmanager def local_file_io_cwd(path=None): try: - local_handler = papermill_io.get_handler("local") + local_handler = papermill_io.get_handler('local') except PapermillException: - logger.warning("No local file handler detected") + logger.warning('No local file handler detected') else: try: old_cwd = local_handler.cwd(path or os.getcwd()) except AttributeError: - logger.warning("Local file handler does not support cwd assignment") + logger.warning('Local file handler does not support cwd assignment') else: try: yield diff --git a/papermill/log.py b/papermill/log.py index 273bc8f3..b90225d2 100644 --- a/papermill/log.py +++ b/papermill/log.py @@ -1,4 +1,4 @@ """Sets up a logger""" import logging -logger = logging.getLogger("papermill") +logger = logging.getLogger('papermill') diff --git a/papermill/models.py b/papermill/models.py index fcbb627f..35c077e5 100644 --- a/papermill/models.py +++ b/papermill/models.py @@ -2,11 +2,11 @@ from collections import namedtuple Parameter = namedtuple( - "Parameter", + 'Parameter', [ - "name", - "inferred_type_name", # string of type - "default", # string representing the default value - "help", + 'name', + 'inferred_type_name', # string of type + 'default', # string representing the default value + 'help', ], ) diff --git a/papermill/parameterize.py b/papermill/parameterize.py index db3ac837..a210f26e 100644 --- a/papermill/parameterize.py +++ b/papermill/parameterize.py @@ -1,15 +1,15 @@ +from datetime import datetime +from uuid import uuid4 + import nbformat from .engines import papermill_engines -from .log import logger from .exceptions import PapermillMissingParameterException from .iorw import read_yaml_file +from .log import logger from .translators import translate_parameters from .utils import find_first_tagged_cell_index -from uuid import uuid4 -from datetime import datetime - def add_builtin_parameters(parameters): """Add built-in parameters to a dictionary of parameters @@ -20,10 +20,10 @@ def add_builtin_parameters(parameters): Dictionary of parameters provided by the user """ with_builtin_parameters = { - "pm": { - "run_uuid": str(uuid4()), - "current_datetime_local": datetime.now(), - "current_datetime_utc": datetime.utcnow(), + 'pm': { + 'run_uuid': str(uuid4()), + 'current_datetime_local': datetime.now(), + 'current_datetime_utc': datetime.utcnow(), } } @@ -53,14 +53,14 @@ def parameterize_path(path, parameters): try: return path.format(**parameters) except KeyError as key_error: - raise PapermillMissingParameterException(f"Missing parameter {key_error}") + raise PapermillMissingParameterException(f'Missing parameter {key_error}') def parameterize_notebook( nb, parameters, report_mode=False, - comment="Parameters", + comment='Parameters', kernel_name=None, language=None, engine_name=None, @@ -93,14 +93,14 @@ def parameterize_notebook( nb = nbformat.v4.upgrade(nb) newcell = nbformat.v4.new_code_cell(source=param_content) - newcell.metadata["tags"] = ["injected-parameters"] + newcell.metadata['tags'] = ['injected-parameters'] if report_mode: - newcell.metadata["jupyter"] = newcell.get("jupyter", {}) - newcell.metadata["jupyter"]["source_hidden"] = True + newcell.metadata['jupyter'] = newcell.get('jupyter', {}) + newcell.metadata['jupyter']['source_hidden'] = True - param_cell_index = find_first_tagged_cell_index(nb, "parameters") - injected_cell_index = find_first_tagged_cell_index(nb, "injected-parameters") + param_cell_index = find_first_tagged_cell_index(nb, 'parameters') + injected_cell_index = find_first_tagged_cell_index(nb, 'injected-parameters') if injected_cell_index >= 0: # Replace the injected cell with a new version before = nb.cells[:injected_cell_index] @@ -116,6 +116,6 @@ def parameterize_notebook( after = nb.cells nb.cells = before + [newcell] + after - nb.metadata.papermill["parameters"] = parameters + nb.metadata.papermill['parameters'] = parameters return nb diff --git a/papermill/s3.py b/papermill/s3.py index ccd2141a..06ac9aff 100644 --- a/papermill/s3.py +++ b/papermill/s3.py @@ -1,8 +1,7 @@ """Utilities for working with S3.""" -import os - import logging +import os import threading import zlib @@ -11,8 +10,7 @@ from .exceptions import AwsError from .utils import retry - -logger = logging.getLogger("papermill.s3") +logger = logging.getLogger('papermill.s3') class Bucket: @@ -32,11 +30,9 @@ def __init__(self, name, service=None): self.name = name self.service = service - def list(self, prefix="", delimiter=None): + def list(self, prefix='', delimiter=None): """Limits a list of Bucket's objects based on prefix and delimiter.""" - return self.service._list( - bucket=self.name, prefix=prefix, delimiter=delimiter, objects=True - ) + return self.service._list(bucket=self.name, prefix=prefix, delimiter=delimiter, objects=True) class Prefix: @@ -61,7 +57,7 @@ def __init__(self, bucket, name, service=None): self.service = service def __str__(self): - return f"s3://{self.bucket.name}/{self.name}" + return f's3://{self.bucket.name}/{self.name}' def __repr__(self): return self.__str__() @@ -106,7 +102,7 @@ def __init__( self.etag = etag if last_modified: try: - self.last_modified = last_modified.isoformat().split("+")[0] + ".000Z" + self.last_modified = last_modified.isoformat().split('+')[0] + '.000Z' except ValueError: self.last_modified = last_modified self.storage_class = storage_class @@ -114,7 +110,7 @@ def __init__( self.service = service def __str__(self): - return f"s3://{self.bucket.name}/{self.name}" + return f's3://{self.bucket.name}/{self.name}' def __repr__(self): return self.__str__() @@ -146,47 +142,45 @@ def __init__(self, keyname=None, *args, **kwargs): with self.lock: if not all(S3.s3_session): session = Session() - client = session.client("s3") + client = session.client('s3') session_params = {} - endpoint_url = os.environ.get("BOTO3_ENDPOINT_URL", None) + endpoint_url = os.environ.get('BOTO3_ENDPOINT_URL', None) if endpoint_url: - session_params["endpoint_url"] = endpoint_url + session_params['endpoint_url'] = endpoint_url - s3 = session.resource("s3", **session_params) + s3 = session.resource('s3', **session_params) S3.s3_session = (session, client, s3) (self.session, self.client, self.s3) = S3.s3_session def _bucket_name(self, bucket): - return self._clean(bucket).split("/", 1)[0] + return self._clean(bucket).split('/', 1)[0] def _clean(self, name): - if name.startswith("s3n:"): - name = "s3:" + name[4:] + if name.startswith('s3n:'): + name = 's3:' + name[4:] if self._is_s3(name): return name[5:] return name def _clean_s3(self, name): - return "s3:" + name[4:] if name.startswith("s3n:") else name + return 's3:' + name[4:] if name.startswith('s3n:') else name def _get_key(self, name): if isinstance(name, Key): return name - return Key( - bucket=self._bucket_name(name), name=self._key_name(name), service=self - ) + return Key(bucket=self._bucket_name(name), name=self._key_name(name), service=self) def _key_name(self, name): - cleaned = self._clean(name).split("/", 1) + cleaned = self._clean(name).split('/', 1) return cleaned[1] if len(cleaned) > 1 else None @retry(3) def _list( self, - prefix="", + prefix='', bucket=None, delimiter=None, keys=False, @@ -194,55 +188,55 @@ def _list( page_size=1000, **kwargs, ): - assert bucket is not None, "You must specify a bucket to list" + assert bucket is not None, 'You must specify a bucket to list' bucket = self._bucket_name(bucket) - paginator = self.client.get_paginator("list_objects_v2") + paginator = self.client.get_paginator('list_objects_v2') operation_parameters = { - "Bucket": bucket, - "Prefix": prefix, - "PaginationConfig": {"PageSize": page_size}, + 'Bucket': bucket, + 'Prefix': prefix, + 'PaginationConfig': {'PageSize': page_size}, } if delimiter: - operation_parameters["Delimiter"] = delimiter + operation_parameters['Delimiter'] = delimiter page_iterator = paginator.paginate(**operation_parameters) def sort(item): - if "Key" in item: - return item["Key"] - return item["Prefix"] + if 'Key' in item: + return item['Key'] + return item['Prefix'] for page in page_iterator: locations = sorted( - [i for i in page.get("Contents", []) + page.get("CommonPrefixes", [])], + [i for i in page.get('Contents', []) + page.get('CommonPrefixes', [])], key=sort, ) for item in locations: if objects or keys: - if "Key" in item: + if 'Key' in item: yield Key( bucket, - item["Key"], - size=item.get("Size"), - etag=item.get("ETag"), - last_modified=item.get("LastModified"), - storage_class=item.get("StorageClass"), + item['Key'], + size=item.get('Size'), + etag=item.get('ETag'), + last_modified=item.get('LastModified'), + storage_class=item.get('StorageClass'), service=self, ) elif objects: - yield Prefix(bucket, item["Prefix"], service=self) + yield Prefix(bucket, item['Prefix'], service=self) else: - prefix = item["Key"] if "Key" in item else item["Prefix"] - yield f"s3://{bucket}/{prefix}" + prefix = item['Key'] if 'Key' in item else item['Prefix'] + yield f's3://{bucket}/{prefix}' def _put( self, source, dest, num_callbacks=10, - policy="bucket-owner-full-control", + policy='bucket-owner-full-control', **kwargs, ): key = self._get_key(dest) @@ -251,9 +245,9 @@ def _put( # support passing in open file obj. Why did we do this in the past? if not isinstance(source, str): - obj.upload_fileobj(source, ExtraArgs={"ACL": policy}) + obj.upload_fileobj(source, ExtraArgs={'ACL': policy}) else: - obj.upload_file(source, ExtraArgs={"ACL": policy}) + obj.upload_file(source, ExtraArgs={'ACL': policy}) return key def _put_string( @@ -261,14 +255,14 @@ def _put_string( source, dest, num_callbacks=10, - policy="bucket-owner-full-control", + policy='bucket-owner-full-control', **kwargs, ): key = self._get_key(dest) obj = self.s3.Object(key.bucket.name, key.name) if isinstance(source, str): - source = source.encode("utf-8") + source = source.encode('utf-8') obj.put(Body=source, ACL=policy) return key @@ -278,7 +272,7 @@ def _is_s3(self, name): return False name = self._clean_s3(name) - return "s3://" in name + return 's3://' in name def cat( self, @@ -286,7 +280,7 @@ def cat( buffersize=None, memsize=2**24, compressed=False, - encoding="UTF-8", + encoding='UTF-8', raw=False, ): """ @@ -296,19 +290,17 @@ def cat( skip encoding. """ - assert self._is_s3(source) or isinstance( - source, Key - ), "source must be a valid s3 path" + assert self._is_s3(source) or isinstance(source, Key), 'source must be a valid s3 path' key = self._get_key(source) if not isinstance(source, Key) else source - compressed = (compressed or key.name.endswith(".gz")) and not raw + compressed = (compressed or key.name.endswith('.gz')) and not raw if compressed: decompress = zlib.decompressobj(16 + zlib.MAX_WBITS) size = 0 bytes_read = 0 err = None - undecoded = "" + undecoded = '' if key: # try to read the file multiple times for i in range(100): @@ -318,7 +310,7 @@ def cat( if not size: size = obj.content_length elif size != obj.content_length: - raise AwsError("key size unexpectedly changed while reading") + raise AwsError('key size unexpectedly changed while reading') # For an empty file, 0 (first-bytes-pos) is equal to the length of the object # hence the range is "unsatisfiable", and botocore correctly handles it by @@ -326,16 +318,16 @@ def cat( if size == 0: break - r = obj.get(Range=f"bytes={bytes_read}-") + r = obj.get(Range=f'bytes={bytes_read}-') try: while bytes_read < size: # this making this weird check because this call is # about 100 times slower if the amt is too high if size - bytes_read > buffersize: - bytes = r["Body"].read(amt=buffersize) + bytes = r['Body'].read(amt=buffersize) else: - bytes = r["Body"].read() + bytes = r['Body'].read() if compressed: s = decompress.decompress(bytes) else: @@ -344,7 +336,7 @@ def cat( if encoding and not raw: try: decoded = undecoded + s.decode(encoding) - undecoded = "" + undecoded = '' yield decoded except UnicodeDecodeError: undecoded += s @@ -356,7 +348,7 @@ def cat( bytes_read += len(bytes) except zlib.error: - logger.error("Error while decompressing [%s]", key.name) + logger.error('Error while decompressing [%s]', key.name) raise except UnicodeDecodeError: raise @@ -371,7 +363,7 @@ def cat( if err: raise Exception else: - raise AwsError("Failed to fully read [%s]" % source.name) + raise AwsError('Failed to fully read [%s]' % source.name) if undecoded: assert encoding is not None # only time undecoded is set @@ -392,8 +384,8 @@ def cp_string(self, source, dest, **kwargs): the s3 location """ - assert isinstance(source, str), "source must be a string" - assert self._is_s3(dest), "Destination must be s3 location" + assert isinstance(source, str), 'source must be a string' + assert self._is_s3(dest), 'Destination must be s3 location' return self._put_string(source, dest, **kwargs) @@ -416,11 +408,9 @@ def list(self, name, iterator=False, **kwargs): if True return iterator rather than converting to list object """ - assert self._is_s3(name), "name must be in form s3://bucket/key" + assert self._is_s3(name), 'name must be in form s3://bucket/key' - it = self._list( - bucket=self._bucket_name(name), prefix=self._key_name(name), **kwargs - ) + it = self._list(bucket=self._bucket_name(name), prefix=self._key_name(name), **kwargs) return iter(it) if iterator else list(it) def listdir(self, name, **kwargs): @@ -442,27 +432,27 @@ def listdir(self, name, **kwargs): files or prefixes that are encountered """ - assert self._is_s3(name), "name must be in form s3://bucket/prefix/" + assert self._is_s3(name), 'name must be in form s3://bucket/prefix/' - if not name.endswith("/"): - name += "/" - return self.list(name, delimiter="/", **kwargs) + if not name.endswith('/'): + name += '/' + return self.list(name, delimiter='/', **kwargs) - def read(self, source, compressed=False, encoding="UTF-8"): + def read(self, source, compressed=False, encoding='UTF-8'): """ Iterates over a file in s3 split on newline. Yields a line in file. """ - buf = "" + buf = '' for block in self.cat(source, compressed=compressed, encoding=encoding): buf += block - if "\n" in buf: - ret, buf = buf.rsplit("\n", 1) - yield from ret.split("\n") + if '\n' in buf: + ret, buf = buf.rsplit('\n', 1) + yield from ret.split('\n') - lines = buf.split("\n") + lines = buf.split('\n') yield from lines[:-1] # only yield the last line if the line has content in it diff --git a/papermill/tests/__init__.py b/papermill/tests/__init__.py index 9843f37e..6ef2067e 100644 --- a/papermill/tests/__init__.py +++ b/papermill/tests/__init__.py @@ -1,13 +1,11 @@ import os - from io import StringIO - -kernel_name = "python3" +kernel_name = 'python3' def get_notebook_path(*args): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "notebooks", *args) + return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'notebooks', *args) def get_notebook_dir(*args): diff --git a/papermill/tests/test_abs.py b/papermill/tests/test_abs.py index 7793f4bd..580828b9 100644 --- a/papermill/tests/test_abs.py +++ b/papermill/tests/test_abs.py @@ -1,14 +1,15 @@ import os import unittest - from unittest.mock import Mock, patch + from azure.identity import EnvironmentCredential + from ..abs import AzureBlobStore class MockBytesIO: def __init__(self): - self.list = [b"hello", b"world!"] + self.list = [b'hello', b'world!'] def __getitem__(self, index): return self.list[index] @@ -23,106 +24,86 @@ class ABSTest(unittest.TestCase): """ def setUp(self): - self.list_blobs = Mock(return_value=["foo", "bar", "baz"]) + self.list_blobs = Mock(return_value=['foo', 'bar', 'baz']) self.upload_blob = Mock() self.download_blob = Mock() self._container_client = Mock(list_blobs=self.list_blobs) - self._blob_client = Mock( - upload_blob=self.upload_blob, download_blob=self.download_blob - ) + self._blob_client = Mock(upload_blob=self.upload_blob, download_blob=self.download_blob) self._blob_service_client = Mock( get_blob_client=Mock(return_value=self._blob_client), get_container_client=Mock(return_value=self._container_client), ) self.abs = AzureBlobStore() self.abs._blob_service_client = Mock(return_value=self._blob_service_client) - os.environ["AZURE_TENANT_ID"] = "mytenantid" - os.environ["AZURE_CLIENT_ID"] = "myclientid" - os.environ["AZURE_CLIENT_SECRET"] = "myclientsecret" + os.environ['AZURE_TENANT_ID'] = 'mytenantid' + os.environ['AZURE_CLIENT_ID'] = 'myclientid' + os.environ['AZURE_CLIENT_SECRET'] = 'myclientsecret' def test_split_url_raises_exception_on_invalid_url(self): with self.assertRaises(Exception) as context: - AzureBlobStore._split_url("this_is_not_a_valid_url") - self.assertTrue( - "Invalid azure blob url 'this_is_not_a_valid_url'" in str(context.exception) - ) + AzureBlobStore._split_url('this_is_not_a_valid_url') + self.assertTrue("Invalid azure blob url 'this_is_not_a_valid_url'" in str(context.exception)) def test_split_url_splits_valid_url(self): - params = AzureBlobStore._split_url( - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken" - ) - self.assertEqual(params["account"], "myaccount") - self.assertEqual(params["container"], "sascontainer") - self.assertEqual(params["blob"], "sasblob.txt") - self.assertEqual(params["sas_token"], "sastoken") + params = AzureBlobStore._split_url('abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken') + self.assertEqual(params['account'], 'myaccount') + self.assertEqual(params['container'], 'sascontainer') + self.assertEqual(params['blob'], 'sasblob.txt') + self.assertEqual(params['sas_token'], 'sastoken') def test_split_url_splits_valid_url_no_sas(self): - params = AzureBlobStore._split_url( - "abs://myaccount.blob.core.windows.net/container/blob.txt" - ) - self.assertEqual(params["account"], "myaccount") - self.assertEqual(params["container"], "container") - self.assertEqual(params["blob"], "blob.txt") - self.assertEqual(params["sas_token"], "") + params = AzureBlobStore._split_url('abs://myaccount.blob.core.windows.net/container/blob.txt') + self.assertEqual(params['account'], 'myaccount') + self.assertEqual(params['container'], 'container') + self.assertEqual(params['blob'], 'blob.txt') + self.assertEqual(params['sas_token'], '') def test_split_url_splits_valid_url_with_prefix(self): params = AzureBlobStore._split_url( - "abs://myaccount.blob.core.windows.net/sascontainer/A/B/sasblob.txt?sastoken" + 'abs://myaccount.blob.core.windows.net/sascontainer/A/B/sasblob.txt?sastoken' ) - self.assertEqual(params["account"], "myaccount") - self.assertEqual(params["container"], "sascontainer") - self.assertEqual(params["blob"], "A/B/sasblob.txt") - self.assertEqual(params["sas_token"], "sastoken") + self.assertEqual(params['account'], 'myaccount') + self.assertEqual(params['container'], 'sascontainer') + self.assertEqual(params['blob'], 'A/B/sasblob.txt') + self.assertEqual(params['sas_token'], 'sastoken') def test_listdir_calls(self): self.assertEqual( - self.abs.listdir( - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken" - ), - ["foo", "bar", "baz"], + self.abs.listdir('abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken'), + ['foo', 'bar', 'baz'], ) - self._blob_service_client.get_container_client.assert_called_once_with( - "sascontainer" - ) - self.list_blobs.assert_called_once_with("sasblob.txt") + self._blob_service_client.get_container_client.assert_called_once_with('sascontainer') + self.list_blobs.assert_called_once_with('sasblob.txt') - @patch("papermill.abs.io.BytesIO", side_effect=MockBytesIO) + @patch('papermill.abs.io.BytesIO', side_effect=MockBytesIO) def test_reads_file(self, mockBytesIO): self.assertEqual( - self.abs.read( - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken" - ), - ["hello", "world!"], - ) - self._blob_service_client.get_blob_client.assert_called_once_with( - "sascontainer", "sasblob.txt" + self.abs.read('abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken'), + ['hello', 'world!'], ) + self._blob_service_client.get_blob_client.assert_called_once_with('sascontainer', 'sasblob.txt') self.download_blob.assert_called_once_with() def test_write_file(self): self.abs.write( - "hello world", - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken", + 'hello world', + 'abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken', ) - self._blob_service_client.get_blob_client.assert_called_once_with( - "sascontainer", "sasblob.txt" - ) - self.upload_blob.assert_called_once_with(data="hello world", overwrite=True) + self._blob_service_client.get_blob_client.assert_called_once_with('sascontainer', 'sasblob.txt') + self.upload_blob.assert_called_once_with(data='hello world', overwrite=True) def test_blob_service_client(self): abs = AzureBlobStore() - blob = abs._blob_service_client(account_name="myaccount", sas_token="sastoken") - self.assertEqual(blob.account_name, "myaccount") + blob = abs._blob_service_client(account_name='myaccount', sas_token='sastoken') + self.assertEqual(blob.account_name, 'myaccount') # Credentials gets funky with v12.0.0, so I comment this out # self.assertEqual(blob.credential, "sastoken") def test_blob_service_client_environment_credentials(self): abs = AzureBlobStore() - blob = abs._blob_service_client(account_name="myaccount", sas_token="") - self.assertEqual(blob.account_name, "myaccount") + blob = abs._blob_service_client(account_name='myaccount', sas_token='') + self.assertEqual(blob.account_name, 'myaccount') self.assertIsInstance(blob.credential, EnvironmentCredential) - self.assertEqual(blob.credential._credential._tenant_id, "mytenantid") - self.assertEqual(blob.credential._credential._client_id, "myclientid") - self.assertEqual( - blob.credential._credential._client_credential, "myclientsecret" - ) + self.assertEqual(blob.credential._credential._tenant_id, 'mytenantid') + self.assertEqual(blob.credential._credential._client_id, 'myclientid') + self.assertEqual(blob.credential._credential._client_credential, 'myclientsecret') diff --git a/papermill/tests/test_adl.py b/papermill/tests/test_adl.py index 6db76be3..952c7a19 100644 --- a/papermill/tests/test_adl.py +++ b/papermill/tests/test_adl.py @@ -1,8 +1,9 @@ import unittest +from unittest.mock import MagicMock, Mock, patch -from unittest.mock import Mock, MagicMock, patch - -from ..adl import ADL, core as adl_core, lib as adl_lib +from ..adl import ADL +from ..adl import core as adl_core +from ..adl import lib as adl_lib class ADLTest(unittest.TestCase): @@ -13,13 +14,13 @@ class ADLTest(unittest.TestCase): def setUp(self): self.ls = Mock( return_value=[ - "path/to/directory/foo", - "path/to/directory/bar", - "path/to/directory/baz", + 'path/to/directory/foo', + 'path/to/directory/bar', + 'path/to/directory/baz', ] ) self.fakeFile = MagicMock() - self.fakeFile.__iter__.return_value = [b"a", b"b", b"c"] + self.fakeFile.__iter__.return_value = [b'a', b'b', b'c'] self.fakeFile.__enter__.return_value = self.fakeFile self.open = Mock(return_value=self.fakeFile) self.fakeAdapter = Mock(open=self.open, ls=self.ls) @@ -28,49 +29,41 @@ def setUp(self): def test_split_url_raises_exception_on_invalid_url(self): with self.assertRaises(Exception) as context: - ADL._split_url("this_is_not_a_valid_url") - self.assertTrue( - "Invalid ADL url 'this_is_not_a_valid_url'" in str(context.exception) - ) + ADL._split_url('this_is_not_a_valid_url') + self.assertTrue("Invalid ADL url 'this_is_not_a_valid_url'" in str(context.exception)) def test_split_url_splits_valid_url(self): - (store_name, path) = ADL._split_url("adl://foo.azuredatalakestore.net/bar/baz") - self.assertEqual(store_name, "foo") - self.assertEqual(path, "bar/baz") + (store_name, path) = ADL._split_url('adl://foo.azuredatalakestore.net/bar/baz') + self.assertEqual(store_name, 'foo') + self.assertEqual(path, 'bar/baz') def test_listdir_calls_ls_on_adl_adapter(self): self.assertEqual( - self.adl.listdir( - "adl://foo_store.azuredatalakestore.net/path/to/directory" - ), + self.adl.listdir('adl://foo_store.azuredatalakestore.net/path/to/directory'), [ - "adl://foo_store.azuredatalakestore.net/path/to/directory/foo", - "adl://foo_store.azuredatalakestore.net/path/to/directory/bar", - "adl://foo_store.azuredatalakestore.net/path/to/directory/baz", + 'adl://foo_store.azuredatalakestore.net/path/to/directory/foo', + 'adl://foo_store.azuredatalakestore.net/path/to/directory/bar', + 'adl://foo_store.azuredatalakestore.net/path/to/directory/baz', ], ) - self.ls.assert_called_once_with("path/to/directory") + self.ls.assert_called_once_with('path/to/directory') def test_read_opens_and_reads_file(self): self.assertEqual( - self.adl.read("adl://foo_store.azuredatalakestore.net/path/to/file"), - ["a", "b", "c"], + self.adl.read('adl://foo_store.azuredatalakestore.net/path/to/file'), + ['a', 'b', 'c'], ) self.fakeFile.__iter__.assert_called_once_with() def test_write_opens_file_and_writes_to_it(self): - self.adl.write( - "hello world", "adl://foo_store.azuredatalakestore.net/path/to/file" - ) - self.fakeFile.write.assert_called_once_with(b"hello world") + self.adl.write('hello world', 'adl://foo_store.azuredatalakestore.net/path/to/file') + self.fakeFile.write.assert_called_once_with(b'hello world') - @patch.object(adl_lib, "auth", return_value="my_token") - @patch.object(adl_core, "AzureDLFileSystem", return_value="my_adapter") + @patch.object(adl_lib, 'auth', return_value='my_token') + @patch.object(adl_core, 'AzureDLFileSystem', return_value='my_adapter') def test_create_adapter(self, azure_dl_filesystem_mock, auth_mock): sut = ADL() - actual = sut._create_adapter("my_store_name") - assert actual == "my_adapter" + actual = sut._create_adapter('my_store_name') + assert actual == 'my_adapter' auth_mock.assert_called_once_with() - azure_dl_filesystem_mock.assert_called_once_with( - "my_token", store_name="my_store_name" - ) + azure_dl_filesystem_mock.assert_called_once_with('my_token', store_name='my_store_name') diff --git a/papermill/tests/test_autosave.py b/papermill/tests/test_autosave.py index b234c29a..74ae06e8 100644 --- a/papermill/tests/test_autosave.py +++ b/papermill/tests/test_autosave.py @@ -1,28 +1,26 @@ -import nbformat import os import tempfile import time import unittest from unittest.mock import patch -from . import get_notebook_path +import nbformat from .. import engines from ..engines import NotebookExecutionManager from ..execute import execute_notebook +from . import get_notebook_path class TestMidCellAutosave(unittest.TestCase): def setUp(self): - self.notebook_name = "test_autosave.ipynb" + self.notebook_name = 'test_autosave.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = nbformat.read(self.notebook_path, as_version=4) def test_autosave_not_too_fast(self): - nb_man = NotebookExecutionManager( - self.nb, output_path="test.ipynb", autosave_cell_every=0.5 - ) - with patch.object(engines, "write_ipynb") as write_mock: + nb_man = NotebookExecutionManager(self.nb, output_path='test.ipynb', autosave_cell_every=0.5) + with patch.object(engines, 'write_ipynb') as write_mock: write_mock.reset_mock() assert write_mock.call_count == 0 # check that the mock is sane nb_man.autosave_cell() # First call to autosave shouldn't trigger save @@ -34,38 +32,30 @@ def test_autosave_not_too_fast(self): assert write_mock.call_count == 1 def test_autosave_disable(self): - nb_man = NotebookExecutionManager( - self.nb, output_path="test.ipynb", autosave_cell_every=0 - ) - with patch.object(engines, "write_ipynb") as write_mock: + nb_man = NotebookExecutionManager(self.nb, output_path='test.ipynb', autosave_cell_every=0) + with patch.object(engines, 'write_ipynb') as write_mock: write_mock.reset_mock() assert write_mock.call_count == 0 # check that the mock is sane nb_man.autosave_cell() # First call to autosave shouldn't trigger save assert write_mock.call_count == 0 nb_man.autosave_cell() # Call again right away. Still shouldn't save. assert write_mock.call_count == 0 - time.sleep( - 0.55 - ) # Sleep for long enough that autosave should work, if enabled + time.sleep(0.55) # Sleep for long enough that autosave should work, if enabled nb_man.autosave_cell() assert write_mock.call_count == 0 # but it's disabled. def test_end2end_autosave_slow_notebook(self): test_dir = tempfile.mkdtemp() - nb_test_executed_fname = os.path.join(test_dir, f"output_{self.notebook_name}") + nb_test_executed_fname = os.path.join(test_dir, f'output_{self.notebook_name}') # Count how many times it writes the file w/o autosave - with patch.object(engines, "write_ipynb") as write_mock: - execute_notebook( - self.notebook_path, nb_test_executed_fname, autosave_cell_every=0 - ) + with patch.object(engines, 'write_ipynb') as write_mock: + execute_notebook(self.notebook_path, nb_test_executed_fname, autosave_cell_every=0) default_write_count = write_mock.call_count # Turn on autosave and see how many more times it gets saved. - with patch.object(engines, "write_ipynb") as write_mock: - execute_notebook( - self.notebook_path, nb_test_executed_fname, autosave_cell_every=1 - ) + with patch.object(engines, 'write_ipynb') as write_mock: + execute_notebook(self.notebook_path, nb_test_executed_fname, autosave_cell_every=1) # This notebook has a cell which takes 2.5 seconds to run. # Autosave every 1 sec should add two more saves. assert write_mock.call_count == default_write_count + 2 diff --git a/papermill/tests/test_cli.py b/papermill/tests/test_cli.py index 7381fd24..ad6ddbed 100755 --- a/papermill/tests/test_cli.py +++ b/papermill/tests/test_cli.py @@ -2,35 +2,34 @@ """ Test the command line interface """ import os -from pathlib import Path -import sys import subprocess +import sys import tempfile -import uuid -import nbclient - -import nbformat import unittest +import uuid +from pathlib import Path from unittest.mock import patch +import nbclient +import nbformat import pytest from click.testing import CliRunner -from . import get_notebook_path, kernel_name from .. import cli -from ..cli import papermill, _is_int, _is_float, _resolve_type +from ..cli import _is_float, _is_int, _resolve_type, papermill +from . import get_notebook_path, kernel_name @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("True", True), - ("False", False), - ("None", None), - ("12.51", 12.51), - ("10", 10), - ("hello world", "hello world"), - ("😍", "😍"), + ('True', True), + ('False', False), + ('None', None), + ('12.51', 12.51), + ('10', 10), + ('hello world', 'hello world'), + ('😍', '😍'), ], ) def test_resolve_type(test_input, expected): @@ -38,17 +37,17 @@ def test_resolve_type(test_input, expected): @pytest.mark.parametrize( - "value,expected", + 'value,expected', [ (13.71, True), - ("False", False), - ("None", False), + ('False', False), + ('None', False), (-8.2, True), (10, True), - ("10", True), - ("12.31", True), - ("hello world", False), - ("😍", False), + ('10', True), + ('12.31', True), + ('hello world', False), + ('😍', False), ], ) def test_is_float(value, expected): @@ -56,17 +55,17 @@ def test_is_float(value, expected): @pytest.mark.parametrize( - "value,expected", + 'value,expected', [ (13.71, True), - ("False", False), - ("None", False), + ('False', False), + ('None', False), (-8.2, True), - ("-23.2", False), + ('-23.2', False), (10, True), - ("13", True), - ("hello world", False), - ("😍", False), + ('13', True), + ('hello world', False), + ('😍', False), ], ) def test_is_int(value, expected): @@ -75,8 +74,8 @@ def test_is_int(value, expected): class TestCLI(unittest.TestCase): default_execute_kwargs = dict( - input_path="input.ipynb", - output_path="output.ipynb", + input_path='input.ipynb', + output_path='output.ipynb', parameters={}, engine_name=None, request_save_on_cell_execute=True, @@ -97,47 +96,39 @@ class TestCLI(unittest.TestCase): def setUp(self): self.runner = CliRunner() self.default_args = [ - self.default_execute_kwargs["input_path"], - self.default_execute_kwargs["output_path"], + self.default_execute_kwargs['input_path'], + self.default_execute_kwargs['output_path'], ] - self.sample_yaml_file = os.path.join( - os.path.dirname(__file__), "parameters", "example.yaml" - ) - self.sample_json_file = os.path.join( - os.path.dirname(__file__), "parameters", "example.json" - ) + self.sample_yaml_file = os.path.join(os.path.dirname(__file__), 'parameters', 'example.yaml') + self.sample_json_file = os.path.join(os.path.dirname(__file__), 'parameters', 'example.json') def augment_execute_kwargs(self, **new_kwargs): kwargs = self.default_execute_kwargs.copy() kwargs.update(new_kwargs) return kwargs - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters(self, execute_patch): self.runner.invoke( papermill, - self.default_args + ["-p", "foo", "bar", "--parameters", "baz", "42"], - ) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": "bar", "baz": 42}) + self.default_args + ['-p', 'foo', 'bar', '--parameters', 'baz', '42'], ) + execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 'bar', 'baz': 42})) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_raw(self, execute_patch): self.runner.invoke( papermill, - self.default_args + ["-r", "foo", "bar", "--parameters_raw", "baz", "42"], - ) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": "bar", "baz": "42"}) + self.default_args + ['-r', 'foo', 'bar', '--parameters_raw', 'baz', '42'], ) + execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 'bar', 'baz': '42'})) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_file(self, execute_patch): extra_args = [ - "-f", + '-f', self.sample_yaml_file, - "--parameters_file", + '--parameters_file', self.sample_json_file, ] self.runner.invoke(papermill, self.default_args + extra_args) @@ -145,45 +136,40 @@ def test_parameters_file(self, execute_patch): **self.augment_execute_kwargs( # Last input wins dict update parameters={ - "foo": 54321, - "bar": "value", - "baz": {"k2": "v2", "k1": "v1"}, - "a_date": "2019-01-01", + 'foo': 54321, + 'bar': 'value', + 'baz': {'k2': 'v2', 'k1': 'v1'}, + 'a_date': '2019-01-01', } ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_yaml(self, execute_patch): self.runner.invoke( papermill, - self.default_args - + ["-y", '{"foo": "bar"}', "--parameters_yaml", '{"foo2": ["baz"]}'], - ) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": "bar", "foo2": ["baz"]}) + self.default_args + ['-y', '{"foo": "bar"}', '--parameters_yaml', '{"foo2": ["baz"]}'], ) + execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 'bar', 'foo2': ['baz']})) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_yaml_date(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["-y", "a_date: 2019-01-01"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"a_date": "2019-01-01"}) - ) + self.runner.invoke(papermill, self.default_args + ['-y', 'a_date: 2019-01-01']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'a_date': '2019-01-01'})) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_empty(self, execute_patch): # "#empty" ---base64--> "I2VtcHR5" with tempfile.TemporaryDirectory() as tmpdir: - empty_yaml = Path(tmpdir) / "empty.yaml" - empty_yaml.write_text("#empty") + empty_yaml = Path(tmpdir) / 'empty.yaml' + empty_yaml.write_text('#empty') extra_args = [ - "--parameters_file", + '--parameters_file', str(empty_yaml), - "--parameters_yaml", - "#empty", - "--parameters_base64", - "I2VtcHR5", + '--parameters_yaml', + '#empty', + '--parameters_base64', + 'I2VtcHR5', ] self.runner.invoke( papermill, @@ -196,139 +182,113 @@ def test_parameters_empty(self, execute_patch): ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_yaml_override(self, execute_patch): self.runner.invoke( papermill, - self.default_args - + ["--parameters_yaml", '{"foo": "bar"}', "-y", '{"foo": ["baz"]}'], + self.default_args + ['--parameters_yaml', '{"foo": "bar"}', '-y', '{"foo": ["baz"]}'], ) execute_patch.assert_called_with( **self.augment_execute_kwargs( # Last input wins dict update - parameters={"foo": ["baz"]} + parameters={'foo': ['baz']} ) ) @patch( - cli.__name__ + ".execute_notebook", - side_effect=nbclient.exceptions.DeadKernelError("Fake"), + cli.__name__ + '.execute_notebook', + side_effect=nbclient.exceptions.DeadKernelError('Fake'), ) def test_parameters_dead_kernel(self, execute_patch): result = self.runner.invoke( papermill, - self.default_args - + ["--parameters_yaml", '{"foo": "bar"}', "-y", '{"foo": ["baz"]}'], + self.default_args + ['--parameters_yaml', '{"foo": "bar"}', '-y', '{"foo": ["baz"]}'], ) assert result.exit_code == 138 - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_base64(self, execute_patch): extra_args = [ - "--parameters_base64", - "eyJmb28iOiAicmVwbGFjZWQiLCAiYmFyIjogMn0=", - "-b", - "eydmb28nOiAxfQ==", + '--parameters_base64', + 'eyJmb28iOiAicmVwbGFjZWQiLCAiYmFyIjogMn0=', + '-b', + 'eydmb28nOiAxfQ==', ] self.runner.invoke(papermill, self.default_args + extra_args) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": 1, "bar": 2}) - ) + execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'foo': 1, 'bar': 2})) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_base64_date(self, execute_patch): self.runner.invoke( papermill, - self.default_args + ["--parameters_base64", "YV9kYXRlOiAyMDE5LTAxLTAx"], - ) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"a_date": "2019-01-01"}) + self.default_args + ['--parameters_base64', 'YV9kYXRlOiAyMDE5LTAxLTAx'], ) + execute_patch.assert_called_with(**self.augment_execute_kwargs(parameters={'a_date': '2019-01-01'})) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_inject_input_path(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--inject-input-path"]) + self.runner.invoke(papermill, self.default_args + ['--inject-input-path']) execute_patch.assert_called_with( - **self.augment_execute_kwargs( - parameters={"PAPERMILL_INPUT_PATH": "input.ipynb"} - ) + **self.augment_execute_kwargs(parameters={'PAPERMILL_INPUT_PATH': 'input.ipynb'}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_inject_output_path(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--inject-output-path"]) + self.runner.invoke(papermill, self.default_args + ['--inject-output-path']) execute_patch.assert_called_with( - **self.augment_execute_kwargs( - parameters={"PAPERMILL_OUTPUT_PATH": "output.ipynb"} - ) + **self.augment_execute_kwargs(parameters={'PAPERMILL_OUTPUT_PATH': 'output.ipynb'}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_inject_paths(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--inject-paths"]) + self.runner.invoke(papermill, self.default_args + ['--inject-paths']) execute_patch.assert_called_with( **self.augment_execute_kwargs( parameters={ - "PAPERMILL_INPUT_PATH": "input.ipynb", - "PAPERMILL_OUTPUT_PATH": "output.ipynb", + 'PAPERMILL_INPUT_PATH': 'input.ipynb', + 'PAPERMILL_OUTPUT_PATH': 'output.ipynb', } ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_engine(self, execute_patch): - self.runner.invoke( - papermill, self.default_args + ["--engine", "engine-that-could"] - ) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(engine_name="engine-that-could") - ) + self.runner.invoke(papermill, self.default_args + ['--engine', 'engine-that-could']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(engine_name='engine-that-could')) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_prepare_only(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--prepare-only"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(prepare_only=True) - ) + self.runner.invoke(papermill, self.default_args + ['--prepare-only']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(prepare_only=True)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_kernel(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["-k", "python3"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(kernel_name="python3") - ) + self.runner.invoke(papermill, self.default_args + ['-k', 'python3']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(kernel_name='python3')) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_language(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["-l", "python"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(language="python") - ) + self.runner.invoke(papermill, self.default_args + ['-l', 'python']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(language='python')) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_set_cwd(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--cwd", "a/path/here"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(cwd="a/path/here") - ) + self.runner.invoke(papermill, self.default_args + ['--cwd', 'a/path/here']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(cwd='a/path/here')) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_progress_bar(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--progress-bar"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(progress_bar=True) - ) + self.runner.invoke(papermill, self.default_args + ['--progress-bar']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(progress_bar=True)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_no_progress_bar(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--no-progress-bar"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(progress_bar=False) - ) + self.runner.invoke(papermill, self.default_args + ['--no-progress-bar']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(progress_bar=False)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_log_output(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--log-output"]) + self.runner.invoke(papermill, self.default_args + ['--log-output']) execute_patch.assert_called_with( **self.augment_execute_kwargs( log_output=True, @@ -336,107 +296,89 @@ def test_log_output(self, execute_patch): ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_log_output_plus_progress(self, execute_patch): - self.runner.invoke( - papermill, self.default_args + ["--log-output", "--progress-bar"] - ) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(log_output=True, progress_bar=True) - ) + self.runner.invoke(papermill, self.default_args + ['--log-output', '--progress-bar']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(log_output=True, progress_bar=True)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_no_log_output(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--no-log-output"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(log_output=False) - ) + self.runner.invoke(papermill, self.default_args + ['--no-log-output']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(log_output=False)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_log_level(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--log-level", "WARNING"]) + self.runner.invoke(papermill, self.default_args + ['--log-level', 'WARNING']) # TODO: this does not actually test log-level being set execute_patch.assert_called_with(**self.augment_execute_kwargs()) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_start_timeout(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--start-timeout", "123"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(start_timeout=123) - ) + self.runner.invoke(papermill, self.default_args + ['--start-timeout', '123']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(start_timeout=123)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_start_timeout_backwards_compatibility(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--start_timeout", "123"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(start_timeout=123) - ) + self.runner.invoke(papermill, self.default_args + ['--start_timeout', '123']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(start_timeout=123)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_execution_timeout(self, execute_patch): - self.runner.invoke( - papermill, self.default_args + ["--execution-timeout", "123"] - ) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(execution_timeout=123) - ) + self.runner.invoke(papermill, self.default_args + ['--execution-timeout', '123']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(execution_timeout=123)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_report_mode(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--report-mode"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(report_mode=True) - ) + self.runner.invoke(papermill, self.default_args + ['--report-mode']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(report_mode=True)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_no_report_mode(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--no-report-mode"]) - execute_patch.assert_called_with( - **self.augment_execute_kwargs(report_mode=False) - ) + self.runner.invoke(papermill, self.default_args + ['--no-report-mode']) + execute_patch.assert_called_with(**self.augment_execute_kwargs(report_mode=False)) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_version(self, execute_patch): - self.runner.invoke(papermill, ["--version"]) + self.runner.invoke(papermill, ['--version']) execute_patch.assert_not_called() - @patch(cli.__name__ + ".execute_notebook") - @patch(cli.__name__ + ".display_notebook_help") + @patch(cli.__name__ + '.execute_notebook') + @patch(cli.__name__ + '.display_notebook_help') def test_help_notebook(self, display_notebook_help, execute_path): - self.runner.invoke(papermill, ["--help-notebook", "input_path.ipynb"]) + self.runner.invoke(papermill, ['--help-notebook', 'input_path.ipynb']) execute_path.assert_not_called() assert display_notebook_help.call_count == 1 - assert display_notebook_help.call_args[0][1] == "input_path.ipynb" + assert display_notebook_help.call_args[0][1] == 'input_path.ipynb' - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_many_args(self, execute_patch): extra_args = [ - "-f", + '-f', self.sample_yaml_file, - "-y", + '-y', '{"yaml_foo": {"yaml_bar": "yaml_baz"}}', - "-b", - "eyJiYXNlNjRfZm9vIjogImJhc2U2NF9iYXIifQ==", - "-p", - "baz", - "replace", - "-r", - "foo", - "54321", - "--kernel", - "R", - "--engine", - "engine-that-could", - "--prepare-only", - "--log-output", - "--autosave-cell-every", - "17", - "--no-progress-bar", - "--start-timeout", - "321", - "--execution-timeout", - "654", - "--report-mode", + '-b', + 'eyJiYXNlNjRfZm9vIjogImJhc2U2NF9iYXIifQ==', + '-p', + 'baz', + 'replace', + '-r', + 'foo', + '54321', + '--kernel', + 'R', + '--engine', + 'engine-that-could', + '--prepare-only', + '--log-output', + '--autosave-cell-every', + '17', + '--no-progress-bar', + '--start-timeout', + '321', + '--execution-timeout', + '654', + '--report-mode', ] self.runner.invoke( papermill, @@ -445,18 +387,18 @@ def test_many_args(self, execute_patch): execute_patch.assert_called_with( **self.augment_execute_kwargs( parameters={ - "foo": "54321", - "bar": "value", - "baz": "replace", - "yaml_foo": {"yaml_bar": "yaml_baz"}, - "base64_foo": "base64_bar", - "a_date": "2019-01-01", + 'foo': '54321', + 'bar': 'value', + 'baz': 'replace', + 'yaml_foo': {'yaml_bar': 'yaml_baz'}, + 'base64_foo': 'base64_bar', + 'a_date': '2019-01-01', }, - engine_name="engine-that-could", + engine_name='engine-that-could', request_save_on_cell_execute=True, autosave_cell_every=17, prepare_only=True, - kernel_name="R", + kernel_name='R', log_output=True, progress_bar=False, start_timeout=321, @@ -468,7 +410,7 @@ def test_many_args(self, execute_patch): def papermill_cli(papermill_args=None, **kwargs): - cmd = [sys.executable, "-m", "papermill"] + cmd = [sys.executable, '-m', 'papermill'] if papermill_args: cmd.extend(papermill_args) return subprocess.Popen(cmd, **kwargs) @@ -476,11 +418,11 @@ def papermill_cli(papermill_args=None, **kwargs): def papermill_version(): try: - proc = papermill_cli(["--version"], stdout=subprocess.PIPE) + proc = papermill_cli(['--version'], stdout=subprocess.PIPE) out, _ = proc.communicate() if proc.returncode: return None - return out.decode("utf-8") + return out.decode('utf-8') except (OSError, SystemExit): # pragma: no cover return None @@ -488,54 +430,50 @@ def papermill_version(): @pytest.fixture() def notebook(): metadata = { - "kernelspec": { - "name": "python3", - "language": "python", - "display_name": "python3", + 'kernelspec': { + 'name': 'python3', + 'language': 'python', + 'display_name': 'python3', } } return nbformat.v4.new_notebook( metadata=metadata, - cells=[ - nbformat.v4.new_markdown_cell("This is a notebook with kernel: python3") - ], + cells=[nbformat.v4.new_markdown_cell('This is a notebook with kernel: python3')], ) -require_papermill_installed = pytest.mark.skipif( - not papermill_version(), reason="papermill is not installed" -) +require_papermill_installed = pytest.mark.skipif(not papermill_version(), reason='papermill is not installed') @require_papermill_installed def test_pipe_in_out_auto(notebook): process = papermill_cli(stdout=subprocess.PIPE, stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, err = process.communicate(input=text.encode("utf-8")) + out, err = process.communicate(input=text.encode('utf-8')) # Test no message on std error assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_in_out_explicit(notebook): - process = papermill_cli(["-", "-"], stdout=subprocess.PIPE, stdin=subprocess.PIPE) + process = papermill_cli(['-', '-'], stdout=subprocess.PIPE, stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, err = process.communicate(input=text.encode("utf-8")) + out, err = process.communicate(input=text.encode('utf-8')) # Test no message on std error assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_out_auto(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') nb_file.write(nbformat.writes(notebook)) process = papermill_cli([str(nb_file)], stdout=subprocess.PIPE) @@ -545,31 +483,31 @@ def test_pipe_out_auto(tmpdir, notebook): assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_out_explicit(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') nb_file.write(nbformat.writes(notebook)) - process = papermill_cli([str(nb_file), "-"], stdout=subprocess.PIPE) + process = papermill_cli([str(nb_file), '-'], stdout=subprocess.PIPE) out, err = process.communicate() # Test no message on std error assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_in_auto(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') process = papermill_cli([str(nb_file)], stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, _ = process.communicate(input=text.encode("utf-8")) + out, _ = process.communicate(input=text.encode('utf-8')) # Nothing on stdout assert not out @@ -581,11 +519,11 @@ def test_pipe_in_auto(tmpdir, notebook): @require_papermill_installed def test_pipe_in_explicit(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') - process = papermill_cli(["-", str(nb_file)], stdin=subprocess.PIPE) + process = papermill_cli(['-', str(nb_file)], stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, _ = process.communicate(input=text.encode("utf-8")) + out, _ = process.communicate(input=text.encode('utf-8')) # Nothing on stdout assert not out @@ -597,20 +535,20 @@ def test_pipe_in_explicit(tmpdir, notebook): @require_papermill_installed def test_stdout_file(tmpdir): - nb_file = tmpdir.join("notebook.ipynb") - stdout_file = tmpdir.join("notebook.stdout") + nb_file = tmpdir.join('notebook.ipynb') + stdout_file = tmpdir.join('notebook.stdout') secret = str(uuid.uuid4()) process = papermill_cli( [ - get_notebook_path("simple_execute.ipynb"), + get_notebook_path('simple_execute.ipynb'), str(nb_file), - "-k", + '-k', kernel_name, - "-p", - "msg", + '-p', + 'msg', secret, - "--stdout-file", + '--stdout-file', str(stdout_file), ] ) @@ -620,4 +558,4 @@ def test_stdout_file(tmpdir): assert not err with open(str(stdout_file)) as fp: - assert fp.read() == secret + "\n" + assert fp.read() == secret + '\n' diff --git a/papermill/tests/test_clientwrap.py b/papermill/tests/test_clientwrap.py index deeb29a1..32309cf6 100644 --- a/papermill/tests/test_clientwrap.py +++ b/papermill/tests/test_clientwrap.py @@ -1,40 +1,39 @@ -import nbformat import unittest - from unittest.mock import call, patch -from . import get_notebook_path +import nbformat -from ..log import logger -from ..engines import NotebookExecutionManager from ..clientwrap import PapermillNotebookClient +from ..engines import NotebookExecutionManager +from ..log import logger +from . import get_notebook_path class TestPapermillClientWrapper(unittest.TestCase): def setUp(self): - self.nb = nbformat.read(get_notebook_path("test_logging.ipynb"), as_version=4) + self.nb = nbformat.read(get_notebook_path('test_logging.ipynb'), as_version=4) self.nb_man = NotebookExecutionManager(self.nb) self.client = PapermillNotebookClient(self.nb_man, log=logger, log_output=True) def test_logging_stderr_msg(self): - with patch.object(logger, "warning") as warning_mock: - for output in self.nb.cells[0].get("outputs", []): + with patch.object(logger, 'warning') as warning_mock: + for output in self.nb.cells[0].get('outputs', []): self.client.log_output_message(output) - warning_mock.assert_called_once_with("INFO:test:test text\n") + warning_mock.assert_called_once_with('INFO:test:test text\n') def test_logging_stdout_msg(self): - with patch.object(logger, "info") as info_mock: - for output in self.nb.cells[1].get("outputs", []): + with patch.object(logger, 'info') as info_mock: + for output in self.nb.cells[1].get('outputs', []): self.client.log_output_message(output) - info_mock.assert_called_once_with("hello world\n") + info_mock.assert_called_once_with('hello world\n') def test_logging_data_msg(self): - with patch.object(logger, "info") as info_mock: - for output in self.nb.cells[2].get("outputs", []): + with patch.object(logger, 'info') as info_mock: + for output in self.nb.cells[2].get('outputs', []): self.client.log_output_message(output) info_mock.assert_has_calls( [ - call(""), - call(""), + call(''), + call(''), ] ) diff --git a/papermill/tests/test_engines.py b/papermill/tests/test_engines.py index e635a6f9..b750a01e 100644 --- a/papermill/tests/test_engines.py +++ b/papermill/tests/test_engines.py @@ -1,17 +1,16 @@ import copy -import dateutil import unittest - from abc import ABCMeta -from unittest.mock import Mock, patch, call -from nbformat.notebooknode import NotebookNode +from unittest.mock import Mock, call, patch -from . import get_notebook_path +import dateutil +from nbformat.notebooknode import NotebookNode from .. import engines, exceptions -from ..log import logger +from ..engines import Engine, NBClientEngine, NotebookExecutionManager from ..iorw import load_notebook_node -from ..engines import NotebookExecutionManager, Engine, NBClientEngine +from ..log import logger +from . import get_notebook_path def AnyMock(cls): @@ -30,11 +29,11 @@ def __eq__(self, other): class TestNotebookExecutionManager(unittest.TestCase): def setUp(self): - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = load_notebook_node(self.notebook_path) self.foo_nb = copy.deepcopy(self.nb) - self.foo_nb.metadata["foo"] = "bar" + self.foo_nb.metadata['foo'] = 'bar' def test_basic_pbar(self): nb_man = NotebookExecutionManager(self.nb) @@ -51,73 +50,69 @@ def test_set_timer(self): nb_man = NotebookExecutionManager(self.nb) now = nb_man.now() - with patch.object(nb_man, "now", return_value=now): + with patch.object(nb_man, 'now', return_value=now): nb_man.set_timer() self.assertEqual(nb_man.start_time, now) self.assertIsNone(nb_man.end_time) def test_save(self): - nb_man = NotebookExecutionManager(self.nb, output_path="test.ipynb") - with patch.object(engines, "write_ipynb") as write_mock: + nb_man = NotebookExecutionManager(self.nb, output_path='test.ipynb') + with patch.object(engines, 'write_ipynb') as write_mock: nb_man.save() - write_mock.assert_called_with(self.nb, "test.ipynb") + write_mock.assert_called_with(self.nb, 'test.ipynb') def test_save_no_output(self): nb_man = NotebookExecutionManager(self.nb) - with patch.object(engines, "write_ipynb") as write_mock: + with patch.object(engines, 'write_ipynb') as write_mock: nb_man.save() write_mock.assert_not_called() def test_save_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.save(nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_get_cell_description(self): nb_man = NotebookExecutionManager(self.nb) self.assertIsNone(nb_man.get_cell_description(nb_man.nb.cells[0])) - self.assertEqual(nb_man.get_cell_description(nb_man.nb.cells[1]), "DESC") + self.assertEqual(nb_man.get_cell_description(nb_man.nb.cells[1]), 'DESC') def test_notebook_start(self): nb_man = NotebookExecutionManager(self.nb) - nb_man.nb.metadata["foo"] = "bar" + nb_man.nb.metadata['foo'] = 'bar' nb_man.save = Mock() nb_man.notebook_start() - self.assertEqual( - nb_man.nb.metadata.papermill["start_time"], nb_man.start_time.isoformat() - ) - self.assertIsNone(nb_man.nb.metadata.papermill["end_time"]) - self.assertIsNone(nb_man.nb.metadata.papermill["duration"]) - self.assertIsNone(nb_man.nb.metadata.papermill["exception"]) + self.assertEqual(nb_man.nb.metadata.papermill['start_time'], nb_man.start_time.isoformat()) + self.assertIsNone(nb_man.nb.metadata.papermill['end_time']) + self.assertIsNone(nb_man.nb.metadata.papermill['duration']) + self.assertIsNone(nb_man.nb.metadata.papermill['exception']) for cell in nb_man.nb.cells: - self.assertIsNone(cell.metadata.papermill["start_time"]) - self.assertIsNone(cell.metadata.papermill["end_time"]) - self.assertIsNone(cell.metadata.papermill["duration"]) - self.assertIsNone(cell.metadata.papermill["exception"]) - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.PENDING - ) - self.assertIsNone(cell.get("execution_count")) - if cell.cell_type == "code": - self.assertEqual(cell.get("outputs"), []) + self.assertIsNone(cell.metadata.papermill['start_time']) + self.assertIsNone(cell.metadata.papermill['end_time']) + self.assertIsNone(cell.metadata.papermill['duration']) + self.assertIsNone(cell.metadata.papermill['exception']) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.PENDING) + self.assertIsNone(cell.get('execution_count')) + if cell.cell_type == 'code': + self.assertEqual(cell.get('outputs'), []) else: - self.assertIsNone(cell.get("outputs")) + self.assertIsNone(cell.get('outputs')) nb_man.save.assert_called_once() def test_notebook_start_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start(nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_notebook_start_markdown_code(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start(nb=self.foo_nb) - self.assertNotIn("execution_count", nb_man.nb.cells[-1]) - self.assertNotIn("outputs", nb_man.nb.cells[-1]) + self.assertNotIn('execution_count', nb_man.nb.cells[-1]) + self.assertNotIn('outputs', nb_man.nb.cells[-1]) def test_cell_start(self): nb_man = NotebookExecutionManager(self.nb) @@ -129,18 +124,16 @@ def test_cell_start(self): nb_man.save = Mock() nb_man.cell_start(cell) - self.assertEqual(cell.metadata.papermill["start_time"], fixed_now.isoformat()) - self.assertFalse(cell.metadata.papermill["exception"]) - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.RUNNING - ) + self.assertEqual(cell.metadata.papermill['start_time'], fixed_now.isoformat()) + self.assertFalse(cell.metadata.papermill['exception']) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.RUNNING) nb_man.save.assert_called_once() def test_cell_start_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.cell_start(self.foo_nb.cells[0], nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_cell_exception(self): nb_man = NotebookExecutionManager(self.nb) @@ -148,16 +141,14 @@ def test_cell_exception(self): cell = nb_man.nb.cells[0] nb_man.cell_exception(cell) - self.assertEqual(nb_man.nb.metadata.papermill["exception"], True) - self.assertEqual(cell.metadata.papermill["exception"], True) - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.FAILED - ) + self.assertEqual(nb_man.nb.metadata.papermill['exception'], True) + self.assertEqual(cell.metadata.papermill['exception'], True) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.FAILED) def test_cell_exception_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.cell_exception(self.foo_nb.cells[0], nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_cell_complete_after_cell_start(self): nb_man = NotebookExecutionManager(self.nb) @@ -173,18 +164,16 @@ def test_cell_complete_after_cell_start(self): nb_man.pbar = Mock() nb_man.cell_complete(cell) - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + start_time = dateutil.parser.parse(cell.metadata.papermill['start_time']) - self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat()) + self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat()) self.assertEqual( - cell.metadata.papermill["duration"], + cell.metadata.papermill['duration'], (fixed_now - start_time).total_seconds(), ) - self.assertFalse(cell.metadata.papermill["exception"]) - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED - ) + self.assertFalse(cell.metadata.papermill['exception']) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED) nb_man.save.assert_called_once() nb_man.pbar.update.assert_called_once() @@ -202,12 +191,10 @@ def test_cell_complete_without_cell_start(self): nb_man.pbar = Mock() nb_man.cell_complete(cell) - self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat()) - self.assertIsNone(cell.metadata.papermill["duration"]) - self.assertFalse(cell.metadata.papermill["exception"]) - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED - ) + self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat()) + self.assertIsNone(cell.metadata.papermill['duration']) + self.assertFalse(cell.metadata.papermill['exception']) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED) nb_man.save.assert_called_once() nb_man.pbar.update.assert_called_once() @@ -227,18 +214,16 @@ def test_cell_complete_after_cell_exception(self): nb_man.pbar = Mock() nb_man.cell_complete(cell) - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + start_time = dateutil.parser.parse(cell.metadata.papermill['start_time']) - self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat()) + self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat()) self.assertEqual( - cell.metadata.papermill["duration"], + cell.metadata.papermill['duration'], (fixed_now - start_time).total_seconds(), ) - self.assertTrue(cell.metadata.papermill["exception"]) - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.FAILED - ) + self.assertTrue(cell.metadata.papermill['exception']) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.FAILED) nb_man.save.assert_called_once() nb_man.pbar.update.assert_called_once() @@ -247,9 +232,9 @@ def test_cell_complete_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start() baz_nb = copy.deepcopy(nb_man.nb) - baz_nb.metadata["baz"] = "buz" + baz_nb.metadata['baz'] = 'buz' nb_man.cell_complete(baz_nb.cells[0], nb=baz_nb) - self.assertEqual(nb_man.nb.metadata["baz"], "buz") + self.assertEqual(nb_man.nb.metadata['baz'], 'buz') def test_notebook_complete(self): nb_man = NotebookExecutionManager(self.nb) @@ -264,17 +249,15 @@ def test_notebook_complete(self): nb_man.notebook_complete() - self.assertIsNotNone(nb_man.nb.metadata.papermill["start_time"]) - start_time = dateutil.parser.parse(nb_man.nb.metadata.papermill["start_time"]) + self.assertIsNotNone(nb_man.nb.metadata.papermill['start_time']) + start_time = dateutil.parser.parse(nb_man.nb.metadata.papermill['start_time']) + self.assertEqual(nb_man.nb.metadata.papermill['end_time'], fixed_now.isoformat()) self.assertEqual( - nb_man.nb.metadata.papermill["end_time"], fixed_now.isoformat() - ) - self.assertEqual( - nb_man.nb.metadata.papermill["duration"], + nb_man.nb.metadata.papermill['duration'], (fixed_now - start_time).total_seconds(), ) - self.assertFalse(nb_man.nb.metadata.papermill["exception"]) + self.assertFalse(nb_man.nb.metadata.papermill['exception']) nb_man.save.assert_called_once() nb_man.cleanup_pbar.assert_called_once() @@ -283,18 +266,16 @@ def test_notebook_complete_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start() baz_nb = copy.deepcopy(nb_man.nb) - baz_nb.metadata["baz"] = "buz" + baz_nb.metadata['baz'] = 'buz' nb_man.notebook_complete(nb=baz_nb) - self.assertEqual(nb_man.nb.metadata["baz"], "buz") + self.assertEqual(nb_man.nb.metadata['baz'], 'buz') def test_notebook_complete_cell_status_completed(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start() nb_man.notebook_complete() for cell in nb_man.nb.cells: - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED - ) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED) def test_notebook_complete_cell_status_with_failed(self): nb_man = NotebookExecutionManager(self.nb) @@ -302,22 +283,20 @@ def test_notebook_complete_cell_status_with_failed(self): nb_man.cell_exception(nb_man.nb.cells[1]) nb_man.notebook_complete() self.assertEqual( - nb_man.nb.cells[0].metadata.papermill["status"], + nb_man.nb.cells[0].metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) self.assertEqual( - nb_man.nb.cells[1].metadata.papermill["status"], + nb_man.nb.cells[1].metadata.papermill['status'], NotebookExecutionManager.FAILED, ) for cell in nb_man.nb.cells[2:]: - self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.PENDING - ) + self.assertEqual(cell.metadata.papermill['status'], NotebookExecutionManager.PENDING) class TestEngineBase(unittest.TestCase): def setUp(self): - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = load_notebook_node(self.notebook_path) @@ -326,28 +305,26 @@ def test_wrap_and_execute_notebook(self): Mocks each wrapped call and proves the correct inputs get applied to the correct underlying calls for execute_notebook. """ - with patch.object(Engine, "execute_managed_notebook") as exec_mock: - with patch.object(engines, "NotebookExecutionManager") as wrap_mock: + with patch.object(Engine, 'execute_managed_notebook') as exec_mock: + with patch.object(engines, 'NotebookExecutionManager') as wrap_mock: Engine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, - bar="baz", + bar='baz', ) wrap_mock.assert_called_once_with( self.nb, - output_path="foo.ipynb", + output_path='foo.ipynb', progress_bar=False, log_output=True, autosave_cell_every=30, ) wrap_mock.return_value.notebook_start.assert_called_once() - exec_mock.assert_called_once_with( - wrap_mock.return_value, "python", log_output=True, bar="baz" - ) + exec_mock.assert_called_once_with(wrap_mock.return_value, 'python', log_output=True, bar='baz') wrap_mock.return_value.notebook_complete.assert_called_once() wrap_mock.return_value.cleanup_pbar.assert_called_once() @@ -359,28 +336,26 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): nb_man.cell_start(cell) nb_man.cell_complete(cell) - with patch.object(NotebookExecutionManager, "save") as save_mock: - nb = CellCallbackEngine.execute_notebook( - copy.deepcopy(self.nb), "python", output_path="foo.ipynb" - ) + with patch.object(NotebookExecutionManager, 'save') as save_mock: + nb = CellCallbackEngine.execute_notebook(copy.deepcopy(self.nb), 'python', output_path='foo.ipynb') self.assertEqual(nb, AnyMock(NotebookNode)) self.assertNotEqual(self.nb, nb) self.assertEqual(save_mock.call_count, 8) - self.assertIsNotNone(nb.metadata.papermill["start_time"]) - self.assertIsNotNone(nb.metadata.papermill["end_time"]) - self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(nb.metadata.papermill["exception"]) + self.assertIsNotNone(nb.metadata.papermill['start_time']) + self.assertIsNotNone(nb.metadata.papermill['end_time']) + self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(nb.metadata.papermill['exception']) for cell in nb.cells: - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - self.assertIsNotNone(cell.metadata.papermill["end_time"]) - self.assertEqual(cell.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(cell.metadata.papermill["exception"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + self.assertIsNotNone(cell.metadata.papermill['end_time']) + self.assertEqual(cell.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) @@ -390,13 +365,9 @@ class NoCellCallbackEngine(Engine): def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): pass - with patch.object(NotebookExecutionManager, "save") as save_mock: - with patch.object( - NotebookExecutionManager, "complete_pbar" - ) as pbar_comp_mock: - nb = NoCellCallbackEngine.execute_notebook( - copy.deepcopy(self.nb), "python", output_path="foo.ipynb" - ) + with patch.object(NotebookExecutionManager, 'save') as save_mock: + with patch.object(NotebookExecutionManager, 'complete_pbar') as pbar_comp_mock: + nb = NoCellCallbackEngine.execute_notebook(copy.deepcopy(self.nb), 'python', output_path='foo.ipynb') self.assertEqual(nb, AnyMock(NotebookNode)) self.assertNotEqual(self.nb, nb) @@ -404,38 +375,38 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): self.assertEqual(save_mock.call_count, 2) pbar_comp_mock.assert_called_once() - self.assertIsNotNone(nb.metadata.papermill["start_time"]) - self.assertIsNotNone(nb.metadata.papermill["end_time"]) - self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(nb.metadata.papermill["exception"]) + self.assertIsNotNone(nb.metadata.papermill['start_time']) + self.assertIsNotNone(nb.metadata.papermill['end_time']) + self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(nb.metadata.papermill['exception']) for cell in nb.cells: - self.assertIsNone(cell.metadata.papermill["start_time"]) - self.assertIsNone(cell.metadata.papermill["end_time"]) - self.assertIsNone(cell.metadata.papermill["duration"]) - self.assertIsNone(cell.metadata.papermill["exception"]) + self.assertIsNone(cell.metadata.papermill['start_time']) + self.assertIsNone(cell.metadata.papermill['end_time']) + self.assertIsNone(cell.metadata.papermill['duration']) + self.assertIsNone(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) class TestNBClientEngine(unittest.TestCase): def setUp(self): - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = load_notebook_node(self.notebook_path) def test_nb_convert_engine(self): - with patch.object(engines, "PapermillNotebookClient") as client_mock: - with patch.object(NotebookExecutionManager, "save") as save_mock: + with patch.object(engines, 'PapermillNotebookClient') as client_mock: + with patch.object(NotebookExecutionManager, 'save') as save_mock: nb = NBClientEngine.execute_notebook( copy.deepcopy(self.nb), - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, - bar="baz", + bar='baz', start_timeout=30, execution_timeout=1000, ) @@ -447,16 +418,14 @@ def test_nb_convert_engine(self): args, kwargs = client_mock.call_args expected = [ - ("timeout", 1000), - ("startup_timeout", 30), - ("kernel_name", "python"), - ("log", logger), - ("log_output", True), + ('timeout', 1000), + ('startup_timeout', 30), + ('kernel_name', 'python'), + ('log', logger), + ('log_output', True), ] actual = {(key, kwargs[key]) for key in kwargs} - msg = ( - f"Expected arguments {expected} are not a subset of actual {actual}" - ) + msg = f'Expected arguments {expected} are not a subset of actual {actual}' self.assertTrue(set(expected).issubset(actual), msg=msg) client_mock.return_value.execute.assert_called_once_with() @@ -464,71 +433,63 @@ def test_nb_convert_engine(self): self.assertEqual(save_mock.call_count, 2) def test_nb_convert_engine_execute(self): - with patch.object(NotebookExecutionManager, "save") as save_mock: + with patch.object(NotebookExecutionManager, 'save') as save_mock: nb = NBClientEngine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, ) self.assertEqual(save_mock.call_count, 8) self.assertEqual(nb, AnyMock(NotebookNode)) - self.assertIsNotNone(nb.metadata.papermill["start_time"]) - self.assertIsNotNone(nb.metadata.papermill["end_time"]) - self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(nb.metadata.papermill["exception"]) + self.assertIsNotNone(nb.metadata.papermill['start_time']) + self.assertIsNotNone(nb.metadata.papermill['end_time']) + self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(nb.metadata.papermill['exception']) for cell in nb.cells: - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - self.assertIsNotNone(cell.metadata.papermill["end_time"]) - self.assertEqual(cell.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(cell.metadata.papermill["exception"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + self.assertIsNotNone(cell.metadata.papermill['end_time']) + self.assertEqual(cell.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) def test_nb_convert_log_outputs(self): - with patch.object(logger, "info") as info_mock: - with patch.object(logger, "warning") as warning_mock: - with patch.object(NotebookExecutionManager, "save"): + with patch.object(logger, 'info') as info_mock: + with patch.object(logger, 'warning') as warning_mock: + with patch.object(NotebookExecutionManager, 'save'): NBClientEngine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, ) info_mock.assert_has_calls( [ - call("Executing notebook with kernel: python"), - call( - "Executing Cell 1---------------------------------------" - ), - call( - "Ending Cell 1------------------------------------------" - ), - call( - "Executing Cell 2---------------------------------------" - ), - call("None\n"), - call( - "Ending Cell 2------------------------------------------" - ), + call('Executing notebook with kernel: python'), + call('Executing Cell 1---------------------------------------'), + call('Ending Cell 1------------------------------------------'), + call('Executing Cell 2---------------------------------------'), + call('None\n'), + call('Ending Cell 2------------------------------------------'), ] ) warning_mock.is_not_called() def test_nb_convert_no_log_outputs(self): - with patch.object(logger, "info") as info_mock: - with patch.object(logger, "warning") as warning_mock: - with patch.object(NotebookExecutionManager, "save"): + with patch.object(logger, 'info') as info_mock: + with patch.object(logger, 'warning') as warning_mock: + with patch.object(NotebookExecutionManager, 'save'): NBClientEngine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=False, ) @@ -542,33 +503,31 @@ def setUp(self): def test_registration(self): mock_engine = Mock() - self.papermill_engines.register("mock_engine", mock_engine) - self.assertIn("mock_engine", self.papermill_engines._engines) - self.assertIs(mock_engine, self.papermill_engines._engines["mock_engine"]) + self.papermill_engines.register('mock_engine', mock_engine) + self.assertIn('mock_engine', self.papermill_engines._engines) + self.assertIs(mock_engine, self.papermill_engines._engines['mock_engine']) def test_getting(self): mock_engine = Mock() - self.papermill_engines.register("mock_engine", mock_engine) + self.papermill_engines.register('mock_engine', mock_engine) # test retrieving an engine works - retrieved_engine = self.papermill_engines.get_engine("mock_engine") + retrieved_engine = self.papermill_engines.get_engine('mock_engine') self.assertIs(mock_engine, retrieved_engine) # test you can't retrieve a non-registered engine self.assertRaises( exceptions.PapermillException, self.papermill_engines.get_engine, - "non-existent", + 'non-existent', ) def test_registering_entry_points(self): fake_entrypoint = Mock(load=Mock()) - fake_entrypoint.name = "fake-engine" + fake_entrypoint.name = 'fake-engine' - with patch( - "entrypoints.get_group_all", return_value=[fake_entrypoint] - ) as mock_get_group_all: + with patch('entrypoints.get_group_all', return_value=[fake_entrypoint]) as mock_get_group_all: self.papermill_engines.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.engine") + mock_get_group_all.assert_called_once_with('papermill.engine') self.assertEqual( - self.papermill_engines.get_engine("fake-engine"), + self.papermill_engines.get_engine('fake-engine'), fake_entrypoint.load.return_value, ) diff --git a/papermill/tests/test_exceptions.py b/papermill/tests/test_exceptions.py index 9c555942..191767fb 100644 --- a/papermill/tests/test_exceptions.py +++ b/papermill/tests/test_exceptions.py @@ -12,29 +12,29 @@ def temp_file(): """NamedTemporaryFile must be set in wb mode, closed without delete, opened with open(file, "rb"), then manually deleted. Otherwise, file fails to be read due to permission error on Windows. """ - with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f: + with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f: yield f os.unlink(f.name) @pytest.mark.parametrize( - "exc,args", + 'exc,args', [ ( exceptions.PapermillExecutionError, - (1, 2, "TestSource", "Exception", Exception(), ["Traceback", "Message"]), + (1, 2, 'TestSource', 'Exception', Exception(), ['Traceback', 'Message']), ), ( exceptions.PapermillMissingParameterException, - ("PapermillMissingParameterException",), + ('PapermillMissingParameterException',), ), - (exceptions.AwsError, ("AwsError",)), - (exceptions.FileExistsError, ("FileExistsError",)), - (exceptions.PapermillException, ("PapermillException",)), - (exceptions.PapermillRateLimitException, ("PapermillRateLimitException",)), + (exceptions.AwsError, ('AwsError',)), + (exceptions.FileExistsError, ('FileExistsError',)), + (exceptions.PapermillException, ('PapermillException',)), + (exceptions.PapermillRateLimitException, ('PapermillRateLimitException',)), ( exceptions.PapermillOptionalDependencyException, - ("PapermillOptionalDependencyException",), + ('PapermillOptionalDependencyException',), ), ], ) @@ -45,7 +45,7 @@ def test_exceptions_are_unpickleable(temp_file, exc, args): temp_file.close() # close to re-open for reading # Read the Pickled File - with open(temp_file.name, "rb") as read_file: + with open(temp_file.name, 'rb') as read_file: read_file.seek(0) data = read_file.read() pickled_err = pickle.loads(data) diff --git a/papermill/tests/test_execute.py b/papermill/tests/test_execute.py index 350d9b0f..6396de35 100644 --- a/papermill/tests/test_execute.py +++ b/papermill/tests/test_execute.py @@ -3,20 +3,19 @@ import tempfile import unittest from copy import deepcopy -from unittest.mock import patch, ANY - from functools import partial from pathlib import Path +from unittest.mock import ANY, patch import nbformat from nbformat import validate from .. import engines, translators -from ..log import logger +from ..exceptions import PapermillExecutionError +from ..execute import execute_notebook from ..iorw import load_notebook_node +from ..log import logger from ..utils import chdir -from ..execute import execute_notebook -from ..exceptions import PapermillExecutionError from . import get_notebook_path, kernel_name execute_notebook = partial(execute_notebook, kernel_name=kernel_name) @@ -25,132 +24,112 @@ class TestNotebookHelpers(unittest.TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) - self.nb_test_executed_fname = os.path.join( - self.test_dir, f"output_{self.notebook_name}" - ) + self.nb_test_executed_fname = os.path.join(self.test_dir, f'output_{self.notebook_name}') def tearDown(self): shutil.rmtree(self.test_dir) - @patch(engines.__name__ + ".PapermillNotebookClient") + @patch(engines.__name__ + '.PapermillNotebookClient') def test_start_timeout(self, preproc_mock): - execute_notebook( - self.notebook_path, self.nb_test_executed_fname, start_timeout=123 - ) + execute_notebook(self.notebook_path, self.nb_test_executed_fname, start_timeout=123) args, kwargs = preproc_mock.call_args expected = [ - ("timeout", None), - ("startup_timeout", 123), - ("kernel_name", kernel_name), - ("log", logger), + ('timeout', None), + ('startup_timeout', 123), + ('kernel_name', kernel_name), + ('log', logger), ] actual = {(key, kwargs[key]) for key in kwargs} self.assertTrue( set(expected).issubset(actual), - msg=f"Expected arguments {expected} are not a subset of actual {actual}", + msg=f'Expected arguments {expected} are not a subset of actual {actual}', ) - @patch(engines.__name__ + ".PapermillNotebookClient") + @patch(engines.__name__ + '.PapermillNotebookClient') def test_default_start_timeout(self, preproc_mock): execute_notebook(self.notebook_path, self.nb_test_executed_fname) args, kwargs = preproc_mock.call_args expected = [ - ("timeout", None), - ("startup_timeout", 60), - ("kernel_name", kernel_name), - ("log", logger), + ('timeout', None), + ('startup_timeout', 60), + ('kernel_name', kernel_name), + ('log', logger), ] actual = {(key, kwargs[key]) for key in kwargs} self.assertTrue( set(expected).issubset(actual), - msg=f"Expected arguments {expected} are not a subset of actual {actual}", + msg=f'Expected arguments {expected} are not a subset of actual {actual}', ) def test_cell_insertion(self): - execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"msg": "Hello"} - ) + execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'msg': 'Hello'}) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", 'msg = "Hello"', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', 'msg = "Hello"', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'}) def test_no_tags(self): - notebook_name = "no_parameters.ipynb" - nb_test_executed_fname = os.path.join(self.test_dir, f"output_{notebook_name}") - execute_notebook( - get_notebook_path(notebook_name), nb_test_executed_fname, {"msg": "Hello"} - ) + notebook_name = 'no_parameters.ipynb' + nb_test_executed_fname = os.path.join(self.test_dir, f'output_{notebook_name}') + execute_notebook(get_notebook_path(notebook_name), nb_test_executed_fname, {'msg': 'Hello'}) test_nb = load_notebook_node(nb_test_executed_fname) self.assertListEqual( - test_nb.cells[0].get("source").split("\n"), - ["# Parameters", 'msg = "Hello"', ""], + test_nb.cells[0].get('source').split('\n'), + ['# Parameters', 'msg = "Hello"', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'}) def test_quoted_params(self): - execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"msg": '"Hello"'} - ) + execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'msg': '"Hello"'}) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'msg = "\"Hello\""', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'msg = "\"Hello\""', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": '"Hello"'}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': '"Hello"'}) def test_backslash_params(self): - execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"foo": r"do\ not\ crash"} - ) + execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'foo': r'do\ not\ crash'}) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'foo = "do\\ not\\ crash"', ""], - ) - self.assertEqual( - test_nb.metadata.papermill.parameters, {"foo": r"do\ not\ crash"} + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'foo = "do\\ not\\ crash"', ''], ) + self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'do\ not\ crash'}) def test_backslash_quote_params(self): - execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"foo": r"bar=\"baz\""} - ) + execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'foo': r'bar=\"baz\"'}) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'foo = "bar=\\\"baz\\\""', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'foo = "bar=\\\"baz\\\""', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"foo": r"bar=\"baz\""}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'bar=\"baz\"'}) def test_double_backslash_quote_params(self): - execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"foo": r'\\"bar\\"'} - ) + execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'foo': r'\\"bar\\"'}) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'foo = "\\\\\"bar\\\\\""', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'foo = "\\\\\"bar\\\\\""', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"foo": r'\\"bar\\"'}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'\\"bar\\"'}) def test_prepare_only(self): - for example in ["broken1.ipynb", "keyboard_interrupt.ipynb"]: + for example in ['broken1.ipynb', 'keyboard_interrupt.ipynb']: path = get_notebook_path(example) result_path = os.path.join(self.test_dir, example) # Should not raise as we don't execute the notebook at all - execute_notebook( - path, result_path, {"foo": r"do\ not\ crash"}, prepare_only=True - ) + execute_notebook(path, result_path, {'foo': r'do\ not\ crash'}, prepare_only=True) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual( - nb.cells[0].get("source").split("\n"), - ["# Parameters", r'foo = "do\\ not\\ crash"', ""], + nb.cells[0].get('source').split('\n'), + ['# Parameters', r'foo = "do\\ not\\ crash"', ''], ) @@ -162,52 +141,43 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test(self): - path = get_notebook_path("broken1.ipynb") + path = get_notebook_path('broken1.ipynb') # check that the notebook has two existing marker cells, so that this test is sure to be # validating the removal logic (the markers are simulatin an error in the first code cell # that has since been fixed) original_nb = load_notebook_node(path) - self.assertEqual( - original_nb.cells[0].metadata["tags"], ["papermill-error-cell-tag"] - ) - self.assertIn("In [1]", original_nb.cells[0].source) - self.assertEqual( - original_nb.cells[2].metadata["tags"], ["papermill-error-cell-tag"] - ) + self.assertEqual(original_nb.cells[0].metadata['tags'], ['papermill-error-cell-tag']) + self.assertIn('In [1]', original_nb.cells[0].source) + self.assertEqual(original_nb.cells[2].metadata['tags'], ['papermill-error-cell-tag']) - result_path = os.path.join(self.test_dir, "broken1.ipynb") + result_path = os.path.join(self.test_dir, 'broken1.ipynb') with self.assertRaises(PapermillExecutionError): execute_notebook(path, result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "markdown") + self.assertEqual(nb.cells[0].cell_type, 'markdown') self.assertRegex( nb.cells[0].source, r'^$', ) - self.assertEqual(nb.cells[0].metadata["tags"], ["papermill-error-cell-tag"]) + self.assertEqual(nb.cells[0].metadata['tags'], ['papermill-error-cell-tag']) - self.assertEqual(nb.cells[1].cell_type, "markdown") + self.assertEqual(nb.cells[1].cell_type, 'markdown') self.assertEqual(nb.cells[2].execution_count, 1) - self.assertEqual(nb.cells[3].cell_type, "markdown") - self.assertEqual(nb.cells[4].cell_type, "markdown") + self.assertEqual(nb.cells[3].cell_type, 'markdown') + self.assertEqual(nb.cells[4].cell_type, 'markdown') - self.assertEqual(nb.cells[5].cell_type, "markdown") - self.assertRegex( - nb.cells[5].source, '' - ) - self.assertEqual(nb.cells[5].metadata["tags"], ["papermill-error-cell-tag"]) + self.assertEqual(nb.cells[5].cell_type, 'markdown') + self.assertRegex(nb.cells[5].source, '') + self.assertEqual(nb.cells[5].metadata['tags'], ['papermill-error-cell-tag']) self.assertEqual(nb.cells[6].execution_count, 2) - self.assertEqual(nb.cells[6].outputs[0].output_type, "error") + self.assertEqual(nb.cells[6].outputs[0].output_type, 'error') self.assertEqual(nb.cells[7].execution_count, None) # double check the removal (the new cells above should be the only two tagged ones) self.assertEqual( - sum( - "papermill-error-cell-tag" in cell.metadata.get("tags", []) - for cell in nb.cells - ), + sum('papermill-error-cell-tag' in cell.metadata.get('tags', []) for cell in nb.cells), 2, ) @@ -220,25 +190,23 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test(self): - path = get_notebook_path("broken2.ipynb") - result_path = os.path.join(self.test_dir, "broken2.ipynb") + path = get_notebook_path('broken2.ipynb') + result_path = os.path.join(self.test_dir, 'broken2.ipynb') with self.assertRaises(PapermillExecutionError): execute_notebook(path, result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "markdown") + self.assertEqual(nb.cells[0].cell_type, 'markdown') self.assertRegex( nb.cells[0].source, r'^.*In \[2\].*$', ) self.assertEqual(nb.cells[1].execution_count, 1) - self.assertEqual(nb.cells[2].cell_type, "markdown") - self.assertRegex( - nb.cells[2].source, '' - ) + self.assertEqual(nb.cells[2].cell_type, 'markdown') + self.assertRegex(nb.cells[2].source, '') self.assertEqual(nb.cells[3].execution_count, 2) - self.assertEqual(nb.cells[3].outputs[0].output_type, "display_data") - self.assertEqual(nb.cells[3].outputs[1].output_type, "error") + self.assertEqual(nb.cells[3].outputs[0].output_type, 'display_data') + self.assertEqual(nb.cells[3].outputs[1].output_type, 'error') self.assertEqual(nb.cells[4].execution_count, None) @@ -246,33 +214,25 @@ def test(self): class TestReportMode(unittest.TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - self.notebook_name = "report_mode_test.ipynb" + self.notebook_name = 'report_mode_test.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) - self.nb_test_executed_fname = os.path.join( - self.test_dir, f"output_{self.notebook_name}" - ) + self.nb_test_executed_fname = os.path.join(self.test_dir, f'output_{self.notebook_name}') def tearDown(self): shutil.rmtree(self.test_dir) def test_report_mode(self): - nb = execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"a": 0}, report_mode=True - ) + nb = execute_notebook(self.notebook_path, self.nb_test_executed_fname, {'a': 0}, report_mode=True) for cell in nb.cells: - if cell.cell_type == "code": - self.assertEqual( - cell.metadata.get("jupyter", {}).get("source_hidden"), True - ) + if cell.cell_type == 'code': + self.assertEqual(cell.metadata.get('jupyter', {}).get('source_hidden'), True) class TestOutputPathNone(unittest.TestCase): def test_output_path_of_none(self): """Output path of None should return notebook node obj but not write an ipynb""" - nb = execute_notebook( - get_notebook_path("simple_execute.ipynb"), None, {"msg": "Hello"} - ) - self.assertEqual(nb.metadata.papermill.parameters, {"msg": "Hello"}) + nb = execute_notebook(get_notebook_path('simple_execute.ipynb'), None, {'msg': 'Hello'}) + self.assertEqual(nb.metadata.papermill.parameters, {'msg': 'Hello'}) class TestCWD(unittest.TestCase): @@ -280,26 +240,20 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() self.base_test_dir = tempfile.mkdtemp() - self.check_notebook_name = "read_check.ipynb" - self.check_notebook_path = os.path.join(self.base_test_dir, "read_check.ipynb") + self.check_notebook_name = 'read_check.ipynb' + self.check_notebook_path = os.path.join(self.base_test_dir, 'read_check.ipynb') # Setup read paths so base_test_dir has check_notebook_name - shutil.copyfile( - get_notebook_path(self.check_notebook_name), self.check_notebook_path - ) - with open(os.path.join(self.test_dir, "check.txt"), "w", encoding="utf-8") as f: + shutil.copyfile(get_notebook_path(self.check_notebook_name), self.check_notebook_path) + with open(os.path.join(self.test_dir, 'check.txt'), 'w', encoding='utf-8') as f: # Needed for read_check to pass - f.write("exists") + f.write('exists') - self.simple_notebook_name = "simple_execute.ipynb" - self.simple_notebook_path = os.path.join( - self.base_test_dir, "simple_execute.ipynb" - ) + self.simple_notebook_name = 'simple_execute.ipynb' + self.simple_notebook_path = os.path.join(self.base_test_dir, 'simple_execute.ipynb') # Setup read paths so base_test_dir has simple_notebook_name - shutil.copyfile( - get_notebook_path(self.simple_notebook_name), self.simple_notebook_path - ) + shutil.copyfile(get_notebook_path(self.simple_notebook_name), self.simple_notebook_path) - self.nb_test_executed_fname = "test_output.ipynb" + self.nb_test_executed_fname = 'test_output.ipynb' def tearDown(self): shutil.rmtree(self.test_dir) @@ -313,23 +267,13 @@ def test_local_save_ignores_cwd_assignment(self): self.nb_test_executed_fname, cwd=self.test_dir, ) - self.assertTrue( - os.path.isfile( - os.path.join(self.base_test_dir, self.nb_test_executed_fname) - ) - ) + self.assertTrue(os.path.isfile(os.path.join(self.base_test_dir, self.nb_test_executed_fname))) def test_execution_respects_cwd_assignment(self): with chdir(self.base_test_dir): # Both paths are relative - execute_notebook( - self.check_notebook_name, self.nb_test_executed_fname, cwd=self.test_dir - ) - self.assertTrue( - os.path.isfile( - os.path.join(self.base_test_dir, self.nb_test_executed_fname) - ) - ) + execute_notebook(self.check_notebook_name, self.nb_test_executed_fname, cwd=self.test_dir) + self.assertTrue(os.path.isfile(os.path.join(self.base_test_dir, self.nb_test_executed_fname))) def test_pathlib_paths(self): # Copy of test_execution_respects_cwd_assignment but with `Path`s @@ -339,9 +283,7 @@ def test_pathlib_paths(self): Path(self.nb_test_executed_fname), cwd=Path(self.test_dir), ) - self.assertTrue( - Path(self.base_test_dir).joinpath(self.nb_test_executed_fname).exists() - ) + self.assertTrue(Path(self.base_test_dir).joinpath(self.nb_test_executed_fname).exists()) class TestSysExit(unittest.TestCase): @@ -352,64 +294,62 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_sys_exit(self): - notebook_name = "sysexit.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'sysexit.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual(nb.cells[0].execution_count, 1) self.assertEqual(nb.cells[1].execution_count, 2) - self.assertEqual(nb.cells[1].outputs[0].output_type, "error") - self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit") - self.assertEqual(nb.cells[1].outputs[0].evalue, "") + self.assertEqual(nb.cells[1].outputs[0].output_type, 'error') + self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit') + self.assertEqual(nb.cells[1].outputs[0].evalue, '') self.assertEqual(nb.cells[2].execution_count, None) def test_sys_exit0(self): - notebook_name = "sysexit0.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'sysexit0.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual(nb.cells[0].execution_count, 1) self.assertEqual(nb.cells[1].execution_count, 2) - self.assertEqual(nb.cells[1].outputs[0].output_type, "error") - self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit") - self.assertEqual(nb.cells[1].outputs[0].evalue, "0") + self.assertEqual(nb.cells[1].outputs[0].output_type, 'error') + self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit') + self.assertEqual(nb.cells[1].outputs[0].evalue, '0') self.assertEqual(nb.cells[2].execution_count, None) def test_sys_exit1(self): - notebook_name = "sysexit1.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'sysexit1.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') with self.assertRaises(PapermillExecutionError): execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "markdown") + self.assertEqual(nb.cells[0].cell_type, 'markdown') self.assertRegex( nb.cells[0].source, r'^$', ) self.assertEqual(nb.cells[1].execution_count, 1) - self.assertEqual(nb.cells[2].cell_type, "markdown") - self.assertRegex( - nb.cells[2].source, '' - ) + self.assertEqual(nb.cells[2].cell_type, 'markdown') + self.assertRegex(nb.cells[2].source, '') self.assertEqual(nb.cells[3].execution_count, 2) - self.assertEqual(nb.cells[3].outputs[0].output_type, "error") + self.assertEqual(nb.cells[3].outputs[0].output_type, 'error') self.assertEqual(nb.cells[4].execution_count, None) def test_system_exit(self): - notebook_name = "systemexit.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'systemexit.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual(nb.cells[0].execution_count, 1) self.assertEqual(nb.cells[1].execution_count, 2) - self.assertEqual(nb.cells[1].outputs[0].output_type, "error") - self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit") - self.assertEqual(nb.cells[1].outputs[0].evalue, "") + self.assertEqual(nb.cells[1].outputs[0].output_type, 'error') + self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit') + self.assertEqual(nb.cells[1].outputs[0].evalue, '') self.assertEqual(nb.cells[2].execution_count, None) @@ -421,11 +361,9 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_from_version_4_4_upgrades(self): - notebook_name = "nb_version_4.4.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") - execute_notebook( - get_notebook_path(notebook_name), result_path, {"var": "It works"} - ) + notebook_name = 'nb_version_4.4.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') + execute_notebook(get_notebook_path(notebook_name), result_path, {'var': 'It works'}) nb = load_notebook_node(result_path) validate(nb) @@ -438,11 +376,9 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_no_v3_language_backport(self): - notebook_name = "blank-vscode.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") - execute_notebook( - get_notebook_path(notebook_name), result_path, {"var": "It works"} - ) + notebook_name = 'blank-vscode.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') + execute_notebook(get_notebook_path(notebook_name), result_path, {'var': 'It works'}) nb = load_notebook_node(result_path) validate(nb) @@ -455,25 +391,21 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): @classmethod def nb_kernel_name(cls, nb, name=None): - return "my_custom_kernel" + return 'my_custom_kernel' @classmethod def nb_language(cls, nb, language=None): - return "my_custom_language" + return 'my_custom_language' def setUp(self): self.test_dir = tempfile.mkdtemp() - self.notebook_path = get_notebook_path("simple_execute.ipynb") - self.nb_test_executed_fname = os.path.join( - self.test_dir, "output_{}".format("simple_execute.ipynb") - ) + self.notebook_path = get_notebook_path('simple_execute.ipynb') + self.nb_test_executed_fname = os.path.join(self.test_dir, 'output_{}'.format('simple_execute.ipynb')) self._orig_papermill_engines = deepcopy(engines.papermill_engines) self._orig_translators = deepcopy(translators.papermill_translators) - engines.papermill_engines.register("custom_engine", self.CustomEngine) - translators.papermill_translators.register( - "my_custom_language", translators.PythonTranslator() - ) + engines.papermill_engines.register('custom_engine', self.CustomEngine) + translators.papermill_translators.register('my_custom_language', translators.PythonTranslator()) def tearDown(self): shutil.rmtree(self.test_dir) @@ -482,46 +414,40 @@ def tearDown(self): @patch.object( CustomEngine, - "execute_managed_notebook", + 'execute_managed_notebook', wraps=CustomEngine.execute_managed_notebook, ) @patch( - "papermill.parameterize.translate_parameters", + 'papermill.parameterize.translate_parameters', wraps=translators.translate_parameters, ) - def test_custom_kernel_name_and_language( - self, translate_parameters, execute_managed_notebook - ): + def test_custom_kernel_name_and_language(self, translate_parameters, execute_managed_notebook): """Tests execute against engine with custom implementations to fetch kernel name and language from the notebook object """ execute_notebook( self.notebook_path, self.nb_test_executed_fname, - engine_name="custom_engine", - parameters={"msg": "fake msg"}, - ) - self.assertEqual( - execute_managed_notebook.call_args[0], (ANY, "my_custom_kernel") + engine_name='custom_engine', + parameters={'msg': 'fake msg'}, ) + self.assertEqual(execute_managed_notebook.call_args[0], (ANY, 'my_custom_kernel')) self.assertEqual( translate_parameters.call_args[0], - (ANY, "my_custom_language", {"msg": "fake msg"}, ANY), + (ANY, 'my_custom_language', {'msg': 'fake msg'}, ANY), ) class TestNotebookNodeInput(unittest.TestCase): def setUp(self): self.test_dir = tempfile.TemporaryDirectory() - self.result_path = os.path.join(self.test_dir.name, "output.ipynb") + self.result_path = os.path.join(self.test_dir.name, 'output.ipynb') def tearDown(self): self.test_dir.cleanup() def test_notebook_node_input(self): - input_nb = nbformat.read( - get_notebook_path("simple_execute.ipynb"), as_version=4 - ) - execute_notebook(input_nb, self.result_path, {"msg": "Hello"}) + input_nb = nbformat.read(get_notebook_path('simple_execute.ipynb'), as_version=4) + execute_notebook(input_nb, self.result_path, {'msg': 'Hello'}) test_nb = nbformat.read(self.result_path, as_version=4) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'}) diff --git a/papermill/tests/test_gcs.py b/papermill/tests/test_gcs.py index 280deb8f..61de47b5 100644 --- a/papermill/tests/test_gcs.py +++ b/papermill/tests/test_gcs.py @@ -69,124 +69,100 @@ class GCSTest(unittest.TestCase): def setUp(self): self.gcs_handler = GCSHandler() - @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper()) + @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper()) def test_gcs_read(self, mock_gcs_filesystem): client = self.gcs_handler._get_client() - self.assertEqual(self.gcs_handler.read("gs://bucket/test.ipynb"), 1) + self.assertEqual(self.gcs_handler.read('gs://bucket/test.ipynb'), 1) # Check that client is only generated once self.assertIs(client, self.gcs_handler._get_client()) - @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper()) + @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper()) def test_gcs_write(self, mock_gcs_filesystem): client = self.gcs_handler._get_client() - self.assertEqual( - self.gcs_handler.write("new value", "gs://bucket/test.ipynb"), 1 - ) + self.assertEqual(self.gcs_handler.write('new value', 'gs://bucket/test.ipynb'), 1) # Check that client is only generated once self.assertIs(client, self.gcs_handler._get_client()) - @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper()) + @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper()) def test_gcs_listdir(self, mock_gcs_filesystem): client = self.gcs_handler._get_client() - self.gcs_handler.listdir("testdir") + self.gcs_handler.listdir('testdir') # Check that client is only generated once self.assertIs(client, self.gcs_handler._get_client()) @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 429}), 10 - ), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 429}), 10), ) def test_gcs_handle_exception(self, mock_gcs_filesystem): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): with self.assertRaises(PapermillRateLimitException): - self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" - ) + self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb') @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 429}), 1 - ), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 429}), 1), ) def test_gcs_retry(self, mock_gcs_filesystem): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): self.assertEqual( - self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" - ), + self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'), 2, ) @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper( - GCSHttpError({"message": "test", "code": 429}), 1 - ), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(GCSHttpError({'message': 'test', 'code': 429}), 1), ) def test_gcs_retry_older_exception(self, mock_gcs_filesystem): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): self.assertEqual( - self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" - ), + self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'), 2, ) - @patch("papermill.iorw.gs_is_retriable", side_effect=fallback_gs_is_retriable) + @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable) @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": None}), 1 - ), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': None}), 1), ) - def test_gcs_fallback_retry_unknown_failure_code( - self, mock_gcs_filesystem, mock_gcs_retriable - ): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + def test_gcs_fallback_retry_unknown_failure_code(self, mock_gcs_filesystem, mock_gcs_retriable): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): self.assertEqual( - self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" - ), + self.gcs_handler.write('raise_limit_exception', 'gs://bucket/test.ipynb'), 2, ) - @patch("papermill.iorw.gs_is_retriable", return_value=False) + @patch('papermill.iorw.gs_is_retriable', return_value=False) @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 500}), 1 - ), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 500}), 1), ) def test_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable): with self.assertRaises(GCSRateLimitException): - self.gcs_handler.write("fatal_exception", "gs://bucket/test.ipynb") + self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb') - @patch("papermill.iorw.gs_is_retriable", side_effect=fallback_gs_is_retriable) + @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable) @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 500}), 1 - ), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(GCSRateLimitException({'message': 'test', 'code': 500}), 1), ) def test_fallback_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable): with self.assertRaises(GCSRateLimitException): - self.gcs_handler.write("fatal_exception", "gs://bucket/test.ipynb") + self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb') @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper(ValueError("not-a-retry"), 1), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(ValueError('not-a-retry'), 1), ) def test_gcs_unretryable(self, mock_gcs_filesystem): with self.assertRaises(ValueError): - self.gcs_handler.write("no_a_rate_limit", "gs://bucket/test.ipynb") + self.gcs_handler.write('no_a_rate_limit', 'gs://bucket/test.ipynb') diff --git a/papermill/tests/test_hdfs.py b/papermill/tests/test_hdfs.py index 0577e1f5..e8c49dd2 100644 --- a/papermill/tests/test_hdfs.py +++ b/papermill/tests/test_hdfs.py @@ -8,7 +8,7 @@ class MockHadoopFileSystem(MagicMock): def get_file_info(self, path): - return [MockFileInfo("test1.ipynb"), MockFileInfo("test2.ipynb")] + return [MockFileInfo('test1.ipynb'), MockFileInfo('test2.ipynb')] def open_input_stream(self, path): return MockHadoopFile() @@ -19,7 +19,7 @@ def open_output_stream(self, path): class MockHadoopFile: def __init__(self): - self._content = b"Content of notebook" + self._content = b'Content of notebook' def __enter__(self, *args): return self @@ -40,8 +40,8 @@ def __init__(self, path): self.path = path -@pytest.mark.skip(reason="No valid dep package for python 3.12 yet") -@patch("papermill.iorw.HadoopFileSystem", side_effect=MockHadoopFileSystem()) +@pytest.mark.skip(reason='No valid dep package for python 3.12 yet') +@patch('papermill.iorw.HadoopFileSystem', side_effect=MockHadoopFileSystem()) class HDFSTest(unittest.TestCase): def setUp(self): self.hdfs_handler = HDFSHandler() @@ -49,8 +49,8 @@ def setUp(self): def test_hdfs_listdir(self, mock_hdfs_filesystem): client = self.hdfs_handler._get_client() self.assertEqual( - self.hdfs_handler.listdir("hdfs:///Projects/"), - ["test1.ipynb", "test2.ipynb"], + self.hdfs_handler.listdir('hdfs:///Projects/'), + ['test1.ipynb', 'test2.ipynb'], ) # Check if client is the same after calling self.assertIs(client, self.hdfs_handler._get_client()) @@ -58,14 +58,12 @@ def test_hdfs_listdir(self, mock_hdfs_filesystem): def test_hdfs_read(self, mock_hdfs_filesystem): client = self.hdfs_handler._get_client() self.assertEqual( - self.hdfs_handler.read("hdfs:///Projects/test1.ipynb"), - b"Content of notebook", + self.hdfs_handler.read('hdfs:///Projects/test1.ipynb'), + b'Content of notebook', ) self.assertIs(client, self.hdfs_handler._get_client()) def test_hdfs_write(self, mock_hdfs_filesystem): client = self.hdfs_handler._get_client() - self.assertEqual( - self.hdfs_handler.write("hdfs:///Projects/test1.ipynb", b"New content"), 1 - ) + self.assertEqual(self.hdfs_handler.write('hdfs:///Projects/test1.ipynb', b'New content'), 1) self.assertIs(client, self.hdfs_handler._get_client()) diff --git a/papermill/tests/test_inspect.py b/papermill/tests/test_inspect.py index bab1df65..6d787e2d 100644 --- a/papermill/tests/test_inspect.py +++ b/papermill/tests/test_inspect.py @@ -3,11 +3,9 @@ import pytest from click import Context - from papermill.inspection import display_notebook_help, inspect_notebook - -NOTEBOOKS_PATH = Path(__file__).parent / "notebooks" +NOTEBOOKS_PATH = Path(__file__).parent / 'notebooks' def _get_fullpath(name): @@ -17,55 +15,55 @@ def _get_fullpath(name): @pytest.fixture def click_context(): mock = MagicMock(spec=Context, command=MagicMock()) - mock.command.get_usage.return_value = "Dummy usage" + mock.command.get_usage.return_value = 'Dummy usage' return mock @pytest.mark.parametrize( - "name, expected", + 'name, expected', [ - (_get_fullpath("no_parameters.ipynb"), {}), + (_get_fullpath('no_parameters.ipynb'), {}), ( - _get_fullpath("simple_execute.ipynb"), + _get_fullpath('simple_execute.ipynb'), { - "msg": { - "name": "msg", - "inferred_type_name": "None", - "default": "None", - "help": "", + 'msg': { + 'name': 'msg', + 'inferred_type_name': 'None', + 'default': 'None', + 'help': '', } }, ), ( - _get_fullpath("complex_parameters.ipynb"), + _get_fullpath('complex_parameters.ipynb'), { - "msg": { - "name": "msg", - "inferred_type_name": "None", - "default": "None", - "help": "", + 'msg': { + 'name': 'msg', + 'inferred_type_name': 'None', + 'default': 'None', + 'help': '', }, - "a": { - "name": "a", - "inferred_type_name": "float", - "default": "2.25", - "help": "Variable a", + 'a': { + 'name': 'a', + 'inferred_type_name': 'float', + 'default': '2.25', + 'help': 'Variable a', }, - "b": { - "name": "b", - "inferred_type_name": "List[str]", - "default": "['Hello','World']", - "help": "Nice list", + 'b': { + 'name': 'b', + 'inferred_type_name': 'List[str]', + 'default': "['Hello','World']", + 'help': 'Nice list', }, - "c": { - "name": "c", - "inferred_type_name": "NoneType", - "default": "None", - "help": "", + 'c': { + 'name': 'c', + 'inferred_type_name': 'NoneType', + 'default': 'None', + 'help': '', }, }, ), - (_get_fullpath("notimplemented_translator.ipynb"), {}), + (_get_fullpath('notimplemented_translator.ipynb'), {}), ], ) def test_inspect_notebook(name, expected): @@ -74,50 +72,50 @@ def test_inspect_notebook(name, expected): def test_str_path(): expected = { - "msg": { - "name": "msg", - "inferred_type_name": "None", - "default": "None", - "help": "", + 'msg': { + 'name': 'msg', + 'inferred_type_name': 'None', + 'default': 'None', + 'help': '', } } - assert inspect_notebook(str(_get_fullpath("simple_execute.ipynb"))) == expected + assert inspect_notebook(str(_get_fullpath('simple_execute.ipynb'))) == expected @pytest.mark.parametrize( - "name, expected", + 'name, expected', [ ( - _get_fullpath("no_parameters.ipynb"), + _get_fullpath('no_parameters.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", "\n No cell tagged 'parameters'", ], ), ( - _get_fullpath("simple_execute.ipynb"), + _get_fullpath('simple_execute.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", - " msg: Unknown type (default None)", + ' msg: Unknown type (default None)', ], ), ( - _get_fullpath("complex_parameters.ipynb"), + _get_fullpath('complex_parameters.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", - " msg: Unknown type (default None)", - " a: float (default 2.25) Variable a", + ' msg: Unknown type (default None)', + ' a: float (default 2.25) Variable a', " b: List[str] (default ['Hello','World'])\n Nice list", - " c: NoneType (default None) ", + ' c: NoneType (default None) ', ], ), ( - _get_fullpath("notimplemented_translator.ipynb"), + _get_fullpath('notimplemented_translator.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", "\n Can't infer anything about this notebook's parameters. It may not have any parameter defined.", # noqa ], @@ -125,7 +123,7 @@ def test_str_path(): ], ) def test_display_notebook_help(click_context, name, expected): - with patch("papermill.inspection.click.echo") as echo: + with patch('papermill.inspection.click.echo') as echo: display_notebook_help(click_context, str(name), None) assert echo.call_count == len(expected) diff --git a/papermill/tests/test_iorw.py b/papermill/tests/test_iorw.py index 39ad12b0..cb1eab75 100644 --- a/papermill/tests/test_iorw.py +++ b/papermill/tests/test_iorw.py @@ -1,31 +1,31 @@ +import io import json -import unittest import os -import io +import unittest +from tempfile import TemporaryDirectory +from unittest.mock import Mock, patch + import nbformat import pytest - from requests.exceptions import ConnectionError -from tempfile import TemporaryDirectory -from unittest.mock import Mock, patch from .. import iorw +from ..exceptions import PapermillException from ..iorw import ( + ADLHandler, HttpHandler, LocalHandler, NoIOHandler, - ADLHandler, NotebookNodeHandler, - StreamHandler, PapermillIO, - read_yaml_file, - papermill_io, + StreamHandler, local_file_io_cwd, + papermill_io, + read_yaml_file, ) -from ..exceptions import PapermillException from . import get_notebook_path -FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures") +FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures') class TestPapermillIO(unittest.TestCase): @@ -38,16 +38,16 @@ def __init__(self, ver): self.ver = ver def read(self, path): - return f"contents from {path} for version {self.ver}" + return f'contents from {path} for version {self.ver}' def listdir(self, path): - return ["fake", "contents"] + return ['fake', 'contents'] def write(self, buf, path): - return f"wrote {buf}" + return f'wrote {buf}' def pretty_path(self, path): - return f"{path}/pretty/{self.ver}" + return f'{path}/pretty/{self.ver}' class FakeByteHandler: def __init__(self, ver): @@ -59,13 +59,13 @@ def read(self, path): return f.read() def listdir(self, path): - return ["fake", "contents"] + return ['fake', 'contents'] def write(self, buf, path): - return f"wrote {buf}" + return f'wrote {buf}' def pretty_path(self, path): - return f"{path}/pretty/{self.ver}" + return f'{path}/pretty/{self.ver}' def setUp(self): self.papermill_io = PapermillIO() @@ -73,8 +73,8 @@ def setUp(self): self.fake1 = self.FakeHandler(1) self.fake2 = self.FakeHandler(2) self.fake_byte1 = self.FakeByteHandler(1) - self.papermill_io.register("fake", self.fake1) - self.papermill_io_bytes.register("notebooks", self.fake_byte1) + self.papermill_io.register('fake', self.fake1) + self.papermill_io_bytes.register('notebooks', self.fake_byte1) self.old_papermill_io = iorw.papermill_io iorw.papermill_io = self.papermill_io @@ -83,117 +83,103 @@ def tearDown(self): iorw.papermill_io = self.old_papermill_io def test_get_handler(self): - self.assertEqual(self.papermill_io.get_handler("fake"), self.fake1) + self.assertEqual(self.papermill_io.get_handler('fake'), self.fake1) def test_get_local_handler(self): with self.assertRaises(PapermillException): - self.papermill_io.get_handler("dne") + self.papermill_io.get_handler('dne') - self.papermill_io.register("local", self.fake2) - self.assertEqual(self.papermill_io.get_handler("dne"), self.fake2) + self.papermill_io.register('local', self.fake2) + self.assertEqual(self.papermill_io.get_handler('dne'), self.fake2) def test_get_no_io_handler(self): self.assertIsInstance(self.papermill_io.get_handler(None), NoIOHandler) def test_get_notebook_node_handler(self): - test_nb = nbformat.read( - get_notebook_path("test_notebooknode_io.ipynb"), as_version=4 - ) - self.assertIsInstance( - self.papermill_io.get_handler(test_nb), NotebookNodeHandler - ) + test_nb = nbformat.read(get_notebook_path('test_notebooknode_io.ipynb'), as_version=4) + self.assertIsInstance(self.papermill_io.get_handler(test_nb), NotebookNodeHandler) def test_entrypoint_register(self): fake_entrypoint = Mock(load=Mock()) - fake_entrypoint.name = "fake-from-entry-point://" + fake_entrypoint.name = 'fake-from-entry-point://' - with patch( - "entrypoints.get_group_all", return_value=[fake_entrypoint] - ) as mock_get_group_all: + with patch('entrypoints.get_group_all', return_value=[fake_entrypoint]) as mock_get_group_all: self.papermill_io.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.io") - fake_ = self.papermill_io.get_handler("fake-from-entry-point://") + mock_get_group_all.assert_called_once_with('papermill.io') + fake_ = self.papermill_io.get_handler('fake-from-entry-point://') assert fake_ == fake_entrypoint.load.return_value def test_register_ordering(self): # Should match fake1 with fake2 path - self.assertEqual(self.papermill_io.get_handler("fake2/path"), self.fake1) + self.assertEqual(self.papermill_io.get_handler('fake2/path'), self.fake1) self.papermill_io.reset() - self.papermill_io.register("fake", self.fake1) - self.papermill_io.register("fake2", self.fake2) + self.papermill_io.register('fake', self.fake1) + self.papermill_io.register('fake2', self.fake2) # Should match fake1 with fake1 path, and NOT fake2 path/match - self.assertEqual(self.papermill_io.get_handler("fake/path"), self.fake1) + self.assertEqual(self.papermill_io.get_handler('fake/path'), self.fake1) # Should match fake2 with fake2 path - self.assertEqual(self.papermill_io.get_handler("fake2/path"), self.fake2) + self.assertEqual(self.papermill_io.get_handler('fake2/path'), self.fake2) def test_read(self): - self.assertEqual( - self.papermill_io.read("fake/path"), "contents from fake/path for version 1" - ) + self.assertEqual(self.papermill_io.read('fake/path'), 'contents from fake/path for version 1') def test_read_bytes(self): - self.assertIsNotNone( - self.papermill_io_bytes.read( - "notebooks/gcs/gcs_in/gcs-simple_notebook.ipynb" - ) - ) + self.assertIsNotNone(self.papermill_io_bytes.read('notebooks/gcs/gcs_in/gcs-simple_notebook.ipynb')) def test_read_with_no_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.read("fake/path") + self.papermill_io.read('fake/path') def test_read_with_invalid_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.read("fake/path/fakeinputpath.ipynb1") + self.papermill_io.read('fake/path/fakeinputpath.ipynb1') def test_read_with_valid_file_extension(self): with pytest.warns(None) as warns: - self.papermill_io.read("fake/path/fakeinputpath.ipynb") + self.papermill_io.read('fake/path/fakeinputpath.ipynb') self.assertEqual(len(warns), 0) def test_read_yaml_with_no_file_extension(self): with pytest.warns(UserWarning): - read_yaml_file("fake/path") + read_yaml_file('fake/path') def test_read_yaml_with_invalid_file_extension(self): with pytest.warns(UserWarning): - read_yaml_file("fake/path/fakeinputpath.ipynb") + read_yaml_file('fake/path/fakeinputpath.ipynb') def test_read_stdin(self): - file_content = "Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ" - with patch("sys.stdin", io.StringIO(file_content)): - self.assertEqual(self.old_papermill_io.read("-"), file_content) + file_content = 'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ' + with patch('sys.stdin', io.StringIO(file_content)): + self.assertEqual(self.old_papermill_io.read('-'), file_content) def test_listdir(self): - self.assertEqual(self.papermill_io.listdir("fake/path"), ["fake", "contents"]) + self.assertEqual(self.papermill_io.listdir('fake/path'), ['fake', 'contents']) def test_write(self): - self.assertEqual(self.papermill_io.write("buffer", "fake/path"), "wrote buffer") + self.assertEqual(self.papermill_io.write('buffer', 'fake/path'), 'wrote buffer') def test_write_with_no_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.write("buffer", "fake/path") + self.papermill_io.write('buffer', 'fake/path') def test_write_with_path_of_none(self): - self.assertIsNone(self.papermill_io.write("buffer", None)) + self.assertIsNone(self.papermill_io.write('buffer', None)) def test_write_with_invalid_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.write("buffer", "fake/path/fakeoutputpath.ipynb1") + self.papermill_io.write('buffer', 'fake/path/fakeoutputpath.ipynb1') def test_write_stdout(self): - file_content = "Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ" + file_content = 'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ' out = io.BytesIO() - with patch("sys.stdout", out): - self.old_papermill_io.write(file_content, "-") - self.assertEqual(out.getvalue(), file_content.encode("utf-8")) + with patch('sys.stdout', out): + self.old_papermill_io.write(file_content, '-') + self.assertEqual(out.getvalue(), file_content.encode('utf-8')) def test_pretty_path(self): - self.assertEqual( - self.papermill_io.pretty_path("fake/path"), "fake/path/pretty/1" - ) + self.assertEqual(self.papermill_io.pretty_path('fake/path'), 'fake/path/pretty/1') class TestLocalHandler(unittest.TestCase): @@ -202,36 +188,34 @@ class TestLocalHandler(unittest.TestCase): """ def test_read_utf8(self): - self.assertEqual( - LocalHandler().read(os.path.join(FIXTURE_PATH, "rock.txt")).strip(), "✄" - ) + self.assertEqual(LocalHandler().read(os.path.join(FIXTURE_PATH, 'rock.txt')).strip(), '✄') def test_write_utf8(self): with TemporaryDirectory() as temp_dir: - path = os.path.join(temp_dir, "paper.txt") - LocalHandler().write("✄", path) - with open(path, encoding="utf-8") as f: - self.assertEqual(f.read().strip(), "✄") + path = os.path.join(temp_dir, 'paper.txt') + LocalHandler().write('✄', path) + with open(path, encoding='utf-8') as f: + self.assertEqual(f.read().strip(), '✄') def test_write_no_directory_exists(self): with self.assertRaises(FileNotFoundError): - LocalHandler().write("buffer", "fake/path/fakenb.ipynb") + LocalHandler().write('buffer', 'fake/path/fakenb.ipynb') def test_write_local_directory(self): - with patch.object(io, "open"): + with patch.object(io, 'open'): # Shouldn't raise with missing directory - LocalHandler().write("buffer", "local.ipynb") + LocalHandler().write('buffer', 'local.ipynb') def test_write_passed_cwd(self): with TemporaryDirectory() as temp_dir: handler = LocalHandler() handler.cwd(temp_dir) - handler.write("✄", "paper.txt") + handler.write('✄', 'paper.txt') - path = os.path.join(temp_dir, "paper.txt") - with open(path, encoding="utf-8") as f: - self.assertEqual(f.read().strip(), "✄") + path = os.path.join(temp_dir, 'paper.txt') + with open(path, encoding='utf-8') as f: + self.assertEqual(f.read().strip(), '✄') def test_local_file_io_cwd(self): with TemporaryDirectory() as temp_dir: @@ -241,16 +225,16 @@ def test_local_file_io_cwd(self): try: local_handler = LocalHandler() papermill_io.reset() - papermill_io.register("local", local_handler) + papermill_io.register('local', local_handler) with local_file_io_cwd(temp_dir): - local_handler.write("✄", "paper.txt") - self.assertEqual(local_handler.read("paper.txt"), "✄") + local_handler.write('✄', 'paper.txt') + self.assertEqual(local_handler.read('paper.txt'), '✄') # Double check it used the tmpdir - path = os.path.join(temp_dir, "paper.txt") - with open(path, encoding="utf-8") as f: - self.assertEqual(f.read().strip(), "✄") + path = os.path.join(temp_dir, 'paper.txt') + with open(path, encoding='utf-8') as f: + self.assertEqual(f.read().strip(), '✄') finally: papermill_io.handlers = handlers @@ -263,7 +247,7 @@ def test_invalid_string(self): # a string from which we can't extract a notebook is assumed to # be a file and an IOError will be raised with self.assertRaises(IOError): - LocalHandler().read("a random string") + LocalHandler().read('a random string') class TestNoIOHandler(unittest.TestCase): @@ -276,10 +260,10 @@ def test_raises_on_listdir(self): NoIOHandler().listdir(None) def test_write_returns_none(self): - self.assertIsNone(NoIOHandler().write("buf", None)) + self.assertIsNone(NoIOHandler().write('buf', None)) def test_pretty_path(self): - expect = "Notebook will not be saved" + expect = 'Notebook will not be saved' self.assertEqual(NoIOHandler().pretty_path(None), expect) @@ -291,20 +275,20 @@ class TestADLHandler(unittest.TestCase): def setUp(self): self.handler = ADLHandler() self.handler._client = Mock( - read=Mock(return_value=["foo", "bar", "baz"]), - listdir=Mock(return_value=["foo", "bar", "baz"]), + read=Mock(return_value=['foo', 'bar', 'baz']), + listdir=Mock(return_value=['foo', 'bar', 'baz']), write=Mock(), ) def test_read(self): - self.assertEqual(self.handler.read("some_path"), "foo\nbar\nbaz") + self.assertEqual(self.handler.read('some_path'), 'foo\nbar\nbaz') def test_listdir(self): - self.assertEqual(self.handler.listdir("some_path"), ["foo", "bar", "baz"]) + self.assertEqual(self.handler.listdir('some_path'), ['foo', 'bar', 'baz']) def test_write(self): - self.handler.write("foo", "bar") - self.handler._client.write.assert_called_once_with("foo", "bar") + self.handler.write('foo', 'bar') + self.handler._client.write.assert_called_once_with('foo', 'bar') class TestHttpHandler(unittest.TestCase): @@ -318,34 +302,32 @@ def test_listdir(self): `listdir` function is not supported. """ with self.assertRaises(PapermillException) as e: - HttpHandler.listdir("http://example.com") + HttpHandler.listdir('http://example.com') - self.assertEqual(f"{e.exception}", "listdir is not supported by HttpHandler") + self.assertEqual(f'{e.exception}', 'listdir is not supported by HttpHandler') def test_read(self): """ Tests that the `read` function performs a request to the giving path and returns the response. """ - path = "http://example.com" - text = "request test response" + path = 'http://example.com' + text = 'request test response' - with patch("papermill.iorw.requests.get") as mock_get: + with patch('papermill.iorw.requests.get') as mock_get: mock_get.return_value = Mock(text=text) self.assertEqual(HttpHandler.read(path), text) - mock_get.assert_called_once_with( - path, headers={"Accept": "application/json"} - ) + mock_get.assert_called_once_with(path, headers={'Accept': 'application/json'}) def test_write(self): """ Tests that the `write` function performs a put request to the given path. """ - path = "http://example.com" + path = 'http://example.com' buf = '{"papermill": true}' - with patch("papermill.iorw.requests.put") as mock_put: + with patch('papermill.iorw.requests.put') as mock_put: HttpHandler.write(buf, path) mock_put.assert_called_once_with(path, json=json.loads(buf)) @@ -353,7 +335,7 @@ def test_write_failure(self): """ Tests that the `write` function raises on failure to put the buffer. """ - path = "http://localhost:9999" + path = 'http://localhost:9999' buf = '{"papermill": true}' with self.assertRaises(ConnectionError): @@ -361,36 +343,34 @@ def test_write_failure(self): class TestStreamHandler(unittest.TestCase): - @patch("sys.stdin", io.StringIO("mock stream")) + @patch('sys.stdin', io.StringIO('mock stream')) def test_read_from_stdin(self): - result = StreamHandler().read("foo") - self.assertEqual(result, "mock stream") + result = StreamHandler().read('foo') + self.assertEqual(result, 'mock stream') def test_raises_on_listdir(self): with self.assertRaises(PapermillException): StreamHandler().listdir(None) - @patch("sys.stdout") + @patch('sys.stdout') def test_write_to_stdout_buffer(self, mock_stdout): mock_stdout.buffer = io.BytesIO() - StreamHandler().write("mock stream", "foo") - self.assertEqual(mock_stdout.buffer.getbuffer(), b"mock stream") + StreamHandler().write('mock stream', 'foo') + self.assertEqual(mock_stdout.buffer.getbuffer(), b'mock stream') - @patch("sys.stdout", new_callable=io.BytesIO) + @patch('sys.stdout', new_callable=io.BytesIO) def test_write_to_stdout(self, mock_stdout): - StreamHandler().write("mock stream", "foo") - self.assertEqual(mock_stdout.getbuffer(), b"mock stream") + StreamHandler().write('mock stream', 'foo') + self.assertEqual(mock_stdout.getbuffer(), b'mock stream') def test_pretty_path_returns_input_path(self): '''Should return the input str, which often is the default registered schema "-"''' - self.assertEqual(StreamHandler().pretty_path("foo"), "foo") + self.assertEqual(StreamHandler().pretty_path('foo'), 'foo') class TestNotebookNodeHandler(unittest.TestCase): def test_read_notebook_node(self): - input_nb = nbformat.read( - get_notebook_path("test_notebooknode_io.ipynb"), as_version=4 - ) + input_nb = nbformat.read(get_notebook_path('test_notebooknode_io.ipynb'), as_version=4) result = NotebookNodeHandler().read(input_nb) expect = ( '{\n "cells": [\n {\n "cell_type": "code",\n "execution_count": null,' @@ -403,12 +383,12 @@ def test_read_notebook_node(self): def test_raises_on_listdir(self): with self.assertRaises(PapermillException): - NotebookNodeHandler().listdir("foo") + NotebookNodeHandler().listdir('foo') def test_raises_on_write(self): with self.assertRaises(PapermillException): - NotebookNodeHandler().write("foo", "bar") + NotebookNodeHandler().write('foo', 'bar') def test_pretty_path(self): - expect = "NotebookNode object" - self.assertEqual(NotebookNodeHandler().pretty_path("foo"), expect) + expect = 'NotebookNode object' + self.assertEqual(NotebookNodeHandler().pretty_path('foo'), expect) diff --git a/papermill/tests/test_parameterize.py b/papermill/tests/test_parameterize.py index 4e2df4f4..fbd12ff0 100644 --- a/papermill/tests/test_parameterize.py +++ b/papermill/tests/test_parameterize.py @@ -1,205 +1,173 @@ import unittest +from datetime import datetime -from ..iorw import load_notebook_node from ..exceptions import PapermillMissingParameterException +from ..iorw import load_notebook_node from ..parameterize import ( + add_builtin_parameters, parameterize_notebook, parameterize_path, - add_builtin_parameters, ) from . import get_notebook_path -from datetime import datetime class TestNotebookParametrizing(unittest.TestCase): def count_nb_injected_parameter_cells(self, nb): - return len( - [ - c - for c in nb.cells - if "injected-parameters" in c.get("metadata", {}).get("tags", []) - ] - ) + return len([c for c in nb.cells if 'injected-parameters' in c.get('metadata', {}).get('tags', [])]) def test_no_tag_copying(self): # Test that injected cell does not copy other tags - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) - test_nb.cells[0]["metadata"]["tags"].append("some tag") + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) + test_nb.cells[0]['metadata']['tags'].append('some tag') - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) cell_zero = test_nb.cells[0] - self.assertTrue("some tag" in cell_zero.get("metadata").get("tags")) - self.assertTrue("parameters" in cell_zero.get("metadata").get("tags")) + self.assertTrue('some tag' in cell_zero.get('metadata').get('tags')) + self.assertTrue('parameters' in cell_zero.get('metadata').get('tags')) cell_one = test_nb.cells[1] - self.assertTrue("some tag" not in cell_one.get("metadata").get("tags")) - self.assertTrue("injected-parameters" in cell_one.get("metadata").get("tags")) + self.assertTrue('some tag' not in cell_one.get('metadata').get('tags')) + self.assertTrue('injected-parameters' in cell_one.get('metadata').get('tags')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_injected_parameters_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) cell_zero = test_nb.cells[0] - self.assertTrue("parameters" in cell_zero.get("metadata").get("tags")) - self.assertTrue( - "injected-parameters" not in cell_zero.get("metadata").get("tags") - ) + self.assertTrue('parameters' in cell_zero.get('metadata').get('tags')) + self.assertTrue('injected-parameters' not in cell_zero.get('metadata').get('tags')) cell_one = test_nb.cells[1] - self.assertTrue("injected-parameters" in cell_one.get("metadata").get("tags")) + self.assertTrue('injected-parameters' in cell_one.get('metadata').get('tags')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_repeated_run_injected_parameters_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 0) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) - parameterize_notebook(test_nb, {"msg": "Hello"}) + parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_no_parameter_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) - test_nb.cells[0]["metadata"]["tags"] = [] + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) + test_nb.cells[0]['metadata']['tags'] = [] - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) cell_zero = test_nb.cells[0] - self.assertTrue("injected-parameters" in cell_zero.get("metadata").get("tags")) - self.assertTrue("parameters" not in cell_zero.get("metadata").get("tags")) + self.assertTrue('injected-parameters' in cell_zero.get('metadata').get('tags')) + self.assertTrue('parameters' not in cell_zero.get('metadata').get('tags')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_repeated_run_no_parameters_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) - test_nb.cells[0]["metadata"]["tags"] = [] + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) + test_nb.cells[0]['metadata']['tags'] = [] self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 0) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_custom_comment(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) - test_nb = parameterize_notebook( - test_nb, {"msg": "Hello"}, comment="This is a custom comment" - ) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}, comment='This is a custom comment') cell_one = test_nb.cells[1] - first_line = cell_one["source"].split("\n")[0] - self.assertEqual(first_line, "# This is a custom comment") + first_line = cell_one['source'].split('\n')[0] + self.assertEqual(first_line, '# This is a custom comment') class TestBuiltinParameters(unittest.TestCase): def test_add_builtin_parameters_keeps_provided_parameters(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertEqual(with_builtin_parameters["foo"], "bar") + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertEqual(with_builtin_parameters['foo'], 'bar') def test_add_builtin_parameters_adds_dict_of_builtins(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("pm", with_builtin_parameters) - self.assertIsInstance(with_builtin_parameters["pm"], type({})) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('pm', with_builtin_parameters) + self.assertIsInstance(with_builtin_parameters['pm'], type({})) def test_add_builtin_parameters_allows_to_override_builtin(self): - with_builtin_parameters = add_builtin_parameters({"pm": "foo"}) - self.assertEqual(with_builtin_parameters["pm"], "foo") + with_builtin_parameters = add_builtin_parameters({'pm': 'foo'}) + self.assertEqual(with_builtin_parameters['pm'], 'foo') def test_builtin_parameters_include_run_uuid(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("run_uuid", with_builtin_parameters["pm"]) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('run_uuid', with_builtin_parameters['pm']) def test_builtin_parameters_include_current_datetime_local(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("current_datetime_local", with_builtin_parameters["pm"]) - self.assertIsInstance( - with_builtin_parameters["pm"]["current_datetime_local"], datetime - ) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('current_datetime_local', with_builtin_parameters['pm']) + self.assertIsInstance(with_builtin_parameters['pm']['current_datetime_local'], datetime) def test_builtin_parameters_include_current_datetime_utc(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("current_datetime_utc", with_builtin_parameters["pm"]) - self.assertIsInstance( - with_builtin_parameters["pm"]["current_datetime_utc"], datetime - ) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('current_datetime_utc', with_builtin_parameters['pm']) + self.assertIsInstance(with_builtin_parameters['pm']['current_datetime_utc'], datetime) class TestPathParameterizing(unittest.TestCase): def test_plain_text_path_with_empty_parameters_object(self): - self.assertEqual(parameterize_path("foo/bar", {}), "foo/bar") + self.assertEqual(parameterize_path('foo/bar', {}), 'foo/bar') def test_plain_text_path_with_none_parameters(self): - self.assertEqual(parameterize_path("foo/bar", None), "foo/bar") + self.assertEqual(parameterize_path('foo/bar', None), 'foo/bar') def test_plain_text_path_with_unused_parameters(self): - self.assertEqual(parameterize_path("foo/bar", {"baz": "quux"}), "foo/bar") + self.assertEqual(parameterize_path('foo/bar', {'baz': 'quux'}), 'foo/bar') def test_path_with_single_parameter(self): - self.assertEqual( - parameterize_path("foo/bar/{baz}", {"baz": "quux"}), "foo/bar/quux" - ) + self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': 'quux'}), 'foo/bar/quux') def test_path_with_boolean_parameter(self): - self.assertEqual( - parameterize_path("foo/bar/{baz}", {"baz": False}), "foo/bar/False" - ) + self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': False}), 'foo/bar/False') def test_path_with_dict_parameter(self): - self.assertEqual( - parameterize_path("foo/{bar[baz]}/", {"bar": {"baz": "quux"}}), "foo/quux/" - ) + self.assertEqual(parameterize_path('foo/{bar[baz]}/', {'bar': {'baz': 'quux'}}), 'foo/quux/') def test_path_with_list_parameter(self): - self.assertEqual( - parameterize_path("foo/{bar[0]}/", {"bar": [1, 2, 3]}), "foo/1/" - ) - self.assertEqual( - parameterize_path("foo/{bar[2]}/", {"bar": [1, 2, 3]}), "foo/3/" - ) + self.assertEqual(parameterize_path('foo/{bar[0]}/', {'bar': [1, 2, 3]}), 'foo/1/') + self.assertEqual(parameterize_path('foo/{bar[2]}/', {'bar': [1, 2, 3]}), 'foo/3/') def test_path_with_none_parameter(self): - self.assertEqual( - parameterize_path("foo/bar/{baz}", {"baz": None}), "foo/bar/None" - ) + self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': None}), 'foo/bar/None') def test_path_with_numeric_parameter(self): - self.assertEqual(parameterize_path("foo/bar/{baz}", {"baz": 42}), "foo/bar/42") + self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': 42}), 'foo/bar/42') def test_path_with_numeric_format_string(self): - self.assertEqual( - parameterize_path("foo/bar/{baz:03d}", {"baz": 42}), "foo/bar/042" - ) + self.assertEqual(parameterize_path('foo/bar/{baz:03d}', {'baz': 42}), 'foo/bar/042') def test_path_with_float_format_string(self): - self.assertEqual( - parameterize_path("foo/bar/{baz:.03f}", {"baz": 0.3}), "foo/bar/0.300" - ) + self.assertEqual(parameterize_path('foo/bar/{baz:.03f}', {'baz': 0.3}), 'foo/bar/0.300') def test_path_with_multiple_parameter(self): - self.assertEqual( - parameterize_path("{foo}/{baz}", {"foo": "bar", "baz": "quux"}), "bar/quux" - ) + self.assertEqual(parameterize_path('{foo}/{baz}', {'foo': 'bar', 'baz': 'quux'}), 'bar/quux') def test_parameterized_path_with_undefined_parameter(self): with self.assertRaises(PapermillMissingParameterException) as context: - parameterize_path("{foo}", {}) + parameterize_path('{foo}', {}) self.assertEqual(str(context.exception), "Missing parameter 'foo'") def test_parameterized_path_with_none_parameters(self): with self.assertRaises(PapermillMissingParameterException) as context: - parameterize_path("{foo}", None) + parameterize_path('{foo}', None) self.assertEqual(str(context.exception), "Missing parameter 'foo'") def test_path_of_none_returns_none(self): - self.assertIsNone(parameterize_path(path=None, parameters={"foo": "bar"})) + self.assertIsNone(parameterize_path(path=None, parameters={'foo': 'bar'})) self.assertIsNone(parameterize_path(path=None, parameters=None)) def test_path_of_notebook_node_returns_input(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) result_nb = parameterize_path(test_nb, parameters=None) self.assertIs(result_nb, test_nb) diff --git a/papermill/tests/test_s3.py b/papermill/tests/test_s3.py index 156b4a7a..de86f5b6 100644 --- a/papermill/tests/test_s3.py +++ b/papermill/tests/test_s3.py @@ -1,52 +1,52 @@ # The following tests are purposely limited to the exposed interface by iorw.py import os.path -import pytest + import boto3 import moto - +import pytest from moto import mock_s3 -from ..s3 import Bucket, Prefix, Key, S3 +from ..s3 import S3, Bucket, Key, Prefix @pytest.fixture def bucket_no_service(): """Returns a bucket instance with no services""" - return Bucket("my_test_bucket") + return Bucket('my_test_bucket') @pytest.fixture def bucket_with_service(): """Returns a bucket instance with a service""" - return Bucket("my_sqs_bucket", ["sqs"]) + return Bucket('my_sqs_bucket', ['sqs']) @pytest.fixture def bucket_sqs(): """Returns a bucket instance with a sqs service""" - return Bucket("my_sqs_bucket", ["sqs"]) + return Bucket('my_sqs_bucket', ['sqs']) @pytest.fixture def bucket_ec2(): """Returns a bucket instance with a ec2 service""" - return Bucket("my_sqs_bucket", ["ec2"]) + return Bucket('my_sqs_bucket', ['ec2']) @pytest.fixture def bucket_multiservice(): """Returns a bucket instance with a ec2 service""" - return Bucket("my_sqs_bucket", ["ec2", "sqs"]) + return Bucket('my_sqs_bucket', ['ec2', 'sqs']) def test_bucket_init(): - assert Bucket("my_test_bucket") - assert Bucket("my_sqs_bucket", "sqs") + assert Bucket('my_test_bucket') + assert Bucket('my_sqs_bucket', 'sqs') def test_bucket_defaults(): - name = "a bucket" + name = 'a bucket' b1 = Bucket(name) b2 = Bucket(name, None) @@ -86,19 +86,19 @@ def test_prefix_init(): Prefix(service=None) with pytest.raises(TypeError): - Prefix("my_test_prefix") + Prefix('my_test_prefix') - b1 = Bucket("my_test_bucket") - p1 = Prefix(b1, "sqs_test", service="sqs") - assert Prefix(b1, "test_bucket") - assert Prefix(b1, "test_bucket", service=None) - assert Prefix(b1, "test_bucket", None) + b1 = Bucket('my_test_bucket') + p1 = Prefix(b1, 'sqs_test', service='sqs') + assert Prefix(b1, 'test_bucket') + assert Prefix(b1, 'test_bucket', service=None) + assert Prefix(b1, 'test_bucket', None) assert p1.bucket.service == p1.service def test_prefix_defaults(): - bucket = Bucket("my data pool") - name = "bigdata bucket" + bucket = Bucket('my data pool') + name = 'bigdata bucket' p1 = Prefix(bucket, name) p2 = Prefix(bucket, name, None) @@ -107,13 +107,13 @@ def test_prefix_defaults(): def test_prefix_str(bucket_sqs): - p1 = Prefix(bucket_sqs, "sqs_prefix_test", "sqs") - assert str(p1) == "s3://" + str(bucket_sqs) + "/sqs_prefix_test" + p1 = Prefix(bucket_sqs, 'sqs_prefix_test', 'sqs') + assert str(p1) == 's3://' + str(bucket_sqs) + '/sqs_prefix_test' def test_prefix_repr(bucket_sqs): - p1 = Prefix(bucket_sqs, "sqs_prefix_test", "sqs") - assert repr(p1) == "s3://" + str(bucket_sqs) + "/sqs_prefix_test" + p1 = Prefix(bucket_sqs, 'sqs_prefix_test', 'sqs') + assert repr(p1) == 's3://' + str(bucket_sqs) + '/sqs_prefix_test' def test_key_init(): @@ -121,13 +121,13 @@ def test_key_init(): def test_key_repr(): - k = Key("foo", "bar") - assert repr(k) == "s3://foo/bar" + k = Key('foo', 'bar') + assert repr(k) == 's3://foo/bar' def test_key_defaults(): - bucket = Bucket("my data pool") - name = "bigdata bucket" + bucket = Bucket('my data pool') + name = 'bigdata bucket' k1 = Key(bucket, name) k2 = Key(bucket, name, None, None, None, None, None) @@ -148,36 +148,36 @@ def test_s3_defaults(): local_dir = os.path.dirname(os.path.abspath(__file__)) -test_bucket_name = "test-pm-bucket" -test_string = "Hello" -test_file_path = "notebooks/s3/s3_in/s3-simple_notebook.ipynb" -test_empty_file_path = "notebooks/s3/s3_in/s3-empty.ipynb" +test_bucket_name = 'test-pm-bucket' +test_string = 'Hello' +test_file_path = 'notebooks/s3/s3_in/s3-simple_notebook.ipynb' +test_empty_file_path = 'notebooks/s3/s3_in/s3-empty.ipynb' with open(os.path.join(local_dir, test_file_path)) as f: test_nb_content = f.read() -no_empty_lines = lambda s: "\n".join([l for l in s.split("\n") if len(l) > 0]) +no_empty_lines = lambda s: '\n'.join([l for l in s.split('\n') if len(l) > 0]) test_clean_nb_content = no_empty_lines(test_nb_content) -read_from_gen = lambda g: "\n".join(g) +read_from_gen = lambda g: '\n'.join(g) -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def s3_client(): mock_s3 = moto.mock_s3() mock_s3.start() - client = boto3.client("s3") + client = boto3.client('s3') client.create_bucket( Bucket=test_bucket_name, - CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}, ) client.put_object(Bucket=test_bucket_name, Key=test_file_path, Body=test_nb_content) - client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body="") + client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body='') yield S3() try: client.delete_object(Bucket=test_bucket_name, Key=test_file_path) - client.delete_object(Bucket=test_bucket_name, Key=test_file_path + ".txt") + client.delete_object(Bucket=test_bucket_name, Key=test_file_path + '.txt') client.delete_object(Bucket=test_bucket_name, Key=test_empty_file_path) except Exception: pass @@ -185,19 +185,19 @@ def s3_client(): def test_s3_read(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_file_path}" + s3_path = f's3://{test_bucket_name}/{test_file_path}' data = read_from_gen(s3_client.read(s3_path)) assert data == test_clean_nb_content def test_s3_read_empty(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_empty_file_path}" + s3_path = f's3://{test_bucket_name}/{test_empty_file_path}' data = read_from_gen(s3_client.read(s3_path)) - assert data == "" + assert data == '' def test_s3_write(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_file_path}.txt" + s3_path = f's3://{test_bucket_name}/{test_file_path}.txt' s3_client.cp_string(test_string, s3_path) data = read_from_gen(s3_client.read(s3_path)) @@ -205,7 +205,7 @@ def test_s3_write(s3_client): def test_s3_overwrite(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_file_path}" + s3_path = f's3://{test_bucket_name}/{test_file_path}' s3_client.cp_string(test_string, s3_path) data = read_from_gen(s3_client.read(s3_path)) @@ -214,8 +214,8 @@ def test_s3_overwrite(s3_client): def test_s3_listdir(s3_client): dir_name = os.path.dirname(test_file_path) - s3_dir = f"s3://{test_bucket_name}/{dir_name}" - s3_path = f"s3://{test_bucket_name}/{test_file_path}" + s3_dir = f's3://{test_bucket_name}/{dir_name}' + s3_path = f's3://{test_bucket_name}/{test_file_path}' dir_listings = s3_client.listdir(s3_dir) assert len(dir_listings) == 2 assert s3_path in dir_listings diff --git a/papermill/tests/test_translators.py b/papermill/tests/test_translators.py index 906784f6..ab49475d 100644 --- a/papermill/tests/test_translators.py +++ b/papermill/tests/test_translators.py @@ -1,8 +1,7 @@ -import pytest - -from unittest.mock import Mock from collections import OrderedDict +from unittest.mock import Mock +import pytest from nbformat.v4 import new_code_cell from .. import translators @@ -11,29 +10,29 @@ @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, '{"foo": "bar"}'), - ({"foo": '"bar"'}, '{"foo": "\\"bar\\""}'), - ({"foo": ["bar"]}, '{"foo": ["bar"]}'), - ({"foo": {"bar": "baz"}}, '{"foo": {"bar": "baz"}}'), - ({"foo": {"bar": '"baz"'}}, '{"foo": {"bar": "\\"baz\\""}}'), - (["foo"], '["foo"]'), - (["foo", '"bar"'], '["foo", "\\"bar\\""]'), - ([{"foo": "bar"}], '[{"foo": "bar"}]'), - ([{"foo": '"bar"'}], '[{"foo": "\\"bar\\""}]'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (float("nan"), "float('nan')"), - (float("-inf"), "float('-inf')"), - (float("inf"), "float('inf')"), - (True, "True"), - (False, "False"), - (None, "None"), + ({'foo': 'bar'}, '{"foo": "bar"}'), + ({'foo': '"bar"'}, '{"foo": "\\"bar\\""}'), + ({'foo': ['bar']}, '{"foo": ["bar"]}'), + ({'foo': {'bar': 'baz'}}, '{"foo": {"bar": "baz"}}'), + ({'foo': {'bar': '"baz"'}}, '{"foo": {"bar": "\\"baz\\""}}'), + (['foo'], '["foo"]'), + (['foo', '"bar"'], '["foo", "\\"bar\\""]'), + ([{'foo': 'bar'}], '[{"foo": "bar"}]'), + ([{'foo': '"bar"'}], '[{"foo": "\\"bar\\""}]'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (float('nan'), "float('nan')"), + (float('-inf'), "float('-inf')"), + (float('inf'), "float('inf')"), + (True, 'True'), + (False, 'False'), + (None, 'None'), ], ) def test_translate_type_python(test_input, expected): @@ -41,16 +40,16 @@ def test_translate_type_python(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'), - ({"foo": True}, "# Parameters\nfoo = True\n"), - ({"foo": 5}, "# Parameters\nfoo = 5\n"), - ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = ["bar", "baz"]\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = {"bar": "baz"}\n'), + ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'), + ({'foo': True}, '# Parameters\nfoo = True\n'), + ({'foo': 5}, '# Parameters\nfoo = 5\n'), + ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = ["bar", "baz"]\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = {"bar": "baz"}\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\nfoo = "bar"\nbaz = ["buz"]\n', ), ], @@ -60,39 +59,39 @@ def test_translate_codify_python(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_python(test_input, expected): assert translators.PythonTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("a = 2", [Parameter("a", "None", "2", "")]), - ("a: int = 2", [Parameter("a", "int", "2", "")]), - ("a = 2 # type:int", [Parameter("a", "int", "2", "")]), + ('a = 2', [Parameter('a', 'None', '2', '')]), + ('a: int = 2', [Parameter('a', 'int', '2', '')]), + ('a = 2 # type:int', [Parameter('a', 'int', '2', '')]), ( - "a = False # Nice variable a", - [Parameter("a", "None", "False", "Nice variable a")], + 'a = False # Nice variable a', + [Parameter('a', 'None', 'False', 'Nice variable a')], ), ( - "a: float = 2.258 # type: int Nice variable a", - [Parameter("a", "float", "2.258", "Nice variable a")], + 'a: float = 2.258 # type: int Nice variable a', + [Parameter('a', 'float', '2.258', 'Nice variable a')], ), ( "a = 'this is a string' # type: int Nice variable a", - [Parameter("a", "int", "'this is a string'", "Nice variable a")], + [Parameter('a', 'int', "'this is a string'", 'Nice variable a')], ), ( "a: List[str] = ['this', 'is', 'a', 'string', 'list'] # Nice variable a", [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this', 'is', 'a', 'string', 'list']", - "Nice variable a", + 'Nice variable a', ) ], ), @@ -100,10 +99,10 @@ def test_translate_comment_python(test_input, expected): "a: List[str] = [\n 'this', # First\n 'is',\n 'a',\n 'string',\n 'list' # Last\n] # Nice variable a", # noqa [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this','is','a','string','list']", - "Nice variable a", + 'Nice variable a', ) ], ), @@ -111,10 +110,10 @@ def test_translate_comment_python(test_input, expected): "a: List[str] = [\n 'this',\n 'is',\n 'a',\n 'string',\n 'list'\n] # Nice variable a", # noqa [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this','is','a','string','list']", - "Nice variable a", + 'Nice variable a', ) ], ), @@ -132,12 +131,12 @@ def test_translate_comment_python(test_input, expected): """, [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this','is','a','string','list']", - "Nice variable a", + 'Nice variable a', ), - Parameter("b", "float", "-2.3432", "My b variable"), + Parameter('b', 'float', '-2.3432', 'My b variable'), ], ), ], @@ -148,26 +147,26 @@ def test_inspect_python(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'list("foo" = "bar")'), - ({"foo": '"bar"'}, 'list("foo" = "\\"bar\\"")'), - ({"foo": ["bar"]}, 'list("foo" = list("bar"))'), - ({"foo": {"bar": "baz"}}, 'list("foo" = list("bar" = "baz"))'), - ({"foo": {"bar": '"baz"'}}, 'list("foo" = list("bar" = "\\"baz\\""))'), - (["foo"], 'list("foo")'), - (["foo", '"bar"'], 'list("foo", "\\"bar\\"")'), - ([{"foo": "bar"}], 'list(list("foo" = "bar"))'), - ([{"foo": '"bar"'}], 'list(list("foo" = "\\"bar\\""))'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "TRUE"), - (False, "FALSE"), - (None, "NULL"), + ({'foo': 'bar'}, 'list("foo" = "bar")'), + ({'foo': '"bar"'}, 'list("foo" = "\\"bar\\"")'), + ({'foo': ['bar']}, 'list("foo" = list("bar"))'), + ({'foo': {'bar': 'baz'}}, 'list("foo" = list("bar" = "baz"))'), + ({'foo': {'bar': '"baz"'}}, 'list("foo" = list("bar" = "\\"baz\\""))'), + (['foo'], 'list("foo")'), + (['foo', '"bar"'], 'list("foo", "\\"bar\\"")'), + ([{'foo': 'bar'}], 'list(list("foo" = "bar"))'), + ([{'foo': '"bar"'}], 'list(list("foo" = "\\"bar\\""))'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'TRUE'), + (False, 'FALSE'), + (None, 'NULL'), ], ) def test_translate_type_r(test_input, expected): @@ -175,28 +174,28 @@ def test_translate_type_r(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_r(test_input, expected): assert translators.RTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'), - ({"foo": True}, "# Parameters\nfoo = TRUE\n"), - ({"foo": 5}, "# Parameters\nfoo = 5\n"), - ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = list("bar", "baz")\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = list("bar" = "baz")\n'), + ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'), + ({'foo': True}, '# Parameters\nfoo = TRUE\n'), + ({'foo': 5}, '# Parameters\nfoo = 5\n'), + ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = list("bar", "baz")\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = list("bar" = "baz")\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\nfoo = "bar"\nbaz = list("buz")\n', ), # Underscores remove - ({"___foo": 5}, "# Parameters\nfoo = 5\n"), + ({'___foo': 5}, '# Parameters\nfoo = 5\n'), ], ) def test_translate_codify_r(parameters, expected): @@ -204,28 +203,28 @@ def test_translate_codify_r(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'Map("foo" -> "bar")'), - ({"foo": '"bar"'}, 'Map("foo" -> "\\"bar\\"")'), - ({"foo": ["bar"]}, 'Map("foo" -> Seq("bar"))'), - ({"foo": {"bar": "baz"}}, 'Map("foo" -> Map("bar" -> "baz"))'), - ({"foo": {"bar": '"baz"'}}, 'Map("foo" -> Map("bar" -> "\\"baz\\""))'), - (["foo"], 'Seq("foo")'), - (["foo", '"bar"'], 'Seq("foo", "\\"bar\\"")'), - ([{"foo": "bar"}], 'Seq(Map("foo" -> "bar"))'), - ([{"foo": '"bar"'}], 'Seq(Map("foo" -> "\\"bar\\""))'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (2147483648, "2147483648L"), - (-2147483649, "-2147483649L"), - (True, "true"), - (False, "false"), - (None, "None"), + ({'foo': 'bar'}, 'Map("foo" -> "bar")'), + ({'foo': '"bar"'}, 'Map("foo" -> "\\"bar\\"")'), + ({'foo': ['bar']}, 'Map("foo" -> Seq("bar"))'), + ({'foo': {'bar': 'baz'}}, 'Map("foo" -> Map("bar" -> "baz"))'), + ({'foo': {'bar': '"baz"'}}, 'Map("foo" -> Map("bar" -> "\\"baz\\""))'), + (['foo'], 'Seq("foo")'), + (['foo', '"bar"'], 'Seq("foo", "\\"bar\\"")'), + ([{'foo': 'bar'}], 'Seq(Map("foo" -> "bar"))'), + ([{'foo': '"bar"'}], 'Seq(Map("foo" -> "\\"bar\\""))'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (2147483648, '2147483648L'), + (-2147483649, '-2147483649L'), + (True, 'true'), + (False, 'false'), + (None, 'None'), ], ) def test_translate_type_scala(test_input, expected): @@ -233,19 +232,19 @@ def test_translate_type_scala(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "//"), ("foo", "// foo"), ("['best effort']", "// ['best effort']")], + 'test_input,expected', + [('', '//'), ('foo', '// foo'), ("['best effort']", "// ['best effort']")], ) def test_translate_comment_scala(test_input, expected): assert translators.ScalaTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "input_name,input_value,expected", + 'input_name,input_value,expected', [ - ("foo", '""', 'val foo = ""'), - ("foo", '"bar"', 'val foo = "bar"'), - ("foo", 'Map("foo" -> "bar")', 'val foo = Map("foo" -> "bar")'), + ('foo', '""', 'val foo = ""'), + ('foo', '"bar"', 'val foo = "bar"'), + ('foo', 'Map("foo" -> "bar")', 'val foo = Map("foo" -> "bar")'), ], ) def test_translate_assign_scala(input_name, input_value, expected): @@ -253,16 +252,16 @@ def test_translate_assign_scala(input_name, input_value, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '// Parameters\nval foo = "bar"\n'), - ({"foo": True}, "// Parameters\nval foo = true\n"), - ({"foo": 5}, "// Parameters\nval foo = 5\n"), - ({"foo": 1.1}, "// Parameters\nval foo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '// Parameters\nval foo = Seq("bar", "baz")\n'), - ({"foo": {"bar": "baz"}}, '// Parameters\nval foo = Map("bar" -> "baz")\n'), + ({'foo': 'bar'}, '// Parameters\nval foo = "bar"\n'), + ({'foo': True}, '// Parameters\nval foo = true\n'), + ({'foo': 5}, '// Parameters\nval foo = 5\n'), + ({'foo': 1.1}, '// Parameters\nval foo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '// Parameters\nval foo = Seq("bar", "baz")\n'), + ({'foo': {'bar': 'baz'}}, '// Parameters\nval foo = Map("bar" -> "baz")\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '// Parameters\nval foo = "bar"\nval baz = Seq("buz")\n', ), ], @@ -273,26 +272,26 @@ def test_translate_codify_scala(parameters, expected): # C# section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'new Dictionary{ { "foo" , "bar" } }'), - ({"foo": '"bar"'}, 'new Dictionary{ { "foo" , "\\"bar\\"" } }'), - (["foo"], 'new [] { "foo" }'), - (["foo", '"bar"'], 'new [] { "foo", "\\"bar\\"" }'), + ({'foo': 'bar'}, 'new Dictionary{ { "foo" , "bar" } }'), + ({'foo': '"bar"'}, 'new Dictionary{ { "foo" , "\\"bar\\"" } }'), + (['foo'], 'new [] { "foo" }'), + (['foo', '"bar"'], 'new [] { "foo", "\\"bar\\"" }'), ( - [{"foo": "bar"}], + [{'foo': 'bar'}], 'new [] { new Dictionary{ { "foo" , "bar" } } }', ), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (2147483648, "2147483648L"), - (-2147483649, "-2147483649L"), - (True, "true"), - (False, "false"), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (2147483648, '2147483648L'), + (-2147483649, '-2147483649L'), + (True, 'true'), + (False, 'false'), ], ) def test_translate_type_csharp(test_input, expected): @@ -300,34 +299,34 @@ def test_translate_type_csharp(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "//"), ("foo", "// foo"), ("['best effort']", "// ['best effort']")], + 'test_input,expected', + [('', '//'), ('foo', '// foo'), ("['best effort']", "// ['best effort']")], ) def test_translate_comment_csharp(test_input, expected): assert translators.CSharpTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "input_name,input_value,expected", - [("foo", '""', 'var foo = "";'), ("foo", '"bar"', 'var foo = "bar";')], + 'input_name,input_value,expected', + [('foo', '""', 'var foo = "";'), ('foo', '"bar"', 'var foo = "bar";')], ) def test_translate_assign_csharp(input_name, input_value, expected): assert translators.CSharpTranslator.assign(input_name, input_value) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '// Parameters\nvar foo = "bar";\n'), - ({"foo": True}, "// Parameters\nvar foo = true;\n"), - ({"foo": 5}, "// Parameters\nvar foo = 5;\n"), - ({"foo": 1.1}, "// Parameters\nvar foo = 1.1;\n"), + ({'foo': 'bar'}, '// Parameters\nvar foo = "bar";\n'), + ({'foo': True}, '// Parameters\nvar foo = true;\n'), + ({'foo': 5}, '// Parameters\nvar foo = 5;\n'), + ({'foo': 1.1}, '// Parameters\nvar foo = 1.1;\n'), ( - {"foo": ["bar", "baz"]}, + {'foo': ['bar', 'baz']}, '// Parameters\nvar foo = new [] { "bar", "baz" };\n', ), ( - {"foo": {"bar": "baz"}}, + {'foo': {'bar': 'baz'}}, '// Parameters\nvar foo = new Dictionary{ { "bar" , "baz" } };\n', ), ], @@ -338,29 +337,29 @@ def test_translate_codify_csharp(parameters, expected): # Powershell section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{`"foo`": `"bar`"}"'), - ({"foo": "bar"}, '@{"foo" = "bar"}'), - ({"foo": '"bar"'}, '@{"foo" = "`"bar`""}'), - ({"foo": ["bar"]}, '@{"foo" = @("bar")}'), - ({"foo": {"bar": "baz"}}, '@{"foo" = @{"bar" = "baz"}}'), - ({"foo": {"bar": '"baz"'}}, '@{"foo" = @{"bar" = "`"baz`""}}'), - (["foo"], '@("foo")'), - (["foo", '"bar"'], '@("foo", "`"bar`"")'), - ([{"foo": "bar"}], '@(@{"foo" = "bar"})'), - ([{"foo": '"bar"'}], '@(@{"foo" = "`"bar`""})'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (float("nan"), "[double]::NaN"), - (float("-inf"), "[double]::NegativeInfinity"), - (float("inf"), "[double]::PositiveInfinity"), - (True, "$True"), - (False, "$False"), - (None, "$Null"), + ({'foo': 'bar'}, '@{"foo" = "bar"}'), + ({'foo': '"bar"'}, '@{"foo" = "`"bar`""}'), + ({'foo': ['bar']}, '@{"foo" = @("bar")}'), + ({'foo': {'bar': 'baz'}}, '@{"foo" = @{"bar" = "baz"}}'), + ({'foo': {'bar': '"baz"'}}, '@{"foo" = @{"bar" = "`"baz`""}}'), + (['foo'], '@("foo")'), + (['foo', '"bar"'], '@("foo", "`"bar`"")'), + ([{'foo': 'bar'}], '@(@{"foo" = "bar"})'), + ([{'foo': '"bar"'}], '@(@{"foo" = "`"bar`""})'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (float('nan'), '[double]::NaN'), + (float('-inf'), '[double]::NegativeInfinity'), + (float('inf'), '[double]::PositiveInfinity'), + (True, '$True'), + (False, '$False'), + (None, '$Null'), ], ) def test_translate_type_powershell(test_input, expected): @@ -368,16 +367,16 @@ def test_translate_type_powershell(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\n$foo = "bar"\n'), - ({"foo": True}, "# Parameters\n$foo = $True\n"), - ({"foo": 5}, "# Parameters\n$foo = 5\n"), - ({"foo": 1.1}, "# Parameters\n$foo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\n$foo = @("bar", "baz")\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\n$foo = @{"bar" = "baz"}\n'), + ({'foo': 'bar'}, '# Parameters\n$foo = "bar"\n'), + ({'foo': True}, '# Parameters\n$foo = $True\n'), + ({'foo': 5}, '# Parameters\n$foo = 5\n'), + ({'foo': 1.1}, '# Parameters\n$foo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\n$foo = @("bar", "baz")\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\n$foo = @{"bar" = "baz"}\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\n$foo = "bar"\n$baz = @("buz")\n', ), ], @@ -387,16 +386,16 @@ def test_translate_codify_powershell(parameters, expected): @pytest.mark.parametrize( - "input_name,input_value,expected", - [("foo", '""', '$foo = ""'), ("foo", '"bar"', '$foo = "bar"')], + 'input_name,input_value,expected', + [('foo', '""', '$foo = ""'), ('foo', '"bar"', '$foo = "bar"')], ) def test_translate_assign_powershell(input_name, input_value, expected): assert translators.PowershellTranslator.assign(input_name, input_value) == expected @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_powershell(test_input, expected): assert translators.PowershellTranslator.comment(test_input) == expected @@ -404,23 +403,23 @@ def test_translate_comment_powershell(test_input, expected): # F# section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, '[ ("foo", "bar" :> IComparable) ] |> Map.ofList'), - ({"foo": '"bar"'}, '[ ("foo", "\\"bar\\"" :> IComparable) ] |> Map.ofList'), - (["foo"], '[ "foo" ]'), - (["foo", '"bar"'], '[ "foo"; "\\"bar\\"" ]'), - ([{"foo": "bar"}], '[ [ ("foo", "bar" :> IComparable) ] |> Map.ofList ]'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (2147483648, "2147483648L"), - (-2147483649, "-2147483649L"), - (True, "true"), - (False, "false"), + ({'foo': 'bar'}, '[ ("foo", "bar" :> IComparable) ] |> Map.ofList'), + ({'foo': '"bar"'}, '[ ("foo", "\\"bar\\"" :> IComparable) ] |> Map.ofList'), + (['foo'], '[ "foo" ]'), + (['foo', '"bar"'], '[ "foo"; "\\"bar\\"" ]'), + ([{'foo': 'bar'}], '[ [ ("foo", "bar" :> IComparable) ] |> Map.ofList ]'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (2147483648, '2147483648L'), + (-2147483649, '-2147483649L'), + (True, 'true'), + (False, 'false'), ], ) def test_translate_type_fsharp(test_input, expected): @@ -428,10 +427,10 @@ def test_translate_type_fsharp(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("", "(* *)"), - ("foo", "(* foo *)"), + ('', '(* *)'), + ('foo', '(* foo *)'), ("['best effort']", "(* ['best effort'] *)"), ], ) @@ -440,23 +439,23 @@ def test_translate_comment_fsharp(test_input, expected): @pytest.mark.parametrize( - "input_name,input_value,expected", - [("foo", '""', 'let foo = ""'), ("foo", '"bar"', 'let foo = "bar"')], + 'input_name,input_value,expected', + [('foo', '""', 'let foo = ""'), ('foo', '"bar"', 'let foo = "bar"')], ) def test_translate_assign_fsharp(input_name, input_value, expected): assert translators.FSharpTranslator.assign(input_name, input_value) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '(* Parameters *)\nlet foo = "bar"\n'), - ({"foo": True}, "(* Parameters *)\nlet foo = true\n"), - ({"foo": 5}, "(* Parameters *)\nlet foo = 5\n"), - ({"foo": 1.1}, "(* Parameters *)\nlet foo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '(* Parameters *)\nlet foo = [ "bar"; "baz" ]\n'), + ({'foo': 'bar'}, '(* Parameters *)\nlet foo = "bar"\n'), + ({'foo': True}, '(* Parameters *)\nlet foo = true\n'), + ({'foo': 5}, '(* Parameters *)\nlet foo = 5\n'), + ({'foo': 1.1}, '(* Parameters *)\nlet foo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '(* Parameters *)\nlet foo = [ "bar"; "baz" ]\n'), ( - {"foo": {"bar": "baz"}}, + {'foo': {'bar': 'baz'}}, '(* Parameters *)\nlet foo = [ ("bar", "baz" :> IComparable) ] |> Map.ofList\n', ), ], @@ -466,26 +465,26 @@ def test_translate_codify_fsharp(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'Dict("foo" => "bar")'), - ({"foo": '"bar"'}, 'Dict("foo" => "\\"bar\\"")'), - ({"foo": ["bar"]}, 'Dict("foo" => ["bar"])'), - ({"foo": {"bar": "baz"}}, 'Dict("foo" => Dict("bar" => "baz"))'), - ({"foo": {"bar": '"baz"'}}, 'Dict("foo" => Dict("bar" => "\\"baz\\""))'), - (["foo"], '["foo"]'), - (["foo", '"bar"'], '["foo", "\\"bar\\""]'), - ([{"foo": "bar"}], '[Dict("foo" => "bar")]'), - ([{"foo": '"bar"'}], '[Dict("foo" => "\\"bar\\"")]'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "true"), - (False, "false"), - (None, "nothing"), + ({'foo': 'bar'}, 'Dict("foo" => "bar")'), + ({'foo': '"bar"'}, 'Dict("foo" => "\\"bar\\"")'), + ({'foo': ['bar']}, 'Dict("foo" => ["bar"])'), + ({'foo': {'bar': 'baz'}}, 'Dict("foo" => Dict("bar" => "baz"))'), + ({'foo': {'bar': '"baz"'}}, 'Dict("foo" => Dict("bar" => "\\"baz\\""))'), + (['foo'], '["foo"]'), + (['foo', '"bar"'], '["foo", "\\"bar\\""]'), + ([{'foo': 'bar'}], '[Dict("foo" => "bar")]'), + ([{'foo': '"bar"'}], '[Dict("foo" => "\\"bar\\"")]'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'true'), + (False, 'false'), + (None, 'nothing'), ], ) def test_translate_type_julia(test_input, expected): @@ -493,16 +492,16 @@ def test_translate_type_julia(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'), - ({"foo": True}, "# Parameters\nfoo = true\n"), - ({"foo": 5}, "# Parameters\nfoo = 5\n"), - ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = ["bar", "baz"]\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = Dict("bar" => "baz")\n'), + ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'), + ({'foo': True}, '# Parameters\nfoo = true\n'), + ({'foo': 5}, '# Parameters\nfoo = 5\n'), + ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = ["bar", "baz"]\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = Dict("bar" => "baz")\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\nfoo = "bar"\nbaz = ["buz"]\n', ), ], @@ -512,44 +511,44 @@ def test_translate_codify_julia(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ('["best effort"]', '# ["best effort"]')], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ('["best effort"]', '# ["best effort"]')], ) def test_translate_comment_julia(test_input, expected): assert translators.JuliaTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{""foo"": ""bar""}"'), - ({1: "foo"}, "containers.Map({'1'}, {\"foo\"})"), - ({1.0: "foo"}, "containers.Map({'1.0'}, {\"foo\"})"), - ({None: "foo"}, "containers.Map({'None'}, {\"foo\"})"), - ({True: "foo"}, "containers.Map({'True'}, {\"foo\"})"), - ({"foo": "bar"}, "containers.Map({'foo'}, {\"bar\"})"), - ({"foo": '"bar"'}, 'containers.Map({\'foo\'}, {"""bar"""})'), - ({"foo": ["bar"]}, "containers.Map({'foo'}, {{\"bar\"}})"), + ({1: 'foo'}, 'containers.Map({\'1\'}, {"foo"})'), + ({1.0: 'foo'}, 'containers.Map({\'1.0\'}, {"foo"})'), + ({None: 'foo'}, 'containers.Map({\'None\'}, {"foo"})'), + ({True: 'foo'}, 'containers.Map({\'True\'}, {"foo"})'), + ({'foo': 'bar'}, 'containers.Map({\'foo\'}, {"bar"})'), + ({'foo': '"bar"'}, 'containers.Map({\'foo\'}, {"""bar"""})'), + ({'foo': ['bar']}, 'containers.Map({\'foo\'}, {{"bar"}})'), ( - {"foo": {"bar": "baz"}}, + {'foo': {'bar': 'baz'}}, "containers.Map({'foo'}, {containers.Map({'bar'}, {\"baz\"})})", ), ( - {"foo": {"bar": '"baz"'}}, + {'foo': {'bar': '"baz"'}}, 'containers.Map({\'foo\'}, {containers.Map({\'bar\'}, {"""baz"""})})', ), - (["foo"], '{"foo"}'), - (["foo", '"bar"'], '{"foo", """bar"""}'), - ([{"foo": "bar"}], "{containers.Map({'foo'}, {\"bar\"})}"), - ([{"foo": '"bar"'}], '{containers.Map({\'foo\'}, {"""bar"""})}'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "true"), - (False, "false"), - (None, "NaN"), + (['foo'], '{"foo"}'), + (['foo', '"bar"'], '{"foo", """bar"""}'), + ([{'foo': 'bar'}], '{containers.Map({\'foo\'}, {"bar"})}'), + ([{'foo': '"bar"'}], '{containers.Map({\'foo\'}, {"""bar"""})}'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'true'), + (False, 'false'), + (None, 'NaN'), ], ) def test_translate_type_matlab(test_input, expected): @@ -557,19 +556,19 @@ def test_translate_type_matlab(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '% Parameters\nfoo = "bar";\n'), - ({"foo": True}, "% Parameters\nfoo = true;\n"), - ({"foo": 5}, "% Parameters\nfoo = 5;\n"), - ({"foo": 1.1}, "% Parameters\nfoo = 1.1;\n"), - ({"foo": ["bar", "baz"]}, '% Parameters\nfoo = {"bar", "baz"};\n'), + ({'foo': 'bar'}, '% Parameters\nfoo = "bar";\n'), + ({'foo': True}, '% Parameters\nfoo = true;\n'), + ({'foo': 5}, '% Parameters\nfoo = 5;\n'), + ({'foo': 1.1}, '% Parameters\nfoo = 1.1;\n'), + ({'foo': ['bar', 'baz']}, '% Parameters\nfoo = {"bar", "baz"};\n'), ( - {"foo": {"bar": "baz"}}, - "% Parameters\nfoo = containers.Map({'bar'}, {\"baz\"});\n", + {'foo': {'bar': 'baz'}}, + '% Parameters\nfoo = containers.Map({\'bar\'}, {"baz"});\n', ), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '% Parameters\nfoo = "bar";\nbaz = {"buz"};\n', ), ], @@ -579,8 +578,8 @@ def test_translate_codify_matlab(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "%"), ("foo", "% foo"), ("['best effort']", "% ['best effort']")], + 'test_input,expected', + [('', '%'), ('foo', '% foo'), ("['best effort']", "% ['best effort']")], ) def test_translate_comment_matlab(test_input, expected): assert translators.MatlabTranslator.comment(test_input) == expected @@ -589,44 +588,32 @@ def test_translate_comment_matlab(test_input, expected): def test_find_translator_with_exact_kernel_name(): my_new_kernel_translator = Mock() my_new_language_translator = Mock() - translators.papermill_translators.register( - "my_new_kernel", my_new_kernel_translator - ) - translators.papermill_translators.register( - "my_new_language", my_new_language_translator - ) + translators.papermill_translators.register('my_new_kernel', my_new_kernel_translator) + translators.papermill_translators.register('my_new_language', my_new_language_translator) assert ( - translators.papermill_translators.find_translator( - "my_new_kernel", "my_new_language" - ) + translators.papermill_translators.find_translator('my_new_kernel', 'my_new_language') is my_new_kernel_translator ) def test_find_translator_with_exact_language(): my_new_language_translator = Mock() - translators.papermill_translators.register( - "my_new_language", my_new_language_translator - ) + translators.papermill_translators.register('my_new_language', my_new_language_translator) assert ( - translators.papermill_translators.find_translator( - "unregistered_kernel", "my_new_language" - ) + translators.papermill_translators.find_translator('unregistered_kernel', 'my_new_language') is my_new_language_translator ) def test_find_translator_with_no_such_kernel_or_language(): with pytest.raises(PapermillException): - translators.papermill_translators.find_translator( - "unregistered_kernel", "unregistered_language" - ) + translators.papermill_translators.find_translator('unregistered_kernel', 'unregistered_language') def test_translate_uses_str_representation_of_unknown_types(): class FooClass: def __str__(self): - return "foo" + return 'foo' obj = FooClass() assert translators.Translator.translate(obj) == '"foo"' @@ -637,7 +624,7 @@ class MyNewTranslator(translators.Translator): pass with pytest.raises(NotImplementedError): - MyNewTranslator.translate_dict({"foo": "bar"}) + MyNewTranslator.translate_dict({'foo': 'bar'}) def test_translator_must_implement_translate_list(): @@ -645,7 +632,7 @@ class MyNewTranslator(translators.Translator): pass with pytest.raises(NotImplementedError): - MyNewTranslator.translate_list(["foo", "bar"]) + MyNewTranslator.translate_list(['foo', 'bar']) def test_translator_must_implement_comment(): @@ -653,24 +640,24 @@ class MyNewTranslator(translators.Translator): pass with pytest.raises(NotImplementedError): - MyNewTranslator.comment("foo") + MyNewTranslator.comment('foo') # Bash/sh section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", "foo"), - ("foo space", "'foo space'"), + ('foo', 'foo'), + ('foo space', "'foo space'"), ("foo's apostrophe", "'foo'\"'\"'s apostrophe'"), - ("shell ( is ) ", "'shell ( is ) '"), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "true"), - (False, "false"), - (None, ""), + ('shell ( is ) ', "'shell ( is ) '"), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'true'), + (False, 'false'), + (None, ''), ], ) def test_translate_type_sh(test_input, expected): @@ -678,23 +665,23 @@ def test_translate_type_sh(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_sh(test_input, expected): assert translators.BashTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, "# Parameters\nfoo=bar\n"), - ({"foo": "shell ( is ) "}, "# Parameters\nfoo='shell ( is ) '\n"), - ({"foo": True}, "# Parameters\nfoo=true\n"), - ({"foo": 5}, "# Parameters\nfoo=5\n"), - ({"foo": 1.1}, "# Parameters\nfoo=1.1\n"), + ({'foo': 'bar'}, '# Parameters\nfoo=bar\n'), + ({'foo': 'shell ( is ) '}, "# Parameters\nfoo='shell ( is ) '\n"), + ({'foo': True}, '# Parameters\nfoo=true\n'), + ({'foo': 5}, '# Parameters\nfoo=5\n'), + ({'foo': 1.1}, '# Parameters\nfoo=1.1\n'), ( - OrderedDict([["foo", "bar"], ["baz", "$dumb(shell)"]]), + OrderedDict([['foo', 'bar'], ['baz', '$dumb(shell)']]), "# Parameters\nfoo=bar\nbaz='$dumb(shell)'\n", ), ], diff --git a/papermill/tests/test_utils.py b/papermill/tests/test_utils.py index 519fa383..4d058fb2 100644 --- a/papermill/tests/test_utils.py +++ b/papermill/tests/test_utils.py @@ -1,59 +1,53 @@ -import pytest import warnings - -from unittest.mock import Mock, call -from tempfile import TemporaryDirectory from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import Mock, call -from nbformat.v4 import new_notebook, new_code_cell +import pytest +from nbformat.v4 import new_code_cell, new_notebook +from ..exceptions import PapermillParameterOverwriteWarning from ..utils import ( any_tagged_cell, - retry, chdir, merge_kwargs, remove_args, + retry, ) -from ..exceptions import PapermillParameterOverwriteWarning def test_no_tagged_cell(): nb = new_notebook( - cells=[new_code_cell("a = 2", metadata={"tags": []})], + cells=[new_code_cell('a = 2', metadata={'tags': []})], ) - assert not any_tagged_cell(nb, "parameters") + assert not any_tagged_cell(nb, 'parameters') def test_tagged_cell(): nb = new_notebook( - cells=[new_code_cell("a = 2", metadata={"tags": ["parameters"]})], + cells=[new_code_cell('a = 2', metadata={'tags': ['parameters']})], ) - assert any_tagged_cell(nb, "parameters") + assert any_tagged_cell(nb, 'parameters') def test_merge_kwargs(): with warnings.catch_warnings(record=True) as wrn: - assert merge_kwargs({"a": 1, "b": 2}, a=3) == {"a": 3, "b": 2} + assert merge_kwargs({'a': 1, 'b': 2}, a=3) == {'a': 3, 'b': 2} assert len(wrn) == 1 assert issubclass(wrn[0].category, PapermillParameterOverwriteWarning) - assert ( - wrn[0].message.__str__() - == "Callee will overwrite caller's argument(s): a=3" - ) + assert wrn[0].message.__str__() == "Callee will overwrite caller's argument(s): a=3" def test_remove_args(): - assert remove_args(["a"], a=1, b=2, c=3) == {"c": 3, "b": 2} + assert remove_args(['a'], a=1, b=2, c=3) == {'c': 3, 'b': 2} def test_retry(): - m = Mock( - side_effect=RuntimeError(), __name__="m", __module__="test_s3", __doc__="m" - ) + m = Mock(side_effect=RuntimeError(), __name__='m', __module__='test_s3', __doc__='m') wrapped_m = retry(3)(m) with pytest.raises(RuntimeError): - wrapped_m("foo") - m.assert_has_calls([call("foo"), call("foo"), call("foo")]) + wrapped_m('foo') + m.assert_has_calls([call('foo'), call('foo'), call('foo')]) def test_chdir(): diff --git a/papermill/translators.py b/papermill/translators.py index ace316bf..0086f84f 100644 --- a/papermill/translators.py +++ b/papermill/translators.py @@ -6,7 +6,6 @@ from .exceptions import PapermillException from .models import Parameter - logger = logging.getLogger(__name__) @@ -29,9 +28,7 @@ def find_translator(self, kernel_name, language): elif language in self._translators: return self._translators[language] raise PapermillException( - "No parameter translator functions specified for kernel '{}' or language '{}'".format( - kernel_name, language - ) + f"No parameter translator functions specified for kernel '{kernel_name}' or language '{language}'" ) @@ -39,15 +36,15 @@ class Translator: @classmethod def translate_raw_str(cls, val): """Reusable by most interpreters""" - return f"{val}" + return f'{val}' @classmethod def translate_escaped_str(cls, str_val): """Reusable by most interpreters""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") - str_val = str_val.replace('"', r"\"") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') + str_val = str_val.replace('"', r'\"') return f'"{str_val}"' @classmethod @@ -73,15 +70,15 @@ def translate_float(cls, val): @classmethod def translate_bool(cls, val): """Default behavior for translation""" - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_dict(cls, val): - raise NotImplementedError(f"dict type translation not implemented for {cls}") + raise NotImplementedError(f'dict type translation not implemented for {cls}') @classmethod def translate_list(cls, val): - raise NotImplementedError(f"list type translation not implemented for {cls}") + raise NotImplementedError(f'list type translation not implemented for {cls}') @classmethod def translate(cls, val): @@ -106,17 +103,17 @@ def translate(cls, val): @classmethod def comment(cls, cmt_str): - raise NotImplementedError(f"comment translation not implemented for {cls}") + raise NotImplementedError(f'comment translation not implemented for {cls}') @classmethod def assign(cls, name, str_val): - return f"{name} = {str_val}" + return f'{name} = {str_val}' @classmethod - def codify(cls, parameters, comment="Parameters"): - content = f"{cls.comment(comment)}\n" + def codify(cls, parameters, comment='Parameters'): + content = f'{cls.comment(comment)}\n' for name, val in parameters.items(): - content += f"{cls.assign(name, cls.translate(val))}\n" + content += f'{cls.assign(name, cls.translate(val))}\n' return content @classmethod @@ -140,7 +137,7 @@ def inspect(cls, parameters_cell): List[Parameter] A list of all parameters """ - raise NotImplementedError(f"parameters introspection not implemented for {cls}") + raise NotImplementedError(f'parameters introspection not implemented for {cls}') class PythonTranslator(Translator): @@ -166,22 +163,20 @@ def translate_bool(cls, val): @classmethod def translate_dict(cls, val): - escaped = ", ".join( - [f"{cls.translate_str(k)}: {cls.translate(v)}" for k, v in val.items()] - ) - return f"{{{escaped}}}" + escaped = ', '.join([f'{cls.translate_str(k)}: {cls.translate(v)}' for k, v in val.items()]) + return f'{{{escaped}}}' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"[{escaped}]" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'[{escaped}]' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod - def codify(cls, parameters, comment="Parameters"): + def codify(cls, parameters, comment='Parameters'): content = super().codify(parameters, comment) try: # Put content through the Black Python code formatter @@ -192,7 +187,7 @@ def codify(cls, parameters, comment="Parameters"): except ImportError: logger.debug("Black is not installed, parameters won't be formatted") except AttributeError as aerr: - logger.warning(f"Black encountered an error, skipping formatting ({aerr})") + logger.warning(f'Black encountered an error, skipping formatting ({aerr})') return content @classmethod @@ -213,7 +208,7 @@ def inspect(cls, parameters_cell): A list of all parameters """ params = [] - src = parameters_cell["source"] + src = parameters_cell['source'] def flatten_accumulator(accumulator): """Flatten a multilines variable definition. @@ -225,10 +220,10 @@ def flatten_accumulator(accumulator): Returns: Flatten definition """ - flat_string = "" + flat_string = '' for line in accumulator[:-1]: - if "#" in line: - comment_pos = line.index("#") + if '#' in line: + comment_pos = line.index('#') flat_string += line[:comment_pos].strip() else: flat_string += line.strip() @@ -244,10 +239,10 @@ def flatten_accumulator(accumulator): grouped_variable = [] accumulator = [] for iline, line in enumerate(src.splitlines()): - if len(line.strip()) == 0 or line.strip().startswith("#"): + if len(line.strip()) == 0 or line.strip().startswith('#'): continue # Skip blank and comment - nequal = line.count("=") + nequal = line.count('=') if nequal > 0: grouped_variable.append(flatten_accumulator(accumulator)) accumulator = [] @@ -265,16 +260,16 @@ def flatten_accumulator(accumulator): match = re.match(cls.PARAMETER_PATTERN, definition) if match is not None: attr = match.groupdict() - if attr["target"] is None: # Fail to get variable name + if attr['target'] is None: # Fail to get variable name continue - type_name = str(attr["annotation"] or attr["type_comment"] or None) + type_name = str(attr['annotation'] or attr['type_comment'] or None) params.append( Parameter( - name=attr["target"].strip(), + name=attr['target'].strip(), inferred_type_name=type_name.strip(), - default=str(attr["value"]).strip(), - help=str(attr["help"] or "").strip(), + default=str(attr['value']).strip(), + help=str(attr['help'] or '').strip(), ) ) @@ -284,85 +279,79 @@ def flatten_accumulator(accumulator): class RTranslator(Translator): @classmethod def translate_none(cls, val): - return "NULL" + return 'NULL' @classmethod def translate_bool(cls, val): - return "TRUE" if val else "FALSE" + return 'TRUE' if val else 'FALSE' @classmethod def translate_dict(cls, val): - escaped = ", ".join( - [f"{cls.translate_str(k)} = {cls.translate(v)}" for k, v in val.items()] - ) - return f"list({escaped})" + escaped = ', '.join([f'{cls.translate_str(k)} = {cls.translate(v)}' for k, v in val.items()]) + return f'list({escaped})' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"list({escaped})" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'list({escaped})' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): # Leading '_' aren't legal R variable names -- so we drop them when injecting - while name.startswith("_"): + while name.startswith('_'): name = name[1:] - return f"{name} = {str_val}" + return f'{name} = {str_val}' class ScalaTranslator(Translator): @classmethod def translate_int(cls, val): strval = cls.translate_raw_str(val) - return strval + "L" if (val > 2147483647 or val < -2147483648) else strval + return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval @classmethod def translate_dict(cls, val): """Translate dicts to scala Maps""" - escaped = ", ".join( - [f"{cls.translate_str(k)} -> {cls.translate(v)}" for k, v in val.items()] - ) - return f"Map({escaped})" + escaped = ', '.join([f'{cls.translate_str(k)} -> {cls.translate(v)}' for k, v in val.items()]) + return f'Map({escaped})' @classmethod def translate_list(cls, val): """Translate list to scala Seq""" - escaped = ", ".join([cls.translate(v) for v in val]) - return f"Seq({escaped})" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'Seq({escaped})' @classmethod def comment(cls, cmt_str): - return f"// {cmt_str}".strip() + return f'// {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"val {name} = {str_val}" + return f'val {name} = {str_val}' class JuliaTranslator(Translator): @classmethod def translate_none(cls, val): - return "nothing" + return 'nothing' @classmethod def translate_dict(cls, val): - escaped = ", ".join( - [f"{cls.translate_str(k)} => {cls.translate(v)}" for k, v in val.items()] - ) - return f"Dict({escaped})" + escaped = ', '.join([f'{cls.translate_str(k)} => {cls.translate(v)}' for k, v in val.items()]) + return f'Dict({escaped})' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"[{escaped}]" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'[{escaped}]' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() class MatlabTranslator(Translator): @@ -370,8 +359,8 @@ class MatlabTranslator(Translator): def translate_escaped_str(cls, str_val): """Translate a string to an escaped Matlab string""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') str_val = str_val.replace('"', '""') return f'"{str_val}"' @@ -379,35 +368,35 @@ def translate_escaped_str(cls, str_val): def __translate_char_array(str_val): """Translates a string to a Matlab char array""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') str_val = str_val.replace("'", "''") return f"'{str_val}'" @classmethod def translate_none(cls, val): - return "NaN" + return 'NaN' @classmethod def translate_dict(cls, val): - keys = ", ".join([f"{cls.__translate_char_array(k)}" for k, v in val.items()]) - vals = ", ".join([f"{cls.translate(v)}" for k, v in val.items()]) - return f"containers.Map({{{keys}}}, {{{vals}}})" + keys = ', '.join([f'{cls.__translate_char_array(k)}' for k, v in val.items()]) + vals = ', '.join([f'{cls.translate(v)}' for k, v in val.items()]) + return f'containers.Map({{{keys}}}, {{{vals}}})' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"{{{escaped}}}" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'{{{escaped}}}' @classmethod def comment(cls, cmt_str): - return f"% {cmt_str}".strip() + return f'% {cmt_str}'.strip() @classmethod - def codify(cls, parameters, comment="Parameters"): - content = f"{cls.comment(comment)}\n" + def codify(cls, parameters, comment='Parameters'): + content = f'{cls.comment(comment)}\n' for name, val in parameters.items(): - content += f"{cls.assign(name, cls.translate(val))};\n" + content += f'{cls.assign(name, cls.translate(val))};\n' return content @@ -415,80 +404,70 @@ class CSharpTranslator(Translator): @classmethod def translate_none(cls, val): # Can't figure out how to do this as nullable - raise NotImplementedError("Option type not implemented for C#.") + raise NotImplementedError('Option type not implemented for C#.') @classmethod def translate_bool(cls, val): - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_int(cls, val): strval = cls.translate_raw_str(val) - return strval + "L" if (val > 2147483647 or val < -2147483648) else strval + return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval @classmethod def translate_dict(cls, val): """Translate dicts to nontyped dictionary""" - kvps = ", ".join( - [ - f"{{ {cls.translate_str(k)} , {cls.translate(v)} }}" - for k, v in val.items() - ] - ) - return f"new Dictionary{{ {kvps} }}" + kvps = ', '.join([f'{{ {cls.translate_str(k)} , {cls.translate(v)} }}' for k, v in val.items()]) + return f'new Dictionary{{ {kvps} }}' @classmethod def translate_list(cls, val): """Translate list to array""" - escaped = ", ".join([cls.translate(v) for v in val]) - return f"new [] {{ {escaped} }}" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'new [] {{ {escaped} }}' @classmethod def comment(cls, cmt_str): - return f"// {cmt_str}".strip() + return f'// {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"var {name} = {str_val};" + return f'var {name} = {str_val};' class FSharpTranslator(Translator): @classmethod def translate_none(cls, val): - return "None" + return 'None' @classmethod def translate_bool(cls, val): - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_int(cls, val): strval = cls.translate_raw_str(val) - return strval + "L" if (val > 2147483647 or val < -2147483648) else strval + return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval @classmethod def translate_dict(cls, val): - tuples = "; ".join( - [ - f"({cls.translate_str(k)}, {cls.translate(v)} :> IComparable)" - for k, v in val.items() - ] - ) - return f"[ {tuples} ] |> Map.ofList" + tuples = '; '.join([f'({cls.translate_str(k)}, {cls.translate(v)} :> IComparable)' for k, v in val.items()]) + return f'[ {tuples} ] |> Map.ofList' @classmethod def translate_list(cls, val): - escaped = "; ".join([cls.translate(v) for v in val]) - return f"[ {escaped} ]" + escaped = '; '.join([cls.translate(v) for v in val]) + return f'[ {escaped} ]' @classmethod def comment(cls, cmt_str): - return f"(* {cmt_str} *)".strip() + return f'(* {cmt_str} *)'.strip() @classmethod def assign(cls, name, str_val): - return f"let {name} = {str_val}" + return f'let {name} = {str_val}' class PowershellTranslator(Translator): @@ -496,8 +475,8 @@ class PowershellTranslator(Translator): def translate_escaped_str(cls, str_val): """Translate a string to an escaped Matlab string""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') str_val = str_val.replace('"', '`"') return f'"{str_val}"' @@ -506,49 +485,47 @@ def translate_float(cls, val): if math.isfinite(val): return cls.translate_raw_str(val) elif math.isnan(val): - return "[double]::NaN" + return '[double]::NaN' elif val < 0: - return "[double]::NegativeInfinity" + return '[double]::NegativeInfinity' else: - return "[double]::PositiveInfinity" + return '[double]::PositiveInfinity' @classmethod def translate_none(cls, val): - return "$Null" + return '$Null' @classmethod def translate_bool(cls, val): - return "$True" if val else "$False" + return '$True' if val else '$False' @classmethod def translate_dict(cls, val): - kvps = "\n ".join( - [f"{cls.translate_str(k)} = {cls.translate(v)}" for k, v in val.items()] - ) - return f"@{{{kvps}}}" + kvps = '\n '.join([f'{cls.translate_str(k)} = {cls.translate(v)}' for k, v in val.items()]) + return f'@{{{kvps}}}' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"@({escaped})" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'@({escaped})' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"${name} = {str_val}" + return f'${name} = {str_val}' class BashTranslator(Translator): @classmethod def translate_none(cls, val): - return "" + return '' @classmethod def translate_bool(cls, val): - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_escaped_str(cls, str_val): @@ -556,35 +533,33 @@ def translate_escaped_str(cls, str_val): @classmethod def translate_list(cls, val): - escaped = " ".join([cls.translate(v) for v in val]) - return f"({escaped})" + escaped = ' '.join([cls.translate(v) for v in val]) + return f'({escaped})' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"{name}={str_val}" + return f'{name}={str_val}' # Instantiate a PapermillIO instance and register Handlers. papermill_translators = PapermillTranslators() -papermill_translators.register("python", PythonTranslator) -papermill_translators.register("R", RTranslator) -papermill_translators.register("scala", ScalaTranslator) -papermill_translators.register("julia", JuliaTranslator) -papermill_translators.register("matlab", MatlabTranslator) -papermill_translators.register(".net-csharp", CSharpTranslator) -papermill_translators.register(".net-fsharp", FSharpTranslator) -papermill_translators.register(".net-powershell", PowershellTranslator) -papermill_translators.register("pysparkkernel", PythonTranslator) -papermill_translators.register("sparkkernel", ScalaTranslator) -papermill_translators.register("sparkrkernel", RTranslator) -papermill_translators.register("bash", BashTranslator) - - -def translate_parameters(kernel_name, language, parameters, comment="Parameters"): - return papermill_translators.find_translator(kernel_name, language).codify( - parameters, comment - ) +papermill_translators.register('python', PythonTranslator) +papermill_translators.register('R', RTranslator) +papermill_translators.register('scala', ScalaTranslator) +papermill_translators.register('julia', JuliaTranslator) +papermill_translators.register('matlab', MatlabTranslator) +papermill_translators.register('.net-csharp', CSharpTranslator) +papermill_translators.register('.net-fsharp', FSharpTranslator) +papermill_translators.register('.net-powershell', PowershellTranslator) +papermill_translators.register('pysparkkernel', PythonTranslator) +papermill_translators.register('sparkkernel', ScalaTranslator) +papermill_translators.register('sparkrkernel', RTranslator) +papermill_translators.register('bash', BashTranslator) + + +def translate_parameters(kernel_name, language, parameters, comment='Parameters'): + return papermill_translators.find_translator(kernel_name, language).codify(parameters, comment) diff --git a/papermill/utils.py b/papermill/utils.py index 532a5a43..e69b710a 100644 --- a/papermill/utils.py +++ b/papermill/utils.py @@ -1,13 +1,12 @@ -import os import logging +import os import warnings - from contextlib import contextmanager from functools import wraps from .exceptions import PapermillParameterOverwriteWarning -logger = logging.getLogger("papermill.utils") +logger = logging.getLogger('papermill.utils') def any_tagged_cell(nb, tag): @@ -48,9 +47,9 @@ def nb_kernel_name(nb, name=None): ValueError If no kernel name is found or provided """ - name = name or nb.metadata.get("kernelspec", {}).get("name") + name = name or nb.metadata.get('kernelspec', {}).get('name') if not name: - raise ValueError("No kernel name found in notebook and no override provided.") + raise ValueError('No kernel name found in notebook and no override provided.') return name @@ -74,12 +73,12 @@ def nb_language(nb, language=None): ValueError If no notebook language is found or provided """ - language = language or nb.metadata.get("language_info", {}).get("name") + language = language or nb.metadata.get('language_info', {}).get('name') if not language: # v3 language path for old notebooks that didn't convert cleanly - language = language or nb.metadata.get("kernelspec", {}).get("language") + language = language or nb.metadata.get('kernelspec', {}).get('language') if not language: - raise ValueError("No language found in notebook and no override provided.") + raise ValueError('No language found in notebook and no override provided.') return language @@ -128,9 +127,7 @@ def merge_kwargs(caller_args, **callee_args): """ conflicts = set(caller_args) & set(callee_args) if conflicts: - args = format( - "; ".join([f"{key}={value}" for key, value in callee_args.items()]) - ) + args = format('; '.join([f'{key}={value}' for key, value in callee_args.items()])) msg = f"Callee will overwrite caller's argument(s): {args}" warnings.warn(msg, PapermillParameterOverwriteWarning) return dict(caller_args, **callee_args) @@ -167,7 +164,7 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - logger.debug(f"Retrying after: {e}") + logger.debug(f'Retrying after: {e}') exception = e else: raise exception diff --git a/papermill/version.py b/papermill/version.py index 824cbf24..3d98bc1d 100644 --- a/papermill/version.py +++ b/papermill/version.py @@ -1 +1 @@ -version = "2.5.0" +version = '2.5.0'