""" Callbacks for calculating scores."""
from contextlib import contextmanager
from contextlib import suppress
from functools import partial
import numpy as np
from sklearn.metrics.scorer import (
check_scoring, _BaseScorer, make_scorer)
from sklearn.model_selection._validation import _score
from skorch.utils import data_from_dataset
from skorch.utils import is_skorch_dataset
from skorch.utils import to_numpy
from skorch.callbacks import Callback
from skorch.utils import check_indexing
from skorch.utils import train_loss_score
from skorch.utils import valid_loss_score
__all__ = ['BatchScoring', 'EpochScoring']
@contextmanager
def cache_net_infer(net, use_caching, y_preds):
"""Caching context for ``skorch.NeuralNet`` instance. Returns
a modified version of the net whose ``infer`` method will
subsequently return cached predictions. Leaving the context
    will undo the overwrite of the ``infer`` method.
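
    A minimal usage sketch (illustrative; assumes ``net`` is an
    initialized ``skorch.NeuralNet`` and ``y_preds`` is a list of
    prediction batches collected during the epoch):

    >>> with cache_net_infer(net, True, y_preds) as cached_net:
    ...     y_proba = cached_net.predict_proba(X)  # served from cache

    """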
if not use_caching:
yield net
return
y_preds = iter(y_preds)
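    # Each call to the patched `infer` returns the next cached
    # prediction batch, in the order the batches were recorded.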
net.infer = lambda *a, **kw: next(y_preds)
try:
yield net
finally:
# By setting net.infer we define an attribute `infer`
# that precedes the bound method `infer`. By deleting
# the entry from the attribute dict we undo this.
del net.__dict__['infer']
def convert_sklearn_metric_function(scoring):
"""If ``scoring`` is a sklearn metric function, convert it to a
    sklearn scorer and return it. Otherwise, return ``scoring`` unchanged.
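
    For example (illustrative):

    >>> from sklearn.metrics import accuracy_score
    >>> scorer = convert_sklearn_metric_function(accuracy_score)

    """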
if callable(scoring):
module = getattr(scoring, '__module__', None)
if (
hasattr(module, 'startswith') and
module.startswith('sklearn.metrics.') and
not module.startswith('sklearn.metrics.scorer') and
not module.startswith('sklearn.metrics.tests.')
):
return make_scorer(scoring)
return scoring
class ScoringBase(Callback):
"""Base class for scoring.
Subclass and implement an ``on_*`` method before using.
"""
def __init__(
self,
scoring,
lower_is_better=True,
on_train=False,
name=None,
target_extractor=to_numpy,
use_caching=True,
):
self.scoring = scoring
self.lower_is_better = lower_is_better
self.on_train = on_train
self.name = name
self.target_extractor = target_extractor
self.use_caching = use_caching
# pylint: disable=protected-access
def _get_name(self):
"""Find name of scoring function."""
if self.name is not None:
return self.name
if self.scoring_ is None:
return 'score'
if isinstance(self.scoring_, str):
return self.scoring_
if isinstance(self.scoring_, partial):
return self.scoring_.func.__name__
if isinstance(self.scoring_, _BaseScorer):
return self.scoring_._score_func.__name__
return self.scoring_.__name__
def initialize(self):
self.best_score_ = np.inf if self.lower_is_better else -np.inf
self.scoring_ = convert_sklearn_metric_function(self.scoring)
self.name_ = self._get_name()
return self
# pylint: disable=attribute-defined-outside-init,arguments-differ
def on_train_begin(self, net, X, y, **kwargs):
self.X_indexing_ = check_indexing(X)
self.y_indexing_ = check_indexing(y)
        # Look for the rightmost index where `*_best` is True.
        # That index is used to get the best score in `net.history`.
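        # E.g. for a `*_best` column of [False, True, False, True,
        # False], the rightmost True is at index 3, so the score
        # recorded at epoch index 3 becomes `best_score_`.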
with suppress(ValueError, IndexError, KeyError):
best_name_history = net.history[:, '{}_best'.format(self.name_)]
idx_best_reverse = best_name_history[::-1].index(True)
idx_best = len(best_name_history) - idx_best_reverse - 1
self.best_score_ = net.history[idx_best, self.name_]
def _scoring(self, net, X_test, y_test):
"""Resolve scoring and apply it to data. Use cached prediction
instead of running inference again, if available."""
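        # check_scoring resolves None or a string to an actual scorer
        # callable with signature (estimator, X, y).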
scorer = check_scoring(net, self.scoring_)
scores = _score(
estimator=net,
X_test=X_test,
y_test=y_test,
scorer=scorer,
is_multimetric=False,
)
return scores
def _is_best_score(self, current_score):
if self.lower_is_better is None:
return None
if self.lower_is_better:
return current_score < self.best_score_
return current_score > self.best_score_
class BatchScoring(ScoringBase):
"""Callback that performs generic scoring on batches.
This callback determines the score after each batch and stores it
in the net's history in the column given by ``name``. At the end
    of the epoch, the average of the scores is determined and also
stored in the history. Furthermore, it is determined whether this
average score is the best score yet and that information is also
stored in the history.
In contrast to ``EpochScoring``, this callback determines the
score for each batch and then averages the score at the end of the
epoch. This can be disadvantageous for some scores if the batch
size is small -- e.g. area under the ROC will return incorrect
    scores in this case. Therefore, it is recommended to use
``EpochScoring`` unless you really need the scores for each batch.
If ``y`` is None, the ``scoring`` function with signature (model, X, y)
must be able to handle ``X`` as a ``Tensor`` and ``y=None``.
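
    A minimal usage sketch (illustrative; ``module`` stands for any
    ``torch.nn.Module`` and the callback name is arbitrary):

    >>> net = NeuralNetClassifier(
    ...     module,
    ...     callbacks=[('acc', BatchScoring('accuracy', lower_is_better=False))])
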
Parameters
----------
scoring : None, str, or callable
If None, use the ``score`` method of the model. If str, it should
        be a valid sklearn scorer (e.g. "f1", "accuracy"). If
a callable, it should have the signature (model, X, y), and it
should return a scalar. This works analogously to the ``scoring``
parameter in sklearn's ``GridSearchCV`` et al.
lower_is_better : bool (default=True)
Whether lower (e.g. log loss) or higher (e.g. accuracy) scores
are better
on_train : bool (default=False)
        Whether this should be called on train or validation data.
name : str or None (default=None)
If not an explicit string, tries to infer the name from the
``scoring`` argument.
target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
use_caching : bool (default=True)
        Re-use the prediction that was already computed for the loss
        to calculate the score. Turning this off will result in an
        additional inference step for each batch.
"""
# pylint: disable=unused-argument,arguments-differ
def on_batch_end(self, net, X, y, training, **kwargs):
if training != self.on_train:
return
if self.scoring in [train_loss_score, valid_loss_score]:
return
y_preds = [kwargs['y_pred']]
with cache_net_infer(net, self.use_caching, y_preds) as cached_net:
# In case of y=None we will not have gathered any samples.
# We expect the scoring function to deal with y=None.
y = None if y is None else self.target_extractor(y)
try:
score = self._scoring(cached_net, X, y)
cached_net.history.record_batch(self.name_, score)
except KeyError:
pass
def get_avg_score(self, history):
if self.on_train:
bs_key = 'train_batch_size'
else:
bs_key = 'valid_batch_size'
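        # Each batch in the last epoch contributes a (batch_size,
        # score) pair; zip(*...) transposes these into the weights
        # and scores used for the weighted average.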
weights, scores = list(zip(
*history[-1, 'batches', :, [bs_key, self.name_]]))
score_avg = np.average(scores, weights=weights)
return score_avg
# pylint: disable=unused-argument
def on_epoch_end(self, net, **kwargs):
history = net.history
try:
history[-1, 'batches', :, self.name_]
except KeyError:
return
score_avg = self.get_avg_score(history)
is_best = self._is_best_score(score_avg)
if is_best:
self.best_score_ = score_avg
history.record(self.name_, score_avg)
if is_best is not None:
history.record(self.name_ + '_best', bool(is_best))
class EpochScoring(ScoringBase):
"""Callback that performs generic scoring on predictions.
At the end of each epoch, this callback makes a prediction on
train or validation data, determines the score for that prediction
and whether it is the best yet, and stores the result in the net's
history.
In case you already computed a score value for each batch you
    can omit the score computation step by returning the value from
the history. For example:
>>> def my_score(net, X=None, y=None):
... losses = net.history[-1, 'batches', :, 'my_score']
... batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
... return np.average(losses, weights=batch_sizes)
>>> net = MyNet(callbacks=[
    ...     ('my_score', EpochScoring(my_score, name='my_score'))])
If you fit with a custom dataset, this callback should work as
expected as long as ``use_caching=True`` which enables the
collection of ``y`` values from the dataset. If you decide to
disable the caching of predictions and ``y`` values, you need
to write your own scoring function that is able to deal with the
dataset and returns a scalar, for example:
>>> def ds_accuracy(net, ds, y=None):
... # assume ds yields (X, y), e.g. torchvision.datasets.MNIST
... y_true = [y for _, y in ds]
... y_pred = net.predict(ds)
... return sklearn.metrics.accuracy_score(y_true, y_pred)
>>> net = MyNet(callbacks=[
... EpochScoring(ds_accuracy, use_caching=False)])
>>> ds = torchvision.datasets.MNIST(root=mnist_path)
>>> net.fit(ds)
Parameters
----------
scoring : None, str, or callable (default=None)
If None, use the ``score`` method of the model. If str, it
should be a valid sklearn scorer (e.g. "f1", "accuracy"). If a
callable, it should have the signature (model, X, y), and it
should return a scalar. This works analogously to the
``scoring`` parameter in sklearn's ``GridSearchCV`` et al.
lower_is_better : bool (default=True)
        Whether lower (e.g. log loss) or higher (e.g. accuracy) scores
        are better.
on_train : bool (default=False)
        Whether this should be called on train or validation data.
name : str or None (default=None)
If not an explicit string, tries to infer the name from the
``scoring`` argument.
target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
use_caching : bool (default=True)
Collect labels and predictions (``y_true`` and ``y_pred``)
over the course of one epoch and use the cached values for
computing the score. The cached values are shared between
all ``EpochScoring`` instances. Disabling this will result
in an additional inference step for each epoch and an
inability to use arbitrary datasets as input (since we
don't know how to extract ``y_true`` from an arbitrary
dataset).
"""
def _initialize_cache(self):
self.y_trues_ = []
self.y_preds_ = []
def initialize(self):
super().initialize()
self._initialize_cache()
return self
# pylint: disable=arguments-differ,unused-argument
def on_epoch_begin(self, net, dataset_train, dataset_valid, **kwargs):
self._initialize_cache()
# pylint: disable=arguments-differ
def on_batch_end(
self, net, y, y_pred, training, **kwargs):
if not self.use_caching or training != self.on_train:
return
# We collect references to the prediction and target data
# emitted by the training process. Since we don't copy the
# data, all *Scoring callback instances use the same
# underlying data. This is also the reason why we don't run
# self.target_extractor(y) here but on epoch end, so that
# there are no copies of parts of y hanging around during
# training.
if y is not None:
self.y_trues_.append(y)
self.y_preds_.append(y_pred)
def get_test_data(self, dataset_train, dataset_valid):
"""Return data needed to perform scoring.
This is a convenience method that handles picking of
train/valid, different types of input data, use of cache,
etc. for you.
Parameters
----------
dataset_train
Incoming training data or dataset.
dataset_valid
Incoming validation data or dataset.
Returns
-------
X_test
Input data used for making the prediction.
y_test
Target ground truth. If caching was enabled, return cached
y_test.
y_pred : list
The predicted targets. If caching was disabled, the list is
empty. If caching was enabled, the list contains the batches
of the predictions. It may thus be necessary to concatenate
the output before working with it:
``y_pred = np.concatenate(y_pred)``
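
        A sketch of consuming the return values (illustrative; ``cb``
        stands for this callback after an epoch has run):

        >>> X_test, y_test, y_pred = cb.get_test_data(
        ...     dataset_train, dataset_valid)
        >>> if y_pred:
        ...     y_pred = np.concatenate(y_pred)
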
"""
dataset = dataset_train if self.on_train else dataset_valid
if self.use_caching:
X_test = dataset
y_pred = self.y_preds_
y_test = [self.target_extractor(y) for y in self.y_trues_]
# In case of y=None we will not have gathered any samples.
# We expect the scoring function to deal with y_test=None.
y_test = np.concatenate(y_test) if y_test else None
return X_test, y_test, y_pred
if is_skorch_dataset(dataset):
X_test, y_test = data_from_dataset(
dataset,
X_indexing=self.X_indexing_,
y_indexing=self.y_indexing_,
)
else:
X_test, y_test = dataset, None
if y_test is not None:
# We allow y_test to be None but the scoring function has
# to be able to deal with it (i.e. called without y_test).
y_test = self.target_extractor(y_test)
return X_test, y_test, []
    def _record_score(self, history, current_score):
        """Record the current score and, if applicable, whether it's the best score
yet.
"""
history.record(self.name_, current_score)
is_best = self._is_best_score(current_score)
if is_best is None:
return
history.record(self.name_ + '_best', bool(is_best))
if is_best:
self.best_score_ = current_score
# pylint: disable=unused-argument,arguments-differ
def on_epoch_end(
self,
net,
dataset_train,
dataset_valid,
**kwargs):
X_test, y_test, y_pred = self.get_test_data(dataset_train, dataset_valid)
if X_test is None:
return
with cache_net_infer(net, self.use_caching, y_pred) as cached_net:
current_score = self._scoring(cached_net, X_test, y_test)
self._record_score(net.history, current_score)
def on_train_end(self, *args, **kwargs):
self._initialize_cache()