Skip to content

Commit

Permalink
Merge pull request #143 from customprogrammingsolutions/bugfix-asserts-as-validation
Browse files Browse the repository at this point in the history

[MRG] Raise exceptions not just asserts
  • Loading branch information
oadams committed May 22, 2018
2 parents e616919 + 0cbebf1 commit 81dfaab
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
3 changes: 3 additions & 0 deletions persephone/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,13 @@ def ensure_no_set_overlap(self) -> None:

if train & valid:
logger.warning("train and valid have overlapping items: {}".format(train & valid))
raise PersephoneException("train and valid have overlapping items: {}".format(train & valid))
if train & test:
logger.warning("train and test have overlapping items: {}".format(train & test))
raise PersephoneException("train and test have overlapping items: {}".format(train & test))
if valid & test:
logger.warning("valid and test have overlapping items: {}".format(valid & test))
raise PersephoneException("valid and test have overlapping items: {}".format(valid & test))

def pickle(self):
""" Pickles the Corpus object in a file in tgt_dir. """
Expand Down
11 changes: 9 additions & 2 deletions persephone/corpus_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, ra
if batch_size:
self.batch_size = batch_size
if num_train % batch_size != 0:
raise PersephoneException("""Number of training examples %d not divisible
by batch size %d.""" % (num_train, batch_size))
logger.error("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, batch_size))
raise PersephoneException("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, batch_size))
else:
# Dynamically change batch size based on number of training
# examples.
Expand All @@ -70,6 +72,11 @@ def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, ra
# For now we hope that training numbers are powers of two or
# something... If not, crash before anything else happens.
assert num_train % self.batch_size == 0
if num_train % self.batch_size != 0:
logger.error("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, self.batch_size))
raise PersephoneException("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, batch_size))

random.seed(rand_seed)

Expand Down
4 changes: 4 additions & 0 deletions persephone/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def zero_pad(matrix, to_length):
x is of length to_length."""

assert matrix.shape[0] <= to_length
if not matrix.shape[0] <= to_length:
logger.error("zero_pad cannot be performed on matrix with shape {}"
" to length {}".format(matrix.shape[0], to_length))
raise ValueError
result = np.zeros((to_length,) + matrix.shape[1:])
result[:matrix.shape[0]] = matrix
return result
Expand Down

0 comments on commit 81dfaab

Please sign in to comment.