diff --git a/DeepSpeech.ipynb b/DeepSpeech.ipynb
index cf96f54ec6..94841a1739 100644
--- a/DeepSpeech.ipynb
+++ b/DeepSpeech.ipynb
@@ -1353,8 +1353,13 @@
     "for epoch in range(start_epoch, training_iters):\n",
     "    params = train_params[:]\n",
     "    \n",
+    "    # Determine if we want to display/validate/checkpoint on this iteration\n",
+    "    is_display_step = display_step > 0 and (epoch + 1) % display_step == 0\n",
+    "    is_validation_step = validation_step > 0 and (epoch + 1) % validation_step == 0\n",
+    "    is_checkpoint_step = (epoch > 0 and (epoch + 1) % checkpoint_step == 0) or epoch == training_iters - 1\n",
+    "    \n",
     "    # Requirements to display a WER report\n",
-    "    if epoch % display_step == 0:\n",
+    "    if is_display_step:\n",
     "        # Reset accuracy\n",
     "        total_accuracy = 0.0\n",
     "        # Create training results tuple\n",
@@ -1382,23 +1387,25 @@
     "            writer.flush()\n",
     "    \n",
     "        # Collect individual sample results for WER report\n",
-    "        if epoch % display_step == 0:\n",
+    "        if is_display_step:\n",
     "            collect_results(train_results, result[2])\n",
     "            # Add batch to total_accuracy\n",
     "            total_accuracy += result[3]\n",
     "    \n",
     "    # Print WER report\n",
-    "    if epoch % display_step == 0:\n",
+    "    if is_display_step:\n",
     "        print \"Epoch:\", '%04d' % (epoch), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / ceil(batches_per_device)))\n",
     "        train_wer = calculate_and_print_wer_report(\"Training\", train_results)\n",
     "        print\n",
+    "    else:\n",
+    "        print \"Epoch: %04d\" % (epoch)\n",
     "    \n",
     "    # Validation step\n",
-    "    if epoch % validation_step == 0:\n",
+    "    if is_validation_step:\n",
     "        dev_wer = run_inference(session, \"Validation\", data_sets.dev.total_batches, feed_dict_validate, results_params)\n",
     "\n",
     "    # Checkpoint the model\n",
-    "    if (epoch % checkpoint_step == 0) or (epoch == training_iters - 1):\n",
+    "    if is_checkpoint_step:\n",
     "        checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')\n",
     "        print \"Checkpointing in directory\", \"%s\" % checkpoint_dir\n",
     "        saver.save(session, checkpoint_path, global_step=epoch)\n",
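
For reference, below is a minimal standalone sketch of the gating logic this change introduces. It assumes only the hyperparameter names used in the diff (training_iters, display_step, validation_step, checkpoint_step); the gated_steps helper and the example values are hypothetical, added purely for illustration, and are not part of the notebook.

# Hypothetical helper (not in the notebook): computes, for each epoch, the
# same display/validation/checkpoint flags that the diff derives in-loop.
def gated_steps(training_iters, display_step, validation_step, checkpoint_step):
    flags = []
    for epoch in range(training_iters):
        # A step fires on every Nth completed epoch; a step size of 0 disables it.
        is_display_step = display_step > 0 and (epoch + 1) % display_step == 0
        is_validation_step = validation_step > 0 and (epoch + 1) % validation_step == 0
        # The final epoch always checkpoints, whatever checkpoint_step is.
        is_checkpoint_step = (epoch > 0 and (epoch + 1) % checkpoint_step == 0) \
            or epoch == training_iters - 1
        flags.append((epoch, is_display_step, is_validation_step, is_checkpoint_step))
    return flags

# Hypothetical example values: 10 epochs, display every 2, validate every 5,
# checkpoint every 4. Prints one (epoch, display, validate, checkpoint) tuple
# per epoch.
for row in gated_steps(10, 2, 5, 4):
    print(row)

Two design points visible in the diff itself: gating on (epoch + 1) rather than epoch makes the first report fire after display_step complete epochs instead of at epoch 0, and the explicit "> 0" guards let a step size of 0 disable display or validation entirely without a modulo-by-zero error.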