// rnnlm/rnnlm-core-training.h
// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_RNNLM_RNNLM_CORE_TRAINING_H_
#define KALDI_RNNLM_RNNLM_CORE_TRAINING_H_

#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
#include "rnnlm/rnnlm-example.h"
#include "nnet3/nnet-nnet.h"
#include "nnet3/nnet-compute.h"
#include "nnet3/nnet-optimize.h"
#include "rnnlm/rnnlm-example-utils.h"

namespace kaldi {
namespace rnnlm {

// These are options relating to the core RNNLM training,
// i.e. training the actual neural net for the RNNLM
// (when the word embeddings are given). This is analogous
// to NnetTrainerOptions in ../nnet-training.h, except that
// in the RNNLM training code a few things are different,
// so we're using a totally separate training class.
// We'll add options as we need them.
struct RnnlmCoreTrainerOptions {
  int32 print_interval;
  BaseFloat momentum;
  BaseFloat max_param_change;
  BaseFloat l2_regularize_factor;
  BaseFloat backstitch_training_scale;
  int32 backstitch_training_interval;

  RnnlmCoreTrainerOptions():
      print_interval(100),
      momentum(0.0),
      max_param_change(2.0),
      l2_regularize_factor(1.0),
      backstitch_training_scale(0.0),
      backstitch_training_interval(1) { }

  void Register(OptionsItf *opts) {
    opts->Register("momentum", &momentum, "Momentum constant to apply during "
                   "training (helps stabilize the update), e.g. 0.9. Note: we "
                   "automatically multiply the learning rate by (1-momentum) "
                   "so that the 'effective' learning rate is the same as "
                   "before (because momentum would normally increase the "
                   "effective learning rate by 1/(1-momentum)).");
    opts->Register("max-param-change", &max_param_change, "The maximum change "
                   "in parameters allowed per minibatch, measured in Euclidean "
                   "norm over the entire model (the change will be clipped to "
                   "this value).");
    opts->Register("l2-regularize-factor", &l2_regularize_factor, "Factor that "
                   "affects the strength of l2 regularization on model "
                   "parameters. The primary way to specify this type of "
                   "l2 regularization is via the 'l2-regularize' "
                   "configuration value at the config-file level. "
                   "--l2-regularize-factor will be multiplied by the "
                   "component-level l2-regularize values and can be used to "
                   "correct for effects related to parallelization by model "
                   "averaging.");
    opts->Register("backstitch-training-scale", &backstitch_training_scale,
                   "Backstitch training factor; if 0, training runs in the "
                   "normal mode. It is referred to as '\\alpha' in our "
                   "publications.");
    opts->Register("backstitch-training-interval",
                   &backstitch_training_interval,
                   "Do backstitch training with the specified interval of "
                   "minibatches. It is referred to as 'n' in our publications.");
  }
};
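
// A minimal usage sketch for registering these options (hypothetical: the
// 'po' variable and usage string are illustrative, assuming the standard
// kaldi::ParseOptions interface from util/parse-options.h):
//
//   RnnlmCoreTrainerOptions opts;
//   ParseOptions po("usage: rnnlm-train [options] ...");
//   opts.Register(&po);
//   po.Read(argc, argv);  // e.g. --momentum=0.9 --max-param-change=1.0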

class ObjectiveTracker {
 public:
  ObjectiveTracker(int32 reporting_interval);

  void AddStats(BaseFloat weight, BaseFloat num_objf,
                BaseFloat den_objf,
                BaseFloat exact_den_objf = 0.0);

  ~ObjectiveTracker();  // Prints stats for the final interval, and the
                        // overall stats.
 private:
  // prints the stats for the current interval.
  void PrintStatsThisInterval() const;
  // zeroes the stats for the current interval and adds them to
  // the global stats.
  void CommitIntervalStats();
  // prints the overall stats.
  void PrintStatsOverall() const;

  int32 reporting_interval_;
  int32 num_egs_this_interval_;
  double tot_weight_this_interval_;  // sum of weights of outputs-- the
                                     // objective is to be divided by this.
  double num_objf_this_interval_;  // numerator term in objective
  double den_objf_this_interval_;  // denominator term in objective, to be
                                   // added to the numerator term.
  // exact_den_objf_this_interval_ is the exact version of the denominator
  // term, of the form log(sum(...)) instead of sum(...). This is included for
  // debugging and diagnostic purposes, and it will be zero if we're using
  // sampling (which will be most of the time).
  double exact_den_objf_this_interval_;

  // the following versions of the variables are overall, not just for
  // this interval: once we finish an interval (e.g. 100 examples), we
  // add the '...this_interval_' versions of the variables to the
  // variables below.
  int32 num_egs_;
  double tot_weight_;
  double num_objf_;
  double den_objf_;
  double exact_den_objf_;
};
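
// For illustration only (a sketch inferred from the member comments above,
// not a quote of the implementation): the per-interval objective that gets
// reported would be computed along the lines of
//
//   double objf = (num_objf_this_interval_ + den_objf_this_interval_) /
//                 tot_weight_this_interval_;
//
// i.e. the denominator term is added to the numerator term and the sum is
// normalized by the total weight of the outputs.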

/** This class does the core part of the training of the RNNLM; the
    word embeddings are supplied to this class for each minibatch.
    While this class can compute objective-function derivatives w.r.t.
    these embeddings, it is not responsible for updating them.
*/
class RnnlmCoreTrainer {
 public:
  /** Constructor.
      @param [in] config  Structure that holds configuration options.
      @param [in,out] nnet  The neural network that is to be trained.
                            Will be modified each time you call Train().
  */
  RnnlmCoreTrainer(const RnnlmCoreTrainerOptions &config,
                   const RnnlmObjectiveOptions &objective_config,
                   nnet3::Nnet *nnet);

  /* Train on one minibatch.
     @param [in] minibatch  The RNNLM minibatch to train on, containing
                   a number of parallel word sequences. It will not
                   necessarily contain words with the 'original' numbering;
                   in most circumstances it will contain just the ones we
                   used; see RenumberRnnlmMinibatch().
     @param [in] derived  Derived quantities of the minibatch, pre-computed
                   by calling GetRnnlmExampleDerived() with suitable
                   arguments.
     @param [in] word_embedding  The matrix giving the embedding of words,
                   of dimension minibatch.vocab_size by the embedding
                   dimension. The numbering of the words does not have to be
                   the 'real' numbering of words; it can consist of words
                   renumbered by RenumberRnnlmMinibatch(); it just has to be
                   consistent with the word-ids present in 'minibatch'.
     @param [out] word_embedding_deriv  If supplied, the derivative of the
                   objective function w.r.t. the word embedding will be
                   *added* to this location; it must have the same dimension
                   as 'word_embedding'.
  */
  void Train(const RnnlmExample &minibatch,
             const RnnlmExampleDerived &derived,
             const CuMatrixBase<BaseFloat> &word_embedding,
             CuMatrixBase<BaseFloat> *word_embedding_deriv = NULL);

  // The backstitch version of the above function. Depending on whether
  // is_backstitch_step1 is true, it could be either the first (backward)
  // step or the second (forward) step of backstitch.
  void TrainBackstitch(bool is_backstitch_step1,
                       const RnnlmExample &minibatch,
                       const RnnlmExampleDerived &derived,
                       const CuMatrixBase<BaseFloat> &word_embedding,
                       CuMatrixBase<BaseFloat> *word_embedding_deriv = NULL);
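
  // A hypothetical calling pattern (a sketch only; the driver loop and
  // variable names here are illustrative, not part of this interface).
  // Per the documentation above, the step1 (backward) call precedes the
  // step2 (forward) call, on the configured interval of minibatches:
  //
  //   if (opts.backstitch_training_scale > 0.0 &&
  //       minibatch_count % opts.backstitch_training_interval == 0) {
  //     trainer.TrainBackstitch(true, minibatch, derived, embedding, &deriv);
  //     trainer.TrainBackstitch(false, minibatch, derived, embedding, &deriv);
  //   } else {
  //     trainer.Train(minibatch, derived, embedding, &deriv);
  //   }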

  // Prints out the final stats.
  void PrintTotalStats() const;

  // Prints out the max-change stats (if nonzero): the percentage of time that
  // per-component max-change and global max-change were enforced.
  void PrintMaxChangeStats() const;

  // Calls ConsolidateMemory() on nnet_ and delta_nnet_.
  void ConsolidateMemory();

  ~RnnlmCoreTrainer();
 private:
  void ProvideInput(const RnnlmExample &minibatch,
                    const RnnlmExampleDerived &derived,
                    const CuMatrixBase<BaseFloat> &word_embedding,
                    nnet3::NnetComputer *computer);

  /** Process the output of the neural net and record the objective function
      in objf_info_.
      @param [in] is_backstitch_step1  If true, update the stats; otherwise,
                    do not.
      @param [in] minibatch  The minibatch for which we're processing the
                    output.
      @param [in] derived  Derived quantities from the minibatch.
      @param [in] word_embedding  The word embedding, with the same numbering
                    as used in the minibatch (may be subsampled at this point).
      @param [out] word_embedding_deriv  If non-NULL, the part of the
                    derivative w.r.t. the word-embedding that arises from the
                    output computation will be *added* to here.
  */
  void ProcessOutput(bool is_backstitch_step1,
                     const RnnlmExample &minibatch,
                     const RnnlmExampleDerived &derived,
                     const CuMatrixBase<BaseFloat> &word_embedding,
                     nnet3::NnetComputer *computer,
                     CuMatrixBase<BaseFloat> *word_embedding_deriv = NULL);

  // Applies per-component max-change and global max-change to all updatable
  // components in *delta_nnet_, and uses *delta_nnet_ to update the
  // parameters in *nnet_.
  void UpdateParamsWithMaxChange();

  const RnnlmCoreTrainerOptions config_;
  const RnnlmObjectiveOptions objective_config_;
  nnet3::Nnet *nnet_;
  nnet3::Nnet *delta_nnet_;  // nnet representing parameter-change for this
                             // minibatch (or, when using momentum, its moving
                             // weighted average).
  nnet3::CachingOptimizingCompiler compiler_;

  int32 num_minibatches_processed_;

  // stats for max-change.
  std::vector<int32> num_max_change_per_component_applied_;
  int32 num_max_change_global_applied_;

  ObjectiveTracker objf_info_;
};
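
// A minimal end-to-end usage sketch (hypothetical; how 'minibatch', 'derived'
// and 'embedding' are produced is outside this header, and the names here are
// illustrative):
//
//   RnnlmCoreTrainerOptions core_opts;
//   RnnlmObjectiveOptions objective_opts;
//   RnnlmCoreTrainer trainer(core_opts, objective_opts, &rnnlm_nnet);
//   for (...) {  // loop over minibatches
//     CuMatrix<BaseFloat> embedding_deriv(embedding.NumRows(),
//                                         embedding.NumCols());  // zeroed
//     trainer.Train(minibatch, derived, embedding, &embedding_deriv);
//     // ... use embedding_deriv to update the embedding matrix elsewhere;
//     // this class does not update the embeddings itself.
//   }
//   trainer.PrintTotalStats();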

}  // namespace rnnlm
}  // namespace kaldi

#endif  // KALDI_RNNLM_RNNLM_CORE_TRAINING_H_