Split --load into two to avoid unexpected behavior at evaluation time
reuben committed Apr 7, 2020
1 parent cc7a0ad commit 0c6e908
Showing 7 changed files with 62 additions and 104 deletions.
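In rough terms, the change splits one flag into two so that checkpoint selection is resolved independently for training and for evaluation. A minimal sketch of the new CLI surface, reusing flags exercised by the updated test script below (the CSV and checkpoint paths here are hypothetical):

# --load_train picks the checkpoint to start training from: 'best', 'last',
# 'init', or 'auto' (try best, then last, then initialize from scratch).
# --load_evaluate picks the checkpoint for test/export/inference: 'best',
# 'last', or 'auto' (try best, then last; never 'init').
python -u DeepSpeech.py \
  --train_files /data/train.csv --dev_files /data/dev.csv --test_files /data/test.csv \
  --checkpoint_dir /tmp/ckpt \
  --load_train auto \
  --load_evaluate best \
  --n_hidden 100 --epochs 10

Previously the single --load flag governed both phases, so a value chosen with training in mind also applied at evaluation time — the unexpected behavior the commit title refers to.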
68 changes: 11 additions & 57 deletions bin/run-tc-transfer.sh
@@ -31,7 +31,7 @@ for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--load_train "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
@@ -45,60 +45,7 @@ for LOAD in 'init' 'last' 'auto'; do
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10

echo "#################################################################################"
echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "#################################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--drop_source_layers 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
--load 'last' \
--train_files "${ru_csv}" --train_batch_size 1 \
--dev_files "${ru_csv}" --dev_batch_size 1 \
--test_files "${ru_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
done

echo "#######################################################"
echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
echo "##### while iterating over loading logic #####"
echo "#######################################################"

for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
echo "#### Train ENGLISH model with just --checkpoint_dir ####"
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10


echo "##############################################################################"
echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--load_train "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
@@ -114,13 +61,20 @@ for LOAD in 'init' 'last' 'auto'; do
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--drop_source_layers 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
- --load 'last' \
+ --load_train 'last' \
--train_files "${ru_csv}" --train_batch_size 1 \
--dev_files "${ru_csv}" --dev_batch_size 1 \
--test_files "${ru_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10

+ # Test transfer learning checkpoint
+ python -u evaluate.py --noshow_progressbar \
+ --test_files "${ru_csv}" --test_batch_size 1 \
+ --alphabet_config_path "${ru_dir}/alphabet.ru" \
+ --load_checkpoint_dir '/tmp/ckpt/transfer/ru' \
+ --scorer_path '' \
+ --n_hidden 100
done
8 changes: 2 additions & 6 deletions training/deepspeech_training/evaluate.py
@@ -16,7 +16,7 @@
from six.moves import zip

from .util.config import Config, initialize_globals
- from .util.checkpoints import load_graph
+ from .util.checkpoints import load_graph_for_evaluation
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
@@ -82,11 +82,7 @@ def evaluate(test_csvs, create_model):
num_processes = 1

with tfv1.Session(config=Config.session_config) as session:
- if FLAGS.load == 'auto':
-     method_order = ['best', 'last']
- else:
-     method_order = [FLAGS.load]
- load_graph(session, method_order)
+ load_graph_for_evaluation(session)

def run_test(init_op, dataset):
wav_filenames = []
33 changes: 15 additions & 18 deletions training/deepspeech_training/train.py
@@ -30,7 +30,7 @@
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
- from .util.checkpoints import load_or_init_graph_for_training, load_graph
+ from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
@@ -508,11 +508,7 @@ def train():
tfv1.get_default_graph().finalize()

# Load checkpoint or initialize variables
- if FLAGS.load == 'auto':
-     method_order = ['best', 'last', 'init']
- else:
-     method_order = [FLAGS.load]
- load_or_init_graph_for_training(session, method_order)
+ load_or_init_graph_for_training(session)

def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
@@ -773,11 +769,7 @@ def export():

with tf.Session() as session:
# Restore variables from checkpoint
- if FLAGS.load == 'auto':
-     method_order = ['best', 'last']
- else:
-     method_order = [FLAGS.load]
- load_graph(session, method_order)
+ load_graph_for_evaluation(session)

output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
@@ -857,11 +849,7 @@ def do_single_file_inference(input_file_path):
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

# Restore variables from training checkpoint
- if FLAGS.load == 'auto':
-     method_order = ['best', 'last']
- else:
-     method_order = [FLAGS.load]
- load_graph(session, method_order)
+ load_graph_for_evaluation(session)

features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
@@ -896,17 +884,26 @@ def do_single_file_inference(input_file_path):
print(decoded[0][1])


- def early_checks():
+ def early_training_checks():
# Check for proper scorer early
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
del scorer

+ if FLAGS.train_files and FLAGS.test_files and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir:
+     log_warn('WARNING: You specified different values for --load_checkpoint_dir '
+              'and --save_checkpoint_dir, but you are running training and testing '
+              'in a single invocation. The testing step will respect --load_checkpoint_dir, '
+              'and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. '
+              'Train and test in two separate invocations, specifying the correct '
+              '--load_checkpoint_dir in both cases, or use the same location '
+              'for loading and saving.')


def main(_):
initialize_globals()
- early_checks()
+ early_training_checks()

if FLAGS.train_files:
tfv1.reset_default_graph()
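The warning added in early_training_checks() above recommends separate invocations when the load and save directories differ. A minimal sketch of that two-step workflow, assuming hypothetical data and checkpoint paths:

# Step 1: train only, reading the source checkpoint from one directory and
# saving the fine-tuned checkpoints to another.
python -u DeepSpeech.py \
  --train_files /data/train.csv --dev_files /data/dev.csv \
  --load_checkpoint_dir /ckpt/source \
  --save_checkpoint_dir /ckpt/finetuned \
  --load_train last \
  --epochs 10

# Step 2: test in a second invocation, pointing --load_checkpoint_dir at the
# checkpoints the training step just produced.
python -u DeepSpeech.py \
  --test_files /data/test.csv \
  --load_checkpoint_dir /ckpt/finetuned \
  --load_evaluate best

Run this way, the test step scores the newly trained checkpoint instead of silently re-loading the source checkpoint.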
37 changes: 24 additions & 13 deletions training/deepspeech_training/util/checkpoints.py
@@ -87,13 +87,7 @@ def _initialize_all_variables(session):
session.run(v.initializer)


- def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
- '''
- Load variables from checkpoint or initialize variables following the method
- order specified in the method_order parameter.
- Valid methods are 'best', 'last' and 'init'.
- '''
+ def _load_or_init_impl(session, method_order, allow_drop_layers):
for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == 'best':
@@ -124,12 +118,29 @@ def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
sys.exit(1)


- def load_graph(session, method_order):
+ def load_or_init_graph_for_training(session):
'''
- Load variables from checkpoint. Initialization is not allowed. Follows the
- method order specified in the method_order parameter.
+ Load variables from checkpoint or initialize variables. By default this will
+ try to load the best validating checkpoint, then try the last checkpoint,
+ and finally initialize the weights from scratch. This can be overridden with
+ the `--load_train` flag. See its documentation for more info.
'''
+ if FLAGS.load_train == 'auto':
+     methods = ['best', 'last', 'init']
+ else:
+     methods = [FLAGS.load_train]
+ _load_or_init_impl(session, methods, allow_drop_layers=True)
+
+
- Valid methods are 'best' and 'last'.
+ def load_graph_for_evaluation(session):
+ '''
+ Load variables from checkpoint. Initialization is not allowed. By default
+ this will try to load the best validating checkpoint, then try the last
+ checkpoint. This can be overridden with the `--load_evaluate` flag. See its
+ documentation for more info.
+ '''
- assert('init' not in method_order)
- load_or_init_graph_for_training(session, method_order, allow_drop_layers=False)
+ if FLAGS.load_evaluate == 'auto':
+     methods = ['best', 'last']
+ else:
+     methods = [FLAGS.load_evaluate]
+ _load_or_init_impl(session, methods, allow_drop_layers=False)
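For illustration, the 'auto' defaults in the two new helpers amount to the following fallback behavior; a sketch with a hypothetical checkpoint directory:

# Fresh directory: --load_train auto falls through 'best' -> 'last' -> 'init'
# and initializes new weights, so training can start from scratch.
python -u DeepSpeech.py --load_train auto --checkpoint_dir /tmp/ckpt/demo \
  --train_files /data/train.csv --epochs 1

# Same directory on a second run: checkpoints now exist, so 'auto' resumes
# from the best validating checkpoint, or the most recent one failing that.
python -u DeepSpeech.py --load_train auto --checkpoint_dir /tmp/ckpt/demo \
  --train_files /data/train.csv --epochs 1

Evaluation-side 'auto' tries only 'best' then 'last'; since 'init' is never in its method order, _load_or_init_impl exits with an error on an empty directory rather than scoring untrained weights.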
9 changes: 6 additions & 3 deletions training/deepspeech_training/util/config.py
@@ -10,7 +10,7 @@

from .flags import FLAGS
from .gpu import get_available_gpus
- from .logging import log_error
+ from .logging import log_error, log_warn
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size

@@ -45,8 +45,11 @@ def initialize_globals():
if not FLAGS.checkpoint_dir:
FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))

- if FLAGS.load not in ['last', 'best', 'init', 'auto']:
-     FLAGS.load = 'auto'
+ if FLAGS.load_train not in ['last', 'best', 'init', 'auto']:
+     FLAGS.load_train = 'auto'
+
+ if FLAGS.load_evaluate not in ['last', 'best', 'auto']:
+     FLAGS.load_evaluate = 'auto'

# Set default summary dir
if not FLAGS.summary_dir:
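Note from the config.py hunk above that an unrecognized value is coerced rather than rejected; a sketch with a hypothetical typo in the flag value:

# A typo like --load_train bst is not an error: initialize_globals() resets
# any value outside the allowed set back to 'auto' before training starts.
python -u DeepSpeech.py --load_train bst --checkpoint_dir /tmp/ckpt/demo \
  --train_files /data/train.csv --epochs 1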
3 changes: 2 additions & 1 deletion training/deepspeech_training/util/flags.py
@@ -101,7 +101,8 @@ def create_flags():
f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
- f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a fresh model, "transfer" for transfer learning, "auto" for trying several options.')
+ f.DEFINE_string('load_train', 'auto', 'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.')
+ f.DEFINE_string('load_evaluate', 'auto', 'what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.')

# Transfer Learning

Expand Down
8 changes: 2 additions & 6 deletions transcribe.py
@@ -29,7 +29,7 @@ def fail(message, code=1):

def transcribe_file(audio_path, tlog_path):
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
- from deepspeech_training.util.checkpoints import load_graph
+ from deepspeech_training.util.checkpoints import load_graph_for_evaluation
initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try:
@@ -50,11 +50,7 @@ def transcribe_file(audio_path, tlog_path):
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
- if FLAGS.load == 'auto':
-     method_order = ['best', 'last']
- else:
-     method_order = [FLAGS.load]
- load_graph(session, method_order)
+ load_graph_for_evaluation(session)
session.run(iterator.make_initializer(data_set))
transcripts = []
while True:
