Skip to content

Commit

Permalink
Jm/syn 21 (#58)
Browse files Browse the repository at this point in the history
- Introduce Keras Early Stopping and Save Best Model features. Set default number of epochs to 100 which should allow most training sequences to automatically stop without potential over-fitting.

- Provide better tracking of which epoch's model was used as the best one in the model history table

- Temporarily disable DP mode
  • Loading branch information
johntmyers committed Oct 5, 2020
1 parent a94d403 commit 53c3df2
Show file tree
Hide file tree
Showing 14 changed files with 344 additions and 133 deletions.
27 changes: 26 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

[![Documentation Status](https://readthedocs.org/projects/gretel-synthetics/badge/?version=stable)](https://gretel-synthetics.readthedocs.io/en/stable/?badge=stable)

This code has been developed and tested on Python 3.6, 3.7, and 3.8.

This code has been developed and tested on Python 3.7. Python 3.8 is currently unsupported. While not developed on Python 3.6, this code will run in Google Colab, which currently uses 3.6. If you wish to use Python 3.6, out side of Google Colab, you may install with the `py36` extras: `pip install gretel-synthetics[tf,py36]`, for example.
This code is developed for TensorFlow 2.3.X and above.

This package allows developers to quickly get immersed with synthetic data generation through the use of neural networks. The more complex pieces of working with libraries like Tensorflow and differential privacy are bundled into friendly Python classes and functions.

Expand All @@ -16,8 +17,32 @@ For example usage, please launch the example Jupyter Notebook and step through t
for free in Google Colaboratory. If you're running on a CPU, you might want to grab a cup of coffee,
or lower `max_lines` and `epochs` to 5000 and 10, respectively.

# Roadmap

## Pre 0.14.X

Prior to the 0.14.x versions of Gretel Synthetics, we noticed that the differential privacy library we are using (tensorflow-privacy) may not be properly called based on the version of TensorFlow being used, particularly TF 2.1+. What this means is that with the `dp` option enabled on versions before 0.14.X, the synthetic data may not have been run through DP optimizers properly. We are currently working with the TensorFlow privacy team on an update to resolve this situation.

## 0.14.X

This release series will continue to operate as the versions prior and we will continue to add new functionality that makes training more automated and user friendly. Some enhancements are incorporating Keras' features to do early stopping of model training based on observed loss or accuracy and ensuring that the best versions of models are stored. This will remove the need to guess an optimal number of training epochs and help train the best model sooner.

One temporary change that will be done in this release series is throwing a `RuntimeError` in the event the `dp` option is enabled. We are doing this for a couple of reasons:

1) We want to reduce the risk DP is not applied properly to your data. By default, `dp` has always been disabled by default, so this will continue to remain the case.

2) We did not want to drastically change the signature of the configuration object. By removing these options it becomes more ambiguous to throw a `TypeError` because of removed parameters than it does to throw a `RunTimeError` with a more detailed explanation of why the option cannot be used temporarily.


## 0.15.X

We are currently working to ensure that our differentially private optimizers are called correctly when enabled, and plan to introduce them in this release series. To correctly subclass the standard non-differentially private optimizers in a future-proof way, we are leveraging the Keras V2 optimizer interfaces introduced in TensorFlow 2.4.x. Additionally, we will be doing a significant amount of hyperparameter optimization and provide default optimizers and hyperparameters for non-DP and DP training.

In this release you may expect to see an interface change to the configuration object. We are exploring the use of an `optimizer` parameter that will take an optional `Optimizer()` or `DPOptimizer()` class that you can instantiate yourself and provide to the configuration. This will allow you to explore multiple optimizers with your data. We will still continue to provide the `dp` boolean option that if used will default to optimal `Optimizer()` or `DPOptimizer()` objects based on our hyperparameter testing and should work well for a variety of general synthetic use cases.


# Getting Started

By default, we do not install Tensorflow via pip as many developers and cloud services such as Google Colab are
running customized versions for their hardware. If you wish to pip install Tensorflow along with gretel-synthetics,
use the [tf] commands below instead.
Expand Down
4 changes: 0 additions & 4 deletions examples/dataframe_batch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,8 @@
"checkpoint_dir = str(Path.cwd() / \"checkpoints\")\n",
"\n",
"config_template = {\n",
" \"max_lines\": 0,\n",
" \"max_line_len\": 2048,\n",
" \"epochs\": 15,\n",
" \"vocab_size\": 20000,\n",
" \"gen_lines\": 100,\n",
" \"dp\": True,\n",
" \"field_delimiter\": \",\",\n",
" \"overwrite\": True,\n",
" \"checkpoint_dir\": checkpoint_dir\n",
Expand Down
7 changes: 2 additions & 5 deletions examples/launch_synthetics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@
"\n",
"config_template = {\n",
" \"checkpoint_dir\": checkpoint_dir,\n",
" \"dp\": True, # enable differential privacy in training\n",
" \"epochs\": 15,\n",
" \"gen_lines\": 100,\n",
" \"overwrite\": True,\n",
" \"save_all_checkpoints\": False,\n",
" \"vocab_size\": 20000\n",
Expand Down Expand Up @@ -224,7 +221,7 @@
"# num_lines=500 will override the synthetic config ``num_lines``, set whatever number you need\n",
"# max_invalid=5000 will override the default invalid line limit that terminates execution, set whatever number you need\n",
"\n",
"bundle.generate()"
"bundle.generate(num_lines=1000, max_invalid=1000)"
]
},
{
Expand Down Expand Up @@ -271,7 +268,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.7.5"
}
},
"nbformat": 4,
Expand Down
6 changes: 1 addition & 5 deletions examples/synthetic_records.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,8 @@
"# The default values for ``max_lines`` and ``epochs`` are optimized for training on a GPU.\n",
"\n",
"config = LocalConfig(\n",
" max_lines=0, # maximum lines of training data. Set to ``0`` to train on entire file\n",
" max_line_len=2048, # the max line length for input training data\n",
" epochs=15, # 15-50 epochs with GPU for best performance\n",
" vocab_size=20000, # tokenizer model vocabulary size\n",
" gen_lines=1000, # the number of generated text lines\n",
" dp=True, # train with differential privacy enabled (privacy assurances, but reduced accuracy)\n",
" field_delimiter=\",\", # specify if the training text is structured, else ``None``\n",
" overwrite=True, # overwrite previously trained model checkpoints\n",
" checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),\n",
Expand Down Expand Up @@ -102,7 +98,7 @@
" else:\n",
" raise Exception('record not 6 parts')\n",
" \n",
"for line in generate_text(config, line_validator=validate_record):\n",
"for line in generate_text(config, line_validator=validate_record, num_lines=1000):\n",
" print(line)"
]
}
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
tensorflow==2.1.0
tensorflow_privacy==0.2.2
tensorflow==2.3.1
tensorflow_privacy==0.5.1
sentencepiece==0.1.91
smart_open==2.0.0
pandas>=1.0.0
smart_open>=2.1.0,<3.0
pandas>=1.1.0
numpy>=1.18.0
tqdm<5.0
loky==2.8.0
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@
packages=find_packages('src'),
python_requires=">=3.6",
install_requires=[
'tensorflow_privacy==0.2.2',
'tensorflow_privacy==0.5.1',
'sentencepiece==0.1.91',
'smart_open>=2.1.0,<3.0',
'tqdm<5.0',
'pandas>=1.0.0',
'pandas>=1.1.0',
'numpy>=1.18.0',
'dataclasses==0.7;python_version<"3.7"',
'loky==2.8.0',
],
extras_require={
'tf': ['tensorflow==2.1.0']
'tf': ['tensorflow==2.3.1']
},
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
34 changes: 31 additions & 3 deletions src/gretel_synthetics/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import dataclass, asdict, field
from typing import Optional


logging.basicConfig(
format="%(asctime)s : %(threadName)s : %(levelname)s : %(message)s",
level=logging.INFO,
Expand All @@ -19,6 +20,8 @@

TOKENIZER_PREFIX = "m"
MODEL_PARAMS = "model_params.json"
VAL_LOSS = "loss"
VAL_ACC = "accuracy"


@dataclass
Expand All @@ -35,7 +38,15 @@ class BaseConfig:
this length will be ignored. Default is ``2048``.
epochs (optional): Number of epochs to train the model. An epoch is an iteration over the entire
training set provided. For production use cases, 15-50 epochs are recommended.
Default is ``30``.
The default is ``100`` and is intentionally set extra high. By default, ``early_stopping``
is also enabled and will stop training epochs once the model is no longer improving.
early_stopping (optional). Defaults to ``True``. If enabled, regardless of the number of epochs, automatically
deduce when the model is no longer improving and terminating training.
early_stopping_patience (optional). Defaults to 5. Number of epochs to wait for when there is no improvement
in the model. After this number of epochs, training will terminate.
best_model_metric (optional). Defaults to "loss". The metric to use to track when a model is no
longer improving. Defaults to the loss value. An alternative option is "accuracy."
A error will be raised if either of this values are not used.
batch_size (optional): Number of samples per gradient update. Using larger batch sizes can help
make more efficient use of CPU/GPU parallelization, at the cost of memory.
If unspecified, batch_size will default to ``64``.
Expand All @@ -55,6 +66,8 @@ class BaseConfig:
compromise between retaining model accuracy and preventing overfitting. Default is 0.2.
rnn_initializer (optional): Initializer for the kernal weights matrix, used for the linear
transformation of the inputs. Default is ``glorot_transform``.
optimizer (optional): Optimizer used by the neural network to maximize accuracy and reduce
loss. Currently supported optimizers: ``Adam``, ``SGD``, and ``Adagrad``. Default is ``Adam``.
field_delimiter (optional): Delimiter to use for training on structured data. When specified,
the delimiter is passed as a user-specified token to the tokenizer, which can improve
synthetic data quality. For unstructured text, leave as ``None``. For structured text
Expand Down Expand Up @@ -105,6 +118,9 @@ class BaseConfig:
save_all_checkpoints (optional). Set to ``True`` to save all model checkpoints as they are created,
which can be useful for optimal model selection. Set to ``False`` to save only the latest
checkpoint. Default is ``True``.
save_best_model (optional). Defaults to ``True``. Track the best version of the model (checkpoint) to be used.
If ``save_all_checkpoints`` is disabled, then the saved model will be overwritten by newer ones only if they
are better.
overwrite (optional). Set to ``True`` to automatically overwrite previously saved model checkpoints.
If ``False``, the trainer will generate an error if checkpoints exist in the model directory.
Default is ``False``.
Expand All @@ -114,7 +130,10 @@ class BaseConfig:

# Training configurations
max_lines: int = 0
epochs: int = 15
epochs: int = 100
early_stopping: bool = True
early_stopping_patience: int = 5
best_model_metric: str = VAL_LOSS
batch_size: int = 64
buffer_size: int = 10000
seq_length: int = 100
Expand All @@ -135,7 +154,7 @@ class BaseConfig:

# Diff privacy configs
dp: bool = False
dp_learning_rate: float = 0.015
dp_learning_rate: float = 0.001
dp_noise_multiplier: float = 1.1
dp_l2_norm_clip: float = 1.0
dp_microbatches: int = 256
Expand All @@ -148,6 +167,7 @@ class BaseConfig:

# Checkpoint storage
save_all_checkpoints: bool = False
save_best_model: bool = True
overwrite: bool = False

@abstractmethod
Expand Down Expand Up @@ -212,6 +232,14 @@ class LocalConfig(BaseConfig, _PathSettingsMixin):
input_data_path: str = None

def __post_init__(self):
# FIXME: Remove @ 0.15.X when new optimizers are available for DP
if self.dp:
raise RuntimeError(
"DP mode is disabled in v0.14.X. Please remove or set this value to ``False`` to continue with out DP. DP will be re-enabled in v0.15.X. Please see the README for more details" # noqa
)

if self.best_model_metric not in (VAL_LOSS, VAL_ACC):
raise AttributeError("Invalid value for bset_model_metric")
if not self.checkpoint_dir or not self.input_data_path:
raise AttributeError(
"Must provide checkpoint_dir and input_path_dir params!"
Expand Down
4 changes: 2 additions & 2 deletions src/gretel_synthetics/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sentencepiece as spm
import tensorflow as tf

from gretel_synthetics.model import _build_sequential_model
from gretel_synthetics.model import build_sequential_model

if TYPE_CHECKING:
from gretel_synthetics.config import BaseConfig, LocalConfig
Expand All @@ -23,7 +23,7 @@ def _load_tokenizer(store: LocalConfig) -> spm.SentencePieceProcessor:
def _prepare_model(
sp: spm.SentencePieceProcessor, batch_size: int, store: LocalConfig
) -> tf.keras.Sequential: # pragma: no cover
model = _build_sequential_model(
model = build_sequential_model(
vocab_size=len(sp), batch_size=batch_size, store=store
)

Expand Down
90 changes: 54 additions & 36 deletions src/gretel_synthetics/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,62 @@
Tensorflow - Keras Sequential RNN (GRU)
"""
import logging
from typing import Tuple, TYPE_CHECKING

from tensorflow.keras.optimizers import RMSprop # pylint: disable=import-error
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer import (
make_gaussian_optimizer_class as make_dp_optimizer,
)
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy

from gretel_synthetics.config import BaseConfig
if TYPE_CHECKING:
from gretel_synthetics.config import BaseConfig
else:
BaseConfig = None


def _build_sequential_model(
DEFAULT = "default"


OPTIMIZERS = {
DEFAULT: {'dp': make_keras_optimizer_class(RMSprop), 'default': RMSprop}
}


def select_optimizer(store: BaseConfig):
if store.dp:
return OPTIMIZERS[DEFAULT]["dp"]
else:
return OPTIMIZERS[DEFAULT][DEFAULT]


def build_sequential_model(
vocab_size: int, batch_size: int, store: BaseConfig
) -> tf.keras.Sequential:
"""
Utilizing tf.keras.Sequential model (LSTM)
"""
model = tf.keras.Sequential(
model_cls = tf.keras.Sequential
optimizer_cls = select_optimizer(store)

if store.dp:
logging.info("Differentially private training enabled")
optimizer = optimizer_cls(
l2_norm_clip=store.dp_l2_norm_clip,
noise_multiplier=store.dp_noise_multiplier,
num_microbatches=store.dp_microbatches,
learning_rate=store.dp_learning_rate
)
# Compute vector of per-example loss rather than its mean over a minibatch.
# To support gradient manipulation over each training point.
loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE
)
else:
logging.info("Differentially private training _not_ enabled")
optimizer = optimizer_cls(learning_rate=0.01)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = model_cls(
[
tf.keras.layers.Embedding(
vocab_size, store.embedding_dim, batch_input_shape=[batch_size, None]
Expand All @@ -40,49 +78,29 @@ def _build_sequential_model(
),
tf.keras.layers.Dropout(store.dropout_rate),
tf.keras.layers.Dense(vocab_size),
]
)

if store.dp:
logging.info("Differentially private training enabled")

rms_prop_optimizer = tf.compat.v1.train.RMSPropOptimizer
dp_rms_prop_optimizer = make_dp_optimizer(rms_prop_optimizer)

optimizer = dp_rms_prop_optimizer(
l2_norm_clip=store.dp_l2_norm_clip,
noise_multiplier=store.dp_noise_multiplier,
num_microbatches=store.dp_microbatches,
learning_rate=store.dp_learning_rate,
)

"""
Compute vector of per-example loss rather than its mean over a minibatch.
To support gradient manipulation over each training point.
"""
loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE
)

else:
logging.warning("Differentially private training _not_ enabled")
optimizer = RMSprop(learning_rate=0.01)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
])

logging.info(f"Using {optimizer._keras_api_names[0]} optimizer "
f"{'in differentially private mode' if store.dp else ''}")
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
return model


def _compute_epsilon(steps: int, store: BaseConfig):
def compute_epsilon(steps: int, store: BaseConfig, epoch_number: int = None) -> Tuple[float, float]:
"""
Calculate epsilon and delta values for differential privacy
Returns:
Tuple of eps, opt_order
"""
# Note: inverse of number of training samples recommended for minimum
# delta in differential privacy
if epoch_number is None:
epoch_number = store.epochs - 1
return compute_dp_sgd_privacy.compute_dp_sgd_privacy(
n=steps,
batch_size=store.batch_size,
noise_multiplier=store.dp_noise_multiplier,
epochs=store.epochs,
epochs=epoch_number,
delta=1.0 / float(steps),
)

0 comments on commit 53c3df2

Please sign in to comment.