Skip to content

Commit

Permalink
fixes doc building
Browse files Browse the repository at this point in the history
  • Loading branch information
edublancas committed Dec 23, 2020
1 parent 78af03b commit bfc41e9
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
61 changes: 52 additions & 9 deletions docs/source/hooks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,50 @@
from os import environ
from pathlib import Path

import jupytext
import nbformat
from ploomber import DAG
from ploomber.tasks import NotebookRunner
from ploomber.products import File
from ploomber.constants import TaskStatus


def make_task(dag, path_in, path_out):
nb = jupytext.read(path_in)

fmt = nbformat.versions[nbformat.current_nbformat]
nb.cells.append(fmt.new_code_cell(metadata=dict(tags=['parameters'])))

name = Path(path_in).name.split('.')[0]
path_preprocessed = Path(path_in).parent / (name + '-preprocessed.ipynb')
nbformat.write(nb, path_preprocessed)

NotebookRunner(Path(path_preprocessed),
File(path_out),
dag,
kernelspec_name='python3',
name=name,
local_execution=True)


def remove_with_tag(nb, tag):
idx = None

for i, cell in enumerate(nb.cells):
if tag in cell.metadata.tags:
idx = i
break

nb.cells.pop(idx)


def post_process_nb(path):
nb = jupytext.read(path)

remove_with_tag(nb, 'injected-parameters')
remove_with_tag(nb, 'parameters')

jupytext.write(nb, path)


def config_init(app, config):
Expand All @@ -15,15 +57,16 @@ def config_init(app, config):

dag = DAG()

NotebookRunner(base_path / 'nbs/SQLiteTracker.md',
File(base_path / 'user_guide/SQLiteTracker.ipynb'),
dag=dag,
kernelspec_name='python3')
make_task(dag, base_path / 'nbs/SQLiteTracker.md',
base_path / 'user_guide/SQLiteTracker.ipynb')

NotebookRunner(base_path / 'nbs/NotebookCollection.py',
File(base_path / 'user_guide/NotebookCollection.ipynb'),
dag=dag,
kernelspec_name='python3',
local_execution=True)
make_task(dag, base_path / 'nbs/NotebookCollection.py',
base_path / 'user_guide/NotebookCollection.ipynb')

dag.build()

for task_name in dag:
task = dag[task_name]

if task.exec_status == TaskStatus.Executed:
post_process_nb(str(task.product))
6 changes: 4 additions & 2 deletions docs/source/nbs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# + tags=["parameters"]
model = 'sklearn.ensemble.RandomForestRegressor'
params = {'min_samples_leaf': 1, 'n_estimators': 50}
params = {'min_samples_leaf': 1, 'n_estimators': 50}

# + tags=["model_name"]
model
Expand All @@ -40,7 +40,7 @@
class_ = getattr(importlib.import_module(module), name)

# + tags=["feature_names"]
d['feature_names']
list(d['feature_names'])
# -

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
Expand Down Expand Up @@ -85,9 +85,11 @@
error_river = df.groupby('CHAS')[['error_abs', 'error_sq']].mean()
error_river.columns = ['mae', 'mse']


def r2_score(df):
return metrics.r2_score(df.y_true, df.y_pred)


r2 = pd.DataFrame(df.groupby('CHAS').apply(r2_score))
r2.columns = ['r2']

Expand Down
2 changes: 1 addition & 1 deletion src/sklearn_evaluation/nb/NotebookIntrospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _safe_literal_eval(source, to_df=False, none_if_error=False):

return result

except SyntaxError:
except (SyntaxError, ValueError):
return None if none_if_error else source


Expand Down

0 comments on commit bfc41e9

Please sign in to comment.