Skip to content

Commit

Permalink
removing dropout for validation
Browse files Browse the repository at this point in the history
  • Loading branch information
swarbrickjones committed Aug 3, 2016
1 parent b9bffcf commit 6e2a31e
Showing 1 changed file with 46 additions and 54 deletions.
100 changes: 46 additions & 54 deletions spelling_bee_RNN.ipynb
Expand Up @@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 317,
"execution_count": 1,
"metadata": {
"collapsed": false
},
Expand Down Expand Up @@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {
"collapsed": false
},
Expand All @@ -80,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {
"collapsed": false
},
Expand All @@ -104,11 +104,20 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"supercalifragilisticexpialidocious\n",
"[55, 64, 53, 26, 42, 6, 43, 7, 32, 54, 5, 41, 7, 43, 37, 55, 57, 35, 42, 25, 42, 55, 53, 38, 6, 43, 7, 21, 48, 56, 7, 55]\n"
]
}
],
"source": [
"max_k = max([len(k) for k,v in pronounce_dict.items()])\n",
"max_v = max([len(v) for k,v in pronounce_dict.items()])\n",
Expand All @@ -127,7 +136,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {
"collapsed": false
},
Expand All @@ -136,8 +145,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"133779\n",
"108006\n"
"133779\n"
]
}
],
Expand Down Expand Up @@ -165,7 +173,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {
"collapsed": false
},
Expand Down Expand Up @@ -213,9 +221,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {
"collapsed": true
"collapsed": false
},
"outputs": [],
"source": [
Expand All @@ -233,7 +241,7 @@
},
{
"cell_type": "code",
"execution_count": 337,
"execution_count": 8,
"metadata": {
"collapsed": false
},
Expand All @@ -250,7 +258,7 @@
},
{
"cell_type": "code",
"execution_count": 338,
"execution_count": 9,
"metadata": {
"collapsed": false
},
Expand Down Expand Up @@ -278,7 +286,7 @@
},
{
"cell_type": "code",
"execution_count": 339,
"execution_count": 10,
"metadata": {
"collapsed": true
},
Expand Down Expand Up @@ -310,26 +318,37 @@
},
{
"cell_type": "code",
"execution_count": 340,
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"keep_prob = tf.placeholder(\"float\")\n",
"\n",
"cells = [rnn_cell.DropoutWrapper(\n",
" rnn_cell.BasicLSTMCell(embedding_dim), output_keep_prob=0.5\n",
" rnn_cell.BasicLSTMCell(embedding_dim), output_keep_prob=keep_prob\n",
" ) for i in range(3)]\n",
"\n",
"stacked_lstm = rnn_cell.MultiRNNCell(cells)\n",
"\n",
"with tf.variable_scope(\"decoders\") as scope:\n",
" decode_outputs, decode_state = seq2seq.embedding_rnn_seq2seq(\n",
" encode_input, decode_input, cell, input_vocab_size, output_vocab_size)\n",
" encode_input, decode_input, stacked_lstm, input_vocab_size, output_vocab_size)\n",
" \n",
" scope.reuse_variables()\n",
" \n",
" decode_outputs_test, decode_state_test = seq2seq.embedding_rnn_seq2seq(\n",
" encode_input, decode_input, cell, input_vocab_size, output_vocab_size, \n",
" encode_input, decode_input, stacked_lstm, input_vocab_size, output_vocab_size, \n",
" feed_previous=True)"
]
},
Expand All @@ -342,7 +361,7 @@
},
{
"cell_type": "code",
"execution_count": 341,
"execution_count": 15,
"metadata": {
"collapsed": false
},
Expand All @@ -356,7 +375,7 @@
},
{
"cell_type": "code",
"execution_count": 342,
"execution_count": 16,
"metadata": {
"collapsed": false
},
Expand All @@ -381,7 +400,7 @@
},
{
"cell_type": "code",
"execution_count": 343,
"execution_count": 17,
"metadata": {
"collapsed": true
},
Expand Down Expand Up @@ -423,7 +442,7 @@
},
{
"cell_type": "code",
"execution_count": 346,
"execution_count": 18,
"metadata": {
"collapsed": true
},
Expand All @@ -439,12 +458,14 @@
"def train_batch(data_iter):\n",
" X, Y = data_iter.next_batch()\n",
" feed_dict = get_feed(X, Y)\n",
" feed_dict[keep_prob] = 0.5\n",
" _, out = sess.run([train_op, loss], feed_dict)\n",
" return out\n",
"\n",
"def get_eval_batch_data(data_iter):\n",
" X, Y = data_iter.next_batch()\n",
" feed_dict = get_feed(X, Y)\n",
" feed_dict[keep_prob] = 1.\n",
" all_output = sess.run([loss] + decode_outputs_test, feed_dict)\n",
" eval_loss = all_output[0]\n",
" decode_output = np.array(all_output[1:]).transpose([1,0,2])\n",
Expand Down Expand Up @@ -473,43 +494,14 @@
},
{
"cell_type": "code",
"execution_count": 348,
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"val loss : 0.200953, val predict = 40.0%\n",
"train loss : 0.165016, train predict = 48.0%\n",
"\n",
"val loss : 0.187188, val predict = 42.5%\n",
"train loss : 0.143054, train predict = 50.6%\n",
"\n",
"val loss : 0.175085, val predict = 45.8%\n",
"train loss : 0.115830, train predict = 56.4%\n",
"\n",
"val loss : 0.178171, val predict = 45.7%\n",
"train loss : 0.111954, train predict = 58.9%\n",
"\n",
"val loss : 0.172554, val predict = 49.3%\n",
"train loss : 0.091568, train predict = 65.0%\n",
"\n",
"val loss : 0.181853, val predict = 46.6%\n",
"train loss : 0.083485, train predict = 66.0%\n",
"\n",
"val loss : 0.179023, val predict = 48.3%\n",
"train loss : 0.064898, train predict = 72.2%\n",
"\n",
"interrupted by user\n"
]
}
],
"outputs": [],
"source": [
"for i in range(10000):\n",
"for i in range(100000):\n",
" try:\n",
" train_batch(train_iter)\n",
" if i % 1000 == 0:\n",
Expand Down

0 comments on commit 6e2a31e

Please sign in to comment.