-
Notifications
You must be signed in to change notification settings - Fork 704
/
xor.py
81 lines (67 loc) · 1.5 KB
/
xor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import sys
import dynet as dy
# Output/loss mode for the network built below:
#   False -> linear output trained with squared distance; booleans are
#            encoded as T = 1 / F = -1.
#   True  -> logistic output trained with binary log loss; booleans are
#            encoded as T = 1 / F = 0.
xsent = False

HIDDEN_SIZE = 8    # units in the single tanh hidden layer
ITERATIONS = 2000  # passes over the four XOR examples

m = dy.Model()
trainer = dy.SimpleSGDTrainer(m)

# Parameters of y_pred = V * tanh(W * x + b) + a
# (wrapped in a logistic when xsent is True).
pW = m.add_parameters((HIDDEN_SIZE, 2))
pb = m.add_parameters(HIDDEN_SIZE)
pV = m.add_parameters((1, HIDDEN_SIZE))
pa = m.add_parameters(1)

# Optional warm start: a single command-line argument names a parameter file
# to load before training.
# NOTE(review): populate_from_textfile is an older DyNet API name; confirm it
# exists in the installed version (newer releases document
# ParameterCollection.populate).
if len(sys.argv) == 2:
    m.populate_from_textfile(sys.argv[1])
# Build the training graph once. x and y are input placeholders whose values
# are overwritten (via .set) for each example, so the same graph and the same
# `loss` node are reused for every training step.
dy.renew_cg()  # start from an empty graph, as the evaluation section does
W = dy.parameter(pW)
b = dy.parameter(pb)
V = dy.parameter(pV)
a = dy.parameter(pa)

x = dy.vecInput(2)      # the two XOR inputs
y = dy.scalarInput(0)   # the target; real value is set per example
h = dy.tanh((W * x) + b)
if xsent:
    # Probability-style output with cross-entropy loss; booleans as 1 / 0.
    y_pred = dy.logistic((V * h) + a)
    loss = dy.binary_log_loss(y_pred, y)
    T = 1
    F = 0
else:
    # Linear output with squared-error loss; booleans as 1 / -1.
    y_pred = (V * h) + a
    loss = dy.squared_distance(y_pred, y)
    T = 1
    F = -1
# Plain per-example SGD over the four XOR cases, repeated ITERATIONS times.
for _ in range(ITERATIONS):
    mloss = 0.0
    for mi in range(4):
        # mi = 0..3 enumerates the input bits (x1, x2) in the order
        # (0,0), (1,0), (0,1), (1,1).
        x1 = mi % 2
        x2 = (mi // 2) % 2
        x.set([T if x1 else F, T if x2 else F])
        y.set(T if x1 != x2 else F)   # XOR target
        mloss += loss.scalar_value()  # forward pass on the shared graph
        loss.backward()               # gradients for this single example
        trainer.update()              # one SGD step per example
    mloss /= 4.  # mean loss over the four examples
    print("loss: %0.9f" % mloss)

# Quick check on the trained graph: prediction for input (F, T). The double
# negation leaves the value unchanged; presumably it is there to exercise
# unary minus on expressions.
x.set([F, T])
z = -(-y_pred)
print(z.scalar_value())

m.save("xor.pymodel")
# Rebuild a fresh graph from the trained parameters (no loss or target is
# needed) and print the prediction for each input combination.
dy.renew_cg()
W = dy.parameter(pW)
b = dy.parameter(pb)
V = dy.parameter(pV)
a = dy.parameter(pa)

x = dy.vecInput(2)
h = dy.tanh((W * x) + b)
if xsent:
    y_pred = dy.logistic((V * h) + a)
else:
    y_pred = (V * h) + a

# Same order and labels as the original hand-written checks.
for label, inputs in (
    ("TF", [T, F]),
    ("FF", [F, F]),
    ("TT", [T, T]),
    ("FT", [F, T]),
):
    x.set(inputs)
    print(label, y_pred.scalar_value())