Updated text generation to return gen_text object (#16)
* Updated text generation to return gen_text object instead of dict to be more compatible with other Gretel services

* Added overwrite protection, field delimiters, black formatting

* Allow custom delim token, setting default to <d>. Update delim token replacement routines.

* No field delim by default, updated tests/notebooks

Co-authored-by: John Myers <john@gretel.ai>
johntmyers and John Myers committed May 19, 2020
1 parent 6b6c2a9 commit e164132
Showing 11 changed files with 247 additions and 148 deletions.
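For orientation before the file-by-file diff: the central change is that generate_text now yields gen_text objects instead of dicts. A minimal before/after sketch, assuming a trained model and an existing LocalConfig named config (both illustrative):

    # Before this commit: each generated result was a dict
    for line in generate_text(config):
        if line["valid"]:
            print(line["text"])

    # After this commit: each result is a gen_text object with attribute access
    for line in generate_text(config):
        if line.valid:
            print(line.text)              # the raw generated record
            print(line.values_as_list())  # fields split on the configured delimiter, or None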
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-0.8.0
+0.9.0
11 changes: 6 additions & 5 deletions examples/research/heart_disease_uci.ipynb
@@ -120,6 +120,7 @@
" dp_microbatches=256, # split batches into minibatches for parallelism\n",
" checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),\n",
" save_all_checkpoints=False,\n",
" field_delimiter=\",\",\n",
" input_data_path=annotated_file # filepath or S3\n",
")"
]
@@ -197,8 +198,8 @@
"synth_df = pd.DataFrame(columns=df.columns)\n",
"\n",
"\n",
"for idx, line in enumerate(generate_text(config, line_validator=validate_record)):\n",
" status = line['valid']\n",
"for idx, record in enumerate(generate_text(config, line_validator=validate_record)):\n",
" status = record.valid\n",
" \n",
" # ensure all generated records are unique\n",
" synth_df = synth_df.drop_duplicates()\n",
@@ -209,8 +210,8 @@
" # if generated record passes validation, save it\n",
" if status:\n",
" print(f\"({synth_cnt}/{records_to_generate} : {status})\") \n",
" print(f\"{line['text']}\")\n",
" data = line['text'].split(\",\")\n",
" print(f\"{line.text}\")\n",
" data = line.values_as_list()\n",
" synth_df = synth_df.append({k:v for k,v in zip(df.columns, data)}, ignore_index=True)\n",
" "
]
@@ -315,7 +316,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
"version": "3.7.5"
}
},
"nbformat": 4,
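The notebook's validate_record body is outside this diff; as a sketch of the updated generation loop, with a hypothetical validator that enforces a fixed field count (generate_text records any raised exception as a failed validation), where df and config come from earlier notebook cells:

    import pandas as pd
    from gretel_synthetics.generate import generate_text

    def validate_record(raw_line: str):
        # hypothetical rule: a UCI heart-disease record has 14 comma-separated fields
        if len(raw_line.split(",")) != 14:
            raise Exception("record does not have 14 fields")

    synth_df = pd.DataFrame(columns=df.columns)

    for idx, record in enumerate(generate_text(config, line_validator=validate_record)):
        if record.valid:
            data = record.values_as_list()
            synth_df = synth_df.append(
                {k: v for k, v in zip(df.columns, data)}, ignore_index=True
            )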
2 changes: 2 additions & 0 deletions examples/synthetic_records.ipynb
@@ -58,6 +58,8 @@
" dp_l2_norm_clip=1.0, # bound optimizer's sensitivity to individual training points\n",
" dp_microbatches=256, # split batches into minibatches for parallelism\n",
" checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),\n",
" field_delimiter=\",\", # if the training text is structured\n",
" # overwrite=True, # enable this if you want to keep training models to the same checkpoint location\n",
" input_data_path=\"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uber_scooter_rides_1day.csv\" # filepath or S3\n",
")"
]
40 changes: 27 additions & 13 deletions src/gretel_synthetics/config.py
@@ -9,15 +9,16 @@
from pathlib import Path
from abc import abstractmethod
from dataclasses import dataclass, asdict, field

+from typing import Optional

logging.basicConfig(
-    format='%(asctime)s : %(threadName)s : %(levelname)s : %(message)s',
-    level=logging.INFO)
+    format="%(asctime)s : %(threadName)s : %(levelname)s : %(message)s",
+    level=logging.INFO,
+)


-TOKENIZER_PREFIX = 'm'
-MODEL_PARAMS = 'model_params.json'
+TOKENIZER_PREFIX = "m"
+MODEL_PARAMS = "model_params.json"


@dataclass
@@ -27,6 +28,7 @@ class _BaseConfig:
should not be used directly. Instead you should use one of the
subclasses which are specific to model and checkpoint storage.
"""

# Training configurations
max_lines: int = 0
epochs: int = 30
@@ -35,16 +37,20 @@
seq_length: int = 100
embedding_dim: int = 256
rnn_units: int = 256
-    dropout_rate: float = .2
-    rnn_initializer: str = 'glorot_uniform'
+    dropout_rate: float = 0.2
+    rnn_initializer: str = "glorot_uniform"

+    # Input data configs
+    field_delimiter: Optional[str] = None
+    field_delimiter_token: str = "<d>"

# Tokenizer settings
vocab_size: int = 500
character_coverage: float = 1.0

# Diff privacy configs
dp: bool = False
-    dp_learning_rate: float = .015
+    dp_learning_rate: float = 0.015
dp_noise_multiplier: float = 1.1
dp_l2_norm_clip: float = 1.0
dp_microbatches: int = 256
@@ -56,6 +62,7 @@

# Checkpoint storage
save_all_checkpoints: bool = True
+    overwrite: bool = False

@abstractmethod
def _set_tokenizer(self): # pragma: no cover
@@ -69,6 +76,7 @@ class _PathSettings:
be used directly. It will be utilized by any configuration
classes that need to utilize path-based storage.
"""

tokenizer_model: str = None
training_data: str = None
tokenizer_prefix: str = TOKENIZER_PREFIX
@@ -85,6 +93,7 @@ class _PathSettingsMixin:
This makes it possible to easily remove the path
settings when serializing the configuration.
"""

paths: _PathSettings = field(default_factory=_PathSettings)

@property
@@ -112,29 +121,34 @@ class LocalConfig(_BaseConfig, _PathSettingsMixin):
This file will be opened, annotated, and then written out to a path
that is generated based on the ``checkpoint_dir``.
"""

checkpoint_dir: str = None
input_data_path: str = None

def __post_init__(self):
if not self.checkpoint_dir or not self.input_data_path:
-            raise AttributeError('Must provide checkpoint_dir and input_path_dir params!')
+            raise AttributeError(
+                "Must provide checkpoint_dir and input_data_path params!"
+            )
if not Path(self.checkpoint_dir).exists():
Path(self.checkpoint_dir).resolve().mkdir()
self._set_tokenizer()

def _set_tokenizer(self):
self.paths.tokenizer_prefix = "m"
-        self.paths.tokenizer_model = Path(self.checkpoint_dir, 'm.model').as_posix()
-        self.paths.training_data = Path(self.checkpoint_dir, 'training_data.txt').as_posix()
+        self.paths.tokenizer_model = Path(self.checkpoint_dir, "m.model").as_posix()
+        self.paths.training_data = Path(
+            self.checkpoint_dir, "training_data.txt"
+        ).as_posix()

def as_dict(self):
d = asdict(self)
-        d.pop('paths')
+        d.pop("paths")
return d

def save_model_params(self):
save_path = Path(self.checkpoint_dir) / MODEL_PARAMS
logging.info(f"Saving model history to {save_path.name}")
-        with open(save_path, 'w') as f:
+        with open(save_path, "w") as f:
json.dump(self.as_dict(), f, indent=2)
return save_path
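Taken together, the new config fields are exercised through LocalConfig roughly like this (a sketch: the paths are placeholders, and the train-time enforcement of the overwrite protection mentioned in the commit message lives outside this file):

    from pathlib import Path
    from gretel_synthetics.config import LocalConfig

    config = LocalConfig(
        checkpoint_dir=(Path.cwd() / "checkpoints").as_posix(),
        input_data_path="training_data.csv",  # placeholder path
        field_delimiter=",",                  # new: mark the training text as field-structured
        # field_delimiter_token="<d>",        # default token standing in for the delimiter
        overwrite=False,                      # new: guards an existing checkpoint_dir
    )
    print(config.as_dict())  # the "paths" entry is stripped before serialization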
97 changes: 63 additions & 34 deletions src/gretel_synthetics/generate.py
@@ -6,16 +6,20 @@
- Created a config
- Trained a model
"""
+from __future__ import annotations
import logging
import sentencepiece as spm
import tensorflow as tf
from collections import namedtuple
from dataclasses import dataclass, asdict
+from typing import Tuple, TYPE_CHECKING, List

-from gretel_synthetics.config import _BaseConfig
from gretel_synthetics.model import _build_sequential_model

-_pred_string = namedtuple('pred_string', ['data'])
+if TYPE_CHECKING:
+    from gretel_synthetics.config import _BaseConfig
+
+_pred_string = namedtuple("pred_string", ["data"])


@dataclass
@@ -32,15 +36,37 @@ class gen_text:
explain: A string that describes why a record failed validation. This is the
string representation of the ``Exception`` that is thrown in a validation
function. This will only be set if validation fails, otherwise will be ``None.``
+        delimiter: If the generated text is a column/field-based record, this will
+            hold the delimiter used to separate the fields from each other.
"""

valid: bool = None
text: str = None
explain: str = None
+    delimiter: str = None

+    def as_dict(self) -> dict:
+        """Serialize the generated record to a dictionary
+        """
+        return asdict(self)

+    def values_as_list(self) -> List[str]:
+        """Attempt to split the generated text on the provided delimiter.
+
+        Returns:
+            A list of values separated by the object's delimiter, or None if there
+            is no delimiter in the text.
+        """
+        if self.delimiter is not None:
+            tmp = self.text.rstrip(self.delimiter)
+            return tmp.split(self.delimiter)
+        return None


logging.basicConfig(
-    format='%(asctime)s : %(threadName)s : %(levelname)s : %(message)s',
-    level=logging.INFO)
+    format="%(asctime)s : %(threadName)s : %(levelname)s : %(message)s",
+    level=logging.INFO,
+)


def _load_tokenizer(store: _BaseConfig) -> spm.SentencePieceProcessor:
@@ -52,29 +78,28 @@ def _load_tokenizer(store: _BaseConfig) -> spm.SentencePieceProcessor:

def _prepare_model(sp: spm, batch_size: int, store: _BaseConfig) -> tf.keras.Sequential:
model = _build_sequential_model(
-        vocab_size=len(sp),
-        batch_size=batch_size,
-        store=store)
+        vocab_size=len(sp), batch_size=batch_size, store=store
+    )

load_dir = store.checkpoint_dir

-    model.load_weights(
-        tf.train.latest_checkpoint(
-            load_dir)).expect_partial()
+    model.load_weights(tf.train.latest_checkpoint(load_dir)).expect_partial()

model.build(tf.TensorShape([1, None]))
model.summary()

return model


-def _gen_text_factory(text: str, valid, explain) -> dict:
-    return dict(
-        asdict(gen_text(valid=valid, text=text, explain=explain))
-    )
+def _load_model(store: _BaseConfig) -> Tuple[spm.SentencePieceProcessor, tf.keras.Sequential]:
+    sp = _load_tokenizer(store)
+    model = _prepare_model(sp, 1, store)
+    return sp, model


-def generate_text(store: _BaseConfig, start_string: str = "<n>", line_validator: callable = None):
+def generate_text(
+    store: _BaseConfig, start_string: str = "<n>", line_validator: callable = None
+):
"""A generator that will load a model and start creating records.
Args:
@@ -113,38 +138,39 @@ def my_validator(raw_line: str):
"""
logging.info(
f"Latest checkpoint: {tf.train.latest_checkpoint(store.checkpoint_dir)}") # noqa
f"Latest checkpoint: {tf.train.latest_checkpoint(store.checkpoint_dir)}"
) # noqa

-    # Restore the latest SentencePiece model
-    sp = _load_tokenizer(store)
-
-    # Load the RNN
-    model = _prepare_model(sp, 1, store)
+    sp, model = _load_model(store)

lines_generated = 0

+    delim = store.field_delimiter

while True:
rec = _predict_chars(model, sp, start_string, store).data
try:
if not line_validator:
-                yield _gen_text_factory(rec, None, None)
+                yield gen_text(text=rec, valid=None, explain=None, delimiter=delim)
else:
line_validator(rec)
-                yield _gen_text_factory(rec, True, None)
+                yield gen_text(text=rec, valid=True, explain=None, delimiter=delim)
except Exception as err:
# logging.warning(f'Line failed validation: {rec} errored with {str(err)}')
-            yield _gen_text_factory(rec, False, str(err))
+            yield gen_text(text=rec, valid=False, explain=str(err), delimiter=delim)
finally:
lines_generated += 1

if lines_generated >= store.gen_lines:
break


-def _predict_chars(model: tf.keras.Sequential,
-                   sp: spm.SentencePieceProcessor,
-                   start_string: str,
-                   store: _BaseConfig) -> str:
+def _predict_chars(
+    model: tf.keras.Sequential,
+    sp: spm.SentencePieceProcessor,
+    start_string: str,
+    store: _BaseConfig,
+) -> str:
"""
Evaluation step (generating text using the learned model).
@@ -175,18 +201,21 @@ def _predict_chars(model: tf.keras.Sequential,
# using a categorical distribution to
# predict the word returned by the model
predictions = predictions / store.gen_temp
-        predicted_id = tf.random.categorical(
-            predictions, num_samples=1)[-1, 0].numpy()
+        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

# We pass the predicted word as the next input to the model
# along with the previous hidden state
input_eval = tf.expand_dims([predicted_id], 0)
sentence_ids.append(int(predicted_id))

decoded = sp.DecodeIds(sentence_ids)
-        decoded = decoded.replace('<c>', ',')

-        if '<n>' in decoded:
-            return _pred_string(decoded.replace('<n>', ''))
+        if store.field_delimiter is not None:
+            decoded = decoded.replace(
+                store.field_delimiter_token,
+                store.field_delimiter
+            )
+
+        if "<n>" in decoded:
+            return _pred_string(decoded.replace("<n>", ""))
elif 0 < store.gen_chars <= len(decoded):
return _pred_string(decoded)
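To see the delimiter handling end to end: per the commit message, training-side routines (not shown in this diff) swap the configured field_delimiter for field_delimiter_token, and _predict_chars above maps the token back, so values_as_list can split the finished record. A standalone sketch with made-up field values:

    from gretel_synthetics.generate import gen_text

    field_delimiter = ","
    field_delimiter_token = "<d>"

    # what _predict_chars does once a full record has been decoded
    decoded = "63<d>1<d>typical<d>145"
    decoded = decoded.replace(field_delimiter_token, field_delimiter)

    record = gen_text(text=decoded, valid=True, delimiter=field_delimiter)
    print(record.values_as_list())  # ['63', '1', 'typical', '145']
    print(record.as_dict())         # serialized form, e.g. for JSON output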
