Jm/docs (#76)
* Header order preservation, test generation during model load in read-mode

* use new keras code for dp model

* Move global patch only with DP model built, add warning
johntmyers committed Nov 25, 2020
1 parent 794fde0 commit 2ccadcb
Showing 5 changed files with 55 additions and 9 deletions.
13 changes: 13 additions & 0 deletions src/gretel_synthetics/batch.py
@@ -532,6 +532,11 @@ def generate_batch_lines(
seed_fields: A dictionary that maps field/column names to initial seed values for those columns. This seed
will only apply to the first batch that gets trained and generated. Additionally, the fields provided
in the mapping MUST exist at the front of the first batch.
NOTE:
This param may also be a list of dicts. If this is the case, then ``num_lines`` will automatically
be set to the list length downstream, and a 1:1 ratio will be used for generating valid lines for
each prefix.
parallelism: The number of concurrent workers to use. ``1`` (the default) disables parallelization,
while a non-positive value means "number of CPUs + x" (i.e., use ``0`` for using as many workers
as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs,
@@ -611,9 +616,17 @@ def generate_all_batch_lines(
will be set to ``False`` in the result dictionary from this method.
num_lines: The number of lines to create from each batch. If ``None`` then the value
from the config template will be used.
NOTE:
Will be overridden and ignored if ``seed_fields`` is a list; in that case it is set to the length of that list.
seed_fields: A dictionary that maps field/column names to initial seed values for those columns. This seed
will only apply to the first batch that gets trained and generated. Additionally, the fields provided
in the mapping MUST exist at the front of the first batch.
NOTE:
This param may also be a list of dicts. If this is the case, then ``num_lines`` will automatically
be set to the list length downstream, and a 1:1 ratio will be used for generating valid lines for
each prefix.
parallelism: The number of concurrent workers to use. ``1`` (the default) disables parallelization,
while a non-positive value means "number of CPUs + x" (i.e., use ``0`` for using as many workers
as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs,
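For context, here is a minimal sketch of how the seed-list behavior documented above might be used. It assumes ``batcher`` is an already-trained ``DataFrameBatch`` instance whose first batch begins with the hypothetical columns ``first_name`` and ``last_name``:

# Each dict is one seed; generation attempts one valid line per seed (1:1)
seeds = [
    {"first_name": "Ada", "last_name": "Lovelace"},
    {"first_name": "Alan", "last_name": "Turing"},
]

# num_lines is omitted on purpose: it is implicitly set to len(seeds) downstream
statuses = batcher.generate_all_batch_lines(seed_fields=seeds)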
12 changes: 12 additions & 0 deletions src/gretel_synthetics/config.py
@@ -118,6 +118,13 @@ def get_training_callable(self) -> Callable:
"""
pass

def gpu_check(self):
"""Optionally do a GPU check and warn if
a GPU is not available, if not overridden,
do nothing
"""
pass

def __post_init__(self):
if not self.checkpoint_dir or not self.input_data_path:
raise AttributeError(
@@ -260,6 +267,11 @@ def get_generator_class(self):
def get_training_callable(self):
return train_rnn

def gpu_check(self):
device_name = tf.test.gpu_device_name()
if not device_name.startswith("/device:GPU:"):
logging.warning("***** GPU not found, CPU will be used instead! *****")


#################
# Config Factory
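The TensorFlow implementation above relies on ``tf.test.gpu_device_name()``, which returns a device string such as ``/device:GPU:0`` when a GPU is visible and an empty string otherwise. A quick way to run the same check standalone (assuming TensorFlow is installed):

import tensorflow as tf

# Prints the device string when a GPU is visible, otherwise a fallback note
print(tf.test.gpu_device_name() or "no GPU found, CPU will be used")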
19 changes: 17 additions & 2 deletions src/gretel_synthetics/generate.py
@@ -68,7 +68,10 @@ class GenText(gen_text):


class BaseGenerator(ABC):
"""Specific generation modules should have a
"""
Do not use directly.
Specific generation modules should have a
subclass of this ABC that implements the core logic
for generating data
"""
@@ -81,6 +84,8 @@ def generate_next(self, num_lines: int, hard_limit: Optional[int] = None) -> Ite
@dataclass
class Settings:
"""
Do not use directly.
Arguments for a generator generating lines of text.
This class contains basic settings for a generation process. It is separated from the Generator class
@@ -145,6 +150,13 @@ def generate_text(
value, then you MUST utilize the field delimiter specified in your config.
An example would be "foo,bar,baz,". Also, if using a field delimiter, the string
MUST end with the delimiter value.
NOTE:
This param may also be a list of prefixes. If this is provided, then
the generator will attempt to create exactly 1 record for each seed in the
list. The ``num_lines`` param will be implicitly set to the size of the list
and this number of records will be created at a 1:1 ratio between prefix strings
and valid records.
line_validator: An optional callback validator function that will take
the raw string value from the generator as a single argument. This validator
can execute arbitrary code with the raw string value. The validator function
Expand All @@ -155,7 +167,10 @@ def generate_text(
max_invalid: If using a ``line_validator``, this is the maximum number of invalid
lines to generate. If the number of invalid lines exceeds this value a ``RuntimeError``
will be raised.
-num_lines: If not ``None``, this will override the ``gen_lines`` value that is provided in the ``config``
+num_lines: If not ``None``, this will override the ``gen_lines`` value that is provided in the ``config``.
NOTE:
If ``start_string`` is a list, this value will be set to the length of that list and any other
values for the param are ignored.
parallelism: The number of concurrent workers to use. ``1`` (the default) disables parallelization,
while a non-positive value means "number of CPUs + x" (i.e., use ``0`` for using as many workers
as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs,
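To make the list-of-prefixes behavior concrete, here is a minimal sketch. It assumes ``config`` is a trained config whose field delimiter is ``,``; note that each prefix must end with that delimiter:

from gretel_synthetics.generate import generate_text

# One valid record is attempted per prefix; num_lines is ignored and
# implicitly becomes len(seeds)
seeds = ["foo,bar,", "baz,qux,"]

for record in generate_text(config, start_string=seeds):
    print(record.text)  # each yielded GenText carries one generated line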
17 changes: 10 additions & 7 deletions src/gretel_synthetics/tensorflow/dp_model.py
@@ -31,13 +31,16 @@ def build_dp_model(store, batch_size, vocab_size) -> tf.keras.Sequential:
"""
logging.warning("Experimental: Differentially private training enabled")

-recurrent_v2 = importlib.import_module("tensorflow.python.keras.layers.recurrent_v2")
-# NOTE: This patches the LSTMs to use the new Keras 2.4.x code paths
-# and will have no effect when the module function is removed
-use_new_code = getattr(recurrent_v2, "_use_new_code", None)
-if use_new_code is not None:
-    logging.warning("******* Patching TensorFlow to utilize new Keras code paths, see: %s", "https://github.com/tensorflow/tensorflow/issues/44917 *******")  # noqa
-    recurrent_v2._use_new_code = lambda: True  # pylint: disable=protected-access
+try:
+    recurrent_v2 = importlib.import_module("tensorflow.python.keras.layers.recurrent_v2")
+    # NOTE: This patches the LSTMs to use the new Keras 2.4.x code paths
+    # and will have no effect when the module function is removed
+    use_new_code = getattr(recurrent_v2, "_use_new_code", None)
+    if use_new_code is not None:
+        logging.warning("******* Patching TensorFlow to utilize new Keras code paths, see: %s", "https://github.com/tensorflow/tensorflow/issues/44917 *******")  # noqa
+        recurrent_v2._use_new_code = lambda: True  # pylint: disable=protected-access
+except ModuleNotFoundError:
+    pass

optimizer = make_keras_optimizer_class(RMSprop)(
l2_norm_clip=store.dp_l2_norm_clip,
3 changes: 3 additions & 0 deletions src/gretel_synthetics/train.py
@@ -6,8 +6,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional


from gretel_synthetics.tokenizers import SentencePieceTokenizerTrainer, tokenizer_from_model_dir


if TYPE_CHECKING:
from gretel_synthetics.config import BaseConfig
from gretel_synthetics.tokenizers import BaseTokenizerTrainer, BaseTokenizer
@@ -63,6 +65,7 @@ def train(store: BaseConfig, tokenizer_trainer: Optional[BaseTokenizerTrainer] =
)
train_fn = store.get_training_callable()
store.save_model_params()
store.gpu_check()
train_fn(params)
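With this change, ``gpu_check()`` runs automatically just before training begins. A hypothetical end-to-end sketch (the paths are placeholders, and ``TensorFlowConfig`` is assumed to be the TensorFlow-backed config from ``config.py``):

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.train import train

config = TensorFlowConfig(
    input_data_path="training-data.csv",  # placeholder
    checkpoint_dir="checkpoints",         # placeholder
)

# train() saves model params, then calls config.gpu_check(), which logs
# a warning when no GPU is visible, and finally invokes the training loop
train(config)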



