Commit
Jm/syn 33 (#60)
* Configure the low-level API to use start strings with delimiters and re-assemble the prefix into predicted records

* Update batch DataFrame mode to use seed options
johntmyers committed Oct 20, 2020
1 parent 2b05b71 · commit 0d26b6f
Showing 7 changed files with 251 additions and 20 deletions.
47 changes: 44 additions & 3 deletions src/gretel_synthetics/batch.py
@@ -24,7 +24,7 @@
import cloudpickle

from gretel_synthetics.config import LocalConfig
from gretel_synthetics.generate import gen_text, generate_text
from gretel_synthetics.generate import gen_text, generate_text, NEWLINE
from gretel_synthetics.generator import TooManyInvalidError
from gretel_synthetics.train import train_rnn

@@ -389,12 +389,30 @@ def set_batch_validator(self, batch_idx: int, validator: Callable):
except KeyError:
raise ValueError("invalid batch number!")

def _validate_batch_seed_values(self, batch: Batch, seed_values: dict) -> str:
"""Validate that seed values line up with the first N columns in a batch. Also construct
an appropriate seed string based on the values in the batch
"""
if len(seed_values) > len(batch.headers):
raise RuntimeError("The number of seed fields is greater than the number of columns in the first batch")

headers_to_seed = batch.headers[:len(seed_values)]
tmp = []
for header in headers_to_seed:
value = seed_values.get(header)
if value is None:
raise RuntimeError(f"The header: {header} is not in the seed values mapping") # noqa
tmp.append(str(value))

return batch.config.field_delimiter.join(tmp) + batch.config.field_delimiter
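
A standalone sketch (not part of this commit) of the seed string this method builds, using hypothetical headers and a ``|`` delimiter:

headers = ["ID_code", "target", "var_0", "var_1"]  # first batch's headers
seed_values = {"ID_code": "foo", "target": 0, "var_0": 33}
field_delimiter = "|"

headers_to_seed = headers[: len(seed_values)]
parts = [str(seed_values[h]) for h in headers_to_seed]
seed_string = field_delimiter.join(parts) + field_delimiter
print(seed_string)  # foo|0|33|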

def generate_batch_lines(
self,
batch_idx: int,
max_invalid=MAX_INVALID,
raise_on_exceed_invalid: bool = False,
num_lines: int = None,
seed_fields: dict = None,
parallelism: int = 0,
) -> bool:
"""Generate lines for a single batch. Lines generated are added
@@ -411,6 +429,9 @@ def generate_batch_lines(
indicating that the batch failed to generate all lines.
num_lines: The number of lines to generate, if ``None``, then we use the number from the
batch's config
seed_fields: A dictionary that maps field/column names to initial seed values for those columns. This seed
will only apply to the first batch that gets trained and generated. Additionally, the fields provided
in the mapping MUST exist at the front of the first batch.
parallelism: The number of concurrent workers to use. ``1`` disables parallelization,
while a non-positive value means "number of CPUs + x" (i.e., use ``0`` for using as many workers
as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs,
@@ -420,6 +441,14 @@
batch = self.batches[batch_idx]
except KeyError: # pragma: no cover
raise ValueError("invalid batch index")

seed_string = NEWLINE

# If we are on batch 0 and we have seed values, we want to validate that
# the seed values line up properly with the first N columns.
if batch_idx == 0 and seed_fields is not None:
seed_string = self._validate_batch_seed_values(batch, seed_fields)

batch: Batch
batch.reset_gen_data()
validator = batch.get_validator()
@@ -430,7 +459,11 @@
line: gen_text
try:
for line in generate_text(
batch.config, line_validator=validator, max_invalid=max_invalid, num_lines=num_lines,
batch.config,
line_validator=validator,
max_invalid=max_invalid,
num_lines=num_lines,
start_string=seed_string,
parallelism=parallelism,
):
if line.valid is None or line.valid is True:
@@ -449,7 +482,11 @@ def generate_batch_lines(
return batch.gen_data_count >= num_lines

def generate_all_batch_lines(
self, max_invalid=MAX_INVALID, raise_on_failed_batch: bool = False, num_lines: int = None,
self,
max_invalid=MAX_INVALID,
raise_on_failed_batch: bool = False,
num_lines: int = None,
seed_fields: dict = None,
parallelism: int = 0,
) -> dict:
"""Generate synthetic lines for all batches. Lines for each batch
@@ -470,6 +507,9 @@ def generate_all_batch_lines(
will be set to ``False`` in the result dictionary from this method.
num_lines: The number of lines to create from each batch. If ``None`` then the value
from the config template will be used.
seed_fields: A dictionary that maps field/column names to initial seed values for those columns. This seed
will only apply to the first batch that gets trained and generated. Additionally, the fields provided
in the mapping MUST exist at the front of the first batch.
parallelism: The number of concurrent workers to use. ``1`` disables parallelization,
while a non-positive value means "number of CPUs + x" (i.e., use ``0`` for using as many workers
as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs,
@@ -491,6 +531,7 @@
max_invalid=max_invalid,
raise_on_exceed_invalid=raise_on_failed_batch,
num_lines=num_lines,
seed_fields=seed_fields,
parallelism=parallelism,
)
return batch_status
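
A hedged usage sketch of the new ``seed_fields`` option (not part of this commit; assumes ``source_df`` and ``config_template`` are defined elsewhere and that all batches have been trained via the batch helpers):

from gretel_synthetics.batch import DataFrameBatch

batches = DataFrameBatch(df=source_df, config=config_template, batch_size=3)
batches.create_training_data()
batches.train_all_batches()
# seed the first two columns of batch 0; later batches generate as usual
status = batches.generate_all_batch_lines(
    num_lines=100,
    seed_fields={"ID_code": "foo", "target": 0},
)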
10 changes: 7 additions & 3 deletions src/gretel_synthetics/generate.py
@@ -11,7 +11,7 @@

import tensorflow as tf

from gretel_synthetics.generator import Generator, Settings
from gretel_synthetics.generator import Generator, Settings, NEWLINE
from gretel_synthetics.generator import gen_text, PredString # noqa # pylint: disable=unused-import
from gretel_synthetics.generate_parallel import get_num_workers, generate_parallel

@@ -29,7 +29,7 @@

def generate_text(
config: LocalConfig,
start_string: str = "<n>",
start_string: str = NEWLINE,
line_validator: Callable = None,
max_invalid: int = 1000,
num_lines: int = None,
@@ -41,7 +41,11 @@ def generate_text(
config: A configuration object, which you must have created previously
start_string: A prefix string that is used to seed the record generation.
By default we use a newline, but you may substitute any initial value here
which will influence how the generator predicts what to generate.
which will influence how the generator predicts what to generate. If you
are working with a field delimiter and you want to seed more than one column
value, then you MUST use the field delimiter specified in your config.
An example would be "foo,bar,baz,". Note that when using a field delimiter, the string
MUST end with the delimiter value.
line_validator: An optional callback validator function that will take
the raw string value from the generator as a single argument. This validator
can execute arbitrary code with the raw string value. The validator function
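
A minimal sketch of the seeded low-level call described in the docstring above (not part of this commit; the paths are hypothetical placeholders and a trained model is assumed):

from gretel_synthetics.config import LocalConfig
from gretel_synthetics.generate import generate_text

config = LocalConfig(
    checkpoint_dir="/path/to/checkpoints",    # hypothetical path
    input_data_path="/path/to/training.csv",  # hypothetical path
    field_delimiter=",",
)
# seed the first three columns; the start string must end with the delimiter
for line in generate_text(config, start_string="foo,bar,baz,", num_lines=10):
    print(line.text)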
70 changes: 63 additions & 7 deletions src/gretel_synthetics/generator.py
@@ -13,6 +13,12 @@
BaseConfig = None
LocalConfig = None

NEWLINE = "<n>"


class GenerationError(Exception):
pass


def _load_tokenizer(store: LocalConfig) -> spm.SentencePieceProcessor:
sp = spm.SentencePieceProcessor()
@@ -55,13 +61,34 @@ class Settings:
This class contains basic settings for a generation process. It is separated from the Generator class
to ensure reliable serializability without an excess amount of code tied to it.
This class will also take a provided start string and validate that it can be used for text
generation. If the ``start_string`` is something other than the default, we have to do a couple of things:
- If the config utilizes a field delimiter, the ``start_string`` MUST end with that delimiter
- Convert the user-facing delimiter char into the special delimiter token specified in the config
"""

config: LocalConfig
start_string: str = "<n>"
start_string: str = NEWLINE
line_validator: Optional[Callable] = None
max_invalid: int = 1000

def __post_init__(self):
if self.start_string != NEWLINE:
self._process_start_string()

def _process_start_string(self):
if not isinstance(self.start_string, str):
raise GenerationError("Seed start_string must be a str!")
if self.config.field_delimiter is not None:
# the start_string must end with the delim
if not self.start_string.endswith(self.config.field_delimiter):
raise GenerationError(f"start_string must end with the specified field delimiter: {self.config.field_delimiter}") # noqa
self.start_string = self.start_string.replace(
self.config.field_delimiter,
self.config.field_delimiter_token
)
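
A standalone illustration of the conversion performed above, assuming ``field_delimiter=","`` and ``field_delimiter_token="<d>"``:

field_delimiter = ","
field_delimiter_token = "<d>"

start_string = "foo,bar,baz,"
assert start_string.endswith(field_delimiter)  # enforced by the check above
processed = start_string.replace(field_delimiter, field_delimiter_token)
print(processed)  # foo<d>bar<d>baz<d>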


@dataclass
class gen_text:
@@ -118,6 +145,11 @@ class Generator:
Args:
settings: the generator settings to use.
NOTE:
If the ``settings`` object has a non-default ``start_string`` set, then that ``start_string`` MUST have
already had special tokens inserted. This should generally be handled during the construction of the Settings
object.
"""
settings: Settings
model: tf.keras.Sequential
@@ -197,6 +229,23 @@ def compiled_predict_and_sample(input_eval):
compiled_predict_and_sample)


def _replace_decoded_tokens(batch_decoded, store: BaseConfig, prefix: str = None) -> List[Tuple[int, str]]:
"""Given a decoded predicted string, that contains special tokens for things like field
delimiters, we restore those tokens back to the original char they were previously.
Additionally, if a ``start_string`` was provided to seed the generation, we need to restore
the delim tokens in that start string and preprend it to the predicted string.
"""
out = []
for i, decoded in batch_decoded:
if store.field_delimiter is not None:
decoded = decoded.replace(store.field_delimiter_token, store.field_delimiter)
if prefix is not None:
decoded = "".join([prefix, decoded])
out.append((i, decoded))
return out
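
A standalone illustration of the restoration above, with hypothetical decoded values:

field_delimiter, field_delimiter_token = ",", "<d>"
prefix = "foo,bar,"  # seed prefix, already in user-facing form
batch_decoded = [(0, "baz<d>qux"), (1, "spam<d>eggs")]

out = []
for i, decoded in batch_decoded:
    decoded = decoded.replace(field_delimiter_token, field_delimiter)
    out.append((i, "".join([prefix, decoded])))
print(out)  # [(0, 'foo,bar,baz,qux'), (1, 'foo,bar,spam,eggs')]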


def _predict_chars(
model: tf.keras.Sequential,
sp: spm.SentencePieceProcessor,
@@ -210,7 +259,8 @@
Args:
model: tf.keras.Sequential model
sp: SentencePiece tokenizer
start_string: string to bootstrap model
start_string: string to bootstrap model. NOTE: this string MUST already have had special tokens
inserted (e.g. ``<d>``)
store: our config object
Returns:
Yields line of text per iteration
@@ -230,19 +280,25 @@ def predict_and_sample(this_input):

model.reset_states()

# if the start string is not the default newline, then we create a prefix string
# that we will prepend to each decoded prediction
prediction_prefix = None
if start_string != NEWLINE:
if store.field_delimiter is not None:
prediction_prefix = start_string.replace(store.field_delimiter_token, store.field_delimiter)
else:
prediction_prefix = start_string

while not_done:
input_eval = predict_and_sample(input_eval)
for i in not_done:
batch_sentence_ids[i].append(int(input_eval[i, 0].numpy()))

batch_decoded = [(i, sp.DecodeIds(batch_sentence_ids[i])) for i in not_done]
if store.field_delimiter is not None:
batch_decoded = [(i, decoded.replace(
store.field_delimiter_token, store.field_delimiter
)) for i, decoded in batch_decoded]
batch_decoded = _replace_decoded_tokens(batch_decoded, store, prediction_prefix)

for i, decoded in batch_decoded:
end_idx = decoded.find("<n>")
end_idx = decoded.find(NEWLINE)
if end_idx >= 0:
decoded = decoded[:end_idx]
yield PredString(decoded)
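
A quick sketch of the end-of-record detection above: decoding is truncated at the first ``<n>`` token.

NEWLINE = "<n>"
decoded = "foo,bar,baz<n>partial,next"
end_idx = decoded.find(NEWLINE)
if end_idx >= 0:
    decoded = decoded[:end_idx]
print(decoded)  # foo,bar,baz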
6 changes: 3 additions & 3 deletions src/gretel_synthetics/train.py
@@ -20,7 +20,7 @@

from gretel_synthetics.model import build_sequential_model, compute_epsilon
from gretel_synthetics.config import BaseConfig, VAL_ACC, VAL_LOSS
from gretel_synthetics.generator import _load_model
from gretel_synthetics.generator import _load_model, NEWLINE


spm_logger = logging.getLogger("sentencepiece")
@@ -219,7 +219,7 @@ def _annotate_training_data(store: BaseConfig):
labeled_text = ""
with open(store.training_data, "w") as f:
for sample in training_text:
chunk = f"{sample}<n>\n"
chunk = f"{sample}{NEWLINE}\n"
f.write(chunk)
labeled_text += chunk
logging.info(
@@ -246,7 +246,7 @@ def _train_tokenizer(store: BaseConfig) -> spm.SentencePieceProcessor:
spm.SentencePieceTrainer.Train(
input=store.training_data,
model_prefix=store.tokenizer_prefix,
user_defined_symbols=["<n>", store.field_delimiter_token],
user_defined_symbols=[NEWLINE, store.field_delimiter_token],
vocab_size=store.vocab_size,
hard_vocab_limit=False,
max_sentence_length=store.max_line_len,
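
An illustration of the annotated line format written by ``_annotate_training_data`` (hypothetical sample):

NEWLINE = "<n>"
sample = "foo|bar|baz"
chunk = f"{sample}{NEWLINE}\n"
print(repr(chunk))  # 'foo|bar|baz<n>\n'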
73 changes: 73 additions & 0 deletions tests/test_batch.py
@@ -241,6 +241,7 @@ def test_generate_all_batch_lines_raise_on_failed(test_data):
"raise_on_exceed_invalid": False,
"num_lines": None,
"parallelism": 0,
"seed_fields": None
}

batches.generate_batch_lines = Mock()
@@ -252,6 +253,7 @@
"raise_on_exceed_invalid": True,
"num_lines": 5,
"parallelism": 0,
"seed_fields": None
}


@@ -285,3 +287,74 @@ def test_read_mode(test_data):
assert write_batch.headers == read_batch.headers
assert asdict(write_batch.config) == asdict(read_batch.config)
assert reader.master_header_list == writer.master_header_list


def test_validate_seed_lines_too_many_fields(test_data):
batches = DataFrameBatch(df=test_data, config=config_template, batch_size=3)

with pytest.raises(RuntimeError) as err:
batches._validate_batch_seed_values(
batches.batches[0],
{
"ID_Code": "foo",
"target": 0,
"var_0": 33,
"var_1": 33
}
)
assert "number of seed fields" in str(err.value)



def test_validate_seed_lines_field_not_present(test_data):
batches = DataFrameBatch(df=test_data, config=config_template, batch_size=3)

with pytest.raises(RuntimeError) as err:
batches._validate_batch_seed_values(
batches.batches[0],
{
"ID_code": "foo",
"target": 0,
"var_1": 33,
}
)
assert "The header: var_0 is not in the seed" in str(err.value)


def test_validate_seed_lines_ok_full_size(test_data):
batches = DataFrameBatch(df=test_data, config=config_template, batch_size=3)

check = batches._validate_batch_seed_values(
batches.batches[0],
{
"ID_code": "foo",
"target": 0,
"var_0": 33,
}
)
assert check == "foo|0|33|"


def test_validate_seed_lines_ok_one_field(test_data):
batches = DataFrameBatch(df=test_data, config=config_template, batch_size=3)

check = batches._validate_batch_seed_values(
batches.batches[0],
{
"ID_code": "foo",
}
)
assert check == "foo|"


def test_validate_seed_lines_ok_two_field(test_data):
batches = DataFrameBatch(df=test_data, config=config_template, batch_size=3)

check = batches._validate_batch_seed_values(
batches.batches[0],
{
"ID_code": "foo",
"target": 1
}
)
assert check == "foo|1|"
