From 03a464fb2b384d69c67d95dce3acbb0e912c273c Mon Sep 17 00:00:00 2001 From: Roman Rader Date: Sat, 23 Mar 2019 01:09:39 +0200 Subject: [PATCH] [PRED-2176] Add optional prediction threshold field to out csv (#155) * PRED-2176 add unit test for no prediction threshold in result * PRED-2176 fix fixture for binary classification in unit tests * PRED-2176 fix help text for column * [PRED-2176] Fix tests --- .../api_response_handlers/pred_api_v10.py | 38 ++++++- datarobot_batch_scoring/batch_scoring.py | 3 +- datarobot_batch_scoring/exceptions.py | 4 + datarobot_batch_scoring/main.py | 7 ++ datarobot_batch_scoring/writer.py | 15 ++- tests/fixtures/temperatura.json | 10 ++ .../temperatura_output_healthy_threshold.csv | 101 ++++++++++++++++++ tests/test_api_response_handlers.py | 58 +++++++++- tests/test_conf_file.py | 4 + tests/test_functional.py | 61 +++++++++++ tests/test_main.py | 64 +++++++++++ tests/test_writer.py | 3 + 12 files changed, 359 insertions(+), 9 deletions(-) create mode 100644 tests/fixtures/temperatura_output_healthy_threshold.csv diff --git a/datarobot_batch_scoring/api_response_handlers/pred_api_v10.py b/datarobot_batch_scoring/api_response_handlers/pred_api_v10.py index 2e440161..ef494a54 100644 --- a/datarobot_batch_scoring/api_response_handlers/pred_api_v10.py +++ b/datarobot_batch_scoring/api_response_handlers/pred_api_v10.py @@ -3,7 +3,8 @@ import json from six.moves import zip -from datarobot_batch_scoring.exceptions import UnexpectedKeptColumnCount +from datarobot_batch_scoring.exceptions import UnexpectedKeptColumnCount, \ + NoPredictionThresholdInResult def row_id_field(result_sorted, batch): @@ -150,6 +151,26 @@ def rows_generator(): return headers, rows_generator() +def pred_threshold_field(result_sorted, pred_threshold_name): + """ Generate prediction threshold field (for classification) + Parameters + ---------- + result_sorted : list[dict] + list of results sorted by rowId + pred_threshold_name : str + column name which should contain prediction threshold + Returns + ------- + header: list[str] + row_generator: iterator + """ + + return [pred_threshold_name], ( + [row.get('predictionThreshold')] + for row in result_sorted + ) + + def pred_decision_field(result_sorted, pred_decision): """ Generate prediction decision field @@ -177,7 +198,7 @@ def format_data(result, batch, **opts): Parameters ---------- - result : list + result : list[dict] list of results batch batch information @@ -192,6 +213,7 @@ def format_data(result, batch, **opts): list of rows """ pred_name = opts.get('pred_name') + pred_threshold_name = opts.get('pred_threshold_name') pred_decision_name = opts.get('pred_decision_name') keep_cols = opts.get('keep_cols') skip_row_id = opts.get('skip_row_id') @@ -240,7 +262,17 @@ def _find_prediction_explanations_key(): ) ) - # Threshold and thresholded decision field ('prediction' value from result) + # Threshold field for classification + # ('predictionThreshold' value from result) + if pred_threshold_name: + if 'predictionThreshold' in single_row: + fields.append( + pred_threshold_field(result_sorted, pred_threshold_name) + ) + else: + raise NoPredictionThresholdInResult() + + # Thresholded decision field ('prediction' value from result) if pred_decision_name: fields.append(pred_decision_field(result_sorted, pred_decision_name)) diff --git a/datarobot_batch_scoring/batch_scoring.py b/datarobot_batch_scoring/batch_scoring.py index 1d3ab8a7..a76e259c 100644 --- a/datarobot_batch_scoring/batch_scoring.py +++ b/datarobot_batch_scoring/batch_scoring.py @@ -77,6 +77,7 @@ def run_batch_predictions(base_url, base_headers, user, pwd, verify_ssl=True, deployment_id=None, max_prediction_explanations=0, + pred_threshold_name=None, pred_decision_name=None): if field_size_limit is not None: @@ -201,7 +202,7 @@ def run_batch_predictions(base_url, base_headers, user, pwd, lid, keep_cols, n_retry, delimiter, dataset, pred_name, ui, fast_mode, encoding, skip_row_id, output_delimiter, - pred_decision_name)) + pred_threshold_name, pred_decision_name)) n_batches_checkpointed_init = len(ctx.db['checkpoints']) ui.debug('number of batches checkpointed initially: {}' diff --git a/datarobot_batch_scoring/exceptions.py b/datarobot_batch_scoring/exceptions.py index 07c4ca1a..a5f91e85 100644 --- a/datarobot_batch_scoring/exceptions.py +++ b/datarobot_batch_scoring/exceptions.py @@ -4,3 +4,7 @@ class ShelveError(Exception): class UnexpectedKeptColumnCount(Exception): pass + + +class NoPredictionThresholdInResult(Exception): + pass diff --git a/datarobot_batch_scoring/main.py b/datarobot_batch_scoring/main.py index 08e6862c..e4defcf1 100644 --- a/datarobot_batch_scoring/main.py +++ b/datarobot_batch_scoring/main.py @@ -213,6 +213,11 @@ def parse_args(argv, standalone=False, deployment_aware=False): 'empty name is used if not specified. For binary ' 'predictions assumes last class in lexical order ' 'as positive') + csv_gr.add_argument('--pred_threshold', type=str, + nargs='?', default=None, + help='Specifies column name for prediction threshold ' + 'for binary classification. Column will not be ' + 'included if not specified') csv_gr.add_argument('--pred_decision', type=str, nargs='?', default=None, help='Specifies column name for prediction decision, ' @@ -306,6 +311,7 @@ def parse_generic_options(parsed_args): skip_row_id = parsed_args['skip_row_id'] field_size_limit = parsed_args.get('field_size_limit') pred_name = parsed_args.get('pred_name') + pred_threshold_name = parsed_args.get('pred_threshold') pred_decision_name = parsed_args.get('pred_decision') dry_run = parsed_args.get('dry_run', False) @@ -363,6 +369,7 @@ def parse_generic_options(parsed_args): 'out_file': out_file, 'output_delimiter': output_delimiter, 'pred_name': pred_name, + 'pred_threshold_name': pred_threshold_name, 'pred_decision_name': pred_decision_name, 'resume': resume, 'skip_dialect': skip_dialect, diff --git a/datarobot_batch_scoring/writer.py b/datarobot_batch_scoring/writer.py index 34df8693..6cd48cac 100644 --- a/datarobot_batch_scoring/writer.py +++ b/datarobot_batch_scoring/writer.py @@ -17,7 +17,7 @@ WriterQueueMsg, ProgressQueueMsg, REPORT_INTERVAL from datarobot_batch_scoring.utils import get_rusage from datarobot_batch_scoring.exceptions import ShelveError, \ - UnexpectedKeptColumnCount + UnexpectedKeptColumnCount, NoPredictionThresholdInResult if six.PY3: import dbm.dumb as dumb_dbm @@ -38,7 +38,7 @@ class RunContext(object): def __init__(self, n_samples, out_file, pid, lid, keep_cols, n_retry, delimiter, dataset, pred_name, ui, file_context, fast_mode, encoding, skip_row_id, output_delimiter, - pred_decision_name): + pred_threshold_name, pred_decision_name): self.n_samples = n_samples self.out_file = out_file self.project_id = pid @@ -48,6 +48,7 @@ def __init__(self, n_samples, out_file, pid, lid, keep_cols, self.delimiter = delimiter self.dataset = dataset self.pred_name = pred_name + self.pred_threshold_name = pred_threshold_name self.pred_decision_name = pred_decision_name self.out_stream = None self._ui = ui @@ -68,7 +69,7 @@ def create(cls, resume, n_samples, out_file, pid, lid, keep_cols, n_retry, delimiter, dataset, pred_name, ui, fast_mode, encoding, skip_row_id, output_delimiter, - pred_decision_name): + pred_threshold_name, pred_decision_name): """Factory method for run contexts. Either resume or start a new one. @@ -84,7 +85,7 @@ def create(cls, resume, n_samples, out_file, pid, lid, return ctx_class(n_samples, out_file, pid, lid, keep_cols, n_retry, delimiter, dataset, pred_name, ui, file_context, fast_mode, encoding, skip_row_id, output_delimiter, - pred_decision_name) + pred_threshold_name, pred_decision_name) def __enter__(self): assert(not self.is_open) @@ -433,6 +434,7 @@ def process_response(self): written_fields, comb = format_data( data, batch, pred_name=self.ctx.pred_name, + pred_threshold_name=self.ctx.pred_threshold_name, pred_decision_name=self.ctx.pred_decision_name, keep_cols=self.ctx.keep_cols, skip_row_id=self.ctx.skip_row_id, @@ -443,6 +445,11 @@ def process_response(self): 'retrieved. This can happen in ' + '--fast mode with --keep_cols where ' + 'some cells contain quoted delimiters') + except NoPredictionThresholdInResult: + self._ui.fatal('No predictionThreshold returned from ' + 'API. --pred_threshold should be used ' + 'only for binary classification ' + 'predictions') except Exception as e: self._ui.fatal(e) diff --git a/tests/fixtures/temperatura.json b/tests/fixtures/temperatura.json index 046173b7..d6725a44 100644 --- a/tests/fixtures/temperatura.json +++ b/tests/fixtures/temperatura.json @@ -9,6 +9,7 @@ "label":0.0 }], "prediction":0.0, + "predictionThreshold":0.5, "rowId":0 }, { @@ -20,6 +21,7 @@ "label":0.0 }], "prediction":0.0, + "predictionThreshold":0.5, "rowId":1 }, { @@ -31,6 +33,7 @@ "label":0.0 }], "prediction":0.0, + "predictionThreshold":0.5, "rowId":2 }, { @@ -39,6 +42,7 @@ {"value":1.0,"label":0.0} ], "prediction":0.0, + "predictionThreshold":0.5, "rowId":3 },{ "predictionValues":[ @@ -46,6 +50,7 @@ {"value":1.0,"label":0.0} ], "prediction":0.0, + "predictionThreshold":0.5, "rowId":4 }, { @@ -55,6 +60,7 @@ "value":1.0,"label":0.0 }], "prediction":0.0, + "predictionThreshold":0.5, "rowId":5 },{ "predictionValues":[{ @@ -63,12 +69,14 @@ "value":0.0,"label":0.0 }], "prediction":1.0, + "predictionThreshold":0.5, "rowId":6 },{ "predictionValues":[{ "value":0.0,"label":1.0 },{"value":1.0,"label":0.0}], "prediction":0.0, + "predictionThreshold":0.5, "rowId":7 },{ "predictionValues":[{ @@ -77,6 +85,7 @@ "value":1.0,"label":0.0 }], "prediction":0.0, + "predictionThreshold":0.5, "rowId":8 },{ "predictionValues":[{ @@ -85,6 +94,7 @@ "value":1.0,"label":0.0 }], "prediction":0.0, + "predictionThreshold":0.5, "rowId":9 }] } \ No newline at end of file diff --git a/tests/fixtures/temperatura_output_healthy_threshold.csv b/tests/fixtures/temperatura_output_healthy_threshold.csv new file mode 100644 index 00000000..c4f8e267 --- /dev/null +++ b/tests/fixtures/temperatura_output_healthy_threshold.csv @@ -0,0 +1,101 @@ +row_id,healthy,threshold +0,0.0,0.5 +1,0.0,0.5 +2,0.0,0.5 +3,0.0,0.5 +4,0.0,0.5 +5,0.0,0.5 +6,1.0,0.5 +7,0.0,0.5 +8,0.0,0.5 +9,0.0,0.5 +10,0.0,0.5 +11,0.0,0.5 +12,0.0,0.5 +13,0.0,0.5 +14,0.0,0.5 +15,0.0,0.5 +16,1.0,0.5 +17,0.0,0.5 +18,0.0,0.5 +19,0.0,0.5 +20,0.0,0.5 +21,0.0,0.5 +22,0.0,0.5 +23,0.0,0.5 +24,0.0,0.5 +25,0.0,0.5 +26,1.0,0.5 +27,0.0,0.5 +28,0.0,0.5 +29,0.0,0.5 +30,0.0,0.5 +31,0.0,0.5 +32,0.0,0.5 +33,0.0,0.5 +34,0.0,0.5 +35,0.0,0.5 +36,1.0,0.5 +37,0.0,0.5 +38,0.0,0.5 +39,0.0,0.5 +40,0.0,0.5 +41,0.0,0.5 +42,0.0,0.5 +43,0.0,0.5 +44,0.0,0.5 +45,0.0,0.5 +46,1.0,0.5 +47,0.0,0.5 +48,0.0,0.5 +49,0.0,0.5 +50,0.0,0.5 +51,0.0,0.5 +52,0.0,0.5 +53,0.0,0.5 +54,0.0,0.5 +55,0.0,0.5 +56,1.0,0.5 +57,0.0,0.5 +58,0.0,0.5 +59,0.0,0.5 +60,0.0,0.5 +61,0.0,0.5 +62,0.0,0.5 +63,0.0,0.5 +64,0.0,0.5 +65,0.0,0.5 +66,1.0,0.5 +67,0.0,0.5 +68,0.0,0.5 +69,0.0,0.5 +70,0.0,0.5 +71,0.0,0.5 +72,0.0,0.5 +73,0.0,0.5 +74,0.0,0.5 +75,0.0,0.5 +76,1.0,0.5 +77,0.0,0.5 +78,0.0,0.5 +79,0.0,0.5 +80,0.0,0.5 +81,0.0,0.5 +82,0.0,0.5 +83,0.0,0.5 +84,0.0,0.5 +85,0.0,0.5 +86,1.0,0.5 +87,0.0,0.5 +88,0.0,0.5 +89,0.0,0.5 +90,0.0,0.5 +91,0.0,0.5 +92,0.0,0.5 +93,0.0,0.5 +94,0.0,0.5 +95,0.0,0.5 +96,1.0,0.5 +97,0.0,0.5 +98,0.0,0.5 +99,0.0,0.5 diff --git a/tests/test_api_response_handlers.py b/tests/test_api_response_handlers.py index 6115a4ed..01855b82 100644 --- a/tests/test_api_response_handlers.py +++ b/tests/test_api_response_handlers.py @@ -1,7 +1,8 @@ import pytest from datarobot_batch_scoring.api_response_handlers import pred_api_v10, api_v1 from datarobot_batch_scoring.consts import Batch -from datarobot_batch_scoring.exceptions import UnexpectedKeptColumnCount +from datarobot_batch_scoring.exceptions import UnexpectedKeptColumnCount, \ + NoPredictionThresholdInResult @pytest.fixture() @@ -10,11 +11,28 @@ def parsed_pred_api_v10_predictions(): { "predictionValues": [{"value": 1, "label": "readmitted"}], "prediction": 1, + "predictionThreshold": 0.5, "rowId": 0 }, { "predictionValues": [{"value": 1, "label": "readmitted"}], "prediction": 1, + "predictionThreshold": 0.5, + "rowId": 1 + }] + + +@pytest.fixture() +def parsed_pred_api_v10_predictions_regression(): + return [ + { + "predictionValues": [{"value": 42.42, "label": "currency"}], + "prediction": 42.42, + "rowId": 0 + }, + { + "predictionValues": [{"value": 84.84, "label": "currency"}], + "prediction": 84.84, "rowId": 1 }] @@ -106,6 +124,7 @@ def test_unpack_data(self): @pytest.mark.parametrize('opts, expected_fields, expected_values', ( ({'pred_name': None, + 'pred_threshold_name': None, 'pred_decision_name': None, 'keep_cols': None, 'skip_row_id': False, @@ -115,6 +134,7 @@ def test_unpack_data(self): [[0, 1], [1, 1]]), ({'pred_name': None, + 'pred_threshold_name': None, 'pred_decision_name': None, 'keep_cols': ['gender'], 'skip_row_id': False, @@ -124,6 +144,7 @@ def test_unpack_data(self): [[0, 'Male', 1], [1, 'Male', 1]]), ({'pred_name': None, + 'pred_threshold_name': None, 'pred_decision_name': None, 'keep_cols': ['gender'], 'skip_row_id': True, @@ -132,14 +153,34 @@ def test_unpack_data(self): ['gender', 'readmitted'], [['Male', 1], ['Male', 1]]), + ({'pred_name': None, + 'pred_threshold_name': 'threshold', + 'keep_cols': None, + 'skip_row_id': False, + 'fast_mode': False, + 'delimiter': ','}, + ['row_id', 'readmitted', 'threshold'], + [[0, 1, 0.5], [1, 1, 0.5]]), + ({'pred_name': None, 'pred_decision_name': 'label', + 'pred_threshold_name': None, 'keep_cols': None, 'skip_row_id': False, 'fast_mode': False, 'delimiter': ','}, ['row_id', 'readmitted', 'label'], [[0, 1, 1], [1, 1, 1]]), + + ({'pred_name': None, + 'pred_decision_name': 'label', + 'pred_threshold_name': 'threshold', + 'keep_cols': None, + 'skip_row_id': False, + 'fast_mode': False, + 'delimiter': ','}, + ['row_id', 'readmitted', 'threshold', 'label'], + [[0, 1, 0.5, 1], [1, 1, 0.5, 1]]), )) def test_format_data(self, parsed_pred_api_v10_predictions, batch, opts, expected_fields, expected_values): @@ -148,6 +189,21 @@ def test_format_data(self, parsed_pred_api_v10_predictions, batch, assert fields == expected_fields assert values == expected_values + @pytest.mark.parametrize('opts', [ + {'pred_name': None, + 'pred_threshold_name': 'threshold', + 'keep_cols': None, + 'skip_row_id': False, + 'fast_mode': False, + 'delimiter': ','}, + ]) + def test_fail_threshold_on_non_binary_classification( + self, parsed_pred_api_v10_predictions_regression, + batch, opts): + with pytest.raises(NoPredictionThresholdInResult): + pred_api_v10.format_data( + parsed_pred_api_v10_predictions_regression, batch, **opts) + class TestApiV1Handlers(object): def test_unpack_data(self): diff --git a/tests/test_conf_file.py b/tests/test_conf_file.py index 182f0cfb..beb5f0b5 100644 --- a/tests/test_conf_file.py +++ b/tests/test_conf_file.py @@ -142,6 +142,7 @@ def test_run_main_with_conf_file(monkeypatch): user=file_username password=file_password max_prediction_explanations=3 + pred_threshold=threshold pred_decision=label""") with NamedTemporaryFile(suffix='.ini', delete=False) as test_file: test_file.write(str(raw_data).encode('utf-8')) @@ -174,6 +175,7 @@ def test_run_main_with_conf_file(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name='threshold', pred_decision_name='label', timeout=None, ui=mock.ANY, @@ -238,6 +240,7 @@ def test_run_main_with_conf_file_deployment_aware(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -302,6 +305,7 @@ def test_run_empty_main_with_conf_file(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, diff --git a/tests/test_functional.py b/tests/test_functional.py index d11294e0..5c006bd6 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -107,6 +107,8 @@ def test_simple(live_server, tmpdir, func_params): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv.gz', pred_name=None, + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -148,6 +150,8 @@ def test_prediction_explanations(live_server, tmpdir): delimiter=None, dataset='tests/fixtures/10kDiabetes.csv', pred_name=None, + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -190,6 +194,8 @@ def test_prediction_explanations_keepcols(live_server, tmpdir): delimiter=None, dataset='tests/fixtures/10kDiabetes.csv', pred_name=None, + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -233,6 +239,8 @@ def test_simple_api_v1(live_server, tmpdir): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv.gz', pred_name=None, + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -274,6 +282,8 @@ def test_simple_transferable(live_server, tmpdir): delimiter=None, dataset='tests/fixtures/regression_predict.csv', pred_name=None, + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -315,6 +325,8 @@ def test_keep_cols(live_server, tmpdir, ui, func_params, fast_mode=False): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -365,6 +377,8 @@ def test_keep_wrong_cols(live_server, tmpdir, func_params, fast_mode=False): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -408,6 +422,8 @@ def test_pred_name_classification(live_server, tmpdir, func_params): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name='healthy', + pred_threshold_name=None, + pred_decision_name=None, timeout=None, ui=ui, auto_sample=False, @@ -424,6 +440,51 @@ def test_pred_name_classification(live_server, tmpdir, func_params): assert expected == f.read(), expected +def test_pred_threshold_classification(live_server, tmpdir, func_params): + # train one model in project + out = tmpdir.join('out.csv') + + ui = PickableMock() + base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url()) + ret = run_batch_predictions( + base_url=base_url, + base_headers={}, + user='username', + pwd='password', + api_token=None, + create_api_token=False, + deployment_id=func_params['deployment_id'], + pid=func_params['pid'], + lid=func_params['lid'], + import_id=None, + n_retry=3, + concurrent=1, + resume=False, + n_samples=10, + out_file=str(out), + keep_cols=None, + delimiter=None, + dataset='tests/fixtures/temperatura_predict.csv', + pred_name='healthy', + pred_threshold_name='threshold', + timeout=None, + ui=ui, + auto_sample=False, + fast_mode=False, + dry_run=False, + encoding='', + skip_dialect=False + ) + + assert ret is None + + expected = out.read_text('utf-8') + with open( + 'tests/fixtures/temperatura_output_healthy_threshold.csv', 'rU' + ) as f: + assert expected == f.read(), expected + + def test_pred_decision_name_classification(live_server, tmpdir, func_params): # train one model in project out = tmpdir.join('out.csv') diff --git a/tests/test_main.py b/tests/test_main.py index db7819e5..57a1f7f5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -41,6 +41,7 @@ def test_without_passed_user_and_passwd(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -92,6 +93,7 @@ def test_keep_cols(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -187,6 +189,7 @@ def test_datarobot_key(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -239,6 +242,7 @@ def test_encoding_options(monkeypatch): delimiter='\t', dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -385,6 +389,7 @@ def test_output_delimiter(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -437,6 +442,7 @@ def test_skip_row_id(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -454,6 +460,59 @@ def test_skip_row_id(monkeypatch): ) +def test_pred_threshold(monkeypatch): + main_args = ['--host', + 'http://localhost:53646/api', + '56dd9570018e213242dfa93c', + '56dd9570018e213242dfa93d', + 'tests/fixtures/temperatura_predict.csv', + '--pred_threshold=threshold', + '--encoding=utf-8', '--skip_dialect'] + + monkeypatch.setattr('datarobot_batch_scoring.main.UI', mock.Mock(spec=UI)) + + with mock.patch( + 'datarobot_batch_scoring.main' + '.run_batch_predictions') as mock_method: + main(argv=main_args) + mock_method.assert_called_once_with( + base_url='http://localhost:53646/predApi/v1.0/', + base_headers={}, + user=mock.ANY, + pwd=mock.ANY, + api_token=None, + create_api_token=False, + pid='56dd9570018e213242dfa93c', + lid='56dd9570018e213242dfa93d', + deployment_id=None, + import_id=None, + n_retry=3, + concurrent=4, + resume=None, + n_samples=False, + out_file='out.csv', + keep_cols=None, + delimiter=None, + dataset='tests/fixtures/temperatura_predict.csv', + pred_name=None, + pred_decision_name=None, + pred_threshold_name='threshold', + timeout=None, + ui=mock.ANY, + fast_mode=False, + auto_sample=True, + dry_run=False, + encoding='utf-8', + skip_dialect=True, + skip_row_id=False, + output_delimiter=None, + compression=False, + field_size_limit=None, + verify_ssl=True, + max_prediction_explanations=0 + ) + + def test_pred_decision(monkeypatch): main_args = ['--host', 'http://localhost:53646/api', @@ -490,6 +549,7 @@ def test_pred_decision(monkeypatch): dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, pred_decision_name='label', + pred_threshold_name=None, timeout=None, ui=mock.ANY, fast_mode=False, @@ -539,6 +599,7 @@ def test_batch_scoring_deployment_aware_call(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -587,6 +648,7 @@ def test_datarobot_transferable_call(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -638,6 +700,7 @@ def test_resume(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, @@ -687,6 +750,7 @@ def test_resume_no(monkeypatch): delimiter=None, dataset='tests/fixtures/temperatura_predict.csv', pred_name=None, + pred_threshold_name=None, pred_decision_name=None, timeout=None, ui=mock.ANY, diff --git a/tests/test_writer.py b/tests/test_writer.py index 2f2937bf..40c3a924 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -23,6 +23,7 @@ def test_no_resume_existing_no_ask_new_context(run_context_file): pred_name='pred_name', ui=ui_mock, fast_mode=None, encoding=None, skip_row_id=None, output_delimiter=None, + pred_threshold_name=None, pred_decision_name=None) assert ui_mock.call_count == 0 assert isinstance(ctx, NewRunContext) @@ -48,6 +49,7 @@ def test_resume_if_it_was_run_already(run_context_file): pred_name='pred_name', ui=ui_mock, fast_mode=None, encoding=None, skip_row_id=None, output_delimiter=None, + pred_threshold_name=None, pred_decision_name=None) assert ui_mock.call_count == 0 assert isinstance(ctx, OldRunContext) @@ -72,6 +74,7 @@ def test_asking_if_resume_not_provided(run_context_file): pred_name='pred_name', ui=ui_mock, fast_mode=None, encoding=None, skip_row_id=None, output_delimiter=None, + pred_threshold_name=None, pred_decision_name=None) ui_mock.prompt_yesno.assert_called_once_with( 'Existing run found. Resume')