Skip to content

Commit

Permalink
fixing decoding when feats don't exist.
Browse files Browse the repository at this point in the history
  • Loading branch information
oadams committed Oct 13, 2018
1 parent 43f5bd1 commit 90dc954
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 38 deletions.
68 changes: 36 additions & 32 deletions persephone/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,22 @@ def decode_corpus(model_path_prefix: Union[str, Path],
output_path: Union[str, Path],
*,
batch_size: int = 64,
preprocessed_output_path: Optional[Path]=None,
feat_dir: Optional[Path]=None,
batch_x_name: str="batch_x:0",
batch_x_lens_name: str="batch_x_lens:0",
output_name: str="hyp_dense_decoded:0") -> List[List[str]]:
input_paths = [Path(corpus.tgt_dir) / "feat" / Path(prefix + ".wav")
input_paths = [Path(corpus.tgt_dir) / "wav" / Path(prefix + ".wav")
for prefix in corpus.untranscribed_prefixes]
decoded = decode(model_path_prefix,
input_paths,
label_set=corpus.labels,
output_path=output_path,
feature_type=corpus.feat_type,
batch_size=batch_size,
preprocessed_output_path=preprocessed_output_path,
batch_x_name=batch_x_name,
batch_x_lens_name=batch_x_lens_name,
output_name=output_name)
decode(model_path_prefix,
input_paths,
label_set=corpus.labels,
output_path=output_path,
feature_type=corpus.feat_type,
batch_size=batch_size,
feat_dir=feat_dir,
batch_x_name=batch_x_name,
batch_x_lens_name=batch_x_lens_name,
output_name=output_name)

def decode(model_path_prefix: Union[str, Path],
input_paths: Sequence[Path],
Expand All @@ -76,7 +76,7 @@ def decode(model_path_prefix: Union[str, Path],
*,
feature_type: str = "fbank", #TODO Make this None and infer feature_type from dimension of NN input layer.
batch_size: int = 64,
preprocessed_output_path: Optional[Path]=None,
feat_dir: Optional[Path]=None,
batch_x_name: str="batch_x:0",
batch_x_lens_name: str="batch_x_lens:0",
output_name: str="hyp_dense_decoded:0") -> List[List[str]]:
Expand All @@ -89,10 +89,11 @@ def decode(model_path_prefix: Union[str, Path],
input_paths: A sequence of `pathlib.Path`s to WAV files to put through
the model provided.
label_set: The set of all the labels this model uses.
output_path: Path to a text file to store the output hypothesis.
feature_type: The type of features this model uses.
Note that this MUST match the type of features that the
model was trained on initially.
preprocessed_output_path: Any files that require preprocessing will be
feat_dir: Any files that require preprocessing will be
saved to the directory specified by this argument.
batch_x_name: The name of the tensorflow input for batch_x
batch_x_lens_name: The name of the tensorflow input for batch_x_lens
Expand All @@ -111,26 +112,28 @@ def decode(model_path_prefix: Union[str, Path],
)

preprocessed_file_paths = []
prefixes = [p.stem for p in input_paths]
for prefix in prefixes:
for p in input_paths:
prefix = p.stem
# Check the "feat" directory as per the filesystem conventions of a Corpus
feature_file_ext = ".{}.npy".format(feature_type)
conventional_npy_location = p.parent.parent / "feat" / (Path(prefix + feature_file_ext))
if conventional_npy_location.exists():
# don't need to preprocess it
preprocessed_file_paths.append(conventional_npy_location)
else:
if not preprocessed_output_path:
preprocessed_output_path = conventional_npy_location
if not feat_dir:
feat_dir = p.parent.parent / "feat"
if not feat_dir.is_dir():
os.makedirs(str(feat_dir))

mono16k_wav_path = preprocessed_output_path / "{}.wav".format(prefix)
feat_path = preprocessed_output_path / "{}.{}.npy".format(prefix, feature_type)
mono16k_wav_path = feat_dir / "{}.wav".format(prefix)
feat_path = feat_dir / "{}.{}.npy".format(prefix, feature_type)
feat_extract.convert_wav(p, mono16k_wav_path)
preprocessed_file_paths.append(feat_path)
# preprocess the files that weren't found in the features directory
# as per the filesystem conventions
if preprocessed_output_path:
feat_extract.from_dir(preprocessed_output_path, feature_type)
if feat_dir:
feat_extract.from_dir(feat_dir, feature_type)

fn_batches = utils.make_batches(preprocessed_file_paths, batch_size)
# Load the model and perform decoding.
Expand All @@ -154,6 +157,7 @@ def decode(model_path_prefix: Union[str, Path],
human_readable = dense_to_human_readable(dense_decoded, indices_to_labels)

output_path = Path(output_path)
prefixes = [p.stem for p in input_paths]
prefixes_hyps = sorted(list(zip(prefixes, human_readable)))
with output_path.open("w", encoding=ENCODING) as f:
for prefix, hyp in prefixes_hyps:
Expand Down Expand Up @@ -253,16 +257,16 @@ def decode(self):
batch_x_name = self.batch_x.name
batch_x_lens_name = self.batch_x_lens.name
output_name = self.dense_decoded.name
decoded = decode(model_path_prefix,
input_paths,
label_set,
output_path = Path(self.exp_dir) / "decoded" / "untranscribed.txt",
feature_type=feature_type,
batch_size=batch_size,
batch_x_name=batch_x_name,
batch_x_lens_name=batch_x_lens_name,
output_name=output_name)

decode(model_path_prefix,
input_paths,
label_set,
output_path = Path(self.exp_dir) / "decoded" / "untranscribed.txt",
feature_type=feature_type,
batch_size=batch_size,
batch_x_name=batch_x_name,
batch_x_lens_name=batch_x_lens_name,
output_name=output_name)
def eval(self, restore_model_path: Optional[str]=None) -> None:
""" Evaluates the model on a test set."""

Expand Down
11 changes: 5 additions & 6 deletions persephone/tests/test_rnn_ctc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,22 @@ def test_model_train_and_decode(tmpdir, create_sine, make_wav, create_test_corpu

make_wav(sine_to_decode, wav_to_decode_path)

output_path = tmpdir.mkdir("decode_output")
feat_dir = tmpdir.mkdir("feat")

model_checkpoint_path = base_directory / "model" / "model_best.ckpt"
result = decode(
decode(
model_checkpoint_path,
[Path(wav_to_decode_path)],
label_set = {"A", "B", "C"},
output_path = tmpdir.join("output.txt"),
feature_type = "fbank",
preprocessed_output_path=Path(str(output_path)),
feat_dir=Path(str(feat_dir)),
batch_x_name = test_model.batch_x.name,
batch_x_lens_name = test_model.batch_x_lens.name,
output_name = test_model.dense_decoded.name
)

assert result
assert len(result) == 1

# TODO Fix this test so that we actually confirm decent decoding output

def test_model_train_callback(tmpdir, create_sine, make_wav, create_test_corpus):
"""Test that we can create a model, train it then get our callback called on each epoch of training"""
Expand Down

0 comments on commit 90dc954

Please sign in to comment.