forked from awslabs/gluonts
/
model_iteration_averaging.py
366 lines (312 loc) · 11.1 KB
/
model_iteration_averaging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Any, Dict, List, Optional
import logging
import mxnet as mx
import mxnet.gluon.nn as nn
from mxnet import gluon
from gluonts.core.component import validated
from .callback import Callback
class IterationAveragingStrategy:
    r"""
    Base class maintaining an iteration-wise running average of model
    parameters; subclasses decide *when* averaging starts via
    ``update_average_trigger``.

    The model averaging is based on paper
    "Stochastic Gradient Descent for Non-smooth Optimization: Convergence
    Results and Optimal Averaging Schemes",
    (http://proceedings.mlr.press/v28/shamir13.pdf), which implements
    polynomial-decay averaging, parameterized by eta. When eta = 0, it is
    equivalent to simple average over all iterations with same weights.
    """
    # Running average of parameters, keyed by parameter name;
    # None until the first update_average() call.
    averaged_model: Optional[Dict[str, mx.nd.NDArray]]
    # Snapshot of the current (non-averaged) parameters, used to restore
    # the model after evaluating with the averaged weights.
    cached_model: Optional[Dict[str, mx.nd.NDArray]]
    # Number of models folded into the average so far.
    average_counter: int
    # Whether the averaging trigger has fired.
    averaging_started: bool
    @validated()
    def __init__(self, eta: float = 0):
        r"""
        Parameters
        ----------
        eta
            Parameter of polynomial-decay averaging; 0 gives a plain
            uniform average over iterations.
        """
        self.eta = eta
        # Dict that maintains the averaged model parameters.
        self.averaged_model = None
        # Temporarily save the current model, so that the averaged model can be
        # used for validation.
        self.cached_model = None
        # The number of models accumulated in the average.
        self.average_counter = 0
        # Indicate whether the model averaging has started.
        self.averaging_started = False
    def update_average_trigger(
        self, metric: Any = None, epoch: int = 0, **kwargs
    ):
        r"""
        Decide whether averaging should start; implementations set
        ``self.averaging_started`` here.

        Parameters
        ----------
        metric
            The criteria to trigger averaging.
        epoch
            The epoch to start averaging.
        Returns
        -------
        """
        raise NotImplementedError()
    def apply(self, model: nn.HybridBlock) -> Optional[Dict]:
        r"""
        Fold ``model`` into the running average if averaging has started;
        otherwise do nothing.

        Parameters
        ----------
        model
            The model of the current iteration.
        Returns
        -------
        The averaged model, None if the averaging hasn't started.
        """
        if self.averaging_started:
            self.update_average(model)
        return self.averaged_model
    def update_average(self, model: nn.HybridBlock):
        r"""
        Update the running average in place with ``model``'s parameters.

        Parameters
        ----------
        model
            The model to update the average.
        """
        self.average_counter += 1
        if self.averaged_model is None:
            # First contribution: seed the average with a copy of the
            # current parameters.
            # NOTE(review): list_data()[0] reads the copy on the first
            # context — presumably assumes single-context training;
            # confirm for multi-device setups.
            self.averaged_model = {
                k: v.list_data()[0].copy()
                for k, v in model.collect_params().items()
            }
        else:
            # Polynomial-decay weight; with eta == 0 this reduces to
            # 1 / average_counter, i.e. a uniform running mean.
            alpha = (self.eta + 1.0) / (self.eta + self.average_counter)
            # moving average
            for name, param_avg in self.averaged_model.items():
                param_avg[:] += alpha * (
                    model.collect_params()[name].list_data()[0] - param_avg
                )
    def load_averaged_model(self, model: nn.HybridBlock):
        r"""
        When validating/evaluating the averaged model in the half way of
        training, use load_averaged_model first to load the averaged model and
        overwrite the current model, do the evaluation, and then use
        load_cached_model to load the current model back.

        Parameters
        ----------
        model
            The model that the averaged model is loaded to.
        """
        if self.averaged_model is not None:
            # cache the current model
            if self.cached_model is None:
                self.cached_model = {
                    k: v.list_data()[0].copy()
                    for k, v in model.collect_params().items()
                }
            else:
                # Reuse existing cache buffers via in-place copy rather
                # than reallocating.
                for name, param_cached in self.cached_model.items():
                    param_cached[:] = model.collect_params()[name].list_data()[
                        0
                    ]
            # load the averaged model
            for name, param_avg in self.averaged_model.items():
                model.collect_params()[name].set_data(param_avg)
    def load_cached_model(self, model: nn.HybridBlock):
        r"""
        Restore the parameters cached by ``load_averaged_model``; no-op if
        nothing was cached.

        Parameters
        ----------
        model
            The model that the cached model is loaded to.
        """
        if self.cached_model is not None:
            # load the cached model
            for name, param_cached in self.cached_model.items():
                model.collect_params()[name].set_data(param_cached)
class NTA(IterationAveragingStrategy):
    r"""
    Non-monotonically Triggered AvSGD (NTA): start averaging once the
    validation metric stops improving against a reference window.

    Based on the paper "Regularizing and Optimizing LSTM Language Models"
    (https://openreview.net/pdf?id=SyyGPP0TZ); a reference implementation
    is available in the Salesforce GitHub repository
    (https://github.com/salesforce/awd-lstm-lm/blob/master/main.py). Note
    that it mismatches the arxiv (and gluonnlp) version, which is referred
    to as NTA_V2 below.
    """
    # History of validation metrics, one entry per trigger call.
    val_logs: List[Any]
    @validated()
    def __init__(
        self,
        epochs: int,
        n: int = 5,
        maximize: bool = False,
        last_n_trigger: bool = False,
        eta: float = 0,
        fallback_alpha: float = 0.05,
    ):
        r"""
        Set ``maximize=True`` when larger metric values are better,
        otherwise the metric is minimized.

        Parameters
        ----------
        epochs
            The total number of epochs.
        n
            The non-monotone interval.
        maximize
            Whether to maximize or minimize the validation metric.
        eta
            Parameter of polynomial-decay averaging.
        last_n_trigger
            If True, use [-n:] in average trigger, otherwise use [:-n].
        fallback_alpha
            Fallback epoch proportion of averaging.
        """
        super().__init__(eta=eta)
        assert 0 <= fallback_alpha <= 1
        self.n = n
        self.maximize = maximize
        self.last_n_trigger = last_n_trigger
        # Historical validation metrics.
        self.val_logs = []
        # The epoch where we fallback to alpha suffix. This solves the edge
        # case where the averaging is never triggered and without the fallback
        # the model of the last epoch would be returned.
        self.fallback_alpha_suffix = epochs * (1.0 - fallback_alpha)
    def update_average_trigger(
        self, metric: Any = None, epoch: int = 0, **kwargs
    ):
        r"""
        Record ``metric`` and flip ``averaging_started`` either when the
        metric degrades relative to the reference window, or when the
        fallback epoch threshold is reached.

        Parameters
        ----------
        metric
            The criteria to trigger averaging.
        epoch
            The epoch to start averaging, not used in NTA
        Returns
        -------
        """
        # Fallback: force averaging on for the final alpha-suffix epochs,
        # regardless of the metric history.
        if not self.averaging_started and epoch >= self.fallback_alpha_suffix:
            self.averaging_started = True
        if not self.averaging_started and self.n > 0:
            # Choose the reference window and the minimum amount of
            # history needed before the trigger may fire.
            if self.last_n_trigger:
                needed_history = self.n
                window = self.val_logs[-self.n :]
            else:
                needed_history = self.n + 1
                window = self.val_logs[: -self.n]
            if len(self.val_logs) >= needed_history:
                degraded = (
                    metric < max(window)
                    if self.maximize
                    else metric > min(window)
                )
                if degraded:
                    self.averaging_started = True
        self.val_logs.append(metric)
class Alpha_Suffix(IterationAveragingStrategy):
    r"""
    Alpha-suffix model averaging: average only the last ``alpha`` fraction
    of training epochs.

    Based on the paper "Making Gradient Descent Optimal for Strongly
    Convex Stochastic Optimization"
    (https://arxiv.org/pdf/1109.5647.pdf).
    """
    # Epoch threshold at which averaging begins.
    alpha_suffix: float
    @validated()
    def __init__(self, epochs: int, alpha: float = 0.75, eta: float = 0):
        r"""
        Take the iteration average over the final ``epochs * alpha``
        epochs.

        Parameters
        ----------
        epochs
            The total number of epochs.
        alpha
            Proportion of averaging.
        eta
            Parameter of polynomial-decay averaging.
        """
        super().__init__(eta=eta)
        assert 0 <= alpha <= 1
        # The epoch where iteration averaging starts.
        self.alpha_suffix = epochs * (1.0 - alpha)
    def update_average_trigger(
        self, metric: Any = None, epoch: int = 0, **kwargs
    ):
        r"""
        Start averaging once ``epoch`` reaches the alpha-suffix threshold.

        Parameters
        ----------
        metric
            The criteria to trigger averaging, not used in Alpha Suffix.
        epoch
            The epoch to start averaging.
        Returns
        -------
        """
        # Single merged guard instead of nested ifs; once set, the flag
        # never goes back to False.
        if not self.averaging_started and epoch >= self.alpha_suffix:
            self.averaging_started = True
class ModelIterationAveraging(Callback):
    """
    Callback to implement iteration based model averaging strategies.

    Parameters
    ----------
    avg_strategy
        IterationAveragingStrategy, one of NTA or Alpha_Suffix from
        gluonts.mx.trainer.model_iteration_averaging
    """
    @validated()
    def __init__(self, avg_strategy: IterationAveragingStrategy):
        self.avg_strategy = avg_strategy
    def on_validation_epoch_start(
        self, training_network: nn.HybridBlock
    ) -> None:
        # Swap the averaged parameters in so validation measures the
        # averaged model (the current parameters are cached first).
        strategy = self.avg_strategy
        strategy.load_averaged_model(training_network)
    def on_validation_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        # Swap the cached (non-averaged) parameters back so training
        # resumes from where it left off.
        strategy = self.avg_strategy
        strategy.load_cached_model(training_network)
        return True
    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
        # Fold the just-updated parameters into the running average
        # (no-op until the strategy's trigger has fired).
        self.avg_strategy.apply(training_network)
        return True
    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: Dict[str, Any],
        ctx: mx.Context,
    ) -> bool:
        # Let the strategy decide whether averaging should (now) be on;
        # epochs are 1-based from the strategy's point of view.
        strategy = self.avg_strategy
        strategy.update_average_trigger(
            metric=epoch_loss, epoch=epoch_no + 1
        )
        # once triggered, update the average immediately
        strategy.apply(training_network)
        return True
    def on_train_end(
        self,
        training_network: nn.HybridBlock,
        temporary_dir: str,
        ctx: mx.context.Context = None,
    ) -> None:
        # Leave the network holding the averaged parameters at the end
        # of training.
        logging.info("Loading averaged parameters.")
        self.avg_strategy.load_averaged_model(training_network)