Skip to content

Commit

Permalink
Merge pull request #169 from customprogrammingsolutions/bugfix-corpus…
Browse files Browse the repository at this point in the history
…-labels

[MRG] Implement default behavior for label parameter
  • Loading branch information
oadams committed Jun 6, 2018
2 parents bc6fe93 + 462a881 commit 28fb857
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions persephone/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class Corpus:
"""

def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Any,
def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Optional[Any] = None,
max_samples:int=1000, speakers: Optional[Sequence[str]] = None) -> None:
""" Construct a `Corpus` instance from preprocessed data.
Expand All @@ -138,7 +138,9 @@ def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Any,
label_type: A string describing the transcription labels. For example,
"phonemes" or "tones".
labels: A set of strings representing labels (tokens) used in
transcription. For example: {"a", "o", "th", ...}
transcription. For example: {"a", "o", "th", ...}.
If this parameter is not provided the experiment directory is
scanned for labels present in the transcription files.
max_samples: The maximum number of samples an utterance in the
corpus may have. If an utterance is longer than this, it is not
included in the corpus.
Expand Down Expand Up @@ -168,7 +170,10 @@ def __init__(self, feat_type: str, label_type: str, tgt_dir: Path, labels: Any,
self.set_and_check_directories(tgt_dir)

# Label-related stuff
self.labels = labels
if labels is not None:
self.labels = labels
else:
self.labels = determine_labels(self.tgt_dir, label_type)
self.vocab_size = len(self.labels)
self.LABEL_TO_INDEX, self.INDEX_TO_LABEL = self.initialize_labels(self.labels)
logger.info("Corpus label set: \n\t{}".format(self.labels))
Expand Down Expand Up @@ -605,4 +610,4 @@ def determine_labels(target_dir: Path, label_type: str) -> set:
print("Unicode decode error on file {}".format(fn))
raise
phonemes = phonemes.union(line_phonemes)
return phonemes
return phonemes

0 comments on commit 28fb857

Please sign in to comment.