Skip to content

Commit

Permalink
Correct calculations for when we display/validate/checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Cwiiis committed Nov 17, 2016
1 parent c323e02 commit a49a2e5
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions DeepSpeech.ipynb
Expand Up @@ -1353,8 +1353,13 @@
" for epoch in range(start_epoch, training_iters):\n",
" params = train_params[:]\n",
" \n",
" # Determine if we want to display/validate/checkpoint on this iteration\n",
" is_display_step = display_step > 0 and (epoch + 1) % display_step == 0\n",
" is_validation_step = validation_step > 0 and (epoch + 1) % validation_step == 0\n",
" is_checkpoint_step = (epoch > 0 and (epoch + 1) % checkpoint_step == 0) or epoch == training_iters - 1\n",
" \n",
" # Requirements to display a WER report\n",
" if epoch % display_step == 0:\n",
" if is_display_step:\n",
" # Reset accuracy\n",
" total_accuracy = 0.0\n",
" # Create training results tuple\n",
Expand Down Expand Up @@ -1382,23 +1387,25 @@
" writer.flush()\n",
" \n",
" # Collect individual sample results for WER report\n",
" if epoch % display_step == 0:\n",
" if is_display_step:\n",
" collect_results(train_results, result[2])\n",
" # Add batch to total_accuracy\n",
" total_accuracy += result[3]\n",
" \n",
" # Print WER report\n",
" if epoch % display_step == 0:\n",
" if is_display_step:\n",
" print \"Epoch:\", '%04d' % (epoch), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / ceil(batches_per_device)))\n",
" train_wer = calculate_and_print_wer_report(\"Training\", train_results)\n",
" print\n",
" else:\n",
" print \"Epoch: %04d\" % (epoch)\n",
" \n",
" # Validation step\n",
" if epoch % validation_step == 0:\n",
" if is_validation_step:\n",
" dev_wer = run_inference(session, \"Validation\", data_sets.dev.total_batches, feed_dict_validate, results_params)\n",
"\n",
" # Checkpoint the model\n",
" if (epoch % checkpoint_step == 0) or (epoch == training_iters - 1):\n",
" if is_checkpoint_step:\n",
" checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')\n",
" print \"Checkpointing in directory\", \"%s\" % checkpoint_dir\n",
" saver.save(session, checkpoint_path, global_step=epoch)\n",
Expand Down

0 comments on commit a49a2e5

Please sign in to comment.