Merge pull request #194 from CustomProgrammingSolutions/feature-training-epoch-callback

[MRG] Callback for training loop at end of each epoch
oadams committed Oct 13, 2018
2 parents cfe244c + d53e5fb commit 7bc6dd6
Showing 2 changed files with 58 additions and 4 deletions.
20 changes: 17 additions & 3 deletions persephone/model.py
@@ -6,7 +6,8 @@
import os
from pathlib import Path
import sys
-from typing import Optional, Union, Sequence, Set, List, Dict
+from typing import Callable, Optional, Union, Sequence, Set, List, Dict

import tensorflow as tf

from .preprocess import labels, feat_extract
@@ -257,7 +258,8 @@ def output_best_scores(self, best_epoch_str: str) -> None:

    def train(self, early_stopping_steps: int = 10, min_epochs: int = 30,
              max_valid_ler: float = 1.0, max_train_ler: float = 0.3,
-             max_epochs: int = 100, restore_model_path: Optional[str]=None) -> None:
+             max_epochs: int = 100, restore_model_path: Optional[str]=None,
+             epoch_callback: Optional[Callable[[Dict], None]]=None) -> None:
        """ Train the model.
        min_epochs: minimum number of epochs to run training for.
@@ -271,6 +273,10 @@ def train(self, early_stopping_steps: int = 10, min_epochs: int = 30,
            Training will continue until this is met or another
            stopping condition occurs.
        restore_model_path: The path to restore a model from.
+       epoch_callback: A callback that is called at the end of each training epoch.
+           The callable is passed a single dictionary containing the epoch number,
+           the current training LER and the current validation LER.
+           This can be useful for progress reporting.
        """
        logger.info("Training model")
        best_valid_ler = 2.0
@@ -319,7 +325,7 @@ def train(self, early_stopping_steps: int = 10, min_epochs: int = 30,
        if os.path.exists(training_log_path):
            logger.error("Error, overwriting existing log file at path {}".format(training_log_path))
        with open(training_log_path, "w") as out_file:
-           for epoch in itertools.count():
+           for epoch in itertools.count(start=1):
                print("\nexp_dir %s, epoch %d" % (self.exp_dir, epoch))
                batch_gen = self.corpus_reader.train_batch_gen()

@@ -372,6 +378,14 @@ def train(self, early_stopping_steps: int = 10, min_epochs: int = 30,
                if best_epoch_str is None:
                    best_epoch_str = epoch_str

+               # Call the callback here if it was defined
+               if epoch_callback:
+                   epoch_callback({
+                       "epoch": epoch,
+                       "training_ler": (train_ler_total / (batch_i + 1)),  # current training LER
+                       "valid_ler": valid_ler,  # current validation LER
+                   })
+
                # Implement early stopping.
                if valid_ler < best_valid_ler:
                    print("New best valid_ler", file=out_file)
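For orientation, here is a minimal sketch of how a caller might use the new epoch_callback hook. The report_progress function and the surrounding setup are illustrative assumptions, not part of this commit; only the epoch_callback parameter and the keys of the dict it receives ("epoch", "training_ler", "valid_ler") come from the change above.

    def report_progress(epoch_info):
        # epoch_info is the dict assembled in train(): the epoch number plus
        # the current training and validation label error rates (LER).
        print("epoch {epoch}: train LER {training_ler:.3f}, valid LER {valid_ler:.3f}"
              .format(**epoch_info))

    # Assuming `model` is a persephone.rnn_ctc.Model built as in the test below:
    model.train(min_epochs=1, max_epochs=10, epoch_callback=report_progress)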
42 changes: 41 additions & 1 deletion persephone/tests/test_rnn_ctc_model.py
@@ -72,4 +72,44 @@ def test_model_train_and_decode(tmpdir, create_sine, make_wav, create_test_corpus):
    )

    assert result
-   assert len(result) == 1
\ No newline at end of file
+   assert len(result) == 1
+
+
+def test_model_train_callback(tmpdir, create_sine, make_wav, create_test_corpus):
+    """Test that we can create a model, train it, and have our callback called at the end of each training epoch."""
+    from persephone.corpus_reader import CorpusReader
+    from persephone.rnn_ctc import Model
+    from pathlib import Path
+    corpus = create_test_corpus()
+
+    # If it turns out that `tgt_dir` is not in the public interface of the Corpus,
+    # this test should change to get the base directory from the fixture that created it.
+    base_directory = corpus.tgt_dir
+    print("base_directory", base_directory)
+
+    corpus_r = CorpusReader(
+        corpus,
+        batch_size=1
+    )
+    assert corpus_r
+
+    test_model = Model(
+        base_directory,
+        corpus_r,
+        num_layers=3,
+        hidden_size=50
+    )
+    assert test_model
+
+    from unittest.mock import Mock
+
+    mock_callback = Mock(return_value=None)
+
+    test_model.train(
+        early_stopping_steps=1,
+        min_epochs=1,
+        max_epochs=10,
+        epoch_callback=mock_callback
+    )
+
+    assert mock_callback.call_count == 10
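A natural follow-up (not part of this commit) would be to assert on the callback's payload as well as its call count; a sketch, assuming the dict keys introduced in model.py above:

    # Inspect the payload of the final invocation via unittest.mock.
    last_payload = mock_callback.call_args[0][0]
    assert set(last_payload) == {"epoch", "training_ler", "valid_ler"}
    assert last_payload["epoch"] == 10  # epochs are 1-based via itertools.count(start=1)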
