Jm/seed updates (#73)
* Allow smart seeding from a list of seeds

* type hint update

* Updates

* Only use .01 for learning rate

Co-authored-by: John Myers <john@gretel.ai>
johntmyers and John Myers committed Nov 19, 2020
1 parent f2c450d commit acc2edb
Showing 8 changed files with 31 additions and 18 deletions.
4 changes: 4 additions & 0 deletions src/gretel_synthetics/batch.py
@@ -511,6 +511,10 @@ def generate_batch_lines(
         validator = batch.get_validator()
         if num_lines is None:
             num_lines = batch.config.gen_lines
+
+        if isinstance(seed_fields, list):
+            num_lines = len(seed_fields)
+
         t = tqdm(total=num_lines, desc="Valid record count ")
         t2 = tqdm(total=max_invalid, desc="Invalid record count ")
         line: GenText
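A hedged usage sketch of the headline change: when seed_fields is a list, generation is pinned to exactly one record per seed. The config values and read-mode setup below are assumptions based on the library's documented API, not part of this commit.

# Sketch only; assumes a batch-mode model was already trained into
# "checkpoints" with a matching TensorFlowConfig.
from gretel_synthetics.batch import DataFrameBatch
from gretel_synthetics.config import TensorFlowConfig

config = TensorFlowConfig(checkpoint_dir="checkpoints", input_data_path="train.csv")  # illustrative
batcher = DataFrameBatch(mode="read", config=config)

# New in this commit: a list of seeds forces num_lines = len(seed_fields),
# so each seed dict yields exactly one generated record.
seeds = [{"state": "CA"}, {"state": "TX"}, {"state": "WA"}]
batcher.generate_all_batch_lines(seed_fields=seeds, max_invalid=1000)
df = batcher.batches_to_df()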
6 changes: 3 additions & 3 deletions src/gretel_synthetics/config.py
@@ -178,7 +178,7 @@ class TensorFlowConfig(BaseConfig):
             training matters. Note: When training with differential privacy enabled,
             if the updates are noisy (such as when the additive noise is large
             compared to the clipping threshold), a low learning rate may help with training.
-            Default is ``0.001``.
+            Default is ``0.01``.
         dp_noise_multiplier (optional): The amount of noise sampled and added to gradients during
             training. Generally, more noise results in better privacy, at the expense of
             model accuracy. Default is ``0.1``.
@@ -218,7 +218,7 @@ class TensorFlowConfig(BaseConfig):
     seq_length: int = 100
     embedding_dim: int = 256
     rnn_units: int = 256
-    learning_rate: float = 0.001
+    learning_rate: float = 0.01
     dropout_rate: float = 0.2
     rnn_initializer: str = "glorot_uniform"

@@ -241,7 +241,7 @@ class TensorFlowConfig(BaseConfig):
 
     def __post_init__(self):
         if self.dp:
-            major, minor, micro = tf.__version__.split(".")
+            major, minor, _ = tf.__version__.split(".")
             if (int(major), int(minor)) < (2, 4):
                 raise RuntimeError(
                     "Running in differential privacy mode requires TensorFlow 2.4.x or greater. "
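A minimal standalone illustration of the version gate above; the rename from micro to _ just signals that the micro version is unused in the tuple comparison. The version string is a stand-in for tf.__version__.

version = "2.4.1"  # stand-in for tf.__version__
major, minor, _ = version.split(".")
if (int(major), int(minor)) < (2, 4):
    raise RuntimeError("Differential privacy mode requires TensorFlow >= 2.4")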
5 changes: 1 addition & 4 deletions src/gretel_synthetics/generate.py
@@ -123,10 +123,7 @@ def _process_start_string(self, start_str: str) -> str:
         # the start_string must end with the delim
         if not start_str.endswith(self.config.field_delimiter):
             raise GenerationError(f"start_str must end with the specified field delimiter: {self.config.field_delimiter}")  # noqa
-        return start_str.replace(
-            self.config.field_delimiter,
-            self.config.field_delimiter_token
-        )
+        return self.tokenizer.tokenize_delimiter(start_str)
 
 
 def generate_text(
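A hedged sketch of the refactored contract: validate the trailing delimiter, then defer the replacement to the tokenizer. process_start_string and MockTokenizer below are stand-ins, not library API; the "<d>" token mirrors the mock used in tests/test_generate.py further down.

def process_start_string(start_str: str, field_delimiter: str, tokenizer) -> str:
    # Validate the trailing delimiter, then let the tokenizer do the swap.
    if not start_str.endswith(field_delimiter):
        raise ValueError(f"start_str must end with: {field_delimiter}")
    return tokenizer.tokenize_delimiter(start_str)

class MockTokenizer:
    def tokenize_delimiter(self, line: str) -> str:
        return line.replace(",", "<d>")

assert process_start_string("foo,bar,", ",", MockTokenizer()) == "foo<d>bar<d>"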
9 changes: 5 additions & 4 deletions src/gretel_synthetics/generate_utils.py
@@ -75,6 +75,7 @@ def generate(
         *,
         seed: Optional[Union[str, dict]] = None,
         validator: Optional[Callable] = None,
+        parallelism: int = 1
     ):
         if self.model_path.is_dir():
             return self._generate(self.model_path, count, file_name, seed, validator)
@@ -87,10 +88,10 @@ def generate(
                 logging.info("Extracting archive to temp dir...")
                 tar_in.extractall(tmpdir)
 
-            return self._generate(Path(tmpdir), count, file_name, seed, validator)
+            return self._generate(Path(tmpdir), count, file_name, seed, validator, parallelism)
 
     def _generate(
-        self, model_dir: Path, count: int, file_name: str, seed, validator
+        self, model_dir: Path, count: int, file_name: str, seed, validator, parallelism
     ) -> str:
         batch_mode = is_model_dir_batch_mode(model_dir)
         if batch_mode:
@@ -101,7 +102,7 @@ def _generate(
             batcher.generate_all_batch_lines(
                 num_lines=count,
                 max_invalid=max(count, MAX_INVALID),
-                parallelism=1,
+                parallelism=parallelism,
                 seed_fields=seed
             )
             out_df = batcher.batches_to_df()
@@ -120,7 +121,7 @@ def _generate(
                 num_lines=count,
                 line_validator=validator,
                 max_invalid=max(count, MAX_INVALID),
-                parallelism=1,
+                parallelism=parallelism,
                 start_string=seed
             ):
                 if data.valid or data.valid is None:
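Usage sketch for the new pass-through parameter, assuming the DataFileGenerator wrapper defined in this module; the model path and counts are illustrative.

from gretel_synthetics.generate_utils import DataFileGenerator

gen = DataFileGenerator("path/to/model.tar.gz")  # illustrative path
# Fan generation out over 4 workers instead of the previously hardcoded 1.
out_file = gen.generate(1000, "out.csv", parallelism=4)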
4 changes: 1 addition & 3 deletions src/gretel_synthetics/tensorflow/generator.py
@@ -201,9 +201,7 @@ def predict_and_sample(this_input):
     prediction_prefix = None
     if _start_string != tokenizer.newline_str:
         if store.field_delimiter is not None:
-            prediction_prefix = _start_string.replace(
-                store.field_delimiter_token, store.field_delimiter
-            )
+            prediction_prefix = tokenizer.detokenize_delimiter(_start_string)
         else:
             prediction_prefix = _start_string

12 changes: 12 additions & 0 deletions src/gretel_synthetics/tokenizers.py
@@ -278,6 +278,12 @@ def _replace_decoded_tokens(self, decoded_line: str) -> str:
         """
         return decoded_line
 
+    def tokenize_delimiter(self, line: str) -> str:
+        return line
+
+    def detokenize_delimiter(self, line: str) -> str:
+        return line
+
 
 ##################
 # Single Char
@@ -526,6 +532,12 @@ def _replace_decoded_tokens(self, decoded_line: str) -> str:
             )
         return decoded_line
 
+    def tokenize_delimiter(self, line: str) -> str:
+        return line.replace(self.field_delimiter, self.field_delimiter_token)
+
+    def detokenize_delimiter(self, line: str) -> str:
+        return line.replace(self.field_delimiter_token, self.field_delimiter)
+
 
 ##########
 # Factory
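The two hooks are inverses for delimiter-aware tokenizers: generate.py calls tokenize_delimiter on the way in and tensorflow/generator.py calls detokenize_delimiter on the way out. A standalone round-trip sketch, assuming "," as the delimiter and "<d>" as its token:

class DelimiterAwareTokenizer:
    field_delimiter = ","
    field_delimiter_token = "<d>"

    def tokenize_delimiter(self, line: str) -> str:
        return line.replace(self.field_delimiter, self.field_delimiter_token)

    def detokenize_delimiter(self, line: str) -> str:
        return line.replace(self.field_delimiter_token, self.field_delimiter)

tok = DelimiterAwareTokenizer()
tokenized = tok.tokenize_delimiter("foo,bar,baz,")  # "foo<d>bar<d>baz<d>"
assert tok.detokenize_delimiter(tokenized) == "foo,bar,baz,"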
2 changes: 1 addition & 1 deletion tests/tensorflow/test_tf_config.py
@@ -43,7 +43,7 @@ def test_local_config_settings(mkdir):
         "character_coverage": 1.0,
         "pretrain_sentence_count": 1000000,
         "dp": False,
-        "learning_rate": 0.001,
+        "learning_rate": 0.01,
         "dp_noise_multiplier": 0.1,
         "dp_l2_norm_clip": 3.0,
         "dp_microbatches": 64,
7 changes: 4 additions & 3 deletions tests/test_generate.py
@@ -12,6 +12,7 @@
 def mock_tokenizer():
     t = Mock()
     t.newline_str = NEWLINE
+    t.tokenize_delimiter = lambda s: s.replace(",", "<d>")
     return t


@@ -31,12 +32,12 @@ def test_delim_missing_trailing_delim(tf_config):
     Settings(config=tf_config, start_string="foo,bar", tokenizer=mock_tokenizer)
 
 
-def test_delim_multi_field(tf_config):
+def test_delim_multi_field(tf_config, mock_tokenizer):
     check = Settings(config=tf_config, start_string="foo,bar,baz,", tokenizer=mock_tokenizer)
     assert check.start_string == "foo<d>bar<d>baz<d>"
 
 
-def test_delim_multi_field_multi_starts(tf_config):
+def test_delim_multi_field_multi_starts(tf_config, mock_tokenizer):
     check = Settings(
         config=tf_config,
         start_string=["one,two,three,", "four,five,six,", "seven,eight,nine,"],
@@ -45,6 +46,6 @@ def test_delim_multi_field_multi_starts(tf_config):
     assert check.start_string == ['one<d>two<d>three<d>', 'four<d>five<d>six<d>', 'seven<d>eight<d>nine<d>']
 
 
-def test_delim_single_field(tf_config):
+def test_delim_single_field(tf_config, mock_tokenizer):
     check = Settings(config=tf_config, start_string="onlyonefield,", tokenizer=mock_tokenizer)
     assert check.start_string == "onlyonefield<d>"
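For readers unfamiliar with pytest: listing mock_tokenizer as a test parameter injects the fixture by name (the @pytest.fixture decorator presumably sits just above the hunk shown), which is why the three tests above gained an argument. A minimal standalone analogue:

import pytest

@pytest.fixture
def greeting():
    return "hello"

def test_uses_fixture(greeting):  # pytest injects the fixture by name
    assert greeting == "hello"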
