Skip to content

Commit

Permalink
parsing injected-parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
edublancas committed Apr 29, 2021
1 parent d9525d2 commit b91ba62
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 25 deletions.
17 changes: 10 additions & 7 deletions src/sklearn_evaluation/nb/NotebookIntrospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,25 @@ def _parse_output(output, literal_eval, to_df, text_only):
to_df=to_df)


def find_cell_with_tag(cells):
def find_cell_with_tag(cells, tag):
for cell in cells:
if ('metadata' in cell and 'tags' in cell['metadata']
and 'parameters' in cell['metadata']['tags']):
and tag in cell['metadata']['tags']):
return cell


def parse_parameters_cell(cells):
def parse_injected_parameters_cell(cells):
# this is a very simple implementation, for a more robust solution
# re-implement with ast or parso
cell = find_cell_with_tag(cells)
cell = find_cell_with_tag(cells, tag='injected-parameters')

if not cell:
return dict()

tuples = [line.split('=') for line in cell['source'].splitlines()]
tuples = [
line.split('=') for line in cell['source'].splitlines()
if not line.startswith('#')
]

return {t[0].strip(): ast.literal_eval(t[1].strip()) for t in tuples}

Expand Down Expand Up @@ -166,5 +169,5 @@ def to_json_serializable(self):
for k, v in self.tag2output_raw.items()
}

def get_params(self):
return parse_parameters_cell(self.nb.cells)
def get_injected_parameters(self):
return parse_injected_parameters_cell(self.nb.cells)
28 changes: 17 additions & 11 deletions tests/nb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest


def save_and_execute_notebook(nb_str, path):
def save_notebook(nb_str, path, execute=True):
nb = jupytext.reads(nb_str, fmt='py:light')
nb.metadata['kernelspec'] = {
'name': 'python3',
Expand All @@ -13,7 +13,10 @@ def save_and_execute_notebook(nb_str, path):
}

nbformat.write(nb, path)
pm.execute_notebook(str(path), str(path))

if execute:
pm.execute_notebook(str(path), str(path))

return str(path)


Expand All @@ -32,7 +35,7 @@ def nb_literals():
dict_ = {'x': 1, 'y': 2}
dict_
"""
save_and_execute_notebook(content, 'nb_literals.ipynb')
save_notebook(content, 'nb_literals.ipynb')


@pytest.fixture
Expand All @@ -50,7 +53,7 @@ def nb_other_literals():
dict_ = {'x': 2, 'y': 3}
dict_
"""
save_and_execute_notebook(content, 'nb_other_literals.ipynb')
save_notebook(content, 'nb_other_literals.ipynb')


@pytest.fixture
Expand All @@ -64,7 +67,7 @@ def nb_plot():
# + tags=["b"]
42
"""
save_and_execute_notebook(content, 'nb_plot.ipynb')
save_notebook(content, 'nb_plot.ipynb')


@pytest.fixture
Expand All @@ -78,7 +81,7 @@ def nb_table():
# + tags=["b"]
42
"""
save_and_execute_notebook(content, 'nb_table.ipynb')
save_notebook(content, 'nb_table.ipynb')


@pytest.fixture
Expand All @@ -89,7 +92,7 @@ def nb_no_output():
# + tags=["a"]
x = 1
"""
save_and_execute_notebook(content, 'nb_no_output.ipynb')
save_notebook(content, 'nb_no_output.ipynb')


@pytest.fixture
Expand All @@ -100,17 +103,20 @@ def nb_invalid_output():
# + tags=["numpy_array"]
np.array([1, 2, 3])
"""
return save_and_execute_notebook(content, 'nb_invalid_output.ipynb')
return save_notebook(content, 'nb_invalid_output.ipynb')


@pytest.fixture
def nb_parameters():
def nb_injected_parameters():
content = """
import numpy as np
# + tags=["parameters"]
# + tags=["injected-parameters"]
# Parameters
x = 1
y = [1, 2]
z = {'a': 1, 'b': 2}
"""
return save_and_execute_notebook(content, 'nb_parameters.ipynb')
return save_notebook(content,
'nb_injected_parameters.ipynb',
execute=False)
11 changes: 4 additions & 7 deletions tests/nb/test_NotebookIntrospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,12 @@ def test_json_serializable(tmp_directory, nb_plot):
assert d['b'] == 42


def test_get_params(tmp_directory, nb_parameters):
d = NotebookIntrospector('nb_parameters.ipynb')
assert d.get_params() == {
def test_get_injected_parameters(tmp_directory, nb_injected_parameters):
d = NotebookIntrospector('nb_injected_parameters.ipynb')

assert d.get_injected_parameters() == {
'x': 1,
'y': [1, 2],
'z': {
'a': 1,
'b': 2
},
'z': {
'a': 1,
'b': 2
Expand Down

0 comments on commit b91ba62

Please sign in to comment.