Skip to content

Commit

Permalink
Merge pull request #216 from CustomProgrammingSolutions/refactor-kwargs
Browse files Browse the repository at this point in the history
[MRG] Make optional arguments keyword args
  • Loading branch information
oadams committed Jul 9, 2019
2 parents 5fe0f10 + ce75b53 commit 13fcc66
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
7 changes: 5 additions & 2 deletions persephone/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ class Corpus:
"""

def __init__(self, feat_type: str, label_type: str, tgt_dir: Path,
*,
labels: Optional[Set[str]] = None,
max_samples:int=1000, speakers: Optional[Sequence[str]] = None) -> None:
max_samples: int=1000,
speakers: Optional[Sequence[str]] = None) -> None:
""" Construct a `Corpus` instance from preprocessed data.
Assumes that the corpus data has been preprocessed and is
Expand Down Expand Up @@ -235,6 +237,7 @@ def __init__(self, feat_type: str, label_type: str, tgt_dir: Path,
@classmethod
def from_elan(cls: Type[CorpusT], org_dir: Path, tgt_dir: Path,
feat_type: str = "fbank", label_type: str = "phonemes",
*,
utterance_filter: Callable[[Utterance], bool] = None,
label_segmenter: Optional[LabelSegmenter] = None,
speakers: List[str] = None, lazy: bool = True,
Expand Down Expand Up @@ -461,7 +464,7 @@ def write_prefixes(prefixes: List[str], prefix_fn: Path) -> None:
print(prefix, file=prefix_f)

@staticmethod
def divide_prefixes(prefixes: List[str], seed:int=0) -> Tuple[List[str], List[str], List[str]]:
def divide_prefixes(prefixes: List[str], *, seed:int=0) -> Tuple[List[str], List[str], List[str]]:
"""Divide data into training, validation and test subsets"""
if len(prefixes) < 3:
raise PersephoneException(
Expand Down
2 changes: 1 addition & 1 deletion persephone/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def output_best_scores(self, best_epoch_str: str) -> None:
"w", encoding=ENCODING) as best_f:
print(best_epoch_str, file=best_f, flush=True)

def train(self, early_stopping_steps: int = 10, min_epochs: int = 30,
def train(self, *, early_stopping_steps: int = 10, min_epochs: int = 30,
max_valid_ler: float = 1.0, max_train_ler: float = 0.3,
max_epochs: int = 100, restore_model_path: Optional[str]=None,
epoch_callback: Optional[Callable[[Dict], None]]=None) -> None:
Expand Down
4 changes: 2 additions & 2 deletions persephone/tests/experiments/test_na.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_tutorial():

# Test the first setup encouraged in the tutorial
labels = corpus.determine_labels(Path(na_example_dir), "phonemes")
corp = corpus.Corpus("fbank", "phonemes", Path(na_example_dir), labels)
corp = corpus.Corpus("fbank", "phonemes", Path(na_example_dir), labels=labels)

exp_dir = experiment.train_ready(corp, directory=EXP_BASE_DIR)

Expand All @@ -111,7 +111,7 @@ def test_fast():

labels = corpus.determine_labels(Path(tiny_example_dir), "phonemes")

corp = corpus.Corpus("fbank", "phonemes", Path(tiny_example_dir), labels)
corp = corpus.Corpus("fbank", "phonemes", Path(tiny_example_dir), labels=labels)
exp_dir = experiment.prep_exp_dir(directory=EXP_BASE_DIR)
model = experiment.get_simple_model(exp_dir, corp)
model.train(min_epochs=2, max_epochs=5)
Expand Down

0 comments on commit 13fcc66

Please sign in to comment.