/
train_xor.cc
72 lines (60 loc) · 1.96 KB
/
train_xor.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include "dynet/dynet.h"
#include "dynet/training.h"
#include "dynet/expr.h"
#include "dynet/io.h"
#include "dynet/model.h"
#include <iostream>
using namespace std;
using namespace dynet;
int main(int argc, char** argv) {
dynet::initialize(argc, argv);
const unsigned ITERATIONS = 30;
// ParameterCollection (all the model parameters).
ParameterCollection m;
SimpleSGDTrainer trainer(m);
const unsigned HIDDEN_SIZE = 8;
Parameter p_W = m.add_parameters({HIDDEN_SIZE, 2});
Parameter p_b = m.add_parameters({HIDDEN_SIZE});
Parameter p_V = m.add_parameters({1, HIDDEN_SIZE});
Parameter p_a = m.add_parameters({1});
if (argc == 2) {
// Load the model and parameters from file if given.
TextFileLoader loader(argv[1]);
loader.populate(m);
}
// Static declaration of the computation graph.
ComputationGraph cg;
Expression W = parameter(cg, p_W);
Expression b = parameter(cg, p_b);
Expression V = parameter(cg, p_V);
Expression a = parameter(cg, p_a);
// Set x_values to change the inputs to the network.
vector<dynet::real> x_values(2);
Expression x = input(cg, {2}, &x_values);
dynet::real y_value; // Set y_value to change the target output.
Expression y = input(cg, &y_value);
Expression h = tanh(W*x + b);
Expression y_pred = V*h + a;
Expression loss_expr = squared_distance(y_pred, y);
// Show the computation graph, just for fun.
cg.print_graphviz();
// Train the parameters.
for (unsigned iter = 0; iter < ITERATIONS; ++iter) {
double loss = 0;
for (unsigned mi = 0; mi < 4; ++mi) {
bool x1 = mi % 2;
bool x2 = (mi / 2) % 2;
x_values[0] = x1 ? 1 : -1;
x_values[1] = x2 ? 1 : -1;
y_value = (x1 != x2) ? 1 : -1;
loss += as_scalar(cg.forward(loss_expr));
cg.backward(loss_expr);
trainer.update();
}
loss /= 4;
cerr << "E = " << loss << endl;
}
// Output the model and parameter objects to a file.
TextFileSaver saver("xor.model");
saver.save(m);
}