Skip to content

Commit

Permalink
Utilize proper max_invalid param for generation (#33)
Browse files Browse the repository at this point in the history
* fixes #32

* One more test verification

Co-authored-by: John Myers <john@gretel.ai>
  • Loading branch information
johntmyers and John Myers committed Jun 19, 2020
1 parent f4f6279 commit f9a6291
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def set_batch_validator(self, batch_idx: int, validator: Callable):
except KeyError:
raise ValueError("invalid batch number!")

def generate_batch_lines(self, batch_idx: int, max_invalid=1000):
def generate_batch_lines(self, batch_idx: int, max_invalid=MAX_INVALID):
"""Generate lines for a single batch. Lines generated are added
to the underlying ``Batch`` object for each batch. The lines
can be accessed after generation and re-assembled into a DataFrame.
Expand All @@ -312,7 +312,7 @@ def generate_batch_lines(self, batch_idx: int, max_invalid=1000):
t2 = tqdm(total=max_invalid, desc="Invalid record count ")
line: gen_text
for line in generate_text(
batch.config, line_validator=validator, max_invalid=MAX_INVALID
batch.config, line_validator=validator, max_invalid=max_invalid
):
if line.valid is None or line.valid is True:
batch.add_valid_data(line)
Expand Down
12 changes: 11 additions & 1 deletion tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ def good():
def bad():
return gen_text(text="1,2,3", valid=False, delimiter=",")

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
assert batches.generate_batch_lines(5, max_invalid=1)
check_call = mock_gen.mock_calls[0]
_, _, kwargs = check_call
assert kwargs["max_invalid"] == 1

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
assert batches.generate_batch_lines(5)
Expand All @@ -148,8 +155,11 @@ def bad():
assert not batches.generate_batch_lines(5)

with patch.object(batches, "generate_batch_lines") as mock_gen:
batches.generate_all_batch_lines()
batches.generate_all_batch_lines(max_invalid=15)
assert mock_gen.call_count == len(batches.batches.keys())
check_call = mock_gen.mock_calls[0]
_, _, kwargs = check_call
assert kwargs["max_invalid"] == 15


# get synthetic df
Expand Down

0 comments on commit f9a6291

Please sign in to comment.