-
Notifications
You must be signed in to change notification settings - Fork 0
/
charRNN.cpp
741 lines (666 loc) · 27.9 KB
/
charRNN.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Hua Zhang mz24cn@hotmail.com
* The code implements C++ version charRNN for mxnet\example\rnn\char-rnn.ipynb with MXNet.cpp API.
* The generated params file is compatiable with python version.
* train() and predict() has been verified with original data samples.
* 2017/1/23:
* Add faster version charRNN based on built-in cuDNN RNN operator, 10 times faster.
* Add time major computation graph, although no substantial performance difference.
* Support continuing training from last params file.
* Rename params file epoch number starts from zero.
*/
#pragma warning(disable: 4996) // VS2015 complains on 'std::copy' ...
#include <cstring>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <vector>
#include <string>
#include <tuple>
#include <algorithm>
#include <functional>
#include <thread>
#include <chrono>
#include "mxnet-cpp/MxNetCpp.h"
using namespace std;
using namespace mxnet::cpp;
struct LSTMState {
Symbol C;
Symbol h;
};
struct LSTMParam {
Symbol i2h_weight;
Symbol i2h_bias;
Symbol h2h_weight;
Symbol h2h_bias;
};
bool TIME_MAJOR = true;
// LSTM Cell symbol
LSTMState LSTM(int num_hidden, const Symbol& indata, const LSTMState& prev_state,
const LSTMParam& param, int seqidx, int layeridx, mx_float dropout = 0) {
auto input = dropout > 0? Dropout(indata, dropout) : indata;
auto prefix = string("t") + to_string(seqidx) + "_l" + to_string(layeridx);
auto i2h = FullyConnected(prefix + "_i2h", input, param.i2h_weight, param.i2h_bias,
num_hidden * 4);
auto h2h = FullyConnected(prefix + "_h2h", prev_state.h, param.h2h_weight, param.h2h_bias,
num_hidden * 4);
auto gates = i2h + h2h;
auto slice_gates = SliceChannel(prefix + "_slice", gates, 4);
auto in_gate = Activation(slice_gates[0], ActivationActType::kSigmoid);
auto in_transform = Activation(slice_gates[1], ActivationActType::kTanh);
auto forget_gate = Activation(slice_gates[2], ActivationActType::kSigmoid);
auto out_gate = Activation(slice_gates[3], ActivationActType::kSigmoid);
LSTMState state;
state.C = (forget_gate * prev_state.C) + (in_gate * in_transform);
state.h = out_gate * Activation(state.C, ActivationActType::kTanh);
return state;
}
Symbol LSTMUnroll(int num_lstm_layer, int sequence_length, int input_dim,
int num_hidden, int num_embed, mx_float dropout = 0) {
auto isTrain = sequence_length > 1;
auto data = Symbol::Variable("data");
if (TIME_MAJOR && isTrain)
data = transpose(data);
auto embed_weight = Symbol::Variable("embed_weight");
auto embed = Embedding("embed", data, embed_weight, input_dim, num_embed);
auto wordvec = isTrain? SliceChannel(embed, sequence_length, TIME_MAJOR? 0 : 1, true) : embed;
vector<LSTMState> last_states;
vector<LSTMParam> param_cells;
for (int l = 0; l < num_lstm_layer; l++) {
string layer = "l" + to_string(l);
LSTMParam param;
param.i2h_weight = Symbol::Variable(layer + "_i2h_weight");
param.i2h_bias = Symbol::Variable(layer + "_i2h_bias");
param.h2h_weight = Symbol::Variable(layer + "_h2h_weight");
param.h2h_bias = Symbol::Variable(layer + "_h2h_bias");
param_cells.push_back(param);
LSTMState state;
state.C = Symbol::Variable(layer + "_init_c");
state.h = Symbol::Variable(layer + "_init_h");
last_states.push_back(state);
}
vector<Symbol> hidden_all;
for (int i = 0; i < sequence_length; i++) {
auto hidden = wordvec[i];
for (int layer = 0; layer < num_lstm_layer; layer++) {
double dp_ratio = layer == 0? 0 : dropout;
auto next_state = LSTM(num_hidden, hidden, last_states[layer], param_cells[layer],
i, layer, dp_ratio);
hidden = next_state.h;
last_states[layer] = next_state;
}
if (dropout > 0)
hidden = Dropout(hidden, dropout);
hidden_all.push_back(hidden);
}
auto hidden_concat = isTrain? Concat(hidden_all, hidden_all.size(), 0) : hidden_all[0];
auto cls_weight = Symbol::Variable("cls_weight");
auto cls_bias = Symbol::Variable("cls_bias");
auto pred = FullyConnected("pred", hidden_concat, cls_weight, cls_bias, input_dim);
auto label = Symbol::Variable("softmax_label");
label = transpose(label);
label = Reshape(label, Shape(), false, Shape(0), false); // -1: infer from graph
auto sm = SoftmaxOutput("softmax", pred, label);
if (isTrain)
return sm;
vector<Symbol> outputs = { sm };
for (auto& state : last_states) {
outputs.push_back(state.C);
outputs.push_back(state.h);
}
return Symbol::Group(outputs);
}
// Currently mxnet GPU version RNN operator is implemented via *fast* NVIDIA cuDNN.
Symbol LSTMWithBuiltInRNNOp(int num_lstm_layer, int sequence_length, int input_dim,
int num_hidden, int num_embed, mx_float dropout = 0) {
auto isTrain = sequence_length > 1;
auto data = Symbol::Variable("data");
if (TIME_MAJOR && isTrain)
data = transpose(data);
auto embed_weight = Symbol::Variable("embed_weight");
auto embed = Embedding("embed", data, embed_weight, input_dim, num_embed);
auto label = Symbol::Variable("softmax_label");
label = transpose(label);
label = Reshape(label, Shape(), false,
Shape(0), false); // FullyConnected requires one dimension
if (!TIME_MAJOR && isTrain)
embed = SwapAxis(embed, 0, 1); // Change to time-major as cuDNN requires
// We need not do the SwapAxis op as python version does. Direct and better performance in C++!
auto rnn_h_init = Symbol::Variable("LSTM_init_h");
auto rnn_c_init = Symbol::Variable("LSTM_init_c");
auto rnn_params = Symbol::Variable("LSTM_parameters"); // See explanations near RNNXavier class
auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, num_hidden, num_lstm_layer,
RNNMode::kLstm, false, dropout, !isTrain);
auto hidden = Reshape(rnn[0], Shape(), false, Shape(0, num_hidden), false);
auto cls_weight = Symbol::Variable("cls_weight");
auto cls_bias = Symbol::Variable("cls_bias");
auto pred = FullyConnected("pred", hidden, cls_weight, cls_bias, input_dim);
/*In rnn-time-major/rnn_cell_demo.py, the author claimed time-major version speeds up
* 1.5~2 times versus batch version. I doubts on the conclusion. In my test, the performance
* of both codes are almost same. In fact, there are no substantially differences between
* two codes. They are both based on time major cuDNN, the computation graph only differs
* slightly on the choices of where to put Reshape/SwapAxis/transpose operation. Here I don't
* use Reshape on pred and keep label shape on SoftmaxOutput like time major version code,
* but Reshape on label for simplification. It doesn't make influence on performacne. */
auto sm = SoftmaxOutput("softmax", pred, label);
if (isTrain)
return sm;
else
return Symbol::Group({ sm, rnn[1/*RNNOpOutputs::kStateOut=1*/],
rnn[2/*RNNOpOutputs::kStateCellOut=2*/] });
}
class Shuffler {
vector<int> sequence;
public:
explicit Shuffler(int size) : sequence(size) {
int* p = sequence.data();
for (int i = 0; i < size; i++)
*p++ = i;
}
void shuffle(function<void(int, int)> lambda = nullptr) {
random_shuffle(sequence.begin(), sequence.end());
int n = 0;
if (lambda != nullptr)
for (int i : sequence)
lambda(n++, i);
}
const int* data() {
return sequence.data();
}
};
class BucketSentenceIter : public DataIter {
Shuffler* random;
int batch, current, end;
unsigned int sequence_length;
Context device;
vector<vector<mx_float>> sequences;
vector<wchar_t> index2chars;
unordered_map<wchar_t, mx_float> charIndices;
public:
BucketSentenceIter(string filename, int minibatch, Context context) : batch(minibatch),
current(-1), device(context) {
auto content = readContent(filename);
buildCharIndex(content);
sequences = convertTextToSequences(content, '\n');
int N = sequences.size() / batch * batch; // total used samples
sequences.resize(N);
sort(sequences.begin(), sequences.end(), [](const vector<mx_float>& a,
const vector<mx_float>& b) { return a.size() < b.size(); });
sequence_length = sequences.back().size();
random = new Shuffler(N);
// We still can get random results if call Reset() firstly
// vector<vector<mx_float>>* target = &sequences;
// random->shuffle([target](int n, int i) { (*target)[n].swap((*target)[i]); });
end = N / batch;
}
virtual ~BucketSentenceIter() {
delete random;
}
unsigned int maxSequenceLength() {
return sequence_length;
}
size_t characterSize() {
return charIndices.size();
}
virtual bool Next(void) {
return ++current < end;
}
virtual NDArray GetData(void) {
const int* indices = random->data();
mx_float *data = new mx_float[sequence_length * batch], *pdata = data;
for (int i = current * batch, end = i + batch; i < end; i++) {
memcpy(pdata, sequences[indices[i]].data(), sequences[indices[i]].size() * sizeof(mx_float));
if (sequences[indices[i]].size() < sequence_length)
memset(pdata + sequences[indices[i]].size(), 0,
(sequence_length - sequences[indices[i]].size()) * sizeof(mx_float));
pdata += sequence_length;
}
NDArray array(Shape(batch, sequence_length), device, false);
array.SyncCopyFromCPU(data, batch * sequence_length);
return array;
}
virtual NDArray GetLabel(void) {
const int* indices = random->data();
mx_float *label = new mx_float[sequence_length * batch], *plabel = label;
for (int i = current * batch, end = i + batch; i < end; i++) {
memcpy(plabel, sequences[indices[i]].data() + 1,
(sequences[indices[i]].size() - 1) * sizeof(mx_float));
memset(plabel + sequences[indices[i]].size() - 1, 0,
(sequence_length - sequences[indices[i]].size() + 1) * sizeof(mx_float));
plabel += sequence_length;
}
NDArray array(Shape(batch, sequence_length), device, false);
array.SyncCopyFromCPU(label, batch * sequence_length);
return array;
}
virtual int GetPadNum(void) {
return sequence_length - sequences[random->data()[current * batch]].size();
}
virtual std::vector<int> GetIndex(void) {
const int* indices = random->data();
vector<int> list(indices + current * batch, indices + current * batch + batch);
return list;
}
virtual void BeforeFirst(void) {
current = -1;
random->shuffle(nullptr);
}
wstring readContent(const string file) {
wifstream ifs(file, ios::binary);
if (ifs) {
wostringstream os;
os << ifs.rdbuf();
return os.str();
}
return L"";
}
void buildCharIndex(const wstring& content) {
// This version buildCharIndex() Compatiable with python version char_rnn dictionary
int n = 1;
charIndices['\0'] = 0; // padding character
index2chars.push_back(0); // padding character index
for (auto c : content)
if (charIndices.find(c) == charIndices.end()) {
charIndices[c] = n++;
index2chars.push_back(c);
}
}
// void buildCharIndex(wstring& content) {
// for (auto c : content)
// charIndices[c]++; // char-frequency map; then char-index map
// vector<tuple<wchar_t, mx_float>> characters;
// for (auto& iter : charIndices)
// characters.push_back(make_tuple(iter.first, iter.second));
// sort(characters.begin(), characters.end(), [](const tuple<wchar_t, mx_float>& a,
// const tuple<wchar_t, mx_float>& b) { return get<1>(a) > get<1>(b); });
// mx_float index = 1; //0 is left for zero-padding
// index2chars.clear();
// index2chars.push_back(0); //zero-padding
// for (auto& t : characters) {
// charIndices[get<0>(t)] = index++;
// index2chars.push_back(get<0>(t));
// }
// }
inline wchar_t character(int i) {
return index2chars[i];
}
inline mx_float index(wchar_t c) {
return charIndices[c];
}
void saveCharIndices(const string file) {
wofstream ofs(file, ios::binary);
if (ofs) {
ofs.write(index2chars.data() + 1, index2chars.size() - 1);
ofs.close();
}
}
static tuple<unordered_map<wchar_t, mx_float>, vector<wchar_t>> loadCharIndices(
const string file) {
wifstream ifs(file, ios::binary);
unordered_map<wchar_t, mx_float> map;
vector<wchar_t> chars;
if (ifs) {
wostringstream os;
os << ifs.rdbuf();
int n = 1;
map[L'\0'] = 0;
chars.push_back(L'\0');
for (auto c : os.str()) {
map[c] = (mx_float) n++;
chars.push_back(c);
}
}
// Note: Can't use {} because this would hit the explicit constructor
return tuple<unordered_map<wchar_t, mx_float>, vector<wchar_t>>(map, chars);
}
vector<vector<mx_float>> convertTextToSequences(const wstring& content, wchar_t spliter) {
vector<vector<mx_float>> sequences;
sequences.push_back(vector<mx_float>());
for (auto c : content)
if (c == spliter && !sequences.back().empty())
sequences.push_back(vector<mx_float>());
else
sequences.back().push_back(charIndices[c]);
return sequences;
}
};
void OutputPerplexity(NDArray* labels, NDArray* output) {
vector<mx_float> charIndices, a;
labels->SyncCopyToCPU(&charIndices, 0L); // 0L indicates all
output->SyncCopyToCPU(&a, 0L)/*4128*84*/;
mx_float loss = 0;
int batchSize = labels->GetShape()[0]/*32*/, sequenceLength = labels->GetShape()[1]/*129*/,
nSamples = output->GetShape()[0]/*4128*/, vocabSize = output->GetShape()[1]/*84*/;
for (int n = 0; n < nSamples; n++) {
int row = n % batchSize, column = n / batchSize, labelOffset = column +
row * sequenceLength; // Search based on column storage: labels.T
mx_float safe_value = max(1e-10f, a[vocabSize * n +
static_cast<int>(charIndices[labelOffset])]);
loss += -log(safe_value); // Calculate negative log-likelihood
}
loss = exp(loss / nSamples);
cout << "Train-Perplexity=" << loss << endl;
}
void SaveCheckpoint(const string filepath, Symbol net, Executor* exe) {
map<string, NDArray> params;
for (auto iter : exe->arg_dict())
if (iter.first.find("_init_") == string::npos
&& iter.first.rfind("data") != iter.first.length() - 4
&& iter.first.rfind("label") != iter.first.length() - 5)
params.insert({"arg:" + iter.first, iter.second});
for (auto iter : exe->aux_dict())
params.insert({"aux:" + iter.first, iter.second});
NDArray::Save(filepath, params);
}
void LoadCheckpoint(const string filepath, Executor* exe) {
map<std::string, NDArray> params = NDArray::LoadToMap(filepath);
for (auto iter : params) {
string type = iter.first.substr(0, 4);
string name = iter.first.substr(4);
NDArray target;
if (type == "arg:")
target = exe->arg_dict()[name];
else if (type == "aux:")
target = exe->aux_dict()[name];
else
continue;
iter.second.CopyTo(&target);
}
}
int input_dim = 0;/*84*/
int sequence_length_max = 0;/*129*/
int num_embed = 256;
int num_lstm_layer = 3;
int num_hidden = 512;
mx_float dropout = 0.2;
void train(const string file, int batch_size, int max_epoch, int start_epoch) {
Context device(DeviceType::kGPU, 0);
BucketSentenceIter dataIter(file, batch_size, device);
string prefix = file.substr(0, file.rfind("."));
dataIter.saveCharIndices(prefix + ".dictionary");
input_dim = static_cast<int>(dataIter.characterSize());
sequence_length_max = dataIter.maxSequenceLength();
auto RNN = LSTMUnroll(num_lstm_layer, sequence_length_max, input_dim, num_hidden,
num_embed, dropout);
map<string, NDArray> args_map;
args_map["data"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
args_map["softmax_label"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
for (int i = 0; i < num_lstm_layer; i++) {
string key = "l" + to_string(i) + "_init_";
args_map[key + "c"] = NDArray(Shape(batch_size, num_hidden), device, false);
args_map[key + "h"] = NDArray(Shape(batch_size, num_hidden), device, false);
}
vector<mx_float> zeros(batch_size * num_hidden, 0);
// RNN.SimpleBind(device, args_map, {}, {{"data", kNullOp}});
Executor* exe = RNN.SimpleBind(device, args_map);
if (start_epoch == -1) {
Xavier xavier = Xavier(Xavier::gaussian, Xavier::in, 2.34);
for (auto &arg : exe->arg_dict())
xavier(arg.first, &arg.second);
} else {
LoadCheckpoint(prefix + "-" + to_string(start_epoch) + ".params", exe);
}
start_epoch++;
mx_float learning_rate = 0.0002;
mx_float weight_decay = 0.000002;
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
opt->SetParam("lr", learning_rate)
->SetParam("wd", weight_decay);
// opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
// ->SetParam("clip_gradient", 10);
for (int epoch = start_epoch; epoch < max_epoch; ++epoch) {
dataIter.Reset();
auto tic = chrono::system_clock::now();
while (dataIter.Next()) {
auto data_batch = dataIter.GetDataBatch();
data_batch.data.CopyTo(&exe->arg_dict()["data"]);
data_batch.label.CopyTo(&exe->arg_dict()["softmax_label"]);
for (int l = 0; l < num_lstm_layer; l++) {
string key = "l" + to_string(l) + "_init_";
exe->arg_dict()[key + "c"].SyncCopyFromCPU(zeros);
exe->arg_dict()[key + "h"].SyncCopyFromCPU(zeros);
}
NDArray::WaitAll();
exe->Forward(true);
exe->Backward();
for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
}
NDArray::WaitAll();
}
auto toc = chrono::system_clock::now();
cout << "Epoch[" << epoch << "] Time Cost:" <<
chrono::duration_cast<chrono::seconds>(toc - tic).count() << " seconds ";
OutputPerplexity(&exe->arg_dict()["softmax_label"], &exe->outputs[0]);
string filepath = prefix + "-" + to_string(epoch) + ".params";
SaveCheckpoint(filepath, RNN, exe);
}
}
/*The original example, rnn_cell_demo.py, uses default Xavier as initalizer, which relies on
* variable name, cannot initialize LSTM_parameters. Thus it was renamed to LSTM_bias,
* which can be initialized as zero. But it cannot converge after 100 epochs in this corpus
* example. Using RNNXavier, after 15 oscillating epochs, it rapidly converges like old
* LSTMUnroll version. */
class RNNXavier : public Xavier {
public:
RNNXavier(RandType rand_type = gaussian, FactorType factor_type = avg,
float magnitude = 3) : Xavier(rand_type, factor_type, magnitude) {
}
virtual ~RNNXavier() {}
protected:
virtual void InitDefault(NDArray* arr) {
Xavier::InitWeight(arr);
}
};
void trainWithBuiltInRNNOp(const string file, int batch_size, int max_epoch, int start_epoch) {
Context device(DeviceType::kGPU, 0);
BucketSentenceIter dataIter(file, batch_size, device);
string prefix = file.substr(0, file.rfind("."));
dataIter.saveCharIndices(prefix + ".dictionary");
input_dim = static_cast<int>(dataIter.characterSize());
sequence_length_max = dataIter.maxSequenceLength();
auto RNN = LSTMWithBuiltInRNNOp(num_lstm_layer, sequence_length_max, input_dim, num_hidden,
num_embed, dropout);
map<string, NDArray> args_map;
args_map["data"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
// Avoiding SwapAxis, batch_size is of second dimension.
args_map["LSTM_init_c"] = NDArray(Shape(num_lstm_layer, batch_size, num_hidden), device, false);
args_map["LSTM_init_h"] = NDArray(Shape(num_lstm_layer, batch_size, num_hidden), device, false);
args_map["softmax_label"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
vector<mx_float> zeros(batch_size * num_lstm_layer * num_hidden, 0);
Executor* exe = RNN.SimpleBind(device, args_map);
if (start_epoch == -1) {
RNNXavier xavier = RNNXavier(Xavier::gaussian, Xavier::in, 2.34);
for (auto &arg : exe->arg_dict())
xavier(arg.first, &arg.second);
} else {
LoadCheckpoint(prefix + "-" + to_string(start_epoch) + ".params", exe);
}
start_epoch++;
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
// opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
// ->SetParam("clip_gradient", 10);
for (int epoch = start_epoch; epoch < max_epoch; ++epoch) {
dataIter.Reset();
auto tic = chrono::system_clock::now();
while (dataIter.Next()) {
auto data_batch = dataIter.GetDataBatch();
data_batch.data.CopyTo(&exe->arg_dict()["data"]);
data_batch.label.CopyTo(&exe->arg_dict()["softmax_label"]);
exe->arg_dict()["LSTM_init_c"].SyncCopyFromCPU(zeros);
exe->arg_dict()["LSTM_init_h"].SyncCopyFromCPU(zeros);
NDArray::WaitAll();
exe->Forward(true);
exe->Backward();
for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
}
NDArray::WaitAll();
}
auto toc = chrono::system_clock::now();
cout << "Epoch[" << epoch << "] Time Cost:" <<
chrono::duration_cast<chrono::seconds>(toc - tic).count() << " seconds ";
OutputPerplexity(&exe->arg_dict()["softmax_label"], &exe->outputs[0]);
string filepath = prefix + "-" + to_string(epoch) + ".params";
SaveCheckpoint(filepath, RNN, exe);
}
}
void predict(wstring* ptext, int sequence_length, const string param_file,
const string dictionary_file) {
Context device(DeviceType::kGPU, 0);
auto results = BucketSentenceIter::loadCharIndices(dictionary_file);
auto dictionary = get<0>(results);
auto charIndices = get<1>(results);
input_dim = static_cast<int>(charIndices.size());
auto RNN = LSTMUnroll(num_lstm_layer, 1, input_dim, num_hidden, num_embed, 0);
map<string, NDArray> args_map;
args_map["data"] = NDArray(Shape(1, 1), device, false);
args_map["softmax_label"] = NDArray(Shape(1, 1), device, false);
vector<mx_float> zeros(1 * num_hidden, 0);
for (int l = 0; l < num_lstm_layer; l++) {
string key = "l" + to_string(l) + "_init_";
args_map[key + "c"] = NDArray(Shape(1, num_hidden), device, false);
args_map[key + "h"] = NDArray(Shape(1, num_hidden), device, false);
args_map[key + "c"].SyncCopyFromCPU(zeros);
args_map[key + "h"].SyncCopyFromCPU(zeros);
}
Executor* exe = RNN.SimpleBind(device, args_map);
LoadCheckpoint(param_file, exe);
mx_float index;
wchar_t next = 0;
vector<mx_float> softmax;
softmax.resize(input_dim);
for (auto c : *ptext) {
exe->arg_dict()["data"].SyncCopyFromCPU(&dictionary[c], 1);
exe->Forward(false);
exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
for (int l = 0; l < num_lstm_layer; l++) {
string key = "l" + to_string(l) + "_init_";
exe->outputs[l * 2 + 1].CopyTo(&args_map[key + "c"]);
exe->outputs[l * 2 + 2].CopyTo(&args_map[key + "h"]);
}
size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
index = (mx_float) n;
next = charIndices[n];
}
ptext->push_back(next);
for (int i = 0; i < sequence_length; i++) {
exe->arg_dict()["data"].SyncCopyFromCPU(&index, 1);
exe->Forward(false);
exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
for (int l = 0; l < num_lstm_layer; l++) {
string key = "l" + to_string(l) + "_init_";
exe->outputs[l * 2 + 1].CopyTo(&args_map[key + "c"]);
exe->outputs[l * 2 + 2].CopyTo(&args_map[key + "h"]);
}
size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
index = (mx_float) n;
next = charIndices[n];
ptext->push_back(next);
}
}
void predictWithBuiltInRNNOp(wstring* ptext, int sequence_length, const string param_file,
const string dictionary_file) {
Context device(DeviceType::kGPU, 0);
auto results = BucketSentenceIter::loadCharIndices(dictionary_file);
auto dictionary = get<0>(results);
auto charIndices = get<1>(results);
input_dim = static_cast<int>(charIndices.size());
auto RNN = LSTMWithBuiltInRNNOp(num_lstm_layer, 1, input_dim, num_hidden, num_embed, 0);
map<string, NDArray> args_map;
args_map["data"] = NDArray(Shape(1, 1), device, false);
args_map["softmax_label"] = NDArray(Shape(1, 1), device, false);
vector<mx_float> zeros(1 * num_lstm_layer * num_hidden, 0);
// Avoiding SwapAxis, batch_size=1 is of second dimension.
args_map["LSTM_init_c"] = NDArray(Shape(num_lstm_layer, 1, num_hidden), device, false);
args_map["LSTM_init_h"] = NDArray(Shape(num_lstm_layer, 1, num_hidden), device, false);
args_map["LSTM_init_c"].SyncCopyFromCPU(zeros);
args_map["LSTM_init_h"].SyncCopyFromCPU(zeros);
Executor* exe = RNN.SimpleBind(device, args_map);
LoadCheckpoint(param_file, exe);
mx_float index;
wchar_t next = 0;
vector<mx_float> softmax;
softmax.resize(input_dim);
for (auto c : *ptext) {
exe->arg_dict()["data"].SyncCopyFromCPU(&dictionary[c], 1);
exe->Forward(false);
exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
exe->outputs[1].CopyTo(&args_map["LSTM_init_h"]);
exe->outputs[2].CopyTo(&args_map["LSTM_init_c"]);
size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
index = (mx_float) n;
next = charIndices[n];
}
ptext->push_back(next);
for (int i = 0; i < sequence_length; i++) {
exe->arg_dict()["data"].SyncCopyFromCPU(&index, 1);
exe->Forward(false);
exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
exe->outputs[1].CopyTo(&args_map["LSTM_init_h"]);
exe->outputs[2].CopyTo(&args_map["LSTM_init_c"]);
size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
index = (mx_float) n;
next = charIndices[n];
ptext->push_back(next);
}
}
int main(int argc, char** argv) {
if (argc < 5) {
cout << "Usage for training: charRNN train[BuiltIn][TimeMajor] {corpus file}"
" {batch size} {max epoch} [{starting epoch}]" << endl;
cout <<"Usage for prediction: charRNN predict[BuiltIn][TimeMajor] {params file}"
" {dictionary file} {beginning of text}" << endl;
cout <<"Note: The {params file} of train/trainBuiltIn/trainTimeMajor/trainBuiltInTimeMajor"
" are not compatible with each other." << endl;
return 0;
}
string task = argv[1];
bool builtIn = task.find("BuiltIn") != string::npos;
TIME_MAJOR = task.find("TimeMajor") != string::npos;
cout << "use BuiltIn cuDNN RNN: " << builtIn << endl
<< "use data as TimeMajor: " << TIME_MAJOR << endl;
if (task.find("train") == 0) {
cout << "train batch size: " << argv[3] << endl
<< "train max epoch: " << argv[4] << endl;
int start_epoch = argc > 5? atoi(argv[5]) : -1;
// this function will generate dictionary file and params file.
if (builtIn)
trainWithBuiltInRNNOp(argv[2], atoi(argv[3]), atoi(argv[4]), start_epoch);
else
train(argv[2], atoi(argv[3]), atoi(argv[4]), start_epoch); // ditto
} else if (task.find("predict") == 0) {
wstring text; // = L"If there is anyone out there who still doubts ";
// Considering of extending to Chinese samples in future, use wchar_t instead of char
for (char c : string(argv[4]))
text.push_back((wchar_t) c);
/*Python version predicts text default to random selecltions. Here I didn't write the random
code, always choose the 'best' character. So the text length reduced to 600. Longer size often
leads to repeated sentances, since training sequence length is only 129 for obama corpus.*/
if (builtIn)
predictWithBuiltInRNNOp(&text, 600, argv[2], argv[3]);
else
predict(&text, 600, argv[2], argv[3]);
wcout << text << endl;
}
MXNotifyShutdown();
return 0;
}