"""Module defining the interface for training algorithms."""
from pylearn2.datasets.dataset import Dataset
class TrainingAlgorithm(object):
"""
An abstract superclass that defines the interface of training
algorithms.
"""
    def _register_update_callbacks(self, update_callbacks):
        """
        Register a callback, or an iterable of callbacks, to be invoked
        after each update.

        Parameters
        ----------
        update_callbacks : None or callable or iterable of callables
            The callback(s) to register. `None` registers an empty list,
            and a single callable is wrapped in a list, so that
            `self.update_callbacks` is always iterable afterwards.
        """
        if update_callbacks is None:
            update_callbacks = []
        # If it's iterable, we're fine. If not, it's a single callback,
        # so wrap it in a list.
        try:
            iter(update_callbacks)
            self.update_callbacks = update_callbacks
        except TypeError:
            self.update_callbacks = [update_callbacks]
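
    # A minimal sketch of how a subclass might use this helper; the
    # `log_update` callback and its use here are illustrative, not part
    # of pylearn2:
    #
    #     def log_update(algorithm):
    #         print("parameters updated")
    #
    #     algorithm._register_update_callbacks(log_update)
    #     assert algorithm.update_callbacks == [log_update]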
    def setup(self, model, dataset):
        """
        Initialize the training algorithm with the model and dataset it
        will train on.

        Parameters
        ----------
        model : object
            Object that implements the Model interface defined in
            `pylearn2.models`.
        dataset : object
            Object that implements the Dataset interface defined in
            `pylearn2.datasets`.

        Notes
        -----
        Called by the training script prior to any calls involving data.
        This is a good place to compile theano functions for doing
        learning.
        """
        self.model = model
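
    # A hedged sketch of what a subclass's setup() might compile; the
    # `self.cost` attribute and the update rule are assumptions for the
    # sketch, not part of this interface:
    #
    #     def setup(self, model, dataset):
    #         self.model = model
    #         X = model.get_input_space().make_theano_batch()
    #         cost = self.cost.expr(model, X)
    #         ...  # build gradient `updates` from `cost`
    #         self.sgd_update = theano.function([X], updates=updates)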
    def train(self, dataset):
        """
        Performs some amount of training, generally one "epoch" of
        online learning.

        Parameters
        ----------
        dataset : object
            Object implementing the dataset interface defined in
            `pylearn2.datasets.dataset.Dataset`.

        Returns
        -------
        None
        """
        raise NotImplementedError()
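
    # A sketch of a typical train() body: one pass over the dataset in
    # fixed-size batches, applying a compiled update and firing the
    # registered callbacks (`self.sgd_update` is the hypothetical
    # function from the setup() sketch above):
    #
    #     def train(self, dataset):
    #         for batch in dataset.iterator(mode='sequential',
    #                                       batch_size=self.batch_size):
    #             self.sgd_update(batch)
    #             for callback in self.update_callbacks:
    #                 callback(self)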
    def _set_monitoring_dataset(self, monitoring_dataset):
        """
        Set the dataset(s) the Monitor should report on, normalizing a
        single Dataset into a dict keyed by the empty string.

        Parameters
        ----------
        monitoring_dataset : None or Dataset or dict
            None for no monitoring, or Dataset, to monitor on one
            dataset, or dict mapping string names to Datasets.
        """
        if isinstance(monitoring_dataset, Dataset):
            self.monitoring_dataset = {'': monitoring_dataset}
        else:
            if monitoring_dataset is not None:
                assert isinstance(monitoring_dataset, dict)
                for key in monitoring_dataset:
                    assert isinstance(key, str)
                    value = monitoring_dataset[key]
                    if not isinstance(value, Dataset):
                        raise TypeError("Monitoring dataset with name " +
                                        key + " is not a dataset, it is a " +
                                        str(type(value)))
            self.monitoring_dataset = monitoring_dataset
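
    # For example (dataset names here are illustrative): passing a
    # single Dataset
    #
    #     algorithm._set_monitoring_dataset(valid_set)
    #
    # stores `{'': valid_set}`, while a dict such as
    #
    #     algorithm._set_monitoring_dataset({'train': train_set,
    #                                        'valid': valid_set})
    #
    # is type-checked and stored as-is.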
    def continue_learning(self, model):
        """
        Return True to continue learning. Called after the Monitor
        has been run on the latest parameters so the monitor may be used
        to determine convergence.

        Parameters
        ----------
        model : Model
            The model currently being trained, whose monitored channels
            may be inspected to decide whether to stop.
        """
        raise NotImplementedError(str(type(self)) + " does not implement "
                                  "continue_learning.")
    def _synchronize_batch_size(self, model):
        """
        Adapts `self.batch_size` to be consistent with `model`.

        Parameters
        ----------
        model : Model
            The model to synchronize the batch size with.
        """
        batch_size = self.batch_size
        if hasattr(model, "force_batch_size"):
            if model.force_batch_size and model.force_batch_size > 0:
                if batch_size is not None:
                    if batch_size != model.force_batch_size:
                        if self.set_batch_size:
                            model.set_batch_size(batch_size)
                        else:
                            raise ValueError("batch_size argument to " +
                                             str(type(self)) +
                                             " conflicts with model's " +
                                             "force_batch_size attribute")
                else:
                    self.batch_size = model.force_batch_size
        if self.batch_size is None:
            raise NoBatchSizeError()
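
    # The three possible outcomes, sketched with hypothetical values:
    #
    #     self.batch_size = 100, model.force_batch_size = 128
    #         -> ValueError, unless self.set_batch_size is True, in which
    #            case model.set_batch_size(100) is called instead
    #     self.batch_size = None, model.force_batch_size = 128
    #         -> self.batch_size becomes 128
    #     self.batch_size = None, model forces no batch size
    #         -> NoBatchSizeError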


class NoBatchSizeError(ValueError):
    """
    An exception raised when the user does not specify a batch size
    anywhere.
    """

    def __init__(self):
        super(NoBatchSizeError, self).__init__(
            "Neither the TrainingAlgorithm nor the model were given a "
            "specification of the batch size.")