Jm/df batch header preserve (#75)
* Header order preservation, test generation during model load in read-mode

* In DP mode, globally patch TF to use new Keras code path
johntmyers committed Nov 25, 2020
1 parent 9f09b91 commit 794fde0
Showing 4 changed files with 67 additions and 7 deletions.
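Taken together, the changes make header handling round-trip: write mode records the source DataFrame's column order, and read mode restores it and smoke-tests generation at load time. A minimal usage sketch of that flow (the DataFrame contents, checkpoint path, and trimmed-down config are illustrative only; a real config carries the full training parameters):

import pandas as pd

from gretel_synthetics.batch import DataFrameBatch

train_df = pd.DataFrame({"name": ["sam"], "city": ["tulsa"], "age": [31]})

config = {
    "checkpoint_dir": "/tmp/header-demo",  # hypothetical path
    "field_delimiter": "|",
    "overwrite": True,  # new: required if checkpoint_dir already has contents
    # ...plus the usual training parameters
}

# Write mode: the column order of train_df is saved to original_headers.json
writer = DataFrameBatch(df=train_df, config=config, batch_size=2)
writer.create_training_data()
# (training itself omitted here: writer.train_all_batches())

# Read mode: original_headers.json is loaded back, and a one-line generation
# run now validates each underlying model during construction
reader = DataFrameBatch(mode="read", config={"checkpoint_dir": "/tmp/header-demo"})

# reader.batches_to_df() reorders synthetic output to the original training
# order, falling back to master_header_list for models trained before this change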
47 changes: 45 additions & 2 deletions src/gretel_synthetics/batch.py
@@ -17,6 +17,7 @@
 import io
 import json
 import glob
+import shutil

 import pandas as pd
 import numpy as np
@@ -43,6 +44,8 @@
 READ = "read"
 WRITE = "write"
 HEADER_FILE = "headers.json"
+ORIG_HEADERS = "original_headers.json"
+CHECKPOINT_DIR = "checkpoint_dir"
 CONFIG_FILE = "model_params.json"
 TRAIN_FILE = "train.csv"
 PATH_HOLDER = "___path_holder___"
@@ -220,7 +223,7 @@ def _build_batch_dirs(
         # we write the headers out as well in case we load these
         # batches back in via "read" mode only later...it's the only
         # way to get the header names back
-        with open(ckpoint / "headers.json", "w") as fout:
+        with open(ckpoint / HEADER_FILE, "w") as fout:
             fout.write(json.dumps(headers))

     return out
@@ -282,6 +285,18 @@ class DataFrameBatch:

     mode: Union[WRITE, READ]

+    master_header_list: List[str]
+    """During training, this is the original column order. When reading from
+    disk, we concatenate all headers from all batches together. This list is
+    not guaranteed to preserve the original header order.
+    """
+
+    original_headers: List[str]
+    """The original header list and order from the training data. This is
+    written out to the model directory during training and loaded back in
+    when using read-only mode.
+    """
+
     def __init__(
         self,
         *,
@@ -308,6 +323,8 @@ def __init__(

         self.tokenizer = tokenizer

+        self.original_headers = None
+
         if self.mode == READ:
             if isinstance(config, dict):
                 _ckpoint_dir = config.get("checkpoint_dir")
@@ -323,6 +340,14 @@
             if not config:
                 raise ValueError("config is required!")

+            checkpoint_path = Path(config[CHECKPOINT_DIR])
+            overwrite = config.get("overwrite", False)
+            if not overwrite and checkpoint_path.is_dir() and any(checkpoint_path.iterdir()):
+                raise RuntimeError("checkpoint_dir already exists and is non-empty, set overwrite on config or remove model directory!")  # noqa
+
+            if overwrite and checkpoint_path.is_dir():
+                shutil.rmtree(checkpoint_path)
+
             if not isinstance(df, pd.DataFrame):
                 raise ValueError("df must be a DataFrame in write mode")

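The new guard is easy to exercise on its own: constructing a write-mode DataFrameBatch over a non-empty checkpoint_dir without overwrite should now raise. A sketch under those assumptions (the temp path and single-key config are illustrative):

import tempfile
from pathlib import Path

import pandas as pd
import pytest

from gretel_synthetics.batch import DataFrameBatch

tmp = Path(tempfile.mkdtemp())
(tmp / "leftover.txt").write_text("stale model artifacts")

# overwrite defaults to False, so the non-empty directory trips the guard
with pytest.raises(RuntimeError):
    DataFrameBatch(df=pd.DataFrame({"a": [1]}), config={"checkpoint_dir": str(tmp)})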
@@ -346,12 +371,30 @@ def __init__(
             self.batches = _build_batch_dirs(
                 self.config["checkpoint_dir"], self.batch_headers, self.config
             )
+
+            # Preserve the original order of the DF headers
+            self.original_headers = list(self._source_df)
+            with open(Path(self.config[CHECKPOINT_DIR]) / ORIG_HEADERS, "w") as fout:
+                fout.write(json.dumps(list(self.original_headers)))
         else:
             self.batches = _crawl_checkpoint_for_batches(self._read_checkpoint_dir)
             self.master_header_list = []
             for batch in self.batches.values():
                 self.master_header_list.extend(batch.headers)
+
+            try:
+                self.original_headers = json.loads(
+                    open(Path(self._read_checkpoint_dir) / ORIG_HEADERS).read()
+                )
+            except FileNotFoundError:
+                self.original_headers = None
+
+            logger.info("Validating underlying models exist via generation test...")
+            try:
+                self.generate_all_batch_lines(parallelism=1, num_lines=1)
+            except Exception as err:
+                raise RuntimeError("Error testing generation during model load") from err

     def _create_header_batches(self):
         num_batches = ceil(len(self._source_df.columns) / self.batch_size)
         tmp = np.array_split(list(self._source_df.columns), num_batches)
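For context on the surrounding code: _create_header_batches slices the source columns into ceil(n / batch_size) groups with np.array_split, which produces the per-batch header lists preserved above. The same arithmetic in isolation, with hypothetical column names:

from math import ceil

import numpy as np

columns = ["c0", "c1", "c2", "c3", "c4"]
batch_size = 2
num_batches = ceil(len(columns) / batch_size)  # 3
print([list(part) for part in np.array_split(columns, num_batches)])
# [['c0', 'c1'], ['c2', 'c3'], ['c4']]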
@@ -625,4 +668,4 @@ def batches_to_df(self) -> pd.DataFrame:
         for batch in batch_iter:
             accum_df = pd.concat([accum_df, batch.synthetic_df], axis=1)

-        return accum_df[self.master_header_list]
+        return accum_df[self.original_headers or self.master_header_list]
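The one-line change above carries the whole fallback: original_headers is None for models trained before this commit, and Python's `or` then selects master_header_list. A self-contained sketch with illustrative column names:

import pandas as pd

accum_df = pd.DataFrame({"b": [1], "a": [2]})
master_header_list = ["b", "a"]  # concatenation order across batch dirs

original_headers = ["a", "b"]  # newly trained model: original order restored
print(list(accum_df[original_headers or master_header_list].columns))  # ['a', 'b']

original_headers = None  # older model: original_headers.json missing
print(list(accum_df[original_headers or master_header_list].columns))  # ['b', 'a']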
1 change: 1 addition & 0 deletions src/gretel_synthetics/config.py
@@ -12,6 +12,7 @@
 import logging
 import tensorflow as tf

+
 import gretel_synthetics.const as const
 from gretel_synthetics.tensorflow.train import train_rnn
 from gretel_synthetics.tensorflow.generator import TensorFlowGenerator
10 changes: 10 additions & 0 deletions src/gretel_synthetics/tensorflow/dp_model.py
@@ -1,5 +1,7 @@
 from typing import Tuple, TYPE_CHECKING
 import logging
+import importlib
+
 import tensorflow as tf
 from tensorflow.keras.optimizers import RMSprop
 from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy
@@ -29,6 +31,14 @@ def build_dp_model(store, batch_size, vocab_size) -> tf.keras.Sequential:
     """
     logging.warning("Experimental: Differentially private training enabled")

+    recurrent_v2 = importlib.import_module("tensorflow.python.keras.layers.recurrent_v2")
+    # NOTE: This patches the LSTMs to use the new Keras 2.4.x code paths
+    # and will have no effect when the module function is removed
+    use_new_code = getattr(recurrent_v2, "_use_new_code", None)
+    if use_new_code is not None:
+        logging.warning("******* Patching TensorFlow to utilize new Keras code paths, see: %s", "https://github.com/tensorflow/tensorflow/issues/44917 *******")  # noqa
+        recurrent_v2._use_new_code = lambda: True  # pylint: disable=protected-access
+
     optimizer = make_keras_optimizer_class(RMSprop)(
         l2_norm_clip=store.dp_l2_norm_clip,
         noise_multiplier=store.dp_noise_multiplier,
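The patch above relies on a defensive lookup so it degrades to a no-op once TensorFlow drops the private hook. The same pattern in isolation (the module and attribute names follow the commit; the wrapper function is illustrative):

import importlib
import logging


def force_new_keras_lstm_path() -> bool:
    recurrent_v2 = importlib.import_module("tensorflow.python.keras.layers.recurrent_v2")
    if getattr(recurrent_v2, "_use_new_code", None) is None:
        return False  # hook removed in newer TF: nothing to patch
    logging.warning("Patching TF to force the new Keras LSTM code path")
    recurrent_v2._use_new_code = lambda: True  # pylint: disable=protected-access
    return True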
16 changes: 11 additions & 5 deletions tests/test_batch.py
@@ -4,11 +4,12 @@
 import random
 from copy import deepcopy
 from dataclasses import asdict
+import json

 import pytest
 import pandas as pd

-from gretel_synthetics.batch import DataFrameBatch, MAX_INVALID, READ
+from gretel_synthetics.batch import DataFrameBatch, MAX_INVALID, READ, ORIG_HEADERS
 from gretel_synthetics.generate import GenText
 from gretel_synthetics.errors import TooManyInvalidError

@@ -34,7 +35,7 @@
     "dp_l2_norm_clip": 1.0,
     "dp_microbatches": 256,
     "field_delimiter": "|",
-    "overwrite": False,
+    "overwrite": True,
     "checkpoint_dir": checkpoint_dir,
 }

@@ -111,6 +112,9 @@ def test_init(test_data):
         assert Path(batch.checkpoint_dir).is_dir()
         assert Path(batch.checkpoint_dir).name == f"batch_{i}"

+    orig_headers = json.loads(open(Path(config_template["checkpoint_dir"]) / ORIG_HEADERS).read())
+    assert set(orig_headers) == set(test_data.columns)
+
     batches.create_training_data()
     df = pd.read_csv(batches.batches[0].input_data_path, sep=config_template["field_delimiter"])
     assert len(df.columns) == len(first_row)
@@ -184,8 +188,9 @@ def bad():


 def test_batches_to_df(test_data):
-    batches = DataFrameBatch(df=pd.DataFrame([
-        {"foo": "bar", "foo1": "bar1", "foo2": "bar2", "foo3": 3}]), config=config_template, batch_size=2)
+    _df = pd.DataFrame([
+        {"foo": "bar", "foo1": "bar1", "foo2": "bar2", "foo3": 3}])
+    batches = DataFrameBatch(df=_df, config=config_template, batch_size=2)

     batches.batches[0].add_valid_data(
         GenText(text="baz|baz1", valid=True, delimiter="|")
@@ -258,7 +263,8 @@ def test_generate_all_batch_lines_raise_on_failed(test_data):
 }


-def test_read_mode(test_data):
+@patch("gretel_synthetics.batch.DataFrameBatch.generate_all_batch_lines")
+def test_read_mode(mock_gen, test_data):
     writer = DataFrameBatch(df=test_data, config=config_template)
     writer.create_training_data()

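test_read_mode gains a @patch because read-mode construction now calls generate_all_batch_lines, which would fail without trained models; unittest.mock injects the stub as the first test argument. The same pattern on a toy class (all names here are illustrative):

from unittest.mock import patch


class Loader:
    def validate(self):
        raise RuntimeError("would need a trained model")

    def load(self):
        self.validate()  # stands in for the generation smoke test
        return "loaded"


@patch.object(Loader, "validate")
def run_test(mock_validate):
    assert Loader().load() == "loaded"  # validate() was stubbed out
    assert mock_validate.called


run_test()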
