Skip to content
This repository has been archived by the owner on Sep 9, 2020. It is now read-only.

Commit

Permalink
[PRED-2176] Add optional prediction decision field to out csv (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
rrader committed Mar 22, 2019
1 parent c7f9897 commit e0f60f1
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 10 deletions.
29 changes: 28 additions & 1 deletion datarobot_batch_scoring/api_response_handlers/pred_api_v10.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,28 @@ def rows_generator():
return headers, rows_generator()


def pred_decision_field(result_sorted, pred_decision):
    """Build the optional prediction-decision output column.

    The decision is the 'prediction' value of each result row (the class
    label chosen by the model), emitted under a caller-supplied header.

    Parameters
    ----------
    result_sorted : list[dict]
        list of results sorted by rowId
    pred_decision : str
        column name which should contain prediction decision (label)

    Returns
    -------
    header : list[str]
        single-element list with the decision column name
    row_generator : iterator
        yields a one-element list per result row
    """
    def _decision_rows():
        # Lazily pull the chosen label out of each result record.
        for record in result_sorted:
            yield [record.get('prediction')]

    return [pred_decision], _decision_rows()


def format_data(result, batch, **opts):
""" Generate rows of response from results
Expand All @@ -170,7 +192,7 @@ def format_data(result, batch, **opts):
list of rows
"""
pred_name = opts.get('pred_name')
# pred_decision = opts.get('pred_decision')
pred_decision_name = opts.get('pred_decision_name')
keep_cols = opts.get('keep_cols')
skip_row_id = opts.get('skip_row_id')
fast_mode = opts.get('fast_mode')
Expand Down Expand Up @@ -217,6 +239,11 @@ def _find_prediction_explanations_key():
prediction_explanations_key
)
)

# Threshold and thresholded decision field ('prediction' value from result)
if pred_decision_name:
fields.append(pred_decision_field(result_sorted, pred_decision_name))

# endregion

headers = list(
Expand Down
6 changes: 4 additions & 2 deletions datarobot_batch_scoring/batch_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def run_batch_predictions(base_url, base_headers, user, pwd,
field_size_limit=None,
verify_ssl=True,
deployment_id=None,
max_prediction_explanations=0):
max_prediction_explanations=0,
pred_decision_name=None):

if field_size_limit is not None:
csv.field_size_limit(field_size_limit)
Expand Down Expand Up @@ -199,7 +200,8 @@ def run_batch_predictions(base_url, base_headers, user, pwd,
RunContext.create(resume, n_samples, out_file, pid,
lid, keep_cols, n_retry, delimiter,
dataset, pred_name, ui, fast_mode,
encoding, skip_row_id, output_delimiter))
encoding, skip_row_id, output_delimiter,
pred_decision_name))

n_batches_checkpointed_init = len(ctx.db['checkpoints'])
ui.debug('number of batches checkpointed initially: {}'
Expand Down
7 changes: 7 additions & 0 deletions datarobot_batch_scoring/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ def parse_args(argv, standalone=False, deployment_aware=False):
'empty name is used if not specified. For binary '
'predictions assumes last class in lexical order '
'as positive')
csv_gr.add_argument('--pred_decision', type=str,
nargs='?', default=None,
help='Specifies column name for prediction decision, '
'the value predicted by the model (class label '
'for classification)')
csv_gr.add_argument('--fast', action='store_true',
default=defaults['fast'],
help='Experimental: faster CSV processor. '
Expand Down Expand Up @@ -301,6 +306,7 @@ def parse_generic_options(parsed_args):
skip_row_id = parsed_args['skip_row_id']
field_size_limit = parsed_args.get('field_size_limit')
pred_name = parsed_args.get('pred_name')
pred_decision_name = parsed_args.get('pred_decision')
dry_run = parsed_args.get('dry_run', False)

n_samples = int(parsed_args['n_samples'])
Expand Down Expand Up @@ -357,6 +363,7 @@ def parse_generic_options(parsed_args):
'out_file': out_file,
'output_delimiter': output_delimiter,
'pred_name': pred_name,
'pred_decision_name': pred_decision_name,
'resume': resume,
'skip_dialect': skip_dialect,
'skip_row_id': skip_row_id,
Expand Down
11 changes: 8 additions & 3 deletions datarobot_batch_scoring/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class RunContext(object):

def __init__(self, n_samples, out_file, pid, lid, keep_cols,
n_retry, delimiter, dataset, pred_name, ui, file_context,
fast_mode, encoding, skip_row_id, output_delimiter):
fast_mode, encoding, skip_row_id, output_delimiter,
pred_decision_name):
self.n_samples = n_samples
self.out_file = out_file
self.project_id = pid
Expand All @@ -47,6 +48,7 @@ def __init__(self, n_samples, out_file, pid, lid, keep_cols,
self.delimiter = delimiter
self.dataset = dataset
self.pred_name = pred_name
self.pred_decision_name = pred_decision_name
self.out_stream = None
self._ui = ui
self.file_context = file_context
Expand All @@ -65,7 +67,8 @@ def __init__(self, n_samples, out_file, pid, lid, keep_cols,
def create(cls, resume, n_samples, out_file, pid, lid,
keep_cols, n_retry,
delimiter, dataset, pred_name, ui,
fast_mode, encoding, skip_row_id, output_delimiter):
fast_mode, encoding, skip_row_id, output_delimiter,
pred_decision_name):
"""Factory method for run contexts.
Either resume or start a new one.
Expand All @@ -80,7 +83,8 @@ def create(cls, resume, n_samples, out_file, pid, lid,

return ctx_class(n_samples, out_file, pid, lid, keep_cols, n_retry,
delimiter, dataset, pred_name, ui, file_context,
fast_mode, encoding, skip_row_id, output_delimiter)
fast_mode, encoding, skip_row_id, output_delimiter,
pred_decision_name)

def __enter__(self):
assert(not self.is_open)
Expand Down Expand Up @@ -429,6 +433,7 @@ def process_response(self):
written_fields, comb = format_data(
data, batch,
pred_name=self.ctx.pred_name,
pred_decision_name=self.ctx.pred_decision_name,
keep_cols=self.ctx.keep_cols,
skip_row_id=self.ctx.skip_row_id,
fast_mode=self.ctx.fast_mode,
Expand Down
101 changes: 101 additions & 0 deletions tests/fixtures/temperatura_output_decision.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
row_id,0.0,1.0,label
0,1.0,0.0,0.0
1,1.0,0.0,0.0
2,1.0,0.0,0.0
3,1.0,0.0,0.0
4,1.0,0.0,0.0
5,1.0,0.0,0.0
6,0.0,1.0,1.0
7,1.0,0.0,0.0
8,1.0,0.0,0.0
9,1.0,0.0,0.0
10,1.0,0.0,0.0
11,1.0,0.0,0.0
12,1.0,0.0,0.0
13,1.0,0.0,0.0
14,1.0,0.0,0.0
15,1.0,0.0,0.0
16,0.0,1.0,1.0
17,1.0,0.0,0.0
18,1.0,0.0,0.0
19,1.0,0.0,0.0
20,1.0,0.0,0.0
21,1.0,0.0,0.0
22,1.0,0.0,0.0
23,1.0,0.0,0.0
24,1.0,0.0,0.0
25,1.0,0.0,0.0
26,0.0,1.0,1.0
27,1.0,0.0,0.0
28,1.0,0.0,0.0
29,1.0,0.0,0.0
30,1.0,0.0,0.0
31,1.0,0.0,0.0
32,1.0,0.0,0.0
33,1.0,0.0,0.0
34,1.0,0.0,0.0
35,1.0,0.0,0.0
36,0.0,1.0,1.0
37,1.0,0.0,0.0
38,1.0,0.0,0.0
39,1.0,0.0,0.0
40,1.0,0.0,0.0
41,1.0,0.0,0.0
42,1.0,0.0,0.0
43,1.0,0.0,0.0
44,1.0,0.0,0.0
45,1.0,0.0,0.0
46,0.0,1.0,1.0
47,1.0,0.0,0.0
48,1.0,0.0,0.0
49,1.0,0.0,0.0
50,1.0,0.0,0.0
51,1.0,0.0,0.0
52,1.0,0.0,0.0
53,1.0,0.0,0.0
54,1.0,0.0,0.0
55,1.0,0.0,0.0
56,0.0,1.0,1.0
57,1.0,0.0,0.0
58,1.0,0.0,0.0
59,1.0,0.0,0.0
60,1.0,0.0,0.0
61,1.0,0.0,0.0
62,1.0,0.0,0.0
63,1.0,0.0,0.0
64,1.0,0.0,0.0
65,1.0,0.0,0.0
66,0.0,1.0,1.0
67,1.0,0.0,0.0
68,1.0,0.0,0.0
69,1.0,0.0,0.0
70,1.0,0.0,0.0
71,1.0,0.0,0.0
72,1.0,0.0,0.0
73,1.0,0.0,0.0
74,1.0,0.0,0.0
75,1.0,0.0,0.0
76,0.0,1.0,1.0
77,1.0,0.0,0.0
78,1.0,0.0,0.0
79,1.0,0.0,0.0
80,1.0,0.0,0.0
81,1.0,0.0,0.0
82,1.0,0.0,0.0
83,1.0,0.0,0.0
84,1.0,0.0,0.0
85,1.0,0.0,0.0
86,0.0,1.0,1.0
87,1.0,0.0,0.0
88,1.0,0.0,0.0
89,1.0,0.0,0.0
90,1.0,0.0,0.0
91,1.0,0.0,0.0
92,1.0,0.0,0.0
93,1.0,0.0,0.0
94,1.0,0.0,0.0
95,1.0,0.0,0.0
96,0.0,1.0,1.0
97,1.0,0.0,0.0
98,1.0,0.0,0.0
99,1.0,0.0,0.0
12 changes: 12 additions & 0 deletions tests/test_api_response_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def test_unpack_data(self):

@pytest.mark.parametrize('opts, expected_fields, expected_values', (
({'pred_name': None,
'pred_decision_name': None,
'keep_cols': None,
'skip_row_id': False,
'fast_mode': False,
Expand All @@ -114,6 +115,7 @@ def test_unpack_data(self):
[[0, 1], [1, 1]]),
({'pred_name': None,
'pred_decision_name': None,
'keep_cols': ['gender'],
'skip_row_id': False,
'fast_mode': False,
Expand All @@ -122,12 +124,22 @@ def test_unpack_data(self):
[[0, 'Male', 1], [1, 'Male', 1]]),
({'pred_name': None,
'pred_decision_name': None,
'keep_cols': ['gender'],
'skip_row_id': True,
'fast_mode': False,
'delimiter': ','},
['gender', 'readmitted'],
[['Male', 1], ['Male', 1]]),
({'pred_name': None,
'pred_decision_name': 'label',
'keep_cols': None,
'skip_row_id': False,
'fast_mode': False,
'delimiter': ','},
['row_id', 'readmitted', 'label'],
[[0, 1, 1], [1, 1, 1]]),
))
def test_format_data(self, parsed_pred_api_v10_predictions, batch,
opts, expected_fields, expected_values):
Expand Down
6 changes: 5 additions & 1 deletion tests/test_conf_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def test_run_main_with_conf_file(monkeypatch):
model_id=56dd9570018e213242dfa93d
user=file_username
password=file_password
max_prediction_explanations=3""")
max_prediction_explanations=3
pred_decision=label""")
with NamedTemporaryFile(suffix='.ini', delete=False) as test_file:
test_file.write(str(raw_data).encode('utf-8'))

Expand Down Expand Up @@ -173,6 +174,7 @@ def test_run_main_with_conf_file(monkeypatch):
delimiter=None,
dataset='tests/fixtures/temperatura_predict.csv',
pred_name=None,
pred_decision_name='label',
timeout=None,
ui=mock.ANY,
auto_sample=False,
Expand Down Expand Up @@ -236,6 +238,7 @@ def test_run_main_with_conf_file_deployment_aware(monkeypatch):
delimiter=None,
dataset='tests/fixtures/temperatura_predict.csv',
pred_name=None,
pred_decision_name=None,
timeout=None,
ui=mock.ANY,
auto_sample=False,
Expand Down Expand Up @@ -299,6 +302,7 @@ def test_run_empty_main_with_conf_file(monkeypatch):
delimiter=None,
dataset='tests/fixtures/temperatura_predict.csv',
pred_name=None,
pred_decision_name=None,
timeout=None,
ui=mock.ANY,
auto_sample=False,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,49 @@ def test_pred_name_classification(live_server, tmpdir, func_params):
assert expected == f.read(), expected


def test_pred_decision_name_classification(live_server, tmpdir, func_params):
    # Score against the live mock server with pred_decision_name set and
    # verify the output CSV gains the extra 'label' decision column,
    # matching the pre-built fixture byte-for-byte.
    out_path = tmpdir.join('out.csv')

    mock_ui = PickableMock()
    api_base = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    result = run_batch_predictions(
        base_url=api_base,
        base_headers={},
        user='username',
        pwd='password',
        api_token=None,
        create_api_token=False,
        deployment_id=func_params['deployment_id'],
        pid=func_params['pid'],
        lid=func_params['lid'],
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out_path),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/temperatura_predict.csv',
        pred_name=None,
        pred_decision_name='label',
        timeout=None,
        ui=mock_ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )

    # A successful run returns None rather than an error code.
    assert result is None

    produced = out_path.read_text('utf-8')
    with open('tests/fixtures/temperatura_output_decision.csv', 'rU') as fixture:
        assert produced == fixture.read(), produced


def test_422(live_server, tmpdir):
# train one model in project
out = tmpdir.join('out.csv')
Expand Down
Loading

0 comments on commit e0f60f1

Please sign in to comment.