Skip to content

Commit

Permalink
bring back plot_importance for compatibility (#274)
Browse files Browse the repository at this point in the history
* bring back plot_importance for compatibility

* allow overwriting image in project.set_function()
  • Loading branch information
yaronha committed May 23, 2020
1 parent d1a9d36 commit 0d69e4e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mlrun/mlutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
roc_multi,
roc_bin,
precision_recall_bin,
plot_roc)
plot_roc,
plot_importance)

from .data import (get_sample,
get_splits,
Expand Down
41 changes: 41 additions & 0 deletions mlrun/mlutils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,47 @@ def feature_importances(model, header):
TableArtifact("feature-importances-tbl", df=feature_imp))


def plot_importance(
context,
model,
key: str = "feature-importances",
plots_dest: str = "plots"
):
"""Display estimated feature importances
Only works for models with attribute 'feature_importances_`
**legacy version please deprecate in functions and demos**
:param context: function context
:param model: fitted model
:param key: key of feature importances plot and table in artifact
store
:param plots_dest: subfolder in artifact store
"""
if not hasattr(model, "feature_importances_"):
raise Exception(
"feature importaces are only available for some models")

# create a feature importance table with desired labels
zipped = zip(model.feature_importances_, context.header)
feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature"]).sort_values(
by="freq", ascending=False
)

gcf_clear(plt)
plt.figure(figsize=(20, 10))
sns.barplot(x="freq", y="feature", data=feature_imp)
plt.title("features")
plt.tight_layout()

fname = f"{plots_dest}/{key}.html"
context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)

# feature importances are also saved as a csv table (generally small):
fname = key + "-tbl.csv"
return context.log_artifact(TableArtifact(key + "-tbl", df=feature_imp), local_path=fname)


def learning_curves(model):
"""model class dependent
Expand Down
8 changes: 7 additions & 1 deletion mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,11 @@ def set_function(self, func, name='', kind='', image=None, with_repo=None):
name, f = _init_function_from_dict(func, self)
elif hasattr(func, 'to_dict'):
name, f = _init_function_from_obj(func, self, name=name)
if image:
f.spec.image = image
if with_repo:
f.spec.build.source = './'

if not name:
raise ValueError('function name must be specified')
else:
Expand Down Expand Up @@ -690,6 +695,8 @@ def _init_function_from_dict(f, project):
elif url.endswith('.yaml') or url.startswith('db://') \
or url.startswith('hub://'):
func = import_function(url)
if image:
func.spec.image = image
elif url.endswith('.ipynb'):
func = code_to_function(name, filename=url, image=image, kind=kind)
elif url.endswith('.py'):
Expand Down Expand Up @@ -723,7 +730,6 @@ def _init_function_from_obj(func, project, name=None):
func.metadata.project = project.name
if project.tag:
func.metadata.tag = project.tag

return name or func.metadata.name, func


Expand Down

0 comments on commit 0d69e4e

Please sign in to comment.