-
Notifications
You must be signed in to change notification settings - Fork 41
/
base_trainer.py
375 lines (305 loc) · 13 KB
/
base_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
from tfsnippet.scaffold import TrainLoop, EventKeys
from tfsnippet.utils import (ensure_variables_initialized,
get_default_session_or_error,
DocInherit, EventSource)
from .evaluator import Evaluator
__all__ = ['BaseTrainer']
def check_epochs_and_steps_arg(epochs=None, steps=None):
    """
    Validate that exactly one of `epochs` and `steps` is specified.

    Args:
        epochs: The epochs argument, :obj:`None` if not specified.
        steps: The steps argument, :obj:`None` if not specified.

    Raises:
        ValueError: If both `epochs` and `steps` are specified, or neither
            is specified.
    """
    epochs_given = epochs is not None
    steps_given = steps is not None
    # Exactly one must be given, i.e., the two flags must differ.
    if epochs_given == steps_given:
        raise ValueError('One and only one of `epochs` and `steps` should '
                         'be specified.')
class OnEveryFewCalls(object):
    """
    Event handler which invokes `callback` only when a counter of the
    trainer's loop is a multiple of `freq`.

    The counter is read from ``trainer.loop`` via the attribute named
    by `key` each time the handler is called.
    """

    def __init__(self, key, freq, callback):
        """
        Construct a new :class:`OnEveryFewCalls`.

        Args:
            key (str): Name of the counter attribute on ``trainer.loop``.
            freq (int): Call `callback` when the counter is divisible
                by this number.
            callback (() -> any): The callable to invoke, with no argument.
        """
        assert callable(callback)
        self.key, self.freq, self.callback = key, freq, callback

    def __call__(self, trainer):
        """
        Invoke the callback if the loop counter is a multiple of `freq`.

        Args:
            trainer: Object whose ``loop`` attribute carries the counter.

        Returns:
            The return value of the callback if it has been invoked,
            otherwise :obj:`None`.
        """
        counter = getattr(trainer.loop, self.key)
        divisible = counter % self.freq == 0
        return self.callback() if divisible else None

    def __repr__(self):  # for `test_base_trainer.py`
        return '{}:{}:{}'.format(self.callback, self.key, self.freq)
@DocInherit
class BaseTrainer(object):
    """
    Base class for all trainers.

    All the trainers provided in :mod:`tfsnippet.trainer` are not
    designed to take control of the training totally, which is often
    assumed in other libraries such as Keras.  Instead, it just takes
    responsibility of assembling different steps of a training process
    together, and run the main training loop.  So it is usually the caller's
    responsibility to derive his training operation from a certain TensorFlow
    optimizer, and pass it to a proper trainer.

    The event schedule of a :class:`BaseTrainer` can be briefly described as::

        events.fire(EventKeys.BEFORE_EXECUTION, self)

        for epoch in epochs:
            events.fire(EventKeys.BEFORE_EPOCH, self)

            for step in steps:
                events.fire(EventKeys.BEFORE_STEP, self)

                ...  # actually train for a step

                events.fire(EventKeys.STEP_EVALUATION, self)
                events.fire(EventKeys.STEP_ANNEALING, self)
                events.fire(EventKeys.STEP_LOGGING, self)
                events.reverse_fire(EventKeys.AFTER_STEP, self)

            events.fire(EventKeys.EPOCH_EVALUATION, self)
            events.fire(EventKeys.EPOCH_ANNEALING, self)
            events.fire(EventKeys.EPOCH_LOGGING, self)
            events.reverse_fire(EventKeys.AFTER_EPOCH, self)

        events.reverse_fire(EventKeys.AFTER_EXECUTION, self)

    Using `trainer.events.on(EventKeys.AFTER_EPOCH, lambda trainer: ...)` can
    register an after-epoch event handler.  Handlers for other events can be
    registered in a similar way.

    To make things even simpler, we provide several methods to register
    callbacks that will run every few epochs/steps, e.g.::

        trainer.evaluate_after_epochs(
            lambda: print('after epoch callback'), 10)  # run every 10 epochs
        trainer.log_after_steps(1000)  # call `loop.print_logs` every 1000 steps
    """

    def __init__(self, loop, ensure_variables_initialized=True):
        """
        Initialize the internal states of :class:`BaseTrainer`.

        Args:
            loop (TrainLoop): The training loop object.
            ensure_variables_initialized (bool): Whether or not to ensure
                the variables are initialized in :meth:`run()`?
        """
        # NOTE: within this method the argument `ensure_variables_initialized`
        # shadows the module-level function of the same name; only the flag
        # is stored here, the function itself is called in :meth:`run()`.
        self._loop = loop
        self._ensure_variables_initialized = ensure_variables_initialized
        self._events = EventSource([
            EventKeys.BEFORE_EXECUTION,
            EventKeys.AFTER_EXECUTION,
            EventKeys.BEFORE_EPOCH,
            EventKeys.EPOCH_EVALUATION,
            EventKeys.EPOCH_ANNEALING,
            EventKeys.EPOCH_LOGGING,
            EventKeys.AFTER_EPOCH,
            EventKeys.BEFORE_STEP,
            EventKeys.STEP_EVALUATION,
            EventKeys.STEP_ANNEALING,
            EventKeys.STEP_LOGGING,
            EventKeys.AFTER_STEP,
        ])
        # guard flag used by :meth:`run()` to reject re-entrant calls
        self._is_fitting = False

    @property
    def loop(self):
        """
        Get the training loop object.

        Returns:
            TrainLoop: The training loop object.
        """
        return self._loop

    @property
    def events(self):
        """
        Get the event source object.

        Returns:
            EventSource: The event source object.
        """
        return self._events

    def run(self):
        """
        Run training loop.

        The events are fired according to the schedule described in the
        class docstring, with this trainer as the only argument.  Note the
        ``AFTER_EXECUTION`` event is fired only if the loop finishes without
        an exception being propagated.

        Raises:
            RuntimeError: If :meth:`run()` is called while a previous call
                on this trainer has not yet returned.
        """
        if self._is_fitting:
            raise RuntimeError('`run()` is not re-entrant.')
        self._is_fitting = True
        try:
            # trigger the before execution event
            self.events.fire(EventKeys.BEFORE_EXECUTION, self)

            # initialize global training status
            session = get_default_session_or_error()
            if self._ensure_variables_initialized:
                ensure_variables_initialized()
            self.loop.print_training_summary()

            for _ in self.loop.iter_epochs():
                # trigger before epoch event
                self.events.fire(EventKeys.BEFORE_EPOCH, self)

                # run steps of this epoch
                for payload in self._iter_steps():
                    # trigger before step event
                    self.events.fire(EventKeys.BEFORE_STEP, self)

                    # run the step
                    self._run_step(session, payload)

                    # trigger after step events
                    self.events.fire(EventKeys.STEP_EVALUATION, self)
                    self.events.fire(EventKeys.STEP_ANNEALING, self)
                    self.events.fire(EventKeys.STEP_LOGGING, self)
                    self.events.reverse_fire(EventKeys.AFTER_STEP, self)

                # trigger after epoch events
                self.events.fire(EventKeys.EPOCH_EVALUATION, self)
                self.events.fire(EventKeys.EPOCH_ANNEALING, self)
                self.events.fire(EventKeys.EPOCH_LOGGING, self)
                self.events.reverse_fire(EventKeys.AFTER_EPOCH, self)

            # trigger the after execution event
            self.events.reverse_fire(EventKeys.AFTER_EXECUTION, self)
        finally:
            # always release the guard, so the trainer can be run again
            # even if the loop above has raised an error
            self._is_fitting = False

    def _iter_steps(self):
        """
        Subclasses should override this to iterate through steps.

        A common implementation of :meth:`_iter_steps()` might be::

            def _iter_steps(self):
                return self.loop.iter_steps(training_data)

        Yields:
            int or (int, tuple[np.ndarray]): The step counter, or the step
                counter as well as the step training data.  Will be directly
                given to :meth:`_run_step` as the `payload` argument.
        """
        raise NotImplementedError()

    def _run_step(self, session, payload):
        """
        Subclasses should override this to run a training step.

        Args:
            session: The TensorFlow session.
            payload: The step payload generated by :meth:`_iter_steps`.
        """
        raise NotImplementedError()

    def log_after_steps(self, freq):
        """
        Add a logging hook to run after every few steps.

        The hook calls ``self.loop.print_logs``.

        Args:
            freq (int): The frequency for this logging hook to run.
        """
        self.events.on(
            EventKeys.STEP_LOGGING,
            OnEveryFewCalls('step', freq, self.loop.print_logs)
        )

    def log_after_epochs(self, freq):
        """
        Add a logging hook to run after every few epochs.

        The hook calls ``self.loop.print_logs``.

        Args:
            freq (int): The frequency for this logging hook to run.
        """
        self.events.on(
            EventKeys.EPOCH_LOGGING,
            OnEveryFewCalls('epoch', freq, self.loop.print_logs)
        )

    def log_after(self, epochs=None, steps=None):
        """
        Add a logging hook to run after every few epochs or steps.

        Args:
            epochs (None or int): Print logs after every this few `epochs`.
            steps (None or int): Print logs after every this few `steps`.

        Raises:
            ValueError: If both `epochs` and `steps` are specified, or neither
                is specified.
        """
        check_epochs_and_steps_arg(epochs, steps)
        if epochs is not None:
            return self.log_after_epochs(epochs)
        else:
            return self.log_after_steps(steps)

    def remove_log_hooks(self):
        """
        Remove logging hooks from all lists.

        Clears the handlers of both the step logging event and the epoch
        logging event.  This method does not return a value.
        """
        self.events.clear_event_handlers(EventKeys.STEP_LOGGING)
        self.events.clear_event_handlers(EventKeys.EPOCH_LOGGING)

    def evaluate_after_steps(self, evaluator, freq):
        """
        Add an evaluation hook to run after every few steps.

        Args:
            evaluator (Evaluator or () -> any): A evaluator object
                (which has ``.run()``), or any callable object.
            freq (int): The frequency for this evaluation hook to run.
        """
        # a callable is used directly, otherwise fall back to its `.run`
        callback = evaluator if callable(evaluator) else evaluator.run
        self.events.on(
            EventKeys.STEP_EVALUATION,
            OnEveryFewCalls('step', freq, callback)
        )

    def evaluate_after_epochs(self, evaluator, freq):
        """
        Add an evaluation hook to run after every few epochs.

        Args:
            evaluator (Evaluator or () -> any): A evaluator object
                (which has ``.run()``), or any callable object.
            freq (int): The frequency for this evaluation hook to run.
        """
        # a callable is used directly, otherwise fall back to its `.run`
        callback = evaluator if callable(evaluator) else evaluator.run
        self.events.on(
            EventKeys.EPOCH_EVALUATION,
            OnEveryFewCalls('epoch', freq, callback)
        )

    def evaluate_after(self, evaluator, epochs=None, steps=None):
        """
        Add an evaluation hook to run after every few epochs or steps.

        Args:
            evaluator (Evaluator or () -> any): A evaluator object
                (which has ``.run()``), or any callable object.
            epochs (None or int): Run validation after every this few `epochs`.
            steps (None or int): Run validation after every this few `steps`.

        Raises:
            ValueError: If both `epochs` and `steps` are specified, or neither
                is specified.
        """
        check_epochs_and_steps_arg(epochs, steps)
        if epochs is not None:
            return self.evaluate_after_epochs(evaluator, freq=epochs)
        else:
            return self.evaluate_after_steps(evaluator, freq=steps)

    def remove_evaluation_hooks(self):
        """
        Remove evaluation hooks from all lists.

        Clears the handlers of both the step evaluation event and the epoch
        evaluation event.  This method does not return a value.
        """
        self.events.clear_event_handlers(EventKeys.STEP_EVALUATION)
        self.events.clear_event_handlers(EventKeys.EPOCH_EVALUATION)

    # legacy names for evaluation
    validate_after_steps = evaluate_after_steps
    validate_after_epochs = evaluate_after_epochs
    validate_after = evaluate_after
    remove_validation_hooks = remove_evaluation_hooks

    def anneal_after_steps(self, value, freq):
        """
        Add an annealing hook to run after every few steps.

        Args:
            value (AnnealingVariable or () -> any): An annealing variable
                (which has ``.anneal()``), or any callable object.
            freq (int): The frequency for this annealing hook to run.
        """
        # a callable is used directly, otherwise fall back to its `.anneal`
        callback = value if callable(value) else value.anneal
        self.events.on(
            EventKeys.STEP_ANNEALING,
            OnEveryFewCalls('step', freq, callback)
        )

    def anneal_after_epochs(self, value, freq):
        """
        Add an annealing hook to run after every few epochs.

        Args:
            value (AnnealingVariable or () -> any): An annealing variable
                (which has ``.anneal()``), or any callable object.
            freq (int): The frequency for this annealing hook to run.
        """
        # a callable is used directly, otherwise fall back to its `.anneal`
        callback = value if callable(value) else value.anneal
        self.events.on(
            EventKeys.EPOCH_ANNEALING,
            OnEveryFewCalls('epoch', freq, callback)
        )

    def anneal_after(self, value, epochs=None, steps=None):
        """
        Add an annealing hook to run after every few epochs or steps.

        Args:
            value (AnnealingVariable or () -> any): An annealing variable
                (which has ``.anneal()``), or any callable object.
            epochs (None or int): Run annealing after every this few `epochs`.
            steps (None or int): Run annealing after every this few `steps`.

        Raises:
            ValueError: If both `epochs` and `steps` are specified, or neither
                is specified.
        """
        check_epochs_and_steps_arg(epochs, steps)
        if epochs is not None:
            return self.anneal_after_epochs(value, freq=epochs)
        else:
            return self.anneal_after_steps(value, freq=steps)

    def remove_annealing_hooks(self):
        """
        Remove annealing hooks from all lists.

        Clears the handlers of both the step annealing event and the epoch
        annealing event.  This method does not return a value.
        """
        self.events.clear_event_handlers(EventKeys.STEP_ANNEALING)
        self.events.clear_event_handlers(EventKeys.EPOCH_ANNEALING)