Commit
* Provide quicker feedback from parallel workers
* Assess hard limit based on nominal chunk, not actual size
* Fix logic of partial chunks
* Don't pre-load entire queue, to avoid max queue size OSError
* Update sample standalone generation module
* Requirements update
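One of the bullets above concerns not pre-loading the entire work queue. As a hedged illustration of that general pattern (a sketch, not the library's actual implementation), a bounded queue fed incrementally makes the producer block when workers fall behind, instead of the process ever overfilling the queue:

```python
import threading
import queue

NUM_CHUNKS = 10  # hypothetical number of work chunks

def produce(q):
    # Feed chunks one at a time; put() blocks when the queue is full,
    # so we never exceed the queue's capacity.
    for i in range(NUM_CHUNKS):
        q.put(i)
    q.put(None)  # sentinel: no more work

def consume(q, results):
    while True:
        item = q.get()
        if item is None:
            break
        results.append(item * 2)  # stand-in for generating a chunk of lines

work_queue = queue.Queue(maxsize=3)  # small bound; producer can never pre-load everything
results = []
producer = threading.Thread(target=produce, args=(work_queue,))
consumer = threading.Thread(target=consume, args=(work_queue, results))
producer.start()
consumer.start()
producer.join()
consumer.join()
print(len(results))
```

The same idea applies with `multiprocessing` queues, which is where an over-eager pre-load can surface as an OSError on platforms with small queue limits.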
1 parent 680e141 · commit 2f707e1
Showing 6 changed files with 158 additions and 20 deletions.
@@ -134,5 +134,6 @@ dmypy.json
venv*

checkpoints
examples/checkpoints.zip

docs/_build
@@ -0,0 +1,78 @@
"""
Example module on how to run data generation from a standalone Python invocation. TensorFlow
requires that processes are launched in "spawn" mode, which requires the use of ``freeze_support()``,
which has to be called in the ``__main__`` scope of the module.

In the event that you choose to export a notebook to a pure module, please note the changes below. These
changes will have a ``NOTE:`` comment.
"""

# NOTE: Required import for launching from a standalone module
from multiprocessing import freeze_support
from pathlib import Path

from gretel_synthetics.config import LocalConfig
from gretel_synthetics.generate import generate_text

PARALLELISM = 0

# Create a config that we can use for both training and generating data.
# The default values for ``max_lines`` and ``epochs`` are optimized for training on a GPU.

# NOTE: Update your ``checkpoint_dir`` and other config params as needed
config = LocalConfig(
    max_lines=0,  # maximum lines of training data; set to ``0`` to train on the entire file
    max_line_len=2048,  # the max line length for input training data
    epochs=15,  # 15-50 epochs with a GPU for best performance
    vocab_size=20000,  # tokenizer model vocabulary size
    gen_lines=1000,  # the number of generated text lines
    dp=True,  # train with differential privacy enabled (privacy assurances, but reduced accuracy)
    field_delimiter=",",  # specify if the training text is structured, else ``None``
    overwrite=True,  # overwrite previously trained model checkpoints
    checkpoint_dir=(Path.cwd() / "checkpoints").as_posix(),
    input_data_path="https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uber_scooter_rides_1day.csv"
)

# Let's generate some text!
#
# The ``generate_text`` function is a generator that will return
# a line of predicted text based on the ``gen_lines`` setting in your
# config.
#
# There is no limit on the line length, as with proper training your model
# should learn where newlines generally occur. However, if you want to
# specify a maximum char length for each line, you may set the ``gen_chars``
# attribute in your config object.

# Optionally, when generating text, you can provide a callable that takes the
# generated line as a single arg. If this function raises any errors, the
# line will fail validation and will not be returned. The exception message
# will be provided as an ``explain`` field in the resulting dict that gets
# created by ``generate_text``.
def validate_record(line):
    rec = line.split(", ")
    if len(rec) == 6:
        float(rec[5])
        float(rec[4])
        float(rec[3])
        float(rec[2])
        int(rec[0])
    else:
        raise Exception("record not 6 parts")

# NOTE: You should put the actual generation routine into a function that can be
# called after the parent Python process is done bootstrapping
def start():
    for line in generate_text(config, line_validator=validate_record, parallelism=PARALLELISM):
        print(line)

# NOTE: Invoke your generation this way
if __name__ == "__main__":
    freeze_support()
    start()
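As a quick sanity check (not part of the commit itself), the validator above can be exercised directly on sample lines: a comma-delimited line with six fields whose numeric positions parse passes silently, and anything else raises, which is exactly what causes a generated line to fail validation:

```python
# Copy of the example's validator, reproduced here so the snippet is self-contained.
def validate_record(line):
    rec = line.split(", ")
    if len(rec) == 6:
        float(rec[5])
        float(rec[4])
        float(rec[3])
        float(rec[2])
        int(rec[0])
    else:
        raise Exception("record not 6 parts")

good = "1, scooter, 2.0, 3.0, 4.0, 5.0"  # hypothetical 6-field record
bad = "1, 2, 3"                           # too few fields

validate_record(good)  # returns None; no exception raised

try:
    validate_record(bad)
    rejected = False
except Exception:
    rejected = True  # the line would be dropped and its message surfaced via ``explain``
```

When run through `generate_text`, a raised exception here means the line is not yielded to the caller; its message appears in the ``explain`` field of the result dict instead.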