From cbef9f0e8c3cbb6b0ed46e3186b17d7c73890965 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Sat, 13 Jun 2020 22:41:25 +0530 Subject: [PATCH] Fix faulty prerun logic in notebook_loader (#28) * removed faulty prerun logic * renamed to notebook_loader for clarity --- testbook/client.py | 2 +- testbook/{execute.py => notebook_loader.py} | 12 ++++----- testbook/tests/test_execute.py | 27 +++++++++++++-------- 3 files changed, 23 insertions(+), 18 deletions(-) rename testbook/{execute.py => notebook_loader.py} (83%) diff --git a/testbook/client.py b/testbook/client.py index f5b7468..e51cb60 100644 --- a/testbook/client.py +++ b/testbook/client.py @@ -73,7 +73,7 @@ def cell_output_text(self, cell): if 'text' in output: text += output['text'] - return text + return text.strip() def inject(self, code, args=None, prerun=None): """Injects given function and executes with arguments passed diff --git a/testbook/execute.py b/testbook/notebook_loader.py similarity index 83% rename from testbook/execute.py rename to testbook/notebook_loader.py index a9d18c3..75b4d7e 100644 --- a/testbook/execute.py +++ b/testbook/notebook_loader.py @@ -13,13 +13,7 @@ def __init__(self, nb_path, prerun=None): with open(self.nb_path) as f: nb = nbformat.read(f, as_version=4) - client = TestbookNotebookClient(nb) - - if self.prerun is not None: - with client.setup_kernel(): - client.execute_cell(self.prerun) - - self.client = client + self.client = TestbookNotebookClient(nb) def _start_kernel(self): if self.client.km is None: @@ -30,6 +24,8 @@ def _start_kernel(self): def __enter__(self): self._start_kernel() + if self.prerun is not None: + self.client.execute_cell(self.prerun) return self.client def __exit__(self, *args): @@ -38,6 +34,8 @@ def __exit__(self, *args): def __call__(self, func): def wrapper(*args, **kwargs): with self.client.setup_kernel(): + if self.prerun is not None: + self.client.execute_cell(self.prerun) func(self.client, *args, **kwargs) wrapper.__name__ = func.__name__ diff --git a/testbook/tests/test_execute.py b/testbook/tests/test_execute.py index 703802c..6ef31d7 100644 --- a/testbook/tests/test_execute.py +++ b/testbook/tests/test_execute.py @@ -4,30 +4,37 @@ def test_execute_cell(): with notebook_loader('testbook/tests/resources/foo.ipynb') as notebook: notebook.execute_cell(1) - assert notebook.cell_output_text(1) == 'hello world\n[1, 2, 3]\n' + assert notebook.cell_output_text(1) == 'hello world\n[1, 2, 3]' notebook.execute_cell([2, 3]) - assert notebook.cell_output_text(3) == 'foo\n' + assert notebook.cell_output_text(3) == 'foo' def test_execute_cell_tags(): with notebook_loader('testbook/tests/resources/foo.ipynb') as notebook: notebook.execute_cell('test1') - assert notebook.cell_output_text('test1') == 'hello world\n[1, 2, 3]\n' + assert notebook.cell_output_text('test1') == 'hello world\n[1, 2, 3]' notebook.execute_cell(['prepare_foo', 'execute_foo']) - assert notebook.cell_output_text('execute_foo') == 'foo\n' + assert notebook.cell_output_text('execute_foo') == 'foo' @notebook_loader("testbook/tests/resources/foo.ipynb") -def test_notebook(notebook): +def test_notebook_loader(notebook): notebook.execute_cell('test1') - assert notebook.cell_output_text('test1') == 'hello world\n[1, 2, 3]\n' + assert notebook.cell_output_text('test1') == 'hello world\n[1, 2, 3]' notebook.execute_cell(['prepare_foo', 'execute_foo']) - assert notebook.cell_output_text('execute_foo') == 'foo\n' + assert notebook.cell_output_text('execute_foo') == 'foo' -@notebook_loader("testbook/tests/resources/foo.ipynb", prerun='test1') -def test_notebook_with_prerun(notebook): - assert notebook.cell_output_text(1) == 'hello world\n[1, 2, 3]\n' +@notebook_loader("testbook/tests/resources/foo.ipynb", prerun='prepare_foo') +def test_notebook_loader_with_prerun(notebook): + notebook.execute_cell('execute_foo') + assert notebook.cell_output_text('execute_foo') == 'foo' + + +def test_notebook_loader_with_prerun_context_manager(): + with notebook_loader("testbook/tests/resources/foo.ipynb", prerun='prepare_foo') as notebook: + notebook.execute_cell('execute_foo') + assert notebook.cell_output_text('execute_foo') == 'foo'