Jm/queue updates (#48)
* Provide quicker feedback from parallel workers

* assess hard limit based on nominal chunk, not actual size

* fix logic of partial chunks

* Don't pre-load entire queue, to avoid max queue size OSError

* Update sample standalone generation module

* reqs update
johntmyers committed Aug 28, 2020
1 parent 680e141 commit 2f707e1
Showing 6 changed files with 158 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -134,5 +134,6 @@ dmypy.json
venv*

checkpoints
examples/checkpoints.zip

docs/_build
78 changes: 78 additions & 0 deletions examples/generate_as_module.py
@@ -0,0 +1,78 @@

"""
Example module showing how to run data generation from a standalone Python invocation. TensorFlow
requires that processes are launched in "spawn" mode, which in turn requires ``freeze_support()``
to be called in the ``__main__`` scope of the module.
If you choose to export a Notebook to a plain module, please note the changes below. These
changes are marked with a ``NOTE:`` comment.
"""

# NOTE: Required import for launching from a standalone module
from multiprocessing import freeze_support
from pathlib import Path

from gretel_synthetics.config import LocalConfig
from gretel_synthetics.generate import generate_text

PARALLELISM = 0

# Create a config that we can use for both training and generating data
# The default values for ``max_lines`` and ``epochs`` are optimized for training on a GPU.


# NOTE: Update your ``checkpoint_dir`` and other config params as needed
config = LocalConfig(
    max_lines=0,  # maximum lines of training data. Set to ``0`` to train on entire file
    max_line_len=2048,  # the max line length for input training data
    epochs=15,  # 15-50 epochs with GPU for best performance
    vocab_size=20000,  # tokenizer model vocabulary size
    gen_lines=1000,  # the number of generated text lines
    dp=True,  # train with differential privacy enabled (privacy assurances, but reduced accuracy)
    field_delimiter=",",  # specify if the training text is structured, else ``None``
    overwrite=True,  # overwrite previously trained model checkpoints
    checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),
    input_data_path="https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uber_scooter_rides_1day.csv"
)


# Let's generate some text!
#
# The ``generate_text`` function is a generator that will return
# a line of predicted text based on the ``gen_lines`` setting in your
# config.
#
# There is no limit on the line length; with proper training, your model
# should learn where newlines generally occur. However, if you want to
# enforce a maximum character length for each line, you may set the ``gen_chars``
# attribute in your config object.


# Optionally, when generating text, you can provide a callable that takes the
# generated line as a single argument. If this function raises any error, the
# line fails validation and is marked as invalid. The exception message
# will be provided as an ``explain`` field in the resulting record that gets
# created by ``generate_text``.
def validate_record(line):
    rec = line.split(", ")
    if len(rec) == 6:
        float(rec[5])
        float(rec[4])
        float(rec[3])
        float(rec[2])
        int(rec[0])
    else:
        raise Exception('record not 6 parts')


# NOTE: You should put the actual generation routine into a function that can be
# called after the parent Python process is done bootstrapping
def start():
    for line in generate_text(config, line_validator=validate_record, parallelism=PARALLELISM):
        print(line)


# NOTE: Invoke your generation this way
if __name__ == "__main__":
    freeze_support()
    start()
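
For reference, here is a minimal sketch of consuming the generator output programmatically instead of printing it. It builds on the example module above and assumes the records yielded by ``generate_text`` expose ``text``, ``valid`` and ``explain`` attributes (``valid`` and ``explain`` are referenced elsewhere in this changeset; ``text`` is assumed here).

# Hypothetical variant of start(): collect valid and invalid lines separately.
# Reuses config, validate_record and PARALLELISM from the example module above.
def collect():
    kept, failures = [], []
    for line in generate_text(config, line_validator=validate_record, parallelism=PARALLELISM):
        if line.valid is False:
            failures.append(line.explain)
        else:
            kept.append(line.text)
    return kept, failures


if __name__ == "__main__":
    freeze_support()
    records, errors = collect()
    print(f"kept {len(records)} lines, {len(errors)} failed validation")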
2 changes: 1 addition & 1 deletion setup.py
@@ -23,7 +23,7 @@
install_requires=[
'tensorflow_privacy==0.2.2',
'sentencepiece==0.1.91',
'smart_open==2.0.0',
'smart_open>=2.1.0,<3.0',
'tqdm<5.0',
'pandas>=1.0.0',
'numpy>=1.18.0',
2 changes: 1 addition & 1 deletion src/gretel_synthetics/generate.py
@@ -110,4 +110,4 @@ def my_validator(raw_line: str):
gen = Generator(settings)
yield from gen.generate_next(_line_count)
else:
yield from generate_parallel(settings, num_workers, chunks)
yield from generate_parallel(settings, num_workers, chunks, num_lines)
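
The extra ``num_lines`` argument lets the parallel path know when all of the requested valid lines have arrived so it can tell the workers to shut down. A small illustration with made-up numbers (assuming, as with the chunking produced by ``split_work``, that the chunk sizes sum to the requested total):

# Hypothetical numbers, not taken from the diff.
chunks = [50, 50, 50, 28]   # nominal work items handed out to workers
num_lines = sum(chunks)     # 178: once this many valid lines have been yielded,
                            # generate_parallel queues a shutdown sentinel per worker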
86 changes: 70 additions & 16 deletions src/gretel_synthetics/generate_parallel.py
@@ -2,8 +2,8 @@
from dataclasses import dataclass, field
from typing import Iterable, List, Optional, Union, Tuple
import queue
import sys
import os
import sys

import cloudpickle

@@ -13,6 +13,9 @@
mp = multiprocessing.get_context('spawn') # fork does not work with tensorflow


MAX_QUEUE_SIZE = 30000


@dataclass
class _WorkerStatus:
"""
@@ -79,7 +82,7 @@ def split_work(parallelism: Union[int, float], total_lines: int, chunk_size: int
return num_workers, chunks


def generate_parallel(settings: Settings, num_workers: int, chunks: List[int]):
def generate_parallel(settings: Settings, num_workers: int, chunks: List[int], num_lines: int):
"""
Runs text generation in parallel mode.
@@ -95,24 +98,33 @@ def generate_parallel(settings: Settings, num_workers: int, chunks: List[int]):
``gen_text`` objects.
"""

# Create a queue of chunks (integers indicating the number of lines that need to be generated).
# This queue is created with sufficient capacity to hold all chunks, and implements a flow control
# mechanism for the parallel generation. It also is used to signal the exit condition to subprocesses,
# as pre-filling the queue once ensures that an empty queue to a worker will always mean to exit.
worker_input_queue = mp.Queue(maxsize=len(chunks))
for chunk in chunks:
# Instruct each worker to return an intermediate result (at the cost of putting a partial chunk
# back into the queue) if it has generated 105% of the requested number of valid lines per chunk
# (regardless of how many lines are valid in the intermediate result). This ensures that workers
# don't get stuck for a long time generating data from bad models, where the ratio of invalid:valid
# lines is very high.
max_response_size = int(chunks[0] * 1.05)

worker_input_queue = mp.Queue(MAX_QUEUE_SIZE)

# Because of the upper limit on the queue size, we pre-load
# either all chunks or at most ``MAX_QUEUE_SIZE`` chunks.
# If there are still more chunks, we'll load those into the queue
# as space becomes available
for _ in range(min(MAX_QUEUE_SIZE, len(chunks))):
chunk = chunks.pop()
worker_input_queue.put_nowait(chunk)

# Create a queue for output produced by the worker. This queue should be large enough to buffer all
# intermediate statuses to ensure that upstream processing doesn't block downstream workers.
worker_output_queue = mp.Queue(maxsize=len(chunks) + num_workers)
worker_output_queue = mp.Queue(MAX_QUEUE_SIZE)

pickled_settings = cloudpickle.dumps(settings)

workers = [
mp.Process(
target=_run_parallel_worker,
args=(pickled_settings, worker_input_queue, worker_output_queue),
args=(pickled_settings, worker_input_queue, worker_output_queue, max_response_size),
)
for _ in range(num_workers)
]
@@ -123,6 +135,7 @@ def generate_parallel(settings: Settings, num_workers: int, chunks: List[int]):

live_workers = len(workers)
total_invalid = 0
total_valid = 0
while live_workers > 0:
output = worker_output_queue.get()

@@ -142,13 +155,34 @@ def generate_parallel(settings: Settings, num_workers: int, chunks: List[int]):
for line in parsed_output.lines:
if line.valid is not None and not line.valid:
total_invalid += 1
else:
total_valid += 1
if total_invalid > settings.max_invalid:
raise RuntimeError("Maximum number of invalid lines reached!")
yield line

if parsed_output.done:
live_workers -= 1 # We aren't expecting anything more from this worker

# if there are still chunks left, try to add them to the queue;
# if there are no more chunks, signal the workers to shut down
while True:
if not chunks:
break
chunk = chunks.pop()
try:
worker_input_queue.put_nowait(chunk)
except queue.Full:
chunks.append(chunk)
break

if total_valid == num_lines:
for _ in range(len(workers)):
try:
worker_input_queue.put_nowait(None)
except queue.Full:
pass

# Join all worker processes (not strictly necessary, but cleaner).
for worker in workers:
worker.join()
@@ -157,7 +191,8 @@ def generate_parallel(settings: Settings, num_workers: int, chunks: List[int]):
def _run_parallel_worker(
pickled_settings: bytes,
input_queue: mp.Queue,
output_queue: mp.Queue):
output_queue: mp.Queue,
max_response_size: Optional[int] = None):
# Workers should be using CPU only (note, this has no effect on the parent process)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

@@ -172,24 +207,43 @@ def _run_parallel_worker(
try:
settings = deserialize_settings(pickled_settings)

for status in _process_all_chunks(settings, input_queue):
for status in _process_all_chunks(settings, input_queue, max_response_size):
output_queue.put(cloudpickle.dumps(status))
except BaseException as e:
# Catch serialization errors etc., and put into queue as a raw str to avoid triggering an exception a
# second time that was caused by lack of serializability.
output_queue.put(str(e))


def _process_all_chunks(settings: Settings, input_queue: mp.Queue) -> Iterable[_WorkerStatus]:
def _process_all_chunks(settings: Settings,
input_queue: mp.Queue,
max_response_size: Optional[int] = None) -> Iterable[_WorkerStatus]:
try:
gen = Generator(settings)

while True:
chunk_size = input_queue.get_nowait()
all_lines = list(gen.generate_next(chunk_size))
chunk_size = input_queue.get()
if chunk_size is None:
yield _WorkerStatus(done=True)
break
prev_invalid = gen.total_invalid
all_lines = list(gen.generate_next(chunk_size, hard_limit=max_response_size))
num_valid_lines = len(all_lines) - (gen.total_invalid - prev_invalid)
if num_valid_lines < chunk_size:
# Put the number of lines by which we fell short back onto the queue.
# This is guaranteed to succeed because every worker will only do this after
# at least removing an element from the queue, thus ensuring that capacity
# constraints are never violated.
input_queue.put(chunk_size - num_valid_lines)
yield _WorkerStatus(lines=all_lines)
except queue.Empty:
# Input queue is pre-filled, so empty queue means we are done.
# NOTE: We shouldn't really get here, but leaving it for now, since we are
# relying on signaling shutdown via a sentinel
#
# It is possible that further elements will be added to the queue
# because a worker generated too many invalid lines, but in this case we can be certain
# that the number of running workers will never fall below the number of concurrently
# pending chunks.
yield _WorkerStatus(done=True)
except BaseException as e:
# Send any exception in its own status object.
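
To make the new flow control easier to follow, here is a minimal, single-threaded sketch of the queue-feeding pattern this file now uses: pre-load at most ``MAX_QUEUE_SIZE`` chunks, top the queue up as space frees, and queue one ``None`` sentinel per worker once enough valid lines have arrived. This keeps the queue's ``maxsize`` bounded, which is what the commit message refers to with "max queue size OSError". The sketch uses a plain ``queue.Queue`` and made-up numbers purely for illustration; the real code uses ``multiprocessing`` queues across worker processes, and the deliberately tiny bound stands in for the 30000 used in the diff.

import queue

MAX_QUEUE_SIZE = 4          # tiny bound for illustration; the diff uses 30000
NUM_WORKERS = 2
chunks = [10] * 7           # hypothetical nominal chunk sizes
num_lines = sum(chunks)     # 70

work_q = queue.Queue(maxsize=MAX_QUEUE_SIZE)

# Pre-load either all chunks or MAX_QUEUE_SIZE chunks, whichever is smaller.
for _ in range(min(MAX_QUEUE_SIZE, len(chunks))):
    work_q.put_nowait(chunks.pop())

total_valid = 0
while total_valid < num_lines:
    # Stand-in for a worker result: pretend a worker consumed one chunk and
    # produced exactly that many valid lines.
    total_valid += work_q.get_nowait()

    # Top the queue back up as space becomes available.
    while chunks:
        try:
            work_q.put_nowait(chunks.pop())
        except queue.Full:
            break

# All requested lines are accounted for: queue one sentinel per worker.
for _ in range(NUM_WORKERS):
    work_q.put_nowait(None)

print(total_valid)          # 70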
9 changes: 7 additions & 2 deletions src/gretel_synthetics/generator.py
@@ -151,20 +151,25 @@ def __init__(self, settings: Settings):
self.sp, self.model = _load_model(settings.config)
self.delim = settings.config.field_delimiter

def generate_next(self, num_lines: int) -> Iterable[gen_text]:
def generate_next(self, num_lines: int, hard_limit: Optional[int] = None) -> Iterable[gen_text]:
"""
Returns a sequence of lines.
Args:
num_lines: the number of _valid_ lines that should be generated during this call. The actual
number of lines returned may be higher, in case of invalid lines in the generation output.
hard_limit: if set, imposes a hard limit on the overall number of lines that are generated during
this call, regardless of whether the requested number of valid lines was hit.
Yields:
A ``gen_text`` object for every line (valid or invalid) that is generated.
"""
valid_lines_generated = 0
while valid_lines_generated < num_lines:
total_lines_generated = 0

while valid_lines_generated < num_lines and (hard_limit is None or total_lines_generated < hard_limit):
rec = _predict_chars(self.model, self.sp, self.settings.start_string, self.settings.config).data
total_lines_generated += 1
_valid = None
try:
if not self.settings.line_validator:
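
A worked example (made-up numbers) of how the 105% hard limit set in ``generate_parallel`` interacts with the new ``hard_limit`` parameter and the shortfall re-queueing:

# Hypothetical numbers, not taken from the diff.
nominal_chunk = 100
hard_limit = int(nominal_chunk * 1.05)   # 105 total lines per request, valid or not

# Suppose a poorly trained model causes the worker to hit the hard limit first:
total_generated = hard_limit             # generate_next stops after 105 lines
invalid = 30
valid = total_generated - invalid        # 75 valid lines in this intermediate result
shortfall = nominal_chunk - valid        # 25 lines get put back onto the input queue
print(hard_limit, valid, shortfall)      # 105 75 25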
