Skip to content

Commit

Permalink
Merge pull request #185 from customprogrammingsolutions/warn-on-invalid-corpus-reader
Browse files (browse the repository at this point in the history)

[MRG] Improve CorpusReader.train_batch_gen in edge case situation
  • Loading branch information
oadams committed Aug 20, 2018
2 parents d73db45 + 45335f4 commit cb5d5ea
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 12 deletions.
3 changes: 3 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Added
- Raise a label mismatch exception if label kwarg to Corpus constructor is inconsistent with automatically determined labels.

### Fixed
- `CorpusReader.train_batch_gen` raises StopIteration instead of returning None if no data can be generated.

## [0.3.2]

### Added
Expand Down
12 changes: 7 additions & 5 deletions persephone/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,12 +478,12 @@ def divide_prefixes(prefixes: List[str], seed:int=0) -> Tuple[List[str], List[st

return train_prefixes, valid_prefixes, test_prefixes

def indices_to_labels(self, indices):
def indices_to_labels(self, indices: Sequence[int]) -> List[str]:
""" Converts a sequence of indices into their corresponding labels."""

return [(self.INDEX_TO_LABEL[index]) for index in indices]

def labels_to_indices(self, labels):
def labels_to_indices(self, labels: Sequence[str]) -> List[int]:
""" Converts a sequence of labels into their corresponding indices."""

return [self.LABEL_TO_INDEX[label] for label in labels]
Expand All @@ -507,14 +507,16 @@ def num_feats(self):
return self._num_feats

def prefixes_to_fns(self, prefixes: List[str]) -> Tuple[List[str], List[str]]:
""" Fetches the file paths to the features files and labels files
corresponding to the provided list of prefixes."""
# TODO Return pathlib.Paths
feat_fns = [str(self.feat_dir / ("%s.%s.npy" % (prefix, self.feat_type)))
for prefix in prefixes]
label_fns = [str(self.label_dir / ("%s.%s" % (prefix, self.label_type)))
for prefix in prefixes]
return feat_fns, label_fns

def get_train_fns(self):
def get_train_fns(self) -> Tuple[List[str], List[str]]:
""" Fetches the training set of the corpus.
Outputs a Tuple of size 2, where the first element is a list of paths
Expand All @@ -523,11 +525,11 @@ def get_train_fns(self):
"""
return self.prefixes_to_fns(self.train_prefixes)

def get_valid_fns(self):
def get_valid_fns(self) -> Tuple[List[str], List[str]]:
""" Fetches the validation set of the corpus."""
return self.prefixes_to_fns(self.valid_prefixes)

def get_test_fns(self):
def get_test_fns(self) -> Tuple[List[str], List[str]]:
""" Fetches the test set of the corpus."""
return self.prefixes_to_fns(self.test_prefixes)

Expand Down
10 changes: 7 additions & 3 deletions persephone/corpus_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import logging
import logging.config
from pathlib import Path
import pprint
import random
from typing import List, Sequence, Iterator

import numpy as np

Expand Down Expand Up @@ -114,12 +116,12 @@ def load_batch(self, fn_batch):
return batch_inputs, batch_inputs_lens, batch_targets


def make_batches(self, utterance_fns):
def make_batches(self, utterance_fns: Sequence[Path]) -> List[Sequence[Path]]:
""" Group utterances into batches for decoding. """

return utils.make_batches(utterance_fns, self.batch_size)

def train_batch_gen(self):
def train_batch_gen(self) -> Iterator:
""" Returns a generator that outputs batches in the training data."""

# Create batches of batch_size and shuffle them.
Expand All @@ -132,6 +134,8 @@ def train_batch_gen(self):
logger.debug("Batch of training filenames: %s",
pprint.pformat(fn_batch))
yield self.load_batch(fn_batch)
else:
raise StopIteration

def valid_batch(self):
""" Returns a single batch with all the validation cases."""
Expand Down Expand Up @@ -192,7 +196,7 @@ def __repr__(self):
"\tbatch_size=%s,\n" % repr(self.batch_size) +
"\tcorpus=\n%s)" % repr(self.corpus))

def calc_time(self):
def calc_time(self) -> None:
"""
Prints statistics about the total duration of recordings in the
corpus.
Expand Down
1 change: 0 additions & 1 deletion persephone/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def train(self, early_stopping_steps: int = 10, min_epochs: int = 30,
batch_gen = self.corpus_reader.train_batch_gen()

train_ler_total = 0
batch_i = None
print("\tBatch...", end="")
for batch_i, batch in enumerate(batch_gen):
print("%d..." % batch_i, end="")
Expand Down
6 changes: 3 additions & 3 deletions persephone/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def batch_per(hyps: Sequence[Sequence[T]],
macro_per += distance.edit_distance(ref, hyp)/len(ref)
return macro_per/len(hyps)

def get_prefixes(dirname, extension):
def get_prefixes(dirname: str, extension: str) -> List[str]:
""" Returns a list of prefixes to files in the directory (which might be a whole
corpus, or a train/valid/test subset). The prefixes include the path leading
up to it, but only the filename up until the first observed period '.'
Expand Down Expand Up @@ -153,7 +153,7 @@ def filter_by_size(feat_dir: Path, prefixes: List[str], feat_type: str,
if length <= max_samples]
return prefixes

def sort_by_size(feat_dir, prefixes, feat_type) -> List[str]:
def sort_by_size(feat_dir: Path, prefixes: List[str], feat_type: str) -> List[str]:
prefix_lens = get_prefix_lens(feat_dir, prefixes, feat_type)
prefix_lens.sort(key=lambda prefix_len: prefix_len[1])
prefixes = [prefix for prefix, _ in prefix_lens]
Expand All @@ -167,7 +167,7 @@ def is_number(string):
except ValueError:
return False

def wav_length(fn):
def wav_length(fn: str) -> float:
""" Returns the length of the WAV file in seconds."""

args = [config.SOX_PATH, fn, "-n", "stat"]
Expand Down

0 comments on commit cb5d5ea

Please sign in to comment.