# Patch summary (commentary only; the patch proper starts at the first
# "diff --git" line below).
#
# Adds a `list_models` helper to `civis.ml` (changelog entry #314).
#
# * CHANGELOG.md: records the new helper under "Unreleased / Added".
# * civis/ml/__init__.py: star-imports `civis.ml._helper`, whose `__all__`
#   exposes only `list_models`.
# * civis/ml/_helper.py (new file): defines `list_models`, which
#     - picks template ids from `_PRED_TEMPLATES`: its keys for
#       job_type="train", its de-duplicated values for "predict", and the
#       union of both for None; any other value raises ValueError before
#       a client is created;
#     - joins the ids with ", " and passes the string as
#       `from_template_id` to `client.scripts.list_custom`;
#     - creates an `APIClient()` when `client` is None;
#     - replaces the default `author` sentinel with
#       `client.users.list_me().id`, while an explicit `author=None` is
#       forwarded unchanged;
#     - sets `order_dir='desc'` unless the caller supplied `order_dir`;
#     - returns whatever `list_custom` returns.
# * civis/ml/tests/test_helper.py (new file): checks that an unknown
#   job_type raises ValueError and that, with a mocked client, the
#   `list_custom` return value is passed through for "train", "predict"
#   and None.
#
# NOTE(review): line breaks and indentation were restored from a
# whitespace-collapsed copy of this patch; the wrap point of the
# CHANGELOG context lines and the spacing before "# NOQA" are presumed,
# so confirm them against the target files before applying.
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23ddce49..1fbfc4f9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
 
 ## Unreleased
 ### Added
+- Add helper function to list CivisML models. (#314)
 - Allow the base URL of the CLI to be configured through the
   `CIVIS_API_ENDPOINT` environment variable, like the civis Python module. (#312)
 - Allow the CLI log level to be configured with the `CIVIS_LOG_LEVEL`
diff --git a/civis/ml/__init__.py b/civis/ml/__init__.py
index 1a0d1fb5..ab7895fe 100644
--- a/civis/ml/__init__.py
+++ b/civis/ml/__init__.py
@@ -1,3 +1,4 @@
 """Machine learning in Civis
 """
 from civis.ml._model import *  # NOQA
+from civis.ml._helper import *  # NOQA
diff --git a/civis/ml/_helper.py b/civis/ml/_helper.py
new file mode 100644
index 00000000..dcfdc9dc
--- /dev/null
+++ b/civis/ml/_helper.py
@@ -0,0 +1,60 @@
+from collections import namedtuple
+
+from civis import APIClient
+from civis.ml._model import _PRED_TEMPLATES
+
+__all__ = ['list_models']
+
+# sentinel value for default author value
+SENTINEL = namedtuple('Sentinel', [])()
+
+
+def list_models(job_type="train", author=SENTINEL, client=None, **kwargs):
+    """List a user's CivisML models.
+
+    Parameters
+    ----------
+    job_type : {"train", "predict", None}
+        The type of model job to list. If "train", list training jobs
+        only (including registered models trained outside of CivisML).
+        If "predict", list prediction jobs only. If None, list both.
+    author : int, optional
+        User id of the user whose models you want to list. Defaults to
+        the current user. Use ``None`` to list models from all users.
+    client : :class:`civis.APIClient`, optional
+        If not provided, an :class:`civis.APIClient` object will be
+        created from the :envvar:`CIVIS_API_KEY`.
+    **kwargs : kwargs
+        Extra keyword arguments passed to `client.scripts.list_custom()`
+
+    See Also
+    --------
+    APIClient.scripts.list_custom
+    """
+    if job_type == "train":
+        template_id_list = list(_PRED_TEMPLATES.keys())
+    elif job_type == "predict":
+        # get a unique list of prediction ids
+        template_id_list = list(set(_PRED_TEMPLATES.values()))
+    elif job_type is None:
+        # use sets to make sure there's no duplicate ids
+        template_id_list = list(set(_PRED_TEMPLATES.keys()).union(
+            set(_PRED_TEMPLATES.values())))
+    else:
+        raise ValueError("Parameter 'job_type' must be None, 'train', "
+                         "or 'predict'.")
+    template_id_str = ', '.join([str(tmp) for tmp in template_id_list])
+
+    if client is None:
+        client = APIClient()
+
+    if author is SENTINEL:
+        author = client.users.list_me().id
+
+    # default to showing most recent models first
+    kwargs.setdefault('order_dir', 'desc')
+
+    models = client.scripts.list_custom(from_template_id=template_id_str,
+                                        author=author,
+                                        **kwargs)
+    return models
diff --git a/civis/ml/tests/test_helper.py b/civis/ml/tests/test_helper.py
new file mode 100644
index 00000000..7bc040ad
--- /dev/null
+++ b/civis/ml/tests/test_helper.py
@@ -0,0 +1,24 @@
+import pytest
+
+from civis.response import Response
+from civis.ml import list_models
+from civis.tests.mocks import create_client_mock
+
+
+def test_list_models_bad_job_type():
+    with pytest.raises(ValueError):
+        list_models(job_type="fake")
+
+
+def test_list_models():
+    resp = [Response({'id': 2834, 'name': 'RFC model'})]
+    m_client = create_client_mock()
+    m_client.scripts.list_custom.return_value = resp
+    out = list_models(job_type='train', client=m_client)
+    assert out == resp
+
+    out = list_models(job_type='predict', client=m_client)
+    assert out == resp
+
+    out = list_models(job_type=None, client=m_client)
+    assert out == resp