{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import random\n",
"import sys\n",
"import tensorflow as tf\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"corpus length: 600893\n",
"total chars: 57\n",
"nb sequences: 200285\n",
"Vectorization...\n",
"pre-processing ready\n"
]
}
],
"source": [
"path = 'nietzsche.txt'\n",
"with open(path) as f:\n",
"    text = f.read().lower()\n",
"print('corpus length:', len(text))\n",
"\n",
"chars = sorted(list(set(text)))\n",
"print('total chars:', len(chars))\n",
"char_indices = dict((c, i) for i, c in enumerate(chars))\n",
"indices_char = dict((i, c) for i, c in enumerate(chars))\n",
"\n",
"# cut the text in semi-redundant sequences of maxlen characters\n",
"maxlen = 40\n",
"step = 3\n",
"sentences = []\n",
"next_chars = []\n",
"for i in range(0, len(text) - maxlen, step):\n",
" sentences.append(text[i: i + maxlen])\n",
" next_chars.append(text[i + maxlen])\n",
"print('nb sequences:', len(sentences))\n",
"\n",
"n_char=len(chars)\n",
"\n",
"print('Vectorization...')\n",
"# use the builtin bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24\n",
"X_data = np.zeros((len(sentences), maxlen, n_char), dtype=bool)\n",
"Y_data = np.zeros((len(sentences), n_char), dtype=bool)\n",
"for i, sentence in enumerate(sentences):\n",
" for t, char in enumerate(sentence):\n",
" X_data[i, t, char_indices[char]] = 1\n",
" Y_data[i, char_indices[next_chars[i]]] = 1\n",
"\n",
"print (\"pre-processing ready\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"parameters ready\n"
]
}
],
"source": [
"# Parameters\n",
"learning_rate = 0.01\n",
"training_iters = 1000000\n",
"batch_size = 128\n",
"display_step = 20\n",
"n_hidden = 128\n",
"\n",
"# tf Graph input\n",
"x = tf.placeholder(\"float\", [None, maxlen, n_char])\n",
"y = tf.placeholder(\"float\", [None, n_char])\n",
"\n",
"# Define weights\n",
"weights = {\n",
" 'out': tf.Variable(tf.random_normal([n_hidden, n_char]))\n",
"}\n",
"biases = {\n",
" 'out': tf.Variable(tf.random_normal([n_char]))\n",
"}\n",
"print (\"parameters ready\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Network ready\n"
]
}
],
"source": [
"with tf.variable_scope(\"model\"):\n",
" #tf.get_variable_scope().reuse_variables()\n",
"\n",
" # Prepare data shape to match `rnn` function requirements\n",
" # Current data input shape: (batch_size, n_steps, n_input)\n",
" # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)\n",
" \n",
" # Permuting batch_size and n_steps\n",
" #x_t = tf.transpose(x, [1, 0, 2])\n",
" # Reshaping to (n_steps*batch_size, n_input)\n",
" #x_t = tf.reshape(x_t, [-1, n_char])\n",
" # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
" #x_t = tf.split(0, maxlen, x_t)\n",
"\n",
" # Define a lstm cell with tensorflow\n",
" if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:\n",
" cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True)\n",
" else:\n",
" cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)\n",
"\n",
" # Get lstm cell output\n",
" outputs, states = tf.nn.dynamic_rnn(cell, x, time_major=False, dtype=tf.float32)\n",
" \n",
" if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:\n",
" outputs = tf.unpack(tf.transpose(outputs, [1, 0, 2])) # to time-major, then unpack so outputs[-1] is the last timestep\n",
" else:\n",
" outputs = tf.unstack(tf.transpose(outputs, [1,0,2]))\n",
"\n",
" # Linear activation, using rnn inner loop last output\n",
" pred = tf.matmul(outputs[-1], weights['out']) + biases['out']\n",
" \n",
" pred_prob= tf.nn.softmax(pred)\n",
"\n",
"#pred = RNN(x, weights, biases)\n",
"\n",
"# Define loss and optimizer\n",
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
"optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)\n",
"\n",
"# Evaluate model\n",
"correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))\n",
"accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
"\n",
"\n",
"\n",
"print (\"Network ready\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"functions ready\n"
]
}
],
"source": [
"def sample(preds, temperature=1.0):\n",
" # helper function to sample an index from a probability array\n",
" preds = np.asarray(preds).astype('float64')\n",
" preds = np.log(preds) / temperature\n",
" exp_preds = np.exp(preds)\n",
" preds = exp_preds / np.sum(exp_preds)\n",
" probas = np.random.multinomial(1, preds.squeeze(), 1)\n",
" return np.argmax(probas)\n",
"\n",
"def make_batches(size, batch_size):\n",
" nb_batch = int(np.floor(size/float(batch_size)))\n",
" #nb_batch = int(np.ceil(size/float(batch_size)))\n",
" return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)]\n",
"\n",
"def slice_X(X, start=None, stop=None):\n",
" if type(X) == list:\n",
" if hasattr(start, '__len__'):\n",
" return [x[start] for x in X]\n",
" else:\n",
" return [x[start:stop] for x in X]\n",
" else:\n",
" if hasattr(start, '__len__'):\n",
" return X[start]\n",
" else:\n",
" return X[start:stop] \n",
" \n",
"print (\"functions ready\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"datasets ready\n"
]
}
],
"source": [
"import itertools\n",
"\n",
"ins=[X_data,Y_data]\n",
"\n",
"n_train=X_data.shape[0]\n",
"\n",
"index_array = np.arange(n_train)\n",
"\n",
"np.random.shuffle(index_array)\n",
"\n",
"batches = make_batches(n_train, batch_size)\n",
"\n",
"ins=[slice_X(ins,index_array[batch_start:batch_end]) for batch_start, batch_end in batches]\n",
"\n",
"iterator=itertools.cycle((data for data in ins if data != []))\n",
"\n",
"print (\"datasets ready\")\n",
"sample_step = 1000\n",
"# Launch the graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iter 2560, Minibatch Loss= 2.816587, Training Accuracy= 0.17969\n",
"Iter 5120, Minibatch Loss= 2.758723, Training Accuracy= 0.17969\n",
"Iter 7680, Minibatch Loss= 2.592314, Training Accuracy= 0.28906\n",
"Iter 10240, Minibatch Loss= 2.297061, Training Accuracy= 0.34375\n",
"Iter 12800, Minibatch Loss= 2.472381, Training Accuracy= 0.32812\n",
"Iter 15360, Minibatch Loss= 2.288723, Training Accuracy= 0.39062\n",
"Iter 17920, Minibatch Loss= 2.264370, Training Accuracy= 0.31250\n",
"Iter 20480, Minibatch Loss= 2.203702, Training Accuracy= 0.38281\n",
"Iter 23040, Minibatch Loss= 2.517159, Training Accuracy= 0.27344\n",
"Iter 25600, Minibatch Loss= 2.214355, Training Accuracy= 0.29688\n",
"Iter 28160, Minibatch Loss= 2.183286, Training Accuracy= 0.32812\n",
"Iter 30720, Minibatch Loss= 2.255154, Training Accuracy= 0.32812\n",
"Iter 33280, Minibatch Loss= 2.002253, Training Accuracy= 0.35156\n",
"Iter 35840, Minibatch Loss= 2.083503, Training Accuracy= 0.37500\n",
"Iter 38400, Minibatch Loss= 2.292865, Training Accuracy= 0.33594\n",
"Iter 40960, Minibatch Loss= 1.919709, Training Accuracy= 0.50000\n",
"Iter 43520, Minibatch Loss= 2.129637, Training Accuracy= 0.39062\n",
"Iter 46080, Minibatch Loss= 2.409417, Training Accuracy= 0.32812\n",
"Iter 48640, Minibatch Loss= 2.014746, Training Accuracy= 0.38281\n",
"Iter 51200, Minibatch Loss= 2.065424, Training Accuracy= 0.39844\n",
"Iter 53760, Minibatch Loss= 1.799022, Training Accuracy= 0.42969\n",
"Iter 56320, Minibatch Loss= 1.999292, Training Accuracy= 0.36719\n",
"Iter 58880, Minibatch Loss= 2.030244, Training Accuracy= 0.42188\n",
"Iter 61440, Minibatch Loss= 2.197674, Training Accuracy= 0.32812\n",
"Iter 64000, Minibatch Loss= 1.837993, Training Accuracy= 0.43750\n",
"Iter 66560, Minibatch Loss= 2.056824, Training Accuracy= 0.32031\n",
"Iter 69120, Minibatch Loss= 2.130220, Training Accuracy= 0.35156\n",
"Iter 71680, Minibatch Loss= 2.416887, Training Accuracy= 0.34375\n",
"Iter 74240, Minibatch Loss= 2.008669, Training Accuracy= 0.46094\n",
"Iter 76800, Minibatch Loss= 1.913556, Training Accuracy= 0.42969\n",
"Iter 79360, Minibatch Loss= 2.016909, Training Accuracy= 0.50000\n",
"Iter 81920, Minibatch Loss= 1.968499, Training Accuracy= 0.39062\n",
"Iter 84480, Minibatch Loss= 2.017895, Training Accuracy= 0.37500\n",
"Iter 87040, Minibatch Loss= 1.901745, Training Accuracy= 0.43750\n",
"Iter 89600, Minibatch Loss= 1.922046, Training Accuracy= 0.44531\n",
"Iter 92160, Minibatch Loss= 1.762730, Training Accuracy= 0.52344\n",
"Iter 94720, Minibatch Loss= 1.774164, Training Accuracy= 0.42188\n",
"Iter 97280, Minibatch Loss= 1.801815, Training Accuracy= 0.46094\n",
"Iter 99840, Minibatch Loss= 1.844371, Training Accuracy= 0.46094\n",
"Iter 102400, Minibatch Loss= 1.792642, Training Accuracy= 0.46875\n",
"Iter 104960, Minibatch Loss= 1.827651, Training Accuracy= 0.50000\n",
"Iter 107520, Minibatch Loss= 1.958480, Training Accuracy= 0.47656\n",
"Iter 110080, Minibatch Loss= 1.858919, Training Accuracy= 0.43750\n",
"Iter 112640, Minibatch Loss= 1.640953, Training Accuracy= 0.49219\n",
"Iter 115200, Minibatch Loss= 1.803569, Training Accuracy= 0.50000\n",
"Iter 117760, Minibatch Loss= 1.974909, Training Accuracy= 0.41406\n",
"Iter 120320, Minibatch Loss= 1.990376, Training Accuracy= 0.39844\n",
"Iter 122880, Minibatch Loss= 2.017883, Training Accuracy= 0.35156\n",
"Iter 125440, Minibatch Loss= 1.892190, Training Accuracy= 0.46875\n",
"----- Generating with seed: \"onstraint, impulsion, pressure, resistan\"\n",
"onstraint, impulsion, pressure, resistantary:\n",
"skelo, inthering that onidgidy\n",
"strear so thingg ][5pooso of irtumpleer of all aquegation tify?\" tifker iffered take fromm aringed and \"itineved with shive be sehovon bare\n",
"may a wait, but \"fils r\n",
"Iter 128000, Minibatch Loss= 2.016952, Training Accuracy= 0.39844\n",
"Iter 130560, Minibatch Loss= 1.652233, Training Accuracy= 0.51562\n",
"Iter 133120, Minibatch Loss= 1.727036, Training Accuracy= 0.50000\n",
"Iter 135680, Minibatch Loss= 1.863072, Training Accuracy= 0.42969\n",
"Iter 138240, Minibatch Loss= 1.812885, Training Accuracy= 0.50000\n",
"Iter 140800, Minibatch Loss= 1.876485, Training Accuracy= 0.41406\n",
"Iter 143360, Minibatch Loss= 1.769477, Training Accuracy= 0.52344\n",
"Iter 145920, Minibatch Loss= 1.697696, Training Accuracy= 0.42969\n",
"Iter 148480, Minibatch Loss= 1.654513, Training Accuracy= 0.53906\n",
"Iter 151040, Minibatch Loss= 1.900216, Training Accuracy= 0.43750\n",
"Iter 153600, Minibatch Loss= 1.829828, Training Accuracy= 0.46875\n",
"Iter 156160, Minibatch Loss= 1.733665, Training Accuracy= 0.53125\n",
"Iter 158720, Minibatch Loss= 1.843327, Training Accuracy= 0.48438\n",
"Iter 161280, Minibatch Loss= 1.559314, Training Accuracy= 0.49219\n",
"Iter 163840, Minibatch Loss= 1.591175, Training Accuracy= 0.56250\n",
"Iter 166400, Minibatch Loss= 1.804161, Training Accuracy= 0.46094\n",
"Iter 168960, Minibatch Loss= 1.850974, Training Accuracy= 0.50781\n",
"Iter 171520, Minibatch Loss= 1.846101, Training Accuracy= 0.45312\n",
"Iter 174080, Minibatch Loss= 1.864464, Training Accuracy= 0.45312\n",
"Iter 176640, Minibatch Loss= 1.881261, Training Accuracy= 0.43750\n",
"Iter 179200, Minibatch Loss= 1.762946, Training Accuracy= 0.53906\n",
"Iter 181760, Minibatch Loss= 1.917798, Training Accuracy= 0.42188\n",
"Iter 184320, Minibatch Loss= 1.876582, Training Accuracy= 0.42969\n",
"Iter 186880, Minibatch Loss= 1.779570, Training Accuracy= 0.44531\n",
"Iter 189440, Minibatch Loss= 1.789607, Training Accuracy= 0.44531\n",
"Iter 192000, Minibatch Loss= 1.696336, Training Accuracy= 0.50000\n",
"Iter 194560, Minibatch Loss= 1.715119, Training Accuracy= 0.46875\n",
"Iter 197120, Minibatch Loss= 1.779335, Training Accuracy= 0.46094\n",
"Iter 199680, Minibatch Loss= 1.621711, Training Accuracy= 0.47656\n",
"Iter 202240, Minibatch Loss= 1.705649, Training Accuracy= 0.48438\n",
"Iter 204800, Minibatch Loss= 1.683780, Training Accuracy= 0.50000\n",
"Iter 207360, Minibatch Loss= 1.544564, Training Accuracy= 0.49219\n",
"Iter 209920, Minibatch Loss= 1.841880, Training Accuracy= 0.52344\n",
"Iter 212480, Minibatch Loss= 1.635493, Training Accuracy= 0.49219\n",
"Iter 215040, Minibatch Loss= 1.516781, Training Accuracy= 0.50000\n",
"Iter 217600, Minibatch Loss= 1.701025, Training Accuracy= 0.52344\n",
"Iter 220160, Minibatch Loss= 1.614446, Training Accuracy= 0.52344\n",
"Iter 222720, Minibatch Loss= 1.649614, Training Accuracy= 0.48438\n",
"Iter 225280, Minibatch Loss= 1.701459, Training Accuracy= 0.53906\n",
"Iter 227840, Minibatch Loss= 1.868622, Training Accuracy= 0.42188\n",
"Iter 230400, Minibatch Loss= 1.508416, Training Accuracy= 0.55469\n",
"Iter 232960, Minibatch Loss= 1.524642, Training Accuracy= 0.50000\n",
"Iter 235520, Minibatch Loss= 1.511087, Training Accuracy= 0.55469\n",
"Iter 238080, Minibatch Loss= 1.646264, Training Accuracy= 0.49219\n",
"Iter 240640, Minibatch Loss= 1.564551, Training Accuracy= 0.56250\n",
"Iter 243200, Minibatch Loss= 1.519640, Training Accuracy= 0.50781\n",
"Iter 245760, Minibatch Loss= 1.626197, Training Accuracy= 0.46094\n",
"Iter 248320, Minibatch Loss= 1.836848, Training Accuracy= 0.45312\n",
"Iter 250880, Minibatch Loss= 1.645749, Training Accuracy= 0.51562\n",
"Iter 253440, Minibatch Loss= 1.424105, Training Accuracy= 0.54688\n",
"----- Generating with seed: \" caused by the display of\n",
"our power over\"\n",
" caused by the display of\n",
"our power over.\n",
"grean at arnforian to as there and persons at virct to de slevery suspectic notsile even one mwake tike nat is later, an as a glegrestunce, to conscian sofetains and sleast, (as the fordies conuruti\n",
"Iter 256000, Minibatch Loss= 1.558126, Training Accuracy= 0.55469\n",
"Iter 258560, Minibatch Loss= 1.692595, Training Accuracy= 0.50000\n",
"Iter 261120, Minibatch Loss= 1.757544, Training Accuracy= 0.47656\n",
"Iter 263680, Minibatch Loss= 1.479481, Training Accuracy= 0.56250\n",
"Iter 266240, Minibatch Loss= 1.643853, Training Accuracy= 0.50000\n",
"Iter 268800, Minibatch Loss= 1.761557, Training Accuracy= 0.53125\n",
"Iter 271360, Minibatch Loss= 1.603626, Training Accuracy= 0.52344\n",
"Iter 273920, Minibatch Loss= 1.899508, Training Accuracy= 0.43750\n"
]
}
],
"source": [
"with tf.Session() as sess:\n",
" # tf.initialize_all_variables() is no longer valid as of\n",
" # 2017-03-02 if using tensorflow >= 0.12\n",
" if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:\n",
" init = tf.initialize_all_variables()\n",
" else:\n",
" init = tf.global_variables_initializer()\n",
" \n",
" sess.run(init)\n",
" step = 1\n",
" # Keep training until reach max iterations\n",
" while step * batch_size < training_iters:\n",
" [batch_x,batch_y] = next(iterator)\n",
" # Run optimization op (backprop)\n",
" _, acc, loss = sess.run([optimizer,accuracy,cost], feed_dict={x: batch_x, y: batch_y})\n",
" \n",
" if step % display_step == 0:\n",
" print (\"Iter \" + str(step*batch_size) + \", Minibatch Loss= \" + \\\n",
" \"{:.6f}\".format(loss) + \", Training Accuracy= \" + \\\n",
" \"{:.5f}\".format(acc))\n",
" step += 1\n",
" \n",
" start_index = random.randint(0, len(text) - maxlen - 1)\n",
" \n",
" if step % sample_step == 0:\n",
" generated = ''\n",
" sentence = text[start_index: start_index + maxlen]\n",
" generated += sentence\n",
" print('----- Generating with seed: \"' + sentence + '\"')\n",
" sys.stdout.write(generated)\n",
"\n",
" for i in range(200):\n",
" x_sample_input = np.zeros((1, maxlen, n_char))\n",
" for t, char in enumerate(sentence):\n",
" x_sample_input[0, t, char_indices[char]] = 1.\n",
"\n",
" preds = sess.run(pred_prob, feed_dict={x: x_sample_input})\n",
" next_index = sample(preds)\n",
" next_char = indices_char[next_index]\n",
"\n",
" generated += next_char\n",
" sentence = sentence[1:] + next_char\n",
"\n",
" sys.stdout.write(next_char)\n",
" sys.stdout.flush()\n",
" print()\n",
" \n",
" \n",
" print (\"Optimization Finished!\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}