Prevent a partial record from being returned by the generator.
GitOrigin-RevId: a0c2de28c479c0e24377c67eba6264cb0642b4d4
pimlock committed Mar 24, 2022
1 parent d8be88e commit 4672906
Showing 2 changed files with 113 additions and 29 deletions.
80 changes: 52 additions & 28 deletions src/gretel_synthetics/batch.py
@@ -25,7 +25,7 @@
from pathlib import Path
from typing import Callable, Dict
from typing import Iterator as IteratorType
from typing import List, Optional, Type, Union
from typing import List, Optional, Tuple, Type, Union

import cloudpickle
import gretel_synthetics.const as const
@@ -712,10 +712,7 @@ def _get_record(self) -> IteratorType[dict]:
# to construct a full line, we'll only count a
# full line once we get through each generator

# if we are using a watchdog thread to monitor generation
# and it throws an exception, a threading event will be set
# that signals generation should stop
if self._thread_event and self._thread_event.is_set():
if self._is_generation_stopped():
break

if self._counter.invalid_count >= self.max_invalid:
@@ -728,29 +725,9 @@ def _get_record(self) -> IteratorType[dict]:
# the next record.
seed_cache = seed_generator.settings.start_string[0]

record = {}
batch: Batch
for batch, gen in generators:
while True:

# see above usage for watchdog thread exception handling
if self._thread_event and self._thread_event.is_set():
break

line = next(gen) # type: GenText
if line.valid is False:
self._cache_invalid(line)
self._counter.invalid_count += 1
if self._counter.invalid_count > self.max_invalid:
raise RuntimeError(
"Invalid record count exceeded during generation"
)
continue
partial_rec = dict(
zip_longest(batch.headers, line.values_as_list(), fillvalue="")
)
record.update(partial_rec)
break
record = self._generate_record(generators)
if record is None:
continue

# Do a final validation, if configured, on the fully constructed
# record, if this validation fails, we'll still increment our
@@ -775,6 +752,53 @@
self._counter.valid_count += 1
yield record

def _generate_record(
self, generators: List[Tuple[Batch, IteratorType]]
) -> Optional[dict]:
"""
Generates a single record, by generating data for each batch and merging it all
together.
Args:
generators: List of generators to use
Returns: Optional dict, it will be ``None`` if the record generation was
interrupted or full record otherwise.
"""
record = {}
batch: Batch
for batch, gen in generators:
while True:

if self._is_generation_stopped():
return None

line = next(gen) # type: GenText
if line.valid is False:
self._cache_invalid(line)
self._counter.invalid_count += 1
if self._counter.invalid_count > self.max_invalid:
raise RuntimeError(
"Invalid record count exceeded during generation"
)
continue
partial_rec = dict(
zip_longest(batch.headers, line.values_as_list(), fillvalue="")
)
record.update(partial_rec)
break

return record

def _is_generation_stopped(self):
"""
If we are using a watchdog thread to monitor generation and it throws an
exception, a threading event will be set that signals generation should stop.
Returns: ``True`` if generation was stopped by setting ``_thread_event``.
"""
return self._thread_event and self._thread_event.is_set()

def __iter__(self):
return self
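
In condensed form, the change above moves the per-batch assembly loop out of _get_record and into _generate_record, which returns None as soon as the stop event fires instead of breaking out of the inner loop with a half-built dict; the caller then skips that iteration, so a partial record can never reach the yield. Below is a minimal, self-contained sketch of that pattern with the class internals replaced by small stand-ins (FakeBatch, FakeLine, and generate_record are illustrative names, not part of the library):

import threading
from dataclasses import dataclass
from itertools import zip_longest
from typing import Iterator, List, Optional, Tuple


@dataclass
class FakeLine:
    """Minimal stand-in for gretel_synthetics' GenText."""
    valid: bool
    values: List[str]

    def values_as_list(self) -> List[str]:
        return self.values


@dataclass
class FakeBatch:
    """Minimal stand-in for Batch; only the headers matter here."""
    headers: List[str]


def generate_record(
    generators: List[Tuple[FakeBatch, Iterator[FakeLine]]],
    stop_event: threading.Event,
) -> Optional[dict]:
    """Assemble one record from all batches, or return None if stopped."""
    record = {}
    for batch, gen in generators:
        while True:
            # Return None instead of a half-built dict when the watchdog
            # event signals that generation should stop.
            if stop_event.is_set():
                return None
            line = next(gen)
            if not line.valid:
                continue  # the real code also enforces a max_invalid budget
            # Merge this batch's columns into the record being assembled.
            record.update(
                dict(zip_longest(batch.headers, line.values_as_list(), fillvalue=""))
            )
            break
    return record

The real helper additionally raises once the invalid-line budget (max_invalid) is exhausted, as shown in the diff; that bookkeeping is omitted here to keep the sketch short.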

62 changes: 61 additions & 1 deletion tests/test_generate.py
@@ -1,11 +1,15 @@
import threading

from typing import Iterator, List
from unittest.mock import Mock

import pytest

from gretel_synthetics.batch import Batch, RecordFactory
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.const import NEWLINE
from gretel_synthetics.errors import GenerationError
from gretel_synthetics.generate import Settings
from gretel_synthetics.generate import gen_text, Settings


@pytest.fixture
@@ -57,3 +61,59 @@ def test_delim_single_field(tf_config, mock_tokenizer):
config=tf_config, start_string="onlyonefield,", tokenizer=mock_tokenizer
)
assert check.start_string == "onlyonefield<d>"


def test_generate_doesnt_return_partial_record_when_stopped(tmp_path):
"""
Test for a case, when generation is stopped by the _thread_event signal.
Makes sure that we are not returning a partially generated records.
"""
dummy_dir = str(tmp_path)
dummy_config = TensorFlowConfig(
checkpoint_dir=dummy_dir, input_data_path=dummy_dir, field_delimiter="|"
)

rf = RecordFactory(
num_lines=10, batches={}, header_list=["colA", "colB"], delimiter="|"
)
generators = [
(_dummy_batch(["colA"], dummy_config), _gen_and_set_thread_event(rf)),
(_dummy_batch(["colB"], dummy_config), _just_gen("123.33")),
]

record = rf._generate_record(generators)
assert record == {"colA": "world", "colB": "123.33"}
assert rf._counter.invalid_count == 1

record = rf._generate_record(generators)
assert record is None


def _gen_and_set_thread_event(factory: RecordFactory) -> Iterator[gen_text]:
"""
This generator:
- yields an invalid text
- yields a valid text
- sets the _thread_event on the factory instance and yields a valid text
"""
yield gen_text(valid=False, text="hello", delimiter="|")
yield gen_text(valid=True, text="world", delimiter="|")

event = threading.Event()
event.set()
factory._thread_event = event
yield gen_text(valid=True, text="world", delimiter="|")


def _just_gen(value: str) -> Iterator[gen_text]:
while True:
yield gen_text(valid=True, text=value, delimiter="|")


def _dummy_batch(headers: List[str], conf: TensorFlowConfig):
return Batch(
checkpoint_dir="dummy",
input_data_path="dummy",
headers=headers,
config=conf,
)
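
The expected record in the first assertion above is produced by the zip_longest merge inside _generate_record: each batch contributes its own columns, and fillvalue="" pads any column a generated line fails to supply. A quick standalone illustration of that padding behavior (the header and value lists here are made up):

from itertools import zip_longest

headers = ["colA", "colB", "colC"]
values = ["world", "123.33"]  # one value short of the header list

partial_rec = dict(zip_longest(headers, values, fillvalue=""))
print(partial_rec)  # {'colA': 'world', 'colB': '123.33', 'colC': ''}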
