Skip to content

Commit

Permalink
adds first implementation of NotebookIntrospector.get_params()
Browse files Browse the repository at this point in the history
  • Loading branch information
edublancas committed Apr 29, 2021
1 parent 77ff2da commit d9525d2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/sklearn_evaluation/nb/NotebookIntrospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,26 @@ def _parse_output(output, literal_eval, to_df, text_only):
to_df=to_df)


def find_cell_with_tag(cells):
for cell in cells:
if ('metadata' in cell and 'tags' in cell['metadata']
and 'parameters' in cell['metadata']['tags']):
return cell


def parse_parameters_cell(cells):
# this is a very simple implementation, for a more robust solution
# re-implement with ast or parso
cell = find_cell_with_tag(cells)

if not cell:
return dict()

tuples = [line.split('=') for line in cell['source'].splitlines()]

return {t[0].strip(): ast.literal_eval(t[1].strip()) for t in tuples}


class NotebookIntrospector(Mapping):
"""Retrieve output from a notebook file with tagged cells.
Expand Down Expand Up @@ -145,3 +165,6 @@ def to_json_serializable(self):
text_only=True)
for k, v in self.tag2output_raw.items()
}

def get_params(self):
return parse_parameters_cell(self.nb.cells)
13 changes: 13 additions & 0 deletions tests/nb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,16 @@ def nb_invalid_output():
np.array([1, 2, 3])
"""
return save_and_execute_notebook(content, 'nb_invalid_output.ipynb')


@pytest.fixture
def nb_parameters():
content = """
import numpy as np
# + tags=["parameters"]
x = 1
y = [1, 2]
z = {'a': 1, 'b': 2}
"""
return save_and_execute_notebook(content, 'nb_parameters.ipynb')
18 changes: 18 additions & 0 deletions tests/nb/test_NotebookIntrospector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ast

from sklearn_evaluation import NotebookIntrospector
from IPython.display import HTML, Image

Expand Down Expand Up @@ -50,3 +52,19 @@ def test_json_serializable(tmp_directory, nb_plot):
# "<Figure size YxZ with 1 Axes>")
assert 'Figure size' in d['a']
assert d['b'] == 42


def test_get_params(tmp_directory, nb_parameters):
d = NotebookIntrospector('nb_parameters.ipynb')
assert d.get_params() == {
'x': 1,
'y': [1, 2],
'z': {
'a': 1,
'b': 2
},
'z': {
'a': 1,
'b': 2
}
}

0 comments on commit d9525d2

Please sign in to comment.