Skip to content

Commit

Permalink
Jm/seed list (#72)
Browse files Browse the repository at this point in the history
* Allow smart seeding from a list of seeds
  • Loading branch information
johntmyers committed Nov 19, 2020
1 parent e15479e commit cd9de8b
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 39 deletions.
46 changes: 31 additions & 15 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,30 +430,46 @@ def set_batch_validator(self, batch_idx: int, validator: Callable):
except KeyError:
raise ValueError("invalid batch number!")

def _validate_batch_seed_values(self, batch: Batch, seed_values: dict) -> str:
def _validate_batch_seed_values(self, batch: Batch, seed_values: Union[dict, List[dict]]) -> Union[str, List[str]]:
"""Validate that seed values line up with the first N columns in a batch. Also construct
an appropiate seed string based on the values in the batch
"""
if len(seed_values) > len(batch.headers):
raise RuntimeError("The number of seed fields is greater than the number of columns in the first batch")

headers_to_seed = batch.headers[:len(seed_values)]
tmp = []
for header in headers_to_seed:
value = seed_values.get(header)
if value is None:
raise RuntimeError(f"The header: {header} is not in the seed values mapping") # noqa
tmp.append(str(value))

return batch.config.field_delimiter.join(tmp) + batch.config.field_delimiter
ret_str = True
if isinstance(seed_values, dict):
seed_values = [seed_values]
elif isinstance(seed_values, list):
ret_str = False
else:
raise TypeError("seed_values should be a dict or list of dicts")

seed_strings = []

for seed in seed_values:
if len(seed) > len(batch.headers):
raise RuntimeError("The number of seed fields is greater than the number of columns in the first batch")

headers_to_seed = batch.headers[:len(seed)]
tmp = []
for header in headers_to_seed:
value = seed.get(header)
if value is None:
raise RuntimeError(f"The header: {header} is not in the seed values mapping") # noqa
tmp.append(str(value))

seed_strings.append(batch.config.field_delimiter.join(tmp) + batch.config.field_delimiter)

if ret_str:
return seed_strings[0]
else:
return seed_strings

def generate_batch_lines(
self,
batch_idx: int,
max_invalid=MAX_INVALID,
raise_on_exceed_invalid: bool = False,
num_lines: int = None,
seed_fields: dict = None,
seed_fields: Union[dict, List[dict]] = None,
parallelism: int = 0,
) -> bool:
"""Generate lines for a single batch. Lines generated are added
Expand Down Expand Up @@ -527,7 +543,7 @@ def generate_all_batch_lines(
max_invalid=MAX_INVALID,
raise_on_failed_batch: bool = False,
num_lines: int = None,
seed_fields: dict = None,
seed_fields: Union[dict, List[dict]] = None,
parallelism: int = 0,
) -> dict:
"""Generate synthetic lines for all batches. Lines for each batch
Expand Down
42 changes: 30 additions & 12 deletions src/gretel_synthetics/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
from collections import namedtuple
from dataclasses import dataclass, asdict
from typing import TYPE_CHECKING, Optional, Callable, List, Iterable, Iterator
from typing import TYPE_CHECKING, Optional, Callable, List, Iterable, Union, Iterator
from abc import ABC, abstractmethod

from gretel_synthetics.generate_parallel import get_num_workers, generate_parallel
Expand Down Expand Up @@ -94,34 +94,44 @@ class Settings:
"""

config: BaseConfig
start_string: Optional[str] = None
start_string: Optional[Union[str, List[str]]] = None
multi_seed: bool = False
line_validator: Optional[Callable] = None
max_invalid: int = 1000
generator: BaseGenerator = None
tokenizer: BaseTokenizer = None

def __post_init__(self):
if self.start_string is not None:
self._process_start_string()
else:
if self.start_string is None:
self.start_string = self.tokenizer.newline_str

def _process_start_string(self):
if not isinstance(self.start_string, str):
else:
if isinstance(self.start_string, str):
self.start_string = self._process_start_string(self.start_string)
elif isinstance(self.start_string, list):
new_strings = []
for s in self.start_string:
new_strings.append(self._process_start_string(s))
self.start_string = new_strings
self.multi_seed = True
else:
raise GenerationError("start_string must be a string or list of strings")

def _process_start_string(self, start_str: str) -> str:
if not isinstance(start_str, str):
raise GenerationError("Seed start_string must be a str!")
if self.config.field_delimiter is not None:
# the start_string must end with the delim
if not self.start_string.endswith(self.config.field_delimiter):
raise GenerationError(f"start_string must end with the specified field delimiter: {self.config.field_delimiter}") # noqa
self.start_string = self.start_string.replace(
if not start_str.endswith(self.config.field_delimiter):
raise GenerationError(f"start_str must end with the specified field delimiter: {self.config.field_delimiter}") # noqa
return start_str.replace(
self.config.field_delimiter,
self.config.field_delimiter_token
)


def generate_text(
config: BaseConfig,
start_string: Optional[str] = None,
start_string: Optional[Union[str, List[str]]] = None,
line_validator: Optional[Callable] = None,
max_invalid: int = 1000,
num_lines: Optional[int] = None,
Expand Down Expand Up @@ -202,7 +212,15 @@ def my_validator(raw_line: str):
else:
_line_count = config.gen_lines

# If we are given a list of start strings, we assume that we
# want to generate a line for each start string, so set the
# line count to this number
if settings.multi_seed:
_line_count = len(start_string)

num_workers = get_num_workers(parallelism, _line_count, chunk_size=5)
if num_workers > 1 and settings.multi_seed:
raise RuntimeError("When providing a list of start strings, parallelism cannot be used")
if num_workers == 1:
gen = generator_class(settings)
yield from gen.generate_next(_line_count)
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_synthetics/generate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _generate(
) -> str:
batch_mode = is_model_dir_batch_mode(model_dir)
if batch_mode:
if seed is not None and not isinstance(seed, dict):
if seed is not None and not isinstance(seed, (dict, list)):
raise TypeError("Seed must be a dict in batch mode")
out_fname = f"{file_name}.csv"
batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir))
Expand Down
28 changes: 17 additions & 11 deletions src/gretel_synthetics/tensorflow/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Iterable,
Optional,
Tuple,
Union
)

import tensorflow as tf
Expand Down Expand Up @@ -48,6 +49,9 @@ class TensorFlowGenerator(BaseGenerator):

def __init__(self, settings: Settings):
self.settings = settings
if self.settings.multi_seed:
self.settings.config.predict_batch_size = 1
self.settings.reset_states = True
self.model = load_model(settings.config, self.settings.tokenizer)
self.delim = settings.config.field_delimiter
self._predictions = self._predict_forever()
Expand Down Expand Up @@ -99,10 +103,10 @@ def generate_next(
text=rec, valid=False, explain=str(err), delimiter=self.delim
)
else:
if self.settings.line_validator and _valid:
valid_lines_generated += 1
elif not self.settings.line_validator:
if (self.settings.line_validator and _valid) or not self.settings.line_validator:
valid_lines_generated += 1
if self.settings.multi_seed:
self.settings.start_string.pop(0)
else:
...

Expand Down Expand Up @@ -153,7 +157,7 @@ def _replace_prefix(
def _predict_chars(
model: tf.keras.Sequential,
tokenizer: BaseTokenizer,
start_string: str,
start_string: Union[str, List[str]],
store: TensorFlowConfig,
predict_and_sample: Optional[Callable] = None,
) -> GeneratorType[PredString, None, None]:
Expand All @@ -171,7 +175,12 @@ def _predict_chars(
"""

# Converting our start string to numbers (vectorizing)
start_vec = tokenizer.encode_to_ids(start_string)
if isinstance(start_string, str):
start_string = [start_string]

_start_string = start_string[0]

start_vec = tokenizer.encode_to_ids(_start_string)
input_eval = tf.constant([start_vec for _ in range(store.predict_batch_size)])

if predict_and_sample is None:
Expand All @@ -189,16 +198,14 @@ def predict_and_sample(this_input):
# expense of model accuracy
model.reset_states()

# if the start string is not the default newline, then we create a prefix string
# that we will append to each decoded prediction
prediction_prefix = None
if start_string != tokenizer.newline_str:
if _start_string != tokenizer.newline_str:
if store.field_delimiter is not None:
prediction_prefix = start_string.replace(
prediction_prefix = _start_string.replace(
store.field_delimiter_token, store.field_delimiter
)
else:
prediction_prefix = start_string
prediction_prefix = _start_string

while not_done:
input_eval = predict_and_sample(input_eval)
Expand All @@ -209,7 +216,6 @@ def predict_and_sample(this_input):
(i, tokenizer.decode_from_ids(batch_sentence_ids[i])) for i in not_done
]
batch_decoded = _replace_prefix(batch_decoded, prediction_prefix)

for i, decoded in batch_decoded:
end_idx = decoded.find(tokenizer.newline_str)
if end_idx >= 0:
Expand Down
17 changes: 17 additions & 0 deletions tests-integration/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,20 @@ def test_generate_batch_smart_seed(model_path, seed, tmp_path):
row = dict(row)
for k, v in seed.items():
assert row[k] == v


@pytest.mark.parametrize(
"model_path,seed", [
("https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/safecast-batch-sp-0-14.tar.gz",
[{"payload.service_handler": "i-051a2a353509414f0"},
{"payload.service_handler": "i-051a2a353509414f1"},
{"payload.service_handler": "i-051a2a353509414f2"},
{"payload.service_handler": "i-051a2a353509414f3"}]) # noqa
]
)
def test_generate_batch_smart_seed_multi(model_path, seed, tmp_path):
gen = DataFileGenerator(model_path)
out_file = str(tmp_path / "outdata")
fname = gen.generate(100, out_file, seed=seed)
df = pd.read_csv(fname)
assert list(df["payload.service_handler"]) == list(pd.DataFrame(seed)["payload.service_handler"])
9 changes: 9 additions & 0 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ def test_delim_multi_field(tf_config):
assert check.start_string == "foo<d>bar<d>baz<d>"


def test_delim_multi_field_multi_starts(tf_config):
check = Settings(
config=tf_config,
start_string=["one,two,three,", "four,five,six,", "seven,eight,nine,"],
tokenizer=mock_tokenizer
)
assert check.start_string == ['one<d>two<d>three<d>', 'four<d>five<d>six<d>', 'seven<d>eight<d>nine<d>']


def test_delim_single_field(tf_config):
check = Settings(config=tf_config, start_string="onlyonefield,", tokenizer=mock_tokenizer)
assert check.start_string == "onlyonefield<d>"

0 comments on commit cd9de8b

Please sign in to comment.