Skip to content

Commit

Permalink
PROD-421: Reset generation callback state each time new generation st…
Browse files Browse the repository at this point in the history
…arts.

GitOrigin-RevId: e1a6462647ef28b1fd3955c25b5314d6f293db31
  • Loading branch information
pimlock committed May 3, 2023
1 parent 24894d5 commit fca0f3a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,8 @@ def generate_all(
progress_callback = None
if callback:
progress_callback = _GenerationCallback(callback, callback_interval)
# reset the underlying callback, since we are starting a new genaration
callback(GenerationProgress(), reset=True)

self.reset()
if output is not None and output not in ("df",):
Expand Down
28 changes: 27 additions & 1 deletion tests/test_generate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading

from typing import Iterator, List
from typing import Any, Iterator, List
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -89,6 +89,32 @@ def test_generate_doesnt_return_partial_record_when_stopped(tmp_path):
assert record is None


def test_generate_resets_previous_progress(tmp_path):
dummy_dir = str(tmp_path)

rf = RecordFactory(
num_lines=10,
batches={},
header_list=["colA", "colB"],
delimiter="|",
)

reset_called = False

def callback(_: Any, *, reset=False):
if reset is True:
nonlocal reset_called
reset_called = True

result = rf.generate_all(
callback=callback, callback_interval=2, callback_threading=True
)

assert reset_called is True
# all will be empty, as there are no batches configured
assert result.records == [{}] * 10


def _gen_and_set_thread_event(factory: RecordFactory) -> Iterator[gen_text]:
"""
This generator:
Expand Down

0 comments on commit fca0f3a

Please sign in to comment.