Skip to content

Commit

Permalink
Add summary info to the synthetic generator (#83)
Browse files Browse the repository at this point in the history
* Add summary info to the synthetic generator

Co-authored-by: Temesghen Kahsai <teme@gretel.ai>
  • Loading branch information
lememta and Temesghen Kahsai committed Jan 21, 2021
1 parent 16337b2 commit bf7aa64
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
18 changes: 11 additions & 7 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def generate_batch_lines(
num_lines: int = None,
seed_fields: Union[dict, List[dict]] = None,
parallelism: int = 0,
) -> bool:
) -> dict:
"""Generate lines for a single batch. Lines generated are added
to the underlying ``Batch`` object for each batch. The lines
can be accessed after generation and re-assembled into a DataFrame.
Expand Down Expand Up @@ -566,6 +566,7 @@ def generate_batch_lines(
t = tqdm(total=num_lines, desc="Valid record count ")
t2 = tqdm(total=max_invalid, desc="Invalid record count ")
line: GenText
n_valid, n_invalid = 0, 0
try:
for line in generate_text(
batch.config,
Expand All @@ -574,21 +575,24 @@ def generate_batch_lines(
num_lines=num_lines,
start_string=seed_string,
parallelism=parallelism,
):
):
if line.valid is None or line.valid is True:
batch.add_valid_data(line)
t.update(1)
n_valid += 1
else:
t2.update(1)
batch.gen_data_invalid.append(line)
n_invalid += 1
except TooManyInvalidError:
if raise_on_exceed_invalid:
raise
else:
return False
t.close()
t2.close()
return batch.gen_data_count >= num_lines
is_valid = batch.gen_data_count >= num_lines
return {'valid_lines': n_valid, 'invalid_lines': n_invalid, 'is_valid': is_valid}

def generate_all_batch_lines(
self,
Expand Down Expand Up @@ -633,12 +637,12 @@ def generate_all_batch_lines(
rounded down.
Returns:
A dictionary of batch number to a bool value that shows if each batch
was able to generate the full number of requested lines::
A dictionary of batch number to a dictionary that reports the number of valid, invalid lines and bool value
that shows if each batch was able to generate the full number of requested lines::
{
0: True,
1: True
0: {'valid_lines' : 1000, 'invalid_lines': 10, 'is_valid': True},
1: {'valid_lines' : 500, 'invalid_lines': 5, 'is_valid': True}
}
"""
batch_status = {}
Expand Down
9 changes: 6 additions & 3 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,21 @@ def bad():

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
assert batches.generate_batch_lines(5, max_invalid=1)
summary = batches.generate_batch_lines(5, max_invalid=1)
assert summary.get('is_valid')
check_call = mock_gen.mock_calls[0]
_, _, kwargs = check_call
assert kwargs["max_invalid"] == 1

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
assert batches.generate_batch_lines(5)
summary = batches.generate_batch_lines(5)
assert summary.get('is_valid')

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
assert not batches.generate_batch_lines(5)
summary = batches.generate_batch_lines(5)
assert not summary.get('is_valid')

with patch.object(batches, "generate_batch_lines") as mock_gen:
batches.generate_all_batch_lines(max_invalid=15)
Expand Down

0 comments on commit bf7aa64

Please sign in to comment.