Skip to content
This repository has been archived by the owner on Sep 9, 2020. It is now read-only.

Commit

Permalink
[PRED-2176] Add optional prediction threshold field to out csv (#155)
Browse files Browse the repository at this point in the history

* PRED-2176 add unit test for no prediction threshold in result

* PRED-2176 fix fixture for binary classification in unit tests

* PRED-2176 fix help text for column

* [PRED-2176] Fix tests
  • Loading branch information
rrader committed Mar 22, 2019
1 parent e0f60f1 commit 03a464f
Show file tree
Hide file tree
Showing 12 changed files with 359 additions and 9 deletions.
38 changes: 35 additions & 3 deletions datarobot_batch_scoring/api_response_handlers/pred_api_v10.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import json
from six.moves import zip

from datarobot_batch_scoring.exceptions import UnexpectedKeptColumnCount
from datarobot_batch_scoring.exceptions import UnexpectedKeptColumnCount, \
NoPredictionThresholdInResult


def row_id_field(result_sorted, batch):
Expand Down Expand Up @@ -150,6 +151,26 @@ def rows_generator():
return headers, rows_generator()


def pred_threshold_field(result_sorted, pred_threshold_name):
""" Generate prediction threshold field (for classification)
Parameters
----------
result_sorted : list[dict]
list of results sorted by rowId
pred_threshold_name : str
column name which should contain prediction threshold
Returns
-------
header: list[str]
row_generator: iterator
"""

return [pred_threshold_name], (
[row.get('predictionThreshold')]
for row in result_sorted
)


def pred_decision_field(result_sorted, pred_decision):
""" Generate prediction decision field
Expand Down Expand Up @@ -177,7 +198,7 @@ def format_data(result, batch, **opts):
Parameters
----------
result : list
result : list[dict]
list of results
batch
batch information
Expand All @@ -192,6 +213,7 @@ def format_data(result, batch, **opts):
list of rows
"""
pred_name = opts.get('pred_name')
pred_threshold_name = opts.get('pred_threshold_name')
pred_decision_name = opts.get('pred_decision_name')
keep_cols = opts.get('keep_cols')
skip_row_id = opts.get('skip_row_id')
Expand Down Expand Up @@ -240,7 +262,17 @@ def _find_prediction_explanations_key():
)
)

# Threshold and thresholded decision field ('prediction' value from result)
# Threshold field for classification
# ('predictionThreshold' value from result)
if pred_threshold_name:
if 'predictionThreshold' in single_row:
fields.append(
pred_threshold_field(result_sorted, pred_threshold_name)
)
else:
raise NoPredictionThresholdInResult()

# Thresholded decision field ('prediction' value from result)
if pred_decision_name:
fields.append(pred_decision_field(result_sorted, pred_decision_name))

Expand Down
3 changes: 2 additions & 1 deletion datarobot_batch_scoring/batch_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def run_batch_predictions(base_url, base_headers, user, pwd,
verify_ssl=True,
deployment_id=None,
max_prediction_explanations=0,
pred_threshold_name=None,
pred_decision_name=None):

if field_size_limit is not None:
Expand Down Expand Up @@ -201,7 +202,7 @@ def run_batch_predictions(base_url, base_headers, user, pwd,
lid, keep_cols, n_retry, delimiter,
dataset, pred_name, ui, fast_mode,
encoding, skip_row_id, output_delimiter,
pred_decision_name))
pred_threshold_name, pred_decision_name))

n_batches_checkpointed_init = len(ctx.db['checkpoints'])
ui.debug('number of batches checkpointed initially: {}'
Expand Down
4 changes: 4 additions & 0 deletions datarobot_batch_scoring/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ class ShelveError(Exception):

class UnexpectedKeptColumnCount(Exception):
pass


class NoPredictionThresholdInResult(Exception):
pass
7 changes: 7 additions & 0 deletions datarobot_batch_scoring/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ def parse_args(argv, standalone=False, deployment_aware=False):
'empty name is used if not specified. For binary '
'predictions assumes last class in lexical order '
'as positive')
csv_gr.add_argument('--pred_threshold', type=str,
nargs='?', default=None,
help='Specifies column name for prediction threshold '
'for binary classification. Column will not be '
'included if not specified')
csv_gr.add_argument('--pred_decision', type=str,
nargs='?', default=None,
help='Specifies column name for prediction decision, '
Expand Down Expand Up @@ -306,6 +311,7 @@ def parse_generic_options(parsed_args):
skip_row_id = parsed_args['skip_row_id']
field_size_limit = parsed_args.get('field_size_limit')
pred_name = parsed_args.get('pred_name')
pred_threshold_name = parsed_args.get('pred_threshold')
pred_decision_name = parsed_args.get('pred_decision')
dry_run = parsed_args.get('dry_run', False)

Expand Down Expand Up @@ -363,6 +369,7 @@ def parse_generic_options(parsed_args):
'out_file': out_file,
'output_delimiter': output_delimiter,
'pred_name': pred_name,
'pred_threshold_name': pred_threshold_name,
'pred_decision_name': pred_decision_name,
'resume': resume,
'skip_dialect': skip_dialect,
Expand Down
15 changes: 11 additions & 4 deletions datarobot_batch_scoring/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
WriterQueueMsg, ProgressQueueMsg, REPORT_INTERVAL
from datarobot_batch_scoring.utils import get_rusage
from datarobot_batch_scoring.exceptions import ShelveError, \
UnexpectedKeptColumnCount
UnexpectedKeptColumnCount, NoPredictionThresholdInResult

if six.PY3:
import dbm.dumb as dumb_dbm
Expand All @@ -38,7 +38,7 @@ class RunContext(object):
def __init__(self, n_samples, out_file, pid, lid, keep_cols,
n_retry, delimiter, dataset, pred_name, ui, file_context,
fast_mode, encoding, skip_row_id, output_delimiter,
pred_decision_name):
pred_threshold_name, pred_decision_name):
self.n_samples = n_samples
self.out_file = out_file
self.project_id = pid
Expand All @@ -48,6 +48,7 @@ def __init__(self, n_samples, out_file, pid, lid, keep_cols,
self.delimiter = delimiter
self.dataset = dataset
self.pred_name = pred_name
self.pred_threshold_name = pred_threshold_name
self.pred_decision_name = pred_decision_name
self.out_stream = None
self._ui = ui
Expand All @@ -68,7 +69,7 @@ def create(cls, resume, n_samples, out_file, pid, lid,
keep_cols, n_retry,
delimiter, dataset, pred_name, ui,
fast_mode, encoding, skip_row_id, output_delimiter,
pred_decision_name):
pred_threshold_name, pred_decision_name):
"""Factory method for run contexts.
Either resume or start a new one.
Expand All @@ -84,7 +85,7 @@ def create(cls, resume, n_samples, out_file, pid, lid,
return ctx_class(n_samples, out_file, pid, lid, keep_cols, n_retry,
delimiter, dataset, pred_name, ui, file_context,
fast_mode, encoding, skip_row_id, output_delimiter,
pred_decision_name)
pred_threshold_name, pred_decision_name)

def __enter__(self):
assert(not self.is_open)
Expand Down Expand Up @@ -433,6 +434,7 @@ def process_response(self):
written_fields, comb = format_data(
data, batch,
pred_name=self.ctx.pred_name,
pred_threshold_name=self.ctx.pred_threshold_name,
pred_decision_name=self.ctx.pred_decision_name,
keep_cols=self.ctx.keep_cols,
skip_row_id=self.ctx.skip_row_id,
Expand All @@ -443,6 +445,11 @@ def process_response(self):
'retrieved. This can happen in ' +
'--fast mode with --keep_cols where ' +
'some cells contain quoted delimiters')
except NoPredictionThresholdInResult:
self._ui.fatal('No predictionThreshold returned from '
'API. --pred_threshold should be used '
'only for binary classification '
'predictions')
except Exception as e:
self._ui.fatal(e)

Expand Down
10 changes: 10 additions & 0 deletions tests/fixtures/temperatura.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"label":0.0
}],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":0
},
{
Expand All @@ -20,6 +21,7 @@
"label":0.0
}],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":1
},
{
Expand All @@ -31,6 +33,7 @@
"label":0.0
}],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":2
},
{
Expand All @@ -39,13 +42,15 @@
{"value":1.0,"label":0.0}
],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":3
},{
"predictionValues":[
{"value":0.0,"label":1.0},
{"value":1.0,"label":0.0}
],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":4
},
{
Expand All @@ -55,6 +60,7 @@
"value":1.0,"label":0.0
}],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":5
},{
"predictionValues":[{
Expand All @@ -63,12 +69,14 @@
"value":0.0,"label":0.0
}],
"prediction":1.0,
"predictionThreshold":0.5,
"rowId":6
},{
"predictionValues":[{
"value":0.0,"label":1.0
},{"value":1.0,"label":0.0}],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":7
},{
"predictionValues":[{
Expand All @@ -77,6 +85,7 @@
"value":1.0,"label":0.0
}],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":8
},{
"predictionValues":[{
Expand All @@ -85,6 +94,7 @@
"value":1.0,"label":0.0
}],
"prediction":0.0,
"predictionThreshold":0.5,
"rowId":9
}]
}
101 changes: 101 additions & 0 deletions tests/fixtures/temperatura_output_healthy_threshold.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
row_id,healthy,threshold
0,0.0,0.5
1,0.0,0.5
2,0.0,0.5
3,0.0,0.5
4,0.0,0.5
5,0.0,0.5
6,1.0,0.5
7,0.0,0.5
8,0.0,0.5
9,0.0,0.5
10,0.0,0.5
11,0.0,0.5
12,0.0,0.5
13,0.0,0.5
14,0.0,0.5
15,0.0,0.5
16,1.0,0.5
17,0.0,0.5
18,0.0,0.5
19,0.0,0.5
20,0.0,0.5
21,0.0,0.5
22,0.0,0.5
23,0.0,0.5
24,0.0,0.5
25,0.0,0.5
26,1.0,0.5
27,0.0,0.5
28,0.0,0.5
29,0.0,0.5
30,0.0,0.5
31,0.0,0.5
32,0.0,0.5
33,0.0,0.5
34,0.0,0.5
35,0.0,0.5
36,1.0,0.5
37,0.0,0.5
38,0.0,0.5
39,0.0,0.5
40,0.0,0.5
41,0.0,0.5
42,0.0,0.5
43,0.0,0.5
44,0.0,0.5
45,0.0,0.5
46,1.0,0.5
47,0.0,0.5
48,0.0,0.5
49,0.0,0.5
50,0.0,0.5
51,0.0,0.5
52,0.0,0.5
53,0.0,0.5
54,0.0,0.5
55,0.0,0.5
56,1.0,0.5
57,0.0,0.5
58,0.0,0.5
59,0.0,0.5
60,0.0,0.5
61,0.0,0.5
62,0.0,0.5
63,0.0,0.5
64,0.0,0.5
65,0.0,0.5
66,1.0,0.5
67,0.0,0.5
68,0.0,0.5
69,0.0,0.5
70,0.0,0.5
71,0.0,0.5
72,0.0,0.5
73,0.0,0.5
74,0.0,0.5
75,0.0,0.5
76,1.0,0.5
77,0.0,0.5
78,0.0,0.5
79,0.0,0.5
80,0.0,0.5
81,0.0,0.5
82,0.0,0.5
83,0.0,0.5
84,0.0,0.5
85,0.0,0.5
86,1.0,0.5
87,0.0,0.5
88,0.0,0.5
89,0.0,0.5
90,0.0,0.5
91,0.0,0.5
92,0.0,0.5
93,0.0,0.5
94,0.0,0.5
95,0.0,0.5
96,1.0,0.5
97,0.0,0.5
98,0.0,0.5
99,0.0,0.5
Loading

0 comments on commit 03a464f

Please sign in to comment.