-
Notifications
You must be signed in to change notification settings - Fork 3
/
Out of the Box LSTM with TensorFlow.ipynb
565 lines (565 loc) · 21.3 KB
/
Out of the Box LSTM with TensorFlow.ipynb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Out of the Box LSTM with TensorFlow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Data Preparation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([[1, 0, 0],\n",
" [0, 1, 0],\n",
" [1, 0, 0],\n",
" [1, 0, 0],\n",
" [0, 0, 0],\n",
" [1, 0, 0]])\n",
"'+-++0+'\n",
"3\n"
]
}
],
"source": [
"import numpy as np\n",
"from pprint import pprint\n",
"import datetime\n",
"\n",
"import data_generator\n",
"\n",
"# Every generated sequence consists of this many instruction characters.\n",
"sequence_length = 6\n",
"\n",
"# data_generator.getSequences(sequence_length) generates all possible combinations of\n",
"# the characters '+-0I', so a sequence length of 6 characters yields a total of\n",
"# 4^6 = 4096 sequences. Each sequence has a numeric solution, for example:\n",
"#   '+-+-+-' = 0\n",
"#   '------' = -6\n",
"#   '0++000' = 2\n",
"#   'I++000' = -2\n",
"reference_input_data, reference_output_data = data_generator.getSequences(sequence_length)\n",
"\n",
"# The sequences come back encoded: every character is represented by a vector.\n",
"# This is what the first encoded sequence looks like:\n",
"first_sequence = reference_input_data[0]\n",
"pprint(first_sequence)\n",
"\n",
"# decodeSequence is the helper that turns the encoding back into characters:\n",
"pprint(data_generator.decodeSequence(first_sequence))\n",
"\n",
"# ... and this is the solution that belongs to that sequence:\n",
"pprint(reference_output_data[0])\n",
"\n",
"# The size of the third axis of the input data is the width of one character vector.\n",
"instruction_count = np.asarray(reference_input_data).shape[2]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"We'll train using 1024/4096 Examples\n"
]
}
],
"source": [
"# Use the first quarter of the data for training. Floor division keeps NUM_EXAMPLES an\n",
"# integer on Python 3 as well; a float index would make the slices below raise a TypeError.\n",
"NUM_EXAMPLES = len(reference_input_data) // 4\n",
"\n",
"train_input = reference_input_data[:NUM_EXAMPLES]\n",
"train_output = reference_output_data[:NUM_EXAMPLES]\n",
"\n",
"# Everything beyond NUM_EXAMPLES is held back for testing.\n",
"test_input = reference_input_data[NUM_EXAMPLES:]\n",
"test_output = reference_output_data[NUM_EXAMPLES:]\n",
"\n",
"print(\"We'll train using \" + str(NUM_EXAMPLES) + \"/\" + str(len(reference_input_data)) + \" Examples\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# A batch of encoded sequences: [batch, sequence_length, instruction_count].\n",
"data = tf.placeholder(tf.float32, [None, sequence_length, instruction_count], name='data')\n",
"\n",
"# One expected result per sequence in the batch. The placeholder is rank 1, so it needs\n",
"# no transpose; `target` is the placeholder itself and is what gets fed.\n",
"target = tf.placeholder(tf.float32, [None], name='target')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. LSTM Layer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Number of units in the LSTM cell.\n",
"LSTM_SIZE = 24"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# A single LSTM cell with LSTM_SIZE units.\n",
"lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=LSTM_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Unroll the cell over the whole sequence; the outputs are [batch, time, LSTM_SIZE].\n",
"lstm_predictions, state = tf.nn.dynamic_rnn(lstm_cell, data, dtype=tf.float32)\n",
"\n",
"# Make time the leading axis so that the output of the last step can be picked with [-1].\n",
"lstm_predictions = tf.transpose(lstm_predictions, [1, 0, 2])\n",
"\n",
"# Slicing has a dense gradient. Selecting the step with tf.gather yields a sparse\n",
"# IndexedSlices gradient instead, which TensorFlow has to densify and warns about\n",
"# ('Converting sparse IndexedSlices to a dense Tensor of unknown shape').\n",
"lstm_prediction = lstm_predictions[-1]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Linear read-out layer: maps the LSTM_SIZE outputs of the last step to a single number.\n",
"weight = tf.Variable(tf.truncated_normal(shape=[LSTM_SIZE, 1]))\n",
"bias = tf.Variable(tf.constant(0.1, shape=[1]))\n",
"\n",
"prediction = tf.add(tf.matmul(lstm_prediction, weight), bias)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Cost & Optimizing"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'mean_square_error_1:0' shape=() dtype=string>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with tf.name_scope('mean_square_error'):\n",
"    # prediction is [batch, 1]; flatten it so it lines up element-wise with target.\n",
"    squared_error = tf.square(tf.subtract(target, tf.reshape(prediction, [-1])))\n",
"    # Average over the batch (rather than sum) so the value is a true mean and stays\n",
"    # comparable between the training set and the larger test set.\n",
"    mean_square_error = tf.reduce_mean(squared_error)\n",
"tf.summary.scalar('mean_square_error', mean_square_error)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/cs/anaconda/envs/TestEnvironment/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py:93: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
" \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
]
}
],
"source": [
"# Adam with its default hyperparameters; `minimize` is the op that performs one update step.\n",
"optimizer = tf.train.AdamOptimizer()\n",
"minimize = optimizer.minimize(loss=mean_square_error)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'error_1:0' shape=() dtype=string>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with tf.name_scope('error'):\n",
"    with tf.name_scope('mistakes'):\n",
"        # A prediction counts as a mistake when its rounded value differs from the target.\n",
"        rounded_prediction = tf.round(tf.unstack(prediction, axis=1))\n",
"        mistakes = tf.not_equal(target, rounded_prediction)\n",
"    with tf.name_scope('error'):\n",
"        # Fraction of mistaken predictions in the batch.\n",
"        error = tf.reduce_mean(tf.cast(mistakes, dtype=tf.float32))\n",
"tf.summary.scalar('error', error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Training"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"sess = tf.InteractiveSession()\n",
"merged = tf.summary.merge_all()\n",
"\n",
"# Both writers log below the same timestamped run directory so that TensorBoard shows the\n",
"# train and test curves of one run side by side.\n",
"date = str(datetime.datetime.now())\n",
"log_dir = 'logs/out_of_the_box_lstm/' + date\n",
"train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)\n",
"test_writer = tf.summary.FileWriter(log_dir + '/test', sess.graph)\n",
"\n",
"init_op = tf.global_variables_initializer()\n",
"sess.run(init_op)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 20 | incorrect 77.5% | mean squ error 8737.9\n",
"Epoch 40 | incorrect 77.0% | mean squ error 8102.8\n",
"Epoch 60 | incorrect 73.7% | mean squ error 7553.9\n",
"Epoch 80 | incorrect 73.6% | mean squ error 7313.4\n",
"Epoch 100 | incorrect 74.2% | mean squ error 7095.0\n",
"Epoch 120 | incorrect 73.2% | mean squ error 6811.3\n",
"Epoch 140 | incorrect 73.0% | mean squ error 6269.7\n",
"Epoch 160 | incorrect 71.0% | mean squ error 5386.7\n",
"Epoch 180 | incorrect 66.4% | mean squ error 4642.8\n",
"Epoch 200 | incorrect 62.3% | mean squ error 3968.5\n",
"Epoch 220 | incorrect 59.7% | mean squ error 3501.6\n",
"Epoch 240 | incorrect 57.3% | mean squ error 3178.4\n",
"Epoch 260 | incorrect 54.8% | mean squ error 2947.7\n",
"Epoch 280 | incorrect 53.4% | mean squ error 2775.0\n",
"Epoch 300 | incorrect 51.7% | mean squ error 2635.8\n",
"Epoch 320 | incorrect 50.4% | mean squ error 2504.8\n",
"Epoch 340 | incorrect 49.7% | mean squ error 2376.9\n",
"Epoch 360 | incorrect 48.2% | mean squ error 2240.8\n",
"Epoch 380 | incorrect 46.7% | mean squ error 2099.4\n",
"Epoch 400 | incorrect 45.4% | mean squ error 1950.7\n",
"Epoch 420 | incorrect 43.1% | mean squ error 1801.9\n",
"Epoch 440 | incorrect 41.1% | mean squ error 1661.6\n",
"Epoch 460 | incorrect 39.1% | mean squ error 1543.4\n",
"Epoch 480 | incorrect 36.5% | mean squ error 1431.7\n",
"Epoch 500 | incorrect 33.8% | mean squ error 1336.4\n",
"Epoch 520 | incorrect 31.5% | mean squ error 1234.5\n",
"Epoch 540 | incorrect 29.2% | mean squ error 1146.9\n",
"Epoch 560 | incorrect 28.0% | mean squ error 1088.1\n",
"Epoch 580 | incorrect 26.0% | mean squ error 1007.1\n",
"Epoch 600 | incorrect 24.4% | mean squ error 947.8\n",
"Epoch 620 | incorrect 23.6% | mean squ error 896.8\n",
"Epoch 640 | incorrect 23.2% | mean squ error 854.0\n",
"Epoch 660 | incorrect 21.7% | mean squ error 805.9\n",
"Epoch 680 | incorrect 20.9% | mean squ error 767.7\n",
"Epoch 700 | incorrect 20.1% | mean squ error 730.7\n",
"Epoch 720 | incorrect 18.9% | mean squ error 696.8\n",
"Epoch 740 | incorrect 18.0% | mean squ error 669.8\n",
"Epoch 760 | incorrect 16.9% | mean squ error 638.2\n",
"Epoch 780 | incorrect 15.8% | mean squ error 612.3\n",
"Epoch 800 | incorrect 15.0% | mean squ error 588.6\n",
"Epoch 820 | incorrect 14.2% | mean squ error 566.5\n",
"Epoch 840 | incorrect 13.6% | mean squ error 546.2\n",
"Epoch 860 | incorrect 13.2% | mean squ error 527.9\n",
"Epoch 880 | incorrect 12.5% | mean squ error 509.6\n",
"Epoch 900 | incorrect 12.0% | mean squ error 492.3\n",
"Epoch 920 | incorrect 11.6% | mean squ error 476.5\n",
"Epoch 940 | incorrect 11.2% | mean squ error 461.1\n",
"Epoch 960 | incorrect 10.8% | mean squ error 446.3\n",
"Epoch 980 | incorrect 11.7% | mean squ error 471.2\n",
"Epoch 1000 | incorrect 9.8% | mean squ error 423.1\n",
"Epoch 1020 | incorrect 9.9% | mean squ error 408.0\n",
"Epoch 1040 | incorrect 9.4% | mean squ error 396.2\n",
"Epoch 1060 | incorrect 9.2% | mean squ error 385.4\n",
"Epoch 1080 | incorrect 8.9% | mean squ error 374.9\n",
"Epoch 1100 | incorrect 8.7% | mean squ error 364.7\n",
"Epoch 1120 | incorrect 8.5% | mean squ error 355.0\n",
"Epoch 1140 | incorrect 8.2% | mean squ error 345.6\n",
"Epoch 1160 | incorrect 7.8% | mean squ error 336.6\n",
"Epoch 1180 | incorrect 7.5% | mean squ error 328.0\n",
"Epoch 1200 | incorrect 7.5% | mean squ error 323.9\n",
"Epoch 1220 | incorrect 7.0% | mean squ error 316.6\n",
"Epoch 1240 | incorrect 6.6% | mean squ error 305.5\n",
"Epoch 1260 | incorrect 6.3% | mean squ error 299.1\n",
"Epoch 1280 | incorrect 6.3% | mean squ error 293.0\n",
"Epoch 1300 | incorrect 6.1% | mean squ error 287.1\n",
"Epoch 1320 | incorrect 5.8% | mean squ error 281.5\n",
"Epoch 1340 | incorrect 5.6% | mean squ error 276.1\n",
"Epoch 1360 | incorrect 5.3% | mean squ error 270.9\n",
"Epoch 1380 | incorrect 5.1% | mean squ error 265.9\n",
"Epoch 1400 | incorrect 4.9% | mean squ error 261.0\n",
"Epoch 1420 | incorrect 4.7% | mean squ error 256.4\n",
"Epoch 1440 | incorrect 4.5% | mean squ error 251.9\n",
"Epoch 1460 | incorrect 4.2% | mean squ error 247.6\n",
"Epoch 1480 | incorrect 4.3% | mean squ error 244.2\n",
"Epoch 1500 | incorrect 4.1% | mean squ error 240.5\n",
"Epoch 1520 | incorrect 3.7% | mean squ error 236.2\n",
"Epoch 1540 | incorrect 3.6% | mean squ error 232.6\n",
"Epoch 1560 | incorrect 3.5% | mean squ error 229.2\n",
"Epoch 1580 | incorrect 3.4% | mean squ error 226.0\n",
"Epoch 1600 | incorrect 3.3% | mean squ error 222.9\n",
"Epoch 1620 | incorrect 3.2% | mean squ error 219.8\n",
"Epoch 1640 | incorrect 3.1% | mean squ error 216.9\n",
"Epoch 1660 | incorrect 3.1% | mean squ error 214.0\n",
"Epoch 1680 | incorrect 3.0% | mean squ error 211.2\n",
"Epoch 1700 | incorrect 2.9% | mean squ error 208.5\n",
"Epoch 1720 | incorrect 2.8% | mean squ error 205.8\n",
"Epoch 1740 | incorrect 2.7% | mean squ error 203.3\n",
"Epoch 1760 | incorrect 3.0% | mean squ error 213.6\n",
"Epoch 1780 | incorrect 2.7% | mean squ error 200.1\n",
"Epoch 1800 | incorrect 2.6% | mean squ error 196.5\n",
"Epoch 1820 | incorrect 2.6% | mean squ error 194.2\n",
"Epoch 1840 | incorrect 2.5% | mean squ error 192.1\n",
"Epoch 1860 | incorrect 2.4% | mean squ error 190.1\n",
"Epoch 1880 | incorrect 2.4% | mean squ error 188.1\n",
"Epoch 1900 | incorrect 2.4% | mean squ error 186.1\n",
"Epoch 1920 | incorrect 2.3% | mean squ error 184.1\n",
"Epoch 1940 | incorrect 2.3% | mean squ error 182.2\n",
"Epoch 1960 | incorrect 2.2% | mean squ error 180.4\n",
"Epoch 1980 | incorrect 2.1% | mean squ error 178.5\n",
"Epoch 2000 | incorrect 2.1% | mean squ error 176.7\n",
"Epoch 2020 | incorrect 2.0% | mean squ error 174.9\n",
"Epoch 2040 | incorrect 2.3% | mean squ error 193.3\n",
"Epoch 2060 | incorrect 2.1% | mean squ error 173.9\n",
"Epoch 2080 | incorrect 1.8% | mean squ error 170.2\n",
"Epoch 2100 | incorrect 1.8% | mean squ error 168.6\n",
"Epoch 2120 | incorrect 1.8% | mean squ error 167.1\n",
"Epoch 2140 | incorrect 1.7% | mean squ error 165.6\n",
"Epoch 2160 | incorrect 1.7% | mean squ error 164.2\n",
"Epoch 2180 | incorrect 1.7% | mean squ error 162.8\n",
"Epoch 2200 | incorrect 1.7% | mean squ error 161.3\n",
"Epoch 2220 | incorrect 1.7% | mean squ error 159.9\n",
"Epoch 2240 | incorrect 1.7% | mean squ error 158.6\n",
"Epoch 2260 | incorrect 1.7% | mean squ error 157.2\n",
"Epoch 2280 | incorrect 1.6% | mean squ error 155.8\n",
"Epoch 2300 | incorrect 1.6% | mean squ error 154.5\n",
"Epoch 2320 | incorrect 1.5% | mean squ error 153.2\n",
"Epoch 2340 | incorrect 1.9% | mean squ error 178.3\n",
"Epoch 2360 | incorrect 1.5% | mean squ error 151.8\n",
"Epoch 2380 | incorrect 1.5% | mean squ error 149.8\n",
"Epoch 2400 | incorrect 1.5% | mean squ error 148.6\n",
"Epoch 2420 | incorrect 1.5% | mean squ error 147.4\n",
"Epoch 2440 | incorrect 1.5% | mean squ error 146.2\n",
"Epoch 2460 | incorrect 1.5% | mean squ error 145.1\n",
"Epoch 2480 | incorrect 1.4% | mean squ error 144.1\n",
"Epoch 2500 | incorrect 1.4% | mean squ error 143.0\n",
"Epoch 2520 | incorrect 1.4% | mean squ error 141.9\n",
"Epoch 2540 | incorrect 1.4% | mean squ error 140.9\n",
"Epoch 2560 | incorrect 1.4% | mean squ error 139.8\n",
"Epoch 2580 | incorrect 1.4% | mean squ error 138.8\n",
"Epoch 2600 | incorrect 1.4% | mean squ error 137.7\n",
"Epoch 2620 | incorrect 1.4% | mean squ error 136.7\n",
"Epoch 2640 | incorrect 1.4% | mean squ error 135.7\n",
"Epoch 2660 | incorrect 1.4% | mean squ error 158.0\n",
"Epoch 2680 | incorrect 1.3% | mean squ error 135.1\n",
"Epoch 2700 | incorrect 1.4% | mean squ error 133.0\n",
"Epoch 2720 | incorrect 1.4% | mean squ error 132.2\n",
"Epoch 2740 | incorrect 1.4% | mean squ error 131.3\n",
"Epoch 2760 | incorrect 1.4% | mean squ error 130.4\n",
"Epoch 2780 | incorrect 1.4% | mean squ error 129.5\n",
"Epoch 2800 | incorrect 1.3% | mean squ error 128.7\n",
"Epoch 2820 | incorrect 1.3% | mean squ error 127.9\n",
"Epoch 2840 | incorrect 1.3% | mean squ error 127.0\n",
"Epoch 2860 | incorrect 1.3% | mean squ error 126.2\n",
"Epoch 2880 | incorrect 1.3% | mean squ error 125.4\n",
"Epoch 2900 | incorrect 1.3% | mean squ error 124.5\n",
"Epoch 2920 | incorrect 1.3% | mean squ error 123.7\n",
"Epoch 2940 | incorrect 1.3% | mean squ error 122.9\n",
"Epoch 2960 | incorrect 1.4% | mean squ error 138.2\n",
"Epoch 2980 | incorrect 1.2% | mean squ error 123.4\n",
"Epoch 3000 | incorrect 1.3% | mean squ error 120.6\n",
"Epoch 3020 | incorrect 1.3% | mean squ error 120.1\n",
"Epoch 3040 | incorrect 1.3% | mean squ error 119.2\n",
"Epoch 3060 | incorrect 1.3% | mean squ error 118.5\n",
"Epoch 3080 | incorrect 1.3% | mean squ error 117.8\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3100 | incorrect 1.3% | mean squ error 117.1\n",
"Epoch 3120 | incorrect 1.3% | mean squ error 116.4\n",
"Epoch 3140 | incorrect 1.3% | mean squ error 115.7\n",
"Epoch 3160 | incorrect 1.3% | mean squ error 115.0\n",
"Epoch 3180 | incorrect 1.3% | mean squ error 114.4\n",
"Epoch 3200 | incorrect 1.3% | mean squ error 114.6\n",
"Epoch 3220 | incorrect 1.3% | mean squ error 114.8\n",
"Epoch 3240 | incorrect 1.3% | mean squ error 112.3\n",
"Epoch 3260 | incorrect 1.3% | mean squ error 111.8\n",
"Epoch 3280 | incorrect 1.3% | mean squ error 111.2\n",
"Epoch 3300 | incorrect 1.3% | mean squ error 110.6\n",
"Epoch 3320 | incorrect 1.3% | mean squ error 110.0\n",
"Epoch 3340 | incorrect 1.3% | mean squ error 109.4\n",
"Epoch 3360 | incorrect 1.3% | mean squ error 108.8\n",
"Epoch 3380 | incorrect 1.3% | mean squ error 108.2\n",
"Epoch 3400 | incorrect 1.3% | mean squ error 107.6\n",
"Epoch 3420 | incorrect 1.3% | mean squ error 107.0\n",
"Epoch 3440 | incorrect 1.3% | mean squ error 106.4\n",
"Epoch 3460 | incorrect 1.3% | mean squ error 109.2\n",
"Epoch 3480 | incorrect 1.3% | mean squ error 109.6\n",
"Epoch 3500 | incorrect 1.3% | mean squ error 105.3\n",
"Epoch 3520 | incorrect 1.2% | mean squ error 104.4\n",
"Epoch 3540 | incorrect 1.2% | mean squ error 103.8\n",
"Epoch 3560 | incorrect 1.2% | mean squ error 103.2\n",
"Epoch 3580 | incorrect 1.2% | mean squ error 102.7\n",
"Epoch 3600 | incorrect 1.2% | mean squ error 102.2\n",
"Epoch 3620 | incorrect 1.2% | mean squ error 101.7\n",
"Epoch 3640 | incorrect 1.2% | mean squ error 101.2\n",
"Epoch 3660 | incorrect 1.2% | mean squ error 100.7\n",
"Epoch 3680 | incorrect 1.2% | mean squ error 100.2\n",
"Epoch 3700 | incorrect 1.2% | mean squ error 99.7\n",
"Epoch 3720 | incorrect 1.2% | mean squ error 99.1\n",
"Epoch 3740 | incorrect 1.1% | mean squ error 128.6\n",
"Epoch 3760 | incorrect 1.2% | mean squ error 98.3\n",
"Epoch 3780 | incorrect 1.2% | mean squ error 97.8\n",
"Epoch 3800 | incorrect 1.2% | mean squ error 97.4\n",
"Epoch 3820 | incorrect 1.2% | mean squ error 96.9\n",
"Epoch 3840 | incorrect 1.2% | mean squ error 96.4\n",
"Epoch 3860 | incorrect 1.2% | mean squ error 96.0\n",
"Epoch 3880 | incorrect 1.2% | mean squ error 95.5\n",
"Epoch 3900 | incorrect 1.2% | mean squ error 95.1\n",
"Epoch 3920 | incorrect 1.2% | mean squ error 94.6\n",
"Epoch 3940 | incorrect 1.2% | mean squ error 94.2\n",
"Epoch 3960 | incorrect 1.2% | mean squ error 93.7\n",
"Epoch 3980 | incorrect 1.2% | mean squ error 93.3\n",
"Epoch 4000 | incorrect 1.3% | mean squ error 112.0\n"
]
}
],
"source": [
"epoch = 4000\n",
"\n",
"for i in range(epoch):\n",
"    if (i + 1) % 20 == 0:\n",
"        # Every 20th step: evaluate on the test data, log it and report progress.\n",
"        summary, incorrect, mean_squ_err = sess.run([merged, error, mean_square_error], {data: test_input, target: test_output})\n",
"        test_writer.add_summary(summary, i)\n",
"\n",
"        print('Epoch {:4d} | incorrect {: 3.1f}% | mean squ error {: 3.1f}'.format(i + 1, incorrect * 100, mean_squ_err))\n",
"    else:\n",
"        # All other steps: record the summaries for the training data.\n",
"        summary = sess.run(merged, {data: train_input, target: train_output})\n",
"        train_writer.add_summary(summary, i)\n",
"\n",
"    # One optimizer step on the full training set.\n",
"    # NOTE(review): assumes the update step belongs to the loop body (runs on every\n",
"    # iteration) rather than to the else branch - confirm.\n",
"    sess.run(minimize, {data: train_input, target: train_output})"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.99913526]], dtype=float32)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test the result: feed one hand-written sequence through the trained network.\n",
"sample = data_generator.encodeSequence(\"00-+++\")\n",
"sess.run(prediction, feed_dict={data: [sample]})"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Release the summary writers and the session.\n",
"for writer in (train_writer, test_writer):\n",
"    writer.close()\n",
"sess.close()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}