
Commit

Merge pull request #197 from CustomProgrammingSolutions/decoding-function

[MRG] Work on standalone model decoding function
oadams committed Sep 24, 2018
2 parents 0b3d450 + bd548df commit 1999ca5
Showing 12 changed files with 515 additions and 83 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
@@ -2,6 +2,8 @@ language: python
python:
- "3.5"
- "3.6"
before_install:
- if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then sudo add-apt-repository -y ppa:mc3man/trusty-media && sudo apt-get update && sudo apt-get install ffmpeg; fi
install:
- pip install .
- pip install pytest-cov
4 changes: 4 additions & 0 deletions changelog.md
@@ -8,12 +8,16 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added
- Raise a label mismatch exception if the `labels` kwarg to the `Corpus` constructor is inconsistent with the automatically determined labels.
- Test fixtures for Corpus creation
- Test coverage for Corpus and Model creation

### Changed
- Update package dependency versions.

### Fixed
- `CorpusReader.train_batch_gen` raises StopIteration instead of returning None if no data can be generated.
- Decoding from a saved model is now possible for arbitrary TensorFlow model topologies that share the same input and output structure, via named arguments that specify the model's input and output tensors.
- RNN CTC model class now accepts `pathlib.Path` for the directory argument.

## [0.3.2]

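As an illustration of the label-mismatch entry above: a `Corpus` constructor call whose `labels` kwarg disagrees with the labels found on disk now fails loudly. A minimal sketch (editor's note, not part of the commit; the corpus path and label set are hypothetical, and the import location of `LabelMismatchException` is assumed):

    from pathlib import Path

    from persephone.corpus import Corpus
    from persephone.exceptions import LabelMismatchException  # assumed location

    # Hypothetical corpus laid out with "wav" and "label" subdirectories;
    # the label set below deliberately omits phonemes present on disk.
    try:
        corpus = Corpus("fbank", "phonemes", Path("my_corpus"), labels={"a", "i"})
    except LabelMismatchException as err:
        print(err)  # user-specified labels differ from those found automatically
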
41 changes: 28 additions & 13 deletions persephone/corpus.py
@@ -112,7 +112,8 @@ class Corpus:
"""

def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Optional[Any] = None,
def __init__(self, feat_type: str, label_type: str, tgt_dir: Path,
labels: Optional[Set[str]] = None,
max_samples:int=1000, speakers: Optional[Sequence[str]] = None) -> None:
""" Construct a `Corpus` instance from preprocessed data.
@@ -151,8 +152,10 @@ def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Optio
if speakers:
raise NotImplementedError("Speakers not implemented")

logger.debug("Creating a new Corpus object with feature type %s, label type %s,"
"target directory %s, label set %s, ms, max_samples, speakers")
logger.debug("Creating a new Corpus object with feature type {}, label type {},"
"target directory {}, label set {}, max_samples {}, speakers {}".format(
feat_type, label_type, labels, tgt_dir, max_samples, speakers)
)

# In case path is supplied as a string, make it a Path
self.tgt_dir = Path(tgt_dir)
Expand All @@ -179,8 +182,8 @@ def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Optio
self.labels = labels
found_labels = determine_labels(self.tgt_dir, label_type)
if found_labels != self.labels:
raise LabelMismatchException("""User specified labels, {}, do
not match those automatically found, {}.""".format(labels,
raise LabelMismatchException("User specified labels, {}, do"
" not match those automatically found, {}.".format(labels,
found_labels))
else:
self.labels = determine_labels(self.tgt_dir, label_type)
@@ -213,7 +216,7 @@ def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Optio
self.get_test_fns()[0]
)
except PersephoneException:
logger.error("Got overlap between train valid and test data sets")
logger.error("Got overlap between train, valid and test data sets")
raise

untranscribed_from_file = self.get_untranscribed_prefixes()
@@ -232,7 +235,7 @@ def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Optio
def from_elan(cls: Type[CorpusT], org_dir: Path, tgt_dir: Path,
feat_type: str = "fbank", label_type: str = "phonemes",
utterance_filter: Callable[[Utterance], bool] = None,
label_segmenter: LabelSegmenter = None,
label_segmenter: Optional[LabelSegmenter] = None,
speakers: List[str] = None, lazy: bool = True,
tier_prefixes: Tuple[str, ...] = ("xv", "rf")) -> CorpusT:
""" Construct a `Corpus` from ELAN files.
@@ -306,7 +309,7 @@ def from_elan(cls: Type[CorpusT], org_dir: Path, tgt_dir: Path,
wav.extract_wavs(utterances, (tgt_dir / "wav"), lazy=lazy)

corpus = cls(feat_type, label_type, tgt_dir,
label_segmenter.labels, speakers=speakers)
labels=label_segmenter.labels, speakers=speakers)
corpus.utterances = utterances
return corpus

@@ -350,7 +353,7 @@ def set_and_check_directories(self, tgt_dir: Path) -> None:
raise PersephoneException(
"The supplied path requires a 'label' subdirectory.")

def initialize_labels(self, labels: Sequence[str]) -> Tuple[dict, dict]:
def initialize_labels(self, labels: Set[str]) -> Tuple[dict, dict]:
"""Create mappings from label to index and index to label"""
logger.debug("Creating mappings for labels")

@@ -459,19 +462,31 @@ def write_prefixes(prefixes: List[str], prefix_fn: Path) -> None:
@staticmethod
def divide_prefixes(prefixes: List[str], seed:int=0) -> Tuple[List[str], List[str], List[str]]:
"""Divide data into training, validation and test subsets"""
if len(prefixes) < 3:
raise PersephoneException(
"{} cannot be split into 3 groups as it only has {} items".format(prefixes, len(prefixes))
)
Ratios = namedtuple("Ratios", ["train", "valid", "test"])
ratios=Ratios(.90, .05, .05)
train_end = int(ratios.train*len(prefixes))
valid_end = int(train_end + ratios.valid*len(prefixes))

# We must make sure that at least one element exists in test
if valid_end == len(prefixes):
valid_end -= 1

# If train_end and valid_end are the same we end up with no valid_prefixes
# so we must ensure at least one prefix is placed in this category
if train_end == valid_end:
train_end -= 1

random.seed(seed)
random.shuffle(prefixes)

train_prefixes = prefixes[:train_end]
valid_prefixes = prefixes[train_end:valid_end]
test_prefixes = prefixes[valid_end:]

# TODO Adjust code to cope properly with toy datasets where these
# subsets might actually be empty.
assert train_prefixes, "Got empty set for training data"
assert valid_prefixes, "Got empty set for validation data"
assert test_prefixes, "Got empty set for testing data"
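
To see why both guards are needed, trace the split arithmetic above for a five-utterance toy corpus (a standalone sketch of the logic, not library code):

    prefixes = ["utt1", "utt2", "utt3", "utt4", "utt5"]
    train_end = int(0.90 * len(prefixes))              # int(4.5)  -> 4
    valid_end = int(train_end + 0.05 * len(prefixes))  # int(4.25) -> 4
    if valid_end == len(prefixes):  # guard: keep at least one test item
        valid_end -= 1              # not triggered here (4 != 5)
    if train_end == valid_end:      # guard: keep at least one validation item
        train_end -= 1              # triggered: train_end becomes 3
    # Result: 3 train, 1 validation, 1 test; no subset is empty.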
@@ -598,7 +613,7 @@ def from_pickle(cls: Type[CorpusT], tgt_dir: Path) -> CorpusT:
return pickle.load(f)


def determine_labels(target_dir: Path, label_type: str) -> set:
def determine_labels(target_dir: Path, label_type: str) -> Set[str]:
""" Returns a set of all phonemes found in the corpus. Assumes that WAV files and
label files are split into utterances and segregated in a directory which contains a
"wav" subdirectory and "label" subdirectory.
@@ -615,7 +630,7 @@ def determine_labels(target_dir: Path, label_type: str) -> set:
raise FileNotFoundError(
"The directory {} does not exist.".format(target_dir))

phonemes = set() # type: set
phonemes = set() # type: Set[str]
for fn in os.listdir(str(label_dir)):
if fn.endswith(str(label_type)):
with (label_dir / fn).open("r") as f:
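The retyped `determine_labels` above unions the symbols found in every transcript under the corpus's "label" subdirectory. Roughly, under a hypothetical layout (editor's sketch of the expected behaviour):

    from pathlib import Path
    from persephone.corpus import determine_labels

    # Given my_corpus/label/utt1.phonemes containing "a b a"
    # and my_corpus/label/utt2.phonemes containing "b c":
    labels = determine_labels(Path("my_corpus"), "phonemes")
    assert labels == {"a", "b", "c"}
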
8 changes: 4 additions & 4 deletions persephone/corpus_reader.py
@@ -182,15 +182,15 @@ def human_readable_hyp_ref(self, dense_decoded, dense_y):

return hyps, refs

def human_readable(self, dense_repr):
def human_readable(self, dense_repr: Sequence[Sequence[int]]) -> List[List[str]]:
""" Returns a human readable version of a dense representation of
either hypotheses or references, to facilitate simple manual inspection.
"""

transcripts = []
for i in range(len(dense_repr)):
transcript = [phn_i for phn_i in dense_repr[i] if phn_i != 0]
transcript = self.corpus.indices_to_labels(transcript)
for dense_r in dense_repr:
non_empty_phonemes = [phn_i for phn_i in dense_r if phn_i != 0]
transcript = self.corpus.indices_to_labels(non_empty_phonemes)
transcripts.append(transcript)

return transcripts
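The dense-to-readable conversion above can be checked with a toy batch in which index 0 is the blank/padding symbol (editor's sketch; the index-to-label mapping is made up):

    index_to_label = {1: "a", 2: "b", 3: "c"}
    dense_repr = [[1, 0, 2], [3, 0, 0]]

    transcripts = []
    for dense_r in dense_repr:
        non_empty_phonemes = [phn_i for phn_i in dense_r if phn_i != 0]
        transcripts.append([index_to_label[i] for i in non_empty_phonemes])

    assert transcripts == [["a", "b"], ["c"]]
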
91 changes: 72 additions & 19 deletions persephone/model.py
@@ -6,11 +6,10 @@
import os
from pathlib import Path
import sys
from typing import Optional, Union, Sequence, Set, List

from typing import Optional, Union, Sequence, Set, List, Dict
import tensorflow as tf

from .preprocess import labels
from .preprocess import labels, feat_extract
from . import utils
from . import config
from .exceptions import PersephoneException
@@ -33,31 +32,80 @@ def load_metagraph(model_path_prefix: Union[str, Path]) -> tf.train.Saver:
metagraph = tf.train.import_meta_graph(model_path_prefix + ".meta")
return metagraph

def dense_to_human_readable(dense_repr, index_to_label):
def dense_to_human_readable(dense_repr: Sequence[Sequence[int]], index_to_label: Dict[int, str]) -> List[List[str]]:
""" Converts a dense representation of model decoded output into human
readable, using a mapping from indices to labels. """

transcripts = []
for i in range(len(dense_repr)):
transcript = [phn_i for phn_i in dense_repr[i] if phn_i != 0]
transcript = [index_to_label[index] for index in transcript]
for dense_r in dense_repr:
non_empty_phonemes = [phn_i for phn_i in dense_r if phn_i != 0]
transcript = [index_to_label[index] for index in non_empty_phonemes]
transcripts.append(transcript)

return transcripts

def decode(model_path_prefix: Union[str, Path],
input_paths: Sequence[Path],
label_set: Set[str]) -> List[List[str]]:
label_set: Set[str],
*,
feature_type: str = "fbank", #TODO Make this None and infer feature_type from dimension of NN input layer.
batch_size: int = 64,
preprocessed_output_path: Optional[Path]=None,
batch_x_name: str="Placeholder:0",
batch_x_lens_name: str="Placeholder_1:0",
output_name: str="SparseToDense:0") -> List[List[str]]:
"""Use an existing tensorflow model that exists on disk to decode
WAV files.
Args:
model_path_prefix: The path to the saved tensorflow model.
This is the full prefix to the ".ckpt" file.
input_paths: A sequence of `pathlib.Path`s to WAV files to put through
the model provided.
label_set: The set of all the labels this model uses.
feature_type: The type of features this model uses.
Note that this MUST match the type of features that the
model was trained on initially.
preprocessed_output_path: Path to which any files that require
preprocessing will be saved.
batch_x_name: The name of the tensorflow input for batch_x
batch_x_lens_name: The name of the tensorflow input for batch_x_lens
output_name: The name of the tensorflow output
"""

model_path_prefix = str(model_path_prefix)

# TODO Confirm that the WAVs exist.

# TODO Confirm that the feature files exist. Create them if they don't.

# TODO Change the second argument to have some upper bound. If the caller
# requests 1000 WAVs be transcribed, they shouldn't all go in one batch.
fn_batches = utils.make_batches(input_paths, len(input_paths))
for p in input_paths:
if not p.exists():
raise PersephoneException(
"The WAV file path {} does not exist".format(p)
)

preprocessed_file_paths = []
for p in input_paths:
# Check the "feat" directory as per the filesystem conventions of a Corpus
prefix = p.stem
feature_file_ext = ".{}.npy".format(feature_type)
conventional_npy_location = p.parent / "feat" / (Path(prefix + feature_file_ext))
if conventional_npy_location.exists():
# don't need to preprocess it
preprocessed_file_paths.append(conventional_npy_location)
else:
if preprocessed_output_path:
mono16k_wav_path = preprocessed_output_path / "{}.wav".format(prefix)
feat_path = preprocessed_output_path / "{}.{}.npy".format(prefix, feature_type)
feat_extract.convert_wav(p, mono16k_wav_path)
preprocessed_file_paths.append(feat_path)
else:
raise PersephoneException(
"Can't preprocess file as no output path was provided, "
"please specify preprocessed_output_path")
# preprocess the files that weren't found in the features directory
# as per the filesystem conventions
if preprocessed_output_path:
feat_extract.from_dir(preprocessed_output_path, feature_type)

fn_batches = utils.make_batches(preprocessed_file_paths, batch_size)
# Load the model and perform decoding.
metagraph = load_metagraph(model_path_prefix)
with tf.Session() as sess:
Expand All @@ -69,10 +117,10 @@ def decode(model_path_prefix: Union[str, Path],
# TODO These placeholder names should be a backup if names from a newer
# naming scheme aren't present. Otherwise this won't generalize to
# different architectures.
feed_dict = {"Placeholder:0": batch_x,
"Placeholder_1:0": batch_x_lens}
feed_dict = {batch_x_name: batch_x,
batch_x_lens_name: batch_x_lens}

dense_decoded = sess.run("SparseToDense:0", feed_dict=feed_dict)
dense_decoded = sess.run(output_name, feed_dict=feed_dict)

# Create a human-readable representation of the decoded.
indices_to_labels = labels.make_indices_to_labels(label_set)
@@ -171,7 +219,7 @@ def eval(self, restore_model_path: Optional[str]=None) -> None:
logger.info("restoring model from %s", restore_model_path)
saver.restore(sess, restore_model_path)
else:
assert self.saved_model_path
assert self.saved_model_path, "{}".format(self.saved_model_path)
logger.info("restoring model from %s", self.saved_model_path)
saver.restore(sess, self.saved_model_path)

@@ -373,5 +421,10 @@ def train(self, early_stopping_steps: int = 10, min_epochs: int = 30,
# number of epochs.
continue

# Check we actually saved a checkpoint
if not self.saved_model_path:
raise PersephoneException(
"No checkpoint was saved so model evaluation cannot be performed. "
"This can happen if the validaion LER never converges.")
# Finally, run evaluation on the test set.
self.eval(restore_model_path=self.saved_model_path)
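
Putting the new `decode` signature together, a standalone decoding run might look like this (a sketch; the checkpoint prefix, WAV path and label set are hypothetical, and the tensor names shown are just the defaults from the signature above):

    from pathlib import Path

    from persephone.model import decode

    label_set = {"a", "e", "i", "o", "u"}  # the labels the model was trained with
    transcripts = decode(
        "exp/1/model/model_best.ckpt",        # prefix of the saved checkpoint
        [Path("untranscribed/utt1.wav")],
        label_set,
        feature_type="fbank",                 # must match the training features
        preprocessed_output_path=Path("untranscribed/feat"),
        batch_x_name="Placeholder:0",         # override these three for
        batch_x_lens_name="Placeholder_1:0",  # differently named topologies
        output_name="SparseToDense:0",
    )
    for transcript in transcripts:
        print(" ".join(transcript))
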
23 changes: 18 additions & 5 deletions persephone/preprocess/feat_extract.py
@@ -60,7 +60,7 @@ def mfcc(wav_path):
feat_fn = wav_path[:-3] + "mfcc13_d.npy"
np.save(feat_fn, all_feats)

def combine_fbank_and_pitch(feat_dir, prefix):
def combine_fbank_and_pitch(feat_dir: str, prefix: str) -> None:

fbank_fn = os.path.join(feat_dir, prefix + ".fbank.npy")
fbanks = np.load(fbank_fn)
@@ -105,13 +105,18 @@ def flatten(feats_3d):
# specific wav format, so that should be coupled together with pitch extraction
# here.
def from_dir(dirpath: Path, feat_type: str) -> None:
""" Performs feature extraction from the WAV files in a directory. """
""" Performs feature extraction from the WAV files in a directory.
Args:
dirpath: A `Path` to the directory where the WAV files reside.
feat_type: The type of features that are being used.
"""

logger.info("Extracting features from directory {}".format(dirpath))

dirname = str(dirpath)

def all_wavs_processed():
def all_wavs_processed() -> bool:
"""
True if all wavs in the directory have a corresponding numpy feature
file; False otherwise.
@@ -127,6 +132,7 @@ def all_wavs_processed():

if all_wavs_processed():
# Then nothing needs to be done here
logger.info("All WAV files already preprocessed")
return
# Otherwise, go on and process everything...

@@ -155,12 +161,19 @@ def all_wavs_processed():
raise PersephoneException("Feature type not found: %s" % feat_type)

def convert_wav(org_wav_fn: Path, tgt_wav_fn: Path) -> None:
""" Converts the wav into a 16bit mono 16000Hz wav."""
""" Converts the wav into a 16bit mono 16000Hz wav.
Args:
org_wav_fn: A `Path` to the original wave file
tgt_wav_fn: The `Path` to output the processed wave file
"""
if not org_wav_fn.exists():
raise FileNotFoundError
args = [config.FFMPEG_PATH,
"-i", str(org_wav_fn), "-ac", "1", "-ar", "16000", str(tgt_wav_fn)]
subprocess.run(args)

def kaldi_pitch(wav_dir, feat_dir):
def kaldi_pitch(wav_dir: str, feat_dir: str) -> None:
""" Extract Kaldi pitch features. Assumes 16k mono wav files."""

logger.debug("Make wav.scp and pitch.scp files")
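The two helpers touched above combine for ad-hoc preprocessing roughly as follows (a minimal sketch, assuming ffmpeg is reachable at `config.FFMPEG_PATH`; the paths are hypothetical and the output directory must already exist):

    from pathlib import Path

    from persephone.preprocess import feat_extract

    raw_wav = Path("recordings/utt1.wav")
    out_dir = Path("recordings/mono16k")
    feat_extract.convert_wav(raw_wav, out_dir / "utt1.wav")  # 16-bit mono 16 kHz via ffmpeg
    feat_extract.from_dir(out_dir, feat_type="fbank")        # writes utt1.fbank.npy alongside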

