Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH helper function to list models #314

Merged
merged 6 commits into from
Aug 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased
### Added
- Add helper function to list CivisML models. (#314)
- Allow the base URL of the CLI to be configured through the
`CIVIS_API_ENDPOINT` environment variable, like the civis Python module. (#312)
- Allow the CLI log level to be configured with the `CIVIS_LOG_LEVEL`
Expand Down
1 change: 1 addition & 0 deletions civis/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Machine learning in Civis
"""
from civis.ml._model import * # NOQA
from civis.ml._helper import * # NOQA
60 changes: 60 additions & 0 deletions civis/ml/_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from collections import namedtuple

from civis import APIClient
from civis.ml._model import _PRED_TEMPLATES

__all__ = ['list_models']

# sentinel value for default author value
SENTINEL = namedtuple('Sentinel', [])()


stephen-hoover marked this conversation as resolved.
Show resolved Hide resolved
def list_models(job_type="train", author=SENTINEL, client=None, **kwargs):
"""List a user's CivisML models.

Parameters
----------
job_type : {"train", "predict", None}
The type of model job to list. If "train", list training jobs
only (including registered models trained outside of CivisML).
If "predict", list prediction jobs only. If None, list both.
author : int, optional
User id of the user whose models you want to list. Defaults to
the current user. Use ``None`` to list models from all users.
client : :class:`civis.APIClient`, optional
If not provided, an :class:`civis.APIClient` object will be
created from the :envvar:`CIVIS_API_KEY`.
**kwargs : kwargs
Extra keyword arguments passed to `client.scripts.list_custom()`

See Also
--------
APIClient.scripts.list_custom
"""
stephen-hoover marked this conversation as resolved.
Show resolved Hide resolved
if job_type == "train":
template_id_list = list(_PRED_TEMPLATES.keys())
elif job_type == "predict":
# get a unique list of prediction ids
template_id_list = list(set(_PRED_TEMPLATES.values()))
elif job_type is None:
# use sets to make sure there's no duplicate ids
template_id_list = list(set(_PRED_TEMPLATES.keys()).union(
set(_PRED_TEMPLATES.values())))
else:
raise ValueError("Parameter 'job_type' must be None, 'train', "
"or 'predict'.")
template_id_str = ', '.join([str(tmp) for tmp in template_id_list])

if client is None:
client = APIClient()

if author is SENTINEL:
author = client.users.list_me().id

# default to showing most recent models first
kwargs.setdefault('order_dir', 'desc')

models = client.scripts.list_custom(from_template_id=template_id_str,
author=author,
**kwargs)
return models
24 changes: 24 additions & 0 deletions civis/ml/tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from civis.response import Response
from civis.ml import list_models
from civis.tests.mocks import create_client_mock


def test_list_models_bad_job_type():
with pytest.raises(ValueError):
list_models(job_type="fake")


def test_list_models():
resp = [Response({'id': 2834, 'name': 'RFC model'})]
m_client = create_client_mock()
m_client.scripts.list_custom.return_value = resp
out = list_models(job_type='train', client=m_client)
assert out == resp

out = list_models(job_type='predict', client=m_client)
assert out == resp

out = list_models(job_type=None, client=m_client)
assert out == resp