Support parallel synthetic text generation using multiprocessing (#39)
* Support parallel synthetic text generation using multiprocessing

* add cloudpickle to test reqs

* review comments

* set CUDA_VISIBLE_DEVICES to -1 in workers

* decode symbols one by one

* remove un-used var, bump version for RC

Co-authored-by: Malte Isberner <malte@gretel.ai>
Co-authored-by: John Myers <john@gretel.ai>
3 people committed Aug 4, 2020
1 parent b8e277a commit deb22ec
Showing 11 changed files with 538 additions and 192 deletions.
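The commit notes above mention setting CUDA_VISIBLE_DEVICES to -1 in workers. A minimal sketch of what that could look like inside a worker process, assuming the variable is set before TensorFlow initializes (the function name below is hypothetical, not the library's actual API)::

    import os

    def _worker_init():  # hypothetical name, for illustration only
        # Hiding all CUDA devices forces each generation worker onto the CPU,
        # avoiding GPU memory contention between concurrent processes.
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"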
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-0.10.3
+0.11.0.rc1
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ smart_open==2.0.0
pandas==1.0.3
numpy==1.18.3
tqdm<5.0
+cloudpickle==1.5.0
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@
'pandas==1.0.3',
'numpy==1.18.3',
'dataclasses==0.7;python_version<"3.7"',
+'cloudpickle==1.5.0',
],
extras_require={
'tf': ['tensorflow==2.1.0']
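cloudpickle is pinned here presumably because generation settings must be shipped to worker processes, and user-supplied validators are often lambdas or locally defined functions that the standard library pickle cannot serialize. A small illustration of that capability (not code from this repository)::

    import cloudpickle

    # Stdlib pickle would raise on this lambda; cloudpickle serializes it.
    validator = lambda line: len(line.split(",")) == 3
    restored = cloudpickle.loads(cloudpickle.dumps(validator))
    assert restored("a,b,c") is True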
20 changes: 16 additions & 4 deletions src/gretel_synthetics/batch.py
@@ -168,7 +168,7 @@ def _crawl_checkpoint_for_batches(checkpoint_dir: str):
for batch_dir in matching_dirs:
idx = int(Path(batch_dir).name.split("_")[-1])
batches.append((idx, _create_batch_from_dir(batch_dir)))

logger.info("Found and loaded %d batches", len(batches))
return dict(sorted(batches, key=lambda b: b[0]))

@@ -389,6 +389,7 @@ def generate_batch_lines(
max_invalid=MAX_INVALID,
raise_on_exceed_invalid: bool = False,
num_lines: int = None,
+parallelism: int = 0,
) -> bool:
"""Generate lines for a single batch. Lines generated are added
to the underlying ``Batch`` object for each batch. The lines
@@ -404,6 +405,10 @@
indicating that the batch failed to generate all lines.
num_lines: The number of lines to generate, if ``None``, then we use the number from the
batch's config
+parallelism: The number of concurrent workers to use. ``1`` disables parallelization,
+    while a non-positive value means "number of CPUs + x" (i.e., the default of ``0`` uses
+    as many workers as there are CPUs). A floating-point value is interpreted as a fraction
+    of the available CPUs, rounded down.
"""
try:
batch = self.batches[batch_idx]
@@ -419,7 +424,8 @@
line: gen_text
try:
for line in generate_text(
-batch.config, line_validator=validator, max_invalid=max_invalid, num_lines=num_lines
+batch.config, line_validator=validator, max_invalid=max_invalid, num_lines=num_lines,
+parallelism=parallelism,
):
if line.valid is None or line.valid is True:
batch.add_valid_data(line)
@@ -437,7 +443,8 @@
return batch.gen_data_count == num_lines

def generate_all_batch_lines(
-self, max_invalid=MAX_INVALID, raise_on_failed_batch: bool = False, num_lines: int = None
+self, max_invalid=MAX_INVALID, raise_on_failed_batch: bool = False, num_lines: int = None,
+parallelism: int = 0,
) -> dict:
"""Generate synthetic lines for all batches. Lines for each batch
are added to the individual ``Batch`` objects. Once generation is
@@ -457,6 +464,10 @@
will be set to ``False`` in the result dictionary from this method.
num_lines: The number of lines to create from each batch. If ``None`` then the value
from the config template will be used.
+parallelism: The number of concurrent workers to use. ``1`` disables parallelization,
+    while a non-positive value means "number of CPUs + x" (i.e., the default of ``0`` uses
+    as many workers as there are CPUs). A floating-point value is interpreted as a fraction
+    of the available CPUs, rounded down.
Returns:
A dictionary of batch number to a bool value that shows if each batch
@@ -473,7 +484,8 @@
idx,
max_invalid=max_invalid,
raise_on_exceed_invalid=raise_on_failed_batch,
-num_lines=num_lines
+num_lines=num_lines,
+parallelism=parallelism,
)
return batch_status

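With this change, callers of the batch API can opt into parallel generation with a single argument. A sketch of the intended usage, assuming ``batcher`` is a DataFrameBatch that has already been trained (assumed setup shown as comments)::

    from gretel_synthetics.batch import DataFrameBatch

    # Assumed prior steps, per the batch API:
    #   batcher = DataFrameBatch(df=source_df, config=config_template)
    #   batcher.create_training_data()
    #   batcher.train_all_batches()
    status = batcher.generate_all_batch_lines(
        max_invalid=5000,
        parallelism=0,  # the default: one worker per available CPU
    )
    print(status)  # e.g. {0: True, 1: True}: one bool per batch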
198 changes: 26 additions & 172 deletions src/gretel_synthetics/generate.py
@@ -7,60 +7,18 @@
- Trained a model
"""
import logging
-from collections import namedtuple
-from dataclasses import dataclass, asdict
-from typing import Tuple, TYPE_CHECKING, List, Callable
+from typing import TYPE_CHECKING, Callable

-import sentencepiece as spm
import tensorflow as tf

-from gretel_synthetics.model import _build_sequential_model
+from gretel_synthetics.generator import Generator, Settings
+from gretel_synthetics.generator import gen_text, PredString  # noqa # pylint: disable=unused-import
+from gretel_synthetics.generate_parallel import split_work, generate_parallel

if TYPE_CHECKING: # pragma: no cover
-    from gretel_synthetics.config import BaseConfig

-PredString = namedtuple("pred_string", ["data"])


-@dataclass
-class gen_text:
-    """
-    A record that is yielded from the ``generate_text`` generator.
-    Attributes:
-        valid: True, False, or None. If the line passed a validation function,
-            then this will be ``True``. If the validation function raised an exception
-            then this will be automatically set to ``False``. If no validation function
-            is used, then this value will be ``None``.
-        text: The actual record as a string
-        explain: A string that describes why a record failed validation. This is the
-            string representation of the ``Exception`` that is thrown in a validation
-            function. This will only be set if validation fails, otherwise it will be ``None``.
-        delimiter: If the generated text is column/field based records, this will hold the
-            delimiter used to separate the fields from each other.
-    """

-    valid: bool = None
-    text: str = None
-    explain: str = None
-    delimiter: str = None

-    def as_dict(self) -> dict:
-        """Serialize the generated record to a dictionary
-        """
-        return asdict(self)

-    def values_as_list(self) -> List[str]:
-        """Attempt to split the generated text on the provided delimiter
-        Returns:
-            A list of values that are separated by the object's delimiter, or None if there
-            is no delimiter in the text
-        """
-        if self.delimiter is not None:
-            tmp = self.text.rstrip(self.delimiter)
-            return tmp.split(self.delimiter)
-        return None
+    from gretel_synthetics.config import LocalConfig
+else:
+    LocalConfig = None


logging.basicConfig(
@@ -69,44 +27,13 @@ def values_as_list(self) -> List[str]:
)


-def _load_tokenizer(store: "BaseConfig") -> spm.SentencePieceProcessor:
-    logging.info("Loading SentencePiece tokenizer")
-    sp = spm.SentencePieceProcessor()
-    sp.Load(store.tokenizer_model)
-    return sp


-def _prepare_model(
-    sp: spm, batch_size: int, store: "BaseConfig"
-) -> tf.keras.Sequential:  # pragma: no cover
-    model = _build_sequential_model(
-        vocab_size=len(sp), batch_size=batch_size, store=store
-    )

-    load_dir = store.checkpoint_dir

-    model.load_weights(tf.train.latest_checkpoint(load_dir)).expect_partial()

-    model.build(tf.TensorShape([1, None]))
-    model.summary()

-    return model


-def _load_model(
-    store: "BaseConfig",
-) -> Tuple[spm.SentencePieceProcessor, tf.keras.Sequential]:
-    sp = _load_tokenizer(store)
-    model = _prepare_model(sp, 1, store)
-    return sp, model


def generate_text(
config: "BaseConfig",
config: LocalConfig,
start_string: str = "<n>",
line_validator: Callable = None,
max_invalid: int = 1000,
-    num_lines: int = None
+    num_lines: int = None,
+    parallelism: int = 0,
):
"""A generator that will load a model and start creating records.
@@ -126,6 +53,10 @@
lines to generate. If the number of invalid lines exceeds this value a ``RunTimeError``
will be raised.
num_lines: If not ``None``, this will override the ``gen_lines`` value that is provided in the ``config``
+parallelism: The number of concurrent workers to use. ``1`` disables parallelization,
+    while a non-positive value means "number of CPUs + x" (i.e., the default of ``0`` uses
+    as many workers as there are CPUs). A floating-point value is interpreted as a fraction
+    of the available CPUs, rounded down.
Simple validator example::
@@ -161,99 +92,22 @@ def my_validator(raw_line: str):
f"Latest checkpoint: {tf.train.latest_checkpoint(config.checkpoint_dir)}"
) # noqa

-    sp, model = _load_model(config)
-    lines_generated = 0
-    delim = config.field_delimiter
-    invalid = 0
+    settings = Settings(
+        config=config,
+        start_string=start_string,
+        line_validator=line_validator,
+        max_invalid=max_invalid,
+    )

if num_lines is not None:
_line_count = num_lines
else:
_line_count = config.gen_lines

-    while True:
-        rec = _predict_chars(model, sp, start_string, config).data
-        _valid = None
-        try:
-            if not line_validator:
-                yield gen_text(text=rec, valid=None, explain=None, delimiter=delim)
-            else:
-                check = line_validator(rec)
-                if check is False:
-                    _valid = False
-                    invalid += 1
-                else:
-                    _valid = True
-                yield gen_text(text=rec, valid=_valid, explain=None, delimiter=delim)
-        except Exception as err:
-            # logging.warning(f'Line failed validation: {rec} errored with {str(err)}')
-            invalid += 1
-            yield gen_text(text=rec, valid=False, explain=str(err), delimiter=delim)
-        else:
-            if line_validator and _valid:
-                lines_generated += 1
-            elif not line_validator:
-                lines_generated += 1
-            else:
-                ...

-        if invalid > max_invalid:
-            raise RuntimeError("Maximum number of invalid lines reached!")

-        if lines_generated >= _line_count:
-            break
+    num_workers, chunks = split_work(parallelism, _line_count)


-def _predict_chars(
-    model: tf.keras.Sequential,
-    sp: spm.SentencePieceProcessor,
-    start_string: str,
-    store: "BaseConfig",
-) -> str:
-    """
-    Evaluation step (generating text using the learned model).
-    Args:
-        model: tf.keras.Sequential model
-        sp: SentencePiece tokenizer
-        start_string: string to bootstrap model
-        store: our config object
-    Returns:
-        Yields line of text per iteration
-    """

-    # Converting our start string to numbers (vectorizing)
-    input_eval = sp.EncodeAsIds(start_string)
-    input_eval = tf.expand_dims(input_eval, 0)

-    # Empty list to collect each line's token ids
-    sentence_ids = []

-    # Here batch size == 1
-    model.reset_states()

-    while True:
-        predictions = model(input_eval)
-        # remove the batch dimension
-        predictions = tf.squeeze(predictions, 0)

-        # using a categorical distribution to
-        # predict the word returned by the model
-        predictions = predictions / store.gen_temp
-        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

-        # We pass the predicted word as the next input to the model
-        # along with the previous hidden state
-        input_eval = tf.expand_dims([predicted_id], 0)
-        sentence_ids.append(int(predicted_id))

-        decoded = sp.DecodeIds(sentence_ids)
-        if store.field_delimiter is not None:
-            decoded = decoded.replace(
-                store.field_delimiter_token, store.field_delimiter
-            )

-        if "<n>" in decoded:
-            return PredString(decoded.replace("<n>", ""))
-        elif 0 < store.gen_chars <= len(decoded):
-            return PredString(decoded)
+    if num_workers == 1:  # Sequential operation
+        gen = Generator(settings)
+        yield from gen.generate_next(_line_count)
+    else:
+        yield from generate_parallel(settings, num_workers, chunks)
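The worker-count resolution described in the docstrings ("number of CPUs + x" for non-positive values, a fraction of CPUs for floats) is delegated to split_work in gretel_synthetics.generate_parallel. A simplified sketch of that resolution rule, which may differ from the real implementation in its details::

    import os

    def resolve_worker_count(parallelism, total_lines: int) -> int:
        # Sketch of the documented semantics only; split_work also chunks the work.
        num_cpus = os.cpu_count() or 1
        if isinstance(parallelism, float):
            workers = int(parallelism * num_cpus)  # fraction of CPUs, rounded down
        elif parallelism <= 0:
            workers = num_cpus + parallelism       # "number of CPUs + x"; 0 means all CPUs
        else:
            workers = parallelism                  # explicit worker count; 1 means sequential
        return max(1, min(workers, total_lines))   # at least 1, never more than lines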

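End to end, parallel generation from the public API then looks like the sketch below; the paths are hypothetical and a model is assumed to have already been trained into the checkpoint directory::

    from gretel_synthetics.config import LocalConfig
    from gretel_synthetics.generate import generate_text

    config = LocalConfig(
        input_data_path="/tmp/data.csv",    # hypothetical path
        checkpoint_dir="/tmp/checkpoints",  # hypothetical path
    )

    def validate_record(line: str):
        # Raising marks the line invalid; returning normally marks it valid
        if len(line.split(",")) != 3:
            raise ValueError("record must have exactly 3 fields")

    # parallelism=2 runs two worker processes; parallelism=1 keeps it sequential
    for record in generate_text(config, line_validator=validate_record,
                                num_lines=100, parallelism=2):
        if record.valid:
            print(record.text)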