From 666641981be0337c99af7022a079a02033caf73b Mon Sep 17 00:00:00 2001 From: Liz Sander Date: Fri, 2 Aug 2019 13:45:11 -0500 Subject: [PATCH 1/6] ENH helper function to list models --- CHANGELOG.md | 1 + civis/ml/__init__.py | 1 + civis/ml/_helper.py | 40 +++++++++++++++++++++++++++++++++++ civis/ml/tests/test_helper.py | 28 ++++++++++++++++++++++++ 4 files changed, 70 insertions(+) create mode 100644 civis/ml/_helper.py create mode 100644 civis/ml/tests/test_helper.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 23ddce49..07b9d150 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased ### Added +- Add helper function to list CivisML models. (#313) - Allow the base URL of the CLI to be configured through the `CIVIS_API_ENDPOINT` environment variable, like the civis Python module. (#312) - Allow the CLI log level to be configured with the `CIVIS_LOG_LEVEL` diff --git a/civis/ml/__init__.py b/civis/ml/__init__.py index 1a0d1fb5..ab7895fe 100644 --- a/civis/ml/__init__.py +++ b/civis/ml/__init__.py @@ -1,3 +1,4 @@ """Machine learning in Civis """ from civis.ml._model import * # NOQA +from civis.ml._helper import * # NOQA diff --git a/civis/ml/_helper.py b/civis/ml/_helper.py new file mode 100644 index 00000000..21c514c7 --- /dev/null +++ b/civis/ml/_helper.py @@ -0,0 +1,40 @@ +from civis import APIClient +from civis.ml._model import _PRED_TEMPLATES + + +def list_models(job_type=None, client=None, **kwargs): + """List the current user's CivisML models. + + Parameters + ---------- + job_type : {None, "train", "predict"} + The type of model job to list. If "train", list training jobs + only (including registered models trained outside of CivisML). + If "predict", list prediction jobs only. If None, list both. + client : :class:`civis.APIClient`, optional + If not provided, an :class:`civis.APIClient` object will be + created from the :envvar:`CIVIS_API_KEY`. + **kwargs : kwargs + Extra keyword arguments passed to `client.scripts.list_custom()` + """ + if job_type == "train": + template_id_list = list(_PRED_TEMPLATES.keys()) + elif job_type == "predict": + # get a unique list of prediction ids + template_id_list = list(set(_PRED_TEMPLATES.values())) + elif job_type is None: + # use sets to make sure there's no duplicate ids + template_id_list = list(set(_PRED_TEMPLATES.keys()).union( + set(_PRED_TEMPLATES.values()))) + else: + raise ValueError("Parameter 'job_type' must be None, 'train', " + "or 'predict'.") + template_id_str = ', '.join([str(tmp) for tmp in template_id_list]) + + if client is None: + client = APIClient() + + models = client.scripts.list_custom(from_template_id=template_id_str, + author=client.users.list_me().id, + **kwargs) + return models diff --git a/civis/ml/tests/test_helper.py b/civis/ml/tests/test_helper.py new file mode 100644 index 00000000..a8434362 --- /dev/null +++ b/civis/ml/tests/test_helper.py @@ -0,0 +1,28 @@ +import mock +import pytest + +from civis.response import Response +from civis.ml import list_models + +TRAIN_TEMPLATES = '11219, 11221, 10582, 9968, 9112, 8387, 7020' +PRED_TEMPLATES = '11220, 10583, 9969, 9113, 8388, 7021' + + +def test_list_models_bad_job_type(): + with pytest.raises(ValueError): + list_models(job_type="fake") + + +def test_list_models(): + resp = [Response({'id': 2834, 'name': 'RFC model'})] + m_client = mock.Mock() + m_client.scripts.list_custom.return_value = resp + m_client.users.list_me.return_value = Response({'id': 2834}) + out = list_models(job_type='train', client=m_client) + assert out == resp + + out = list_models(job_type='predict', client=m_client) + assert out == resp + + out = list_models(job_type=None, client=m_client) + assert out == resp From 0f6c2a1a73d1affd17f4ba3143df95717e36709c Mon Sep 17 00:00:00 2001 From: Liz Sander Date: Fri, 2 Aug 2019 13:51:57 -0500 Subject: [PATCH 2/6] STY remove stray constants --- civis/ml/tests/test_helper.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/civis/ml/tests/test_helper.py b/civis/ml/tests/test_helper.py index a8434362..b3f31566 100644 --- a/civis/ml/tests/test_helper.py +++ b/civis/ml/tests/test_helper.py @@ -4,9 +4,6 @@ from civis.response import Response from civis.ml import list_models -TRAIN_TEMPLATES = '11219, 11221, 10582, 9968, 9112, 8387, 7020' -PRED_TEMPLATES = '11220, 10583, 9969, 9113, 8388, 7021' - def test_list_models_bad_job_type(): with pytest.raises(ValueError): From 32775176224ddb64c30ff0c653ca30cb792de6b9 Mon Sep 17 00:00:00 2001 From: Liz Sander Date: Fri, 2 Aug 2019 16:31:52 -0500 Subject: [PATCH 3/6] ENH code review changes --- civis/ml/_helper.py | 23 ++++++++++++++++++++--- civis/ml/tests/test_helper.py | 5 ++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/civis/ml/_helper.py b/civis/ml/_helper.py index 21c514c7..ff3269aa 100644 --- a/civis/ml/_helper.py +++ b/civis/ml/_helper.py @@ -1,9 +1,16 @@ +from collections import namedtuple + from civis import APIClient from civis.ml._model import _PRED_TEMPLATES +__all__ = ['list_models'] + +# sentinel value for default author value +SENTINEL = namedtuple('Sentinel', [])() + -def list_models(job_type=None, client=None, **kwargs): - """List the current user's CivisML models. +def list_models(job_type="train", author=SENTINEL, client=None, **kwargs): + """List a user's CivisML models. Parameters ---------- @@ -11,11 +18,18 @@ def list_models(job_type=None, client=None, **kwargs): The type of model job to list. If "train", list training jobs only (including registered models trained outside of CivisML). If "predict", list prediction jobs only. If None, list both. + author : int, optional + User id of the user whose models you want to list. Defaults to + the current user. Use ``None`` to list models from all users. client : :class:`civis.APIClient`, optional If not provided, an :class:`civis.APIClient` object will be created from the :envvar:`CIVIS_API_KEY`. **kwargs : kwargs Extra keyword arguments passed to `client.scripts.list_custom()` + + See Also + -------- + APIClient.scripts.list_custom """ if job_type == "train": template_id_list = list(_PRED_TEMPLATES.keys()) @@ -34,7 +48,10 @@ def list_models(job_type=None, client=None, **kwargs): if client is None: client = APIClient() + if author is SENTINEL: + author = client.users.list_me().id + models = client.scripts.list_custom(from_template_id=template_id_str, - author=client.users.list_me().id, + author=author, **kwargs) return models diff --git a/civis/ml/tests/test_helper.py b/civis/ml/tests/test_helper.py index b3f31566..7bc040ad 100644 --- a/civis/ml/tests/test_helper.py +++ b/civis/ml/tests/test_helper.py @@ -1,8 +1,8 @@ -import mock import pytest from civis.response import Response from civis.ml import list_models +from civis.tests.mocks import create_client_mock def test_list_models_bad_job_type(): @@ -12,9 +12,8 @@ def test_list_models_bad_job_type(): def test_list_models(): resp = [Response({'id': 2834, 'name': 'RFC model'})] - m_client = mock.Mock() + m_client = create_client_mock() m_client.scripts.list_custom.return_value = resp - m_client.users.list_me.return_value = Response({'id': 2834}) out = list_models(job_type='train', client=m_client) assert out == resp From a9c2bdd51e1d07aee9dc7a34f634c1e85f5e4875 Mon Sep 17 00:00:00 2001 From: Liz Sander Date: Fri, 2 Aug 2019 16:47:47 -0500 Subject: [PATCH 4/6] ENH order_dir desc by default --- civis/ml/_helper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/civis/ml/_helper.py b/civis/ml/_helper.py index ff3269aa..a2b02f9d 100644 --- a/civis/ml/_helper.py +++ b/civis/ml/_helper.py @@ -51,6 +51,9 @@ def list_models(job_type="train", author=SENTINEL, client=None, **kwargs): if author is SENTINEL: author = client.users.list_me().id + # default to showing most recent models first + kwargs.setdefault('order_dir', 'desc') + models = client.scripts.list_custom(from_template_id=template_id_str, author=author, **kwargs) From 55fbebc05a8cabedc150990ec782196cac732fb4 Mon Sep 17 00:00:00 2001 From: Liz Sander Date: Fri, 2 Aug 2019 16:53:07 -0500 Subject: [PATCH 5/6] DOC update docstring --- civis/ml/_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/civis/ml/_helper.py b/civis/ml/_helper.py index a2b02f9d..dcfdc9dc 100644 --- a/civis/ml/_helper.py +++ b/civis/ml/_helper.py @@ -14,7 +14,7 @@ def list_models(job_type="train", author=SENTINEL, client=None, **kwargs): Parameters ---------- - job_type : {None, "train", "predict"} + job_type : {"train", "predict", None} The type of model job to list. If "train", list training jobs only (including registered models trained outside of CivisML). If "predict", list prediction jobs only. If None, list both. From 7179ffef1abb7668708072057e42a4bc1446e14d Mon Sep 17 00:00:00 2001 From: Liz Sander Date: Fri, 2 Aug 2019 16:54:38 -0500 Subject: [PATCH 6/6] DOC fix PR number in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 07b9d150..1fbfc4f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased ### Added -- Add helper function to list CivisML models. (#313) +- Add helper function to list CivisML models. (#314) - Allow the base URL of the CLI to be configured through the `CIVIS_API_ENDPOINT` environment variable, like the civis Python module. (#312) - Allow the CLI log level to be configured with the `CIVIS_LOG_LEVEL`