Skip to content

Commit

Permalink
ENH helper function to list models (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
Elizabeth Sander committed Aug 2, 2019
1 parent f881fb1 commit 84ecf67
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased
### Added
- Add helper function to list CivisML models. (#314)
- Allow the base URL of the CLI to be configured through the
`CIVIS_API_ENDPOINT` environment variable, like the civis Python module. (#312)
- Allow the CLI log level to be configured with the `CIVIS_LOG_LEVEL`
Expand Down
1 change: 1 addition & 0 deletions civis/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Machine learning in Civis
"""
from civis.ml._model import * # NOQA
from civis.ml._helper import * # NOQA
60 changes: 60 additions & 0 deletions civis/ml/_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from collections import namedtuple

from civis import APIClient
from civis.ml._model import _PRED_TEMPLATES

__all__ = ['list_models']

# sentinel value for default author value
SENTINEL = namedtuple('Sentinel', [])()


def list_models(job_type="train", author=SENTINEL, client=None, **kwargs):
"""List a user's CivisML models.
Parameters
----------
job_type : {"train", "predict", None}
The type of model job to list. If "train", list training jobs
only (including registered models trained outside of CivisML).
If "predict", list prediction jobs only. If None, list both.
author : int, optional
User id of the user whose models you want to list. Defaults to
the current user. Use ``None`` to list models from all users.
client : :class:`civis.APIClient`, optional
If not provided, an :class:`civis.APIClient` object will be
created from the :envvar:`CIVIS_API_KEY`.
**kwargs : kwargs
Extra keyword arguments passed to `client.scripts.list_custom()`
See Also
--------
APIClient.scripts.list_custom
"""
if job_type == "train":
template_id_list = list(_PRED_TEMPLATES.keys())
elif job_type == "predict":
# get a unique list of prediction ids
template_id_list = list(set(_PRED_TEMPLATES.values()))
elif job_type is None:
# use sets to make sure there's no duplicate ids
template_id_list = list(set(_PRED_TEMPLATES.keys()).union(
set(_PRED_TEMPLATES.values())))
else:
raise ValueError("Parameter 'job_type' must be None, 'train', "
"or 'predict'.")
template_id_str = ', '.join([str(tmp) for tmp in template_id_list])

if client is None:
client = APIClient()

if author is SENTINEL:
author = client.users.list_me().id

# default to showing most recent models first
kwargs.setdefault('order_dir', 'desc')

models = client.scripts.list_custom(from_template_id=template_id_str,
author=author,
**kwargs)
return models
24 changes: 24 additions & 0 deletions civis/ml/tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from civis.response import Response
from civis.ml import list_models
from civis.tests.mocks import create_client_mock


def test_list_models_bad_job_type():
with pytest.raises(ValueError):
list_models(job_type="fake")


def test_list_models():
resp = [Response({'id': 2834, 'name': 'RFC model'})]
m_client = create_client_mock()
m_client.scripts.list_custom.return_value = resp
out = list_models(job_type='train', client=m_client)
assert out == resp

out = list_models(job_type='predict', client=m_client)
assert out == resp

out = list_models(job_type=None, client=m_client)
assert out == resp

0 comments on commit 84ecf67

Please sign in to comment.