ENGPROD-33: Increase log verbosity for synthetics training
GitOrigin-RevId: fc43866916d034fc9c41d924346c6ec8d5498cf0
pimlock committed Mar 11, 2022
1 parent 90ae80a commit d8be88e
Showing 2 changed files with 21 additions and 8 deletions.
11 changes: 11 additions & 0 deletions src/gretel_synthetics/batch.py
@@ -1129,9 +1129,20 @@ def train_all_batches(self):
         """Train a model for each batch."""
         if self.mode == READ:  # pragma: no cover
             raise RuntimeError("Method cannot be used in read-only mode")
+
+        self._log_batches()
         for idx in self.batches.keys():
             self.train_batch(idx)
 
+    def _log_batches(self):
+        batch_sizes = ", ".join(str(len(b.headers)) for b in self.batches.values())
+        batch_sizes = f"[{batch_sizes}]"
+
+        logger.info(
+            f"Running training on {len(self.batches)} batches.",
+            extra={"user_log": True, "ctx": {"batch_sizes": batch_sizes}},
+        )
+
     def set_batch_validator(self, batch_idx: int, validator: Callable):
         """Set a validator for a specific batch. If a validator is configured
         for a batch, each generated record from that batch will be sent
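Note: Python's stdlib logging copies every key of `extra` onto the emitted LogRecord, so the `user_log` flag and `ctx` payload above become plain attributes on each record. A minimal sketch of how a log consumer could surface only these user-facing records; the `UserLogFilter` class is hypothetical and not part of this commit:

import logging


class UserLogFilter(logging.Filter):
    """Keep only records emitted with extra={"user_log": True, ...}."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Keys passed via `extra` become attributes on the LogRecord.
        return getattr(record, "user_log", False)


handler = logging.StreamHandler()
handler.addFilter(UserLogFilter())
handler.setFormatter(logging.Formatter("%(asctime)s : %(message)s"))

# Attach at the package namespace so module loggers beneath it are covered.
logging.getLogger("gretel_synthetics").addHandler(handler)

Structured context such as batch_sizes then stays available to handlers as getattr(record, "ctx", None).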
18 changes: 10 additions & 8 deletions src/gretel_synthetics/tensorflow/train.py
@@ -40,6 +40,8 @@
 spm_logger = logging.getLogger("sentencepiece")
 spm_logger.setLevel(logging.INFO)
 
+logger = logging.getLogger(__name__)
+
 logging.basicConfig(
     format="%(asctime)s : %(threadName)s : %(levelname)s : %(message)s",
     level=logging.INFO,
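Note: since this file is src/gretel_synthetics/tensorflow/train.py, `logging.getLogger(__name__)` returns a logger named "gretel_synthetics.tensorflow.train", whereas the bare `logging.info(...)` calls replaced in the hunks below go through the root logger. The named logger lets callers tune verbosity for this package alone, e.g. (a sketch, not part of this commit):

import logging

# Raise verbosity for the gretel_synthetics package only, leaving the
# root logger and third-party libraries untouched.
logging.getLogger("gretel_synthetics").setLevel(logging.DEBUG)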
@@ -194,15 +196,15 @@ def _save_history_csv(
         # Log differential privacy settings from best training checkpoint
         epsilon = df.at[best_idx, "epsilon"]
         delta = df.at[best_idx, "delta"]
-        logging.warning(
+        logger.warning(
             f"Model satisfies differential privacy with epsilon ε={epsilon:.2f} "
             f"and delta δ={delta:.6f}"
         )
     else:
         df.drop(["epsilon", "delta"], axis=1, inplace=True)
 
     save_path = Path(save_dir) / "model_history.csv"
-    logging.info(f"Saving model history to {save_path.name}")
+    logger.info(f"Saving model history to {save_path.name}")
     df.to_csv(save_path.as_posix(), index=False)


@@ -243,7 +245,7 @@ def train_rnn(params: TrainingParams):
     num_batches_train, validation_dataset, training_dataset = _create_dataset(
         store, text_iter, num_lines, tokenizer
     )
-    logging.info("Initializing synthetic model")
+    logger.info("Initializing synthetic model", extra={"user_log": True})
     model = build_model(
         vocab_size=tokenizer.total_vocab_size,
         batch_size=store.batch_size,
@@ -320,7 +322,7 @@ def train_rnn(params: TrainingParams):
             store.best_model_metric,
             best_val,
         )
-    logging.info(f"Saving model to {tf.train.latest_checkpoint(store.checkpoint_dir)}")
+    logger.info(f"Saving model to {tf.train.latest_checkpoint(store.checkpoint_dir)}")
 
 
 def _create_dataset(
@@ -334,13 +336,13 @@ def _create_dataset(
     Create two lookup tables: one mapping characters to numbers,
     and another for numbers to characters.
     """
-    logging.info("Tokenizing input data")
+    logger.info("Tokenizing input data", extra={"user_log": True})
     ids = []
     for line in tqdm(text_iter, total=num_lines):
         _tokens = tokenizer.encode_to_ids(line)
         ids.extend(_tokens)
 
-    logging.info("Shuffling input data")
+    logger.info("Shuffling input data", extra={"user_log": True})
     char_dataset = tf.data.Dataset.from_tensor_slices(ids)
     sequences = char_dataset.batch(store.seq_length + 1, drop_remainder=True)
     full_dataset = (
@@ -365,13 +367,13 @@ def recover(x, y):
         return y
 
     if store.validation_split:
-        logging.info("Creating validation dataset")
+        logger.info("Creating validation dataset", extra={"user_log": True})
         validation_dataset = (
             full_dataset.enumerate()
             .filter(is_validation)
             .map(recover, num_parallel_calls=tf.data.AUTOTUNE)
         )
-        logging.info("Creating training dataset")
+        logger.info("Creating training dataset", extra={"user_log": True})
         train_dataset = (
             full_dataset.enumerate()
             .filter(is_train)
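Note: the validation split in the final hunk uses a common tf.data idiom: enumerate elements, route each index to the validation or training predicate, then map the index back off. A self-contained sketch under assumed values (the range dataset and the hard-coded every-5th ratio are illustrative; the real ratio derives from store.validation_split):

import tensorflow as tf

ds = tf.data.Dataset.range(10)


def is_validation(idx, value):
    return idx % 5 == 0  # every 5th element goes to validation


def is_train(idx, value):
    return idx % 5 != 0


def recover(idx, value):
    return value  # strip the index added by enumerate()


validation = ds.enumerate().filter(is_validation).map(recover)
train = ds.enumerate().filter(is_train).map(recover)

print(list(validation.as_numpy_iterator()))  # [0, 5]
print(list(train.as_numpy_iterator()))       # [1, 2, 3, 4, 6, 7, 8, 9]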
