Added read and write modes for batch interface (#38)
* Added read and write modes for batch interface

* Update some defaults

Co-authored-by: John Myers <john@gretel.ai>
johntmyers and John Myers committed Jul 7, 2020
1 parent d8394b7 commit ad1f7ec
Showing 6 changed files with 290 additions and 48 deletions.
44 changes: 42 additions & 2 deletions examples/dataframe_batch.ipynb
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -60,6 +60,8 @@
"source": [
"from pathlib import Path\n",
"\n",
"checkpoint_dir = str(Path.cwd() / \"checkpoints\")\n",
"\n",
"config_template = {\n",
" \"max_lines\": 0,\n",
" \"max_line_len\": 2048,\n",
@@ -69,7 +71,7 @@
" \"dp\": True,\n",
" \"field_delimiter\": \",\",\n",
" \"overwrite\": True,\n",
" \"checkpoint_dir\": str(Path.cwd() / \"checkpoints\")\n",
" \"checkpoint_dir\": checkpoint_dir\n",
"}"
]
},
@@ -152,6 +154,44 @@
"# Finally, we can re-assemble all synthetic batches into our new synthetic DF\n",
"batcher.batches_to_df()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Read only mode\n",
"\n",
"If you've already created a model(s) and simply want to load that data to generate more lines, you can use the read-only mode for the batch interface. No input DataFrame is required and it will automatically try and load model information from a primary checkpoint directory.\n",
"\n",
"Additionally, you can also control the number of lines you wish to generate with the ``num_lines`` parameter for generation. This option exists for write mode as well and overrides the number of lines specified in the synthetic config that was used."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"read_batch = DataFrameBatch(mode=\"read\", checkpoint_dir=checkpoint_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"read_batch.generate_all_batch_lines(num_lines=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"read_batch.batches_to_df()"
]
}
],
"metadata": {
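Taken together, the notebook changes demonstrate a full round trip: train in the default write mode, then reopen the same checkpoint directory in read mode. The snippet below is a minimal sketch assembled from the cells above, not part of this commit; `my_df` and its CSV source are placeholders, the config is trimmed to the keys the constructor checks, and the import path follows this repo's `src/gretel_synthetics/batch.py` layout.

```python
# A hedged sketch (not part of this commit) of the write -> read round trip.
from pathlib import Path

import pandas as pd

from gretel_synthetics.batch import DataFrameBatch

checkpoint_dir = str(Path.cwd() / "checkpoints")

# Trimmed config; see the notebook cell above for the full template.
config_template = {
    "field_delimiter": ",",           # required by the write-mode constructor
    "overwrite": True,
    "checkpoint_dir": checkpoint_dir,
}

my_df = pd.read_csv("my_training_data.csv")  # placeholder input data

# Write mode (the default): split columns into batches, train, generate
batcher = DataFrameBatch(df=my_df, config=config_template)
batcher.create_training_data()
batcher.train_all_batches()

# Read mode: no DataFrame required; batch models are discovered on disk
read_batch = DataFrameBatch(mode="read", checkpoint_dir=checkpoint_dir)
read_batch.generate_all_batch_lines(num_lines=5)
synthetic_df = read_batch.batches_to_df()
```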
192 changes: 157 additions & 35 deletions src/gretel_synthetics/batch.py
@@ -16,6 +16,8 @@
from copy import deepcopy
import logging
import io
import json
import glob

import pandas as pd
import numpy as np
@@ -35,6 +37,11 @@
BATCH_SIZE = 15
FIELD_DELIM = "field_delimiter"
GEN_LINES = "gen_lines"
READ = "read"
WRITE = "write"
HEADER_FILE = "headers.json"
CONFIG_FILE = "model_params.json"
TRAIN_FILE = "train.csv"


@dataclass
@@ -44,6 +51,7 @@ class Batch:
such as ``DataFrameBatch``. This class holds all of the necessary information
for training, data generation and DataFrame re-assembly.
"""

checkpoint_dir: str
input_data_path: str
headers: List[str]
@@ -119,6 +127,52 @@ def get_validator(self):
return self._basic_validator


def _create_batch_from_dir(batch_dir: str):
path = Path(batch_dir)
if not path.is_dir(): # pragma: no cover
raise ValueError("%s is not a directory" % batch_dir)

if not (path / HEADER_FILE).is_file(): # pragma: no cover
raise ValueError("missing headers")
headers = json.loads(open(path / HEADER_FILE).read())

if not (path / CONFIG_FILE).is_file(): # pragma: no cover
raise ValueError("missing model param file")
config = json.loads(open(path / CONFIG_FILE).read())

if not (path / TRAIN_FILE).is_file(): # pragma: no cover
raise ValueError("missing training data")
train_path = str(path / TRAIN_FILE)

batch = Batch(
checkpoint_dir=batch_dir,
input_data_path=train_path,
headers=headers,
config=LocalConfig(**config),
)

batch.load_validator_from_file()

return batch

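`_create_batch_from_dir` pins down the on-disk contract for read mode. For reference, here is a small helper, not part of this commit, that mirrors those checks against the layout `_crawl_checkpoint_for_batches` discovers:

```python
from pathlib import Path

# Illustrative helper (not in this commit). Read mode expects
# checkpoint_dir/batch_N/ directories, each holding headers.json,
# model_params.json and train.csv (HEADER_FILE, CONFIG_FILE, TRAIN_FILE).
def check_read_mode_layout(checkpoint_dir: str) -> None:
    required = ("headers.json", "model_params.json", "train.csv")
    batch_dirs = sorted(Path(checkpoint_dir).glob("batch_*"))
    if not batch_dirs:
        raise ValueError("no batch_* directories found")
    for batch_dir in batch_dirs:
        for name in required:
            if not (batch_dir / name).is_file():
                raise ValueError(f"{batch_dir} is missing {name}")
```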

def _crawl_checkpoint_for_batches(checkpoint_dir: str):
logger.info("Looking for and loading batch data...")
matching_dirs = glob.glob(str(Path(checkpoint_dir) / "batch_*"))
if not matching_dirs:
raise ValueError(
"checkpoint directory does not exist or does not contain batch data"
)

batches = []
for batch_dir in matching_dirs:
idx = int(Path(batch_dir).name.split("_")[-1])
batches.append((idx, _create_batch_from_dir(batch_dir)))

logger.info("Found and loaded %d batches", len(batches))
return dict(sorted(batches, key=lambda b: b[0]))


def _build_batch_dirs(
base_ckpoint: str, headers: List[List[str]], config: dict
) -> dict:
@@ -148,6 +202,12 @@ def _build_batch_dirs(
# try to load any previously saved validators
out[i].load_validator_from_file()

# we also write the headers out in case we load these
# batches back in via "read" mode later...it's the only
# way to get the header names back
with open(ckpoint / "headers.json", "w") as fout:
fout.write(json.dumps(headers))

return out


@@ -194,40 +254,62 @@ class DataFrameBatch:
def __init__(
self,
*,
df: pd.DataFrame,
df: pd.DataFrame = None,
batch_size: int = BATCH_SIZE,
batch_headers: List[List[str]] = None,
config: dict = None
config: dict = None,
mode: str = WRITE,
checkpoint_dir: str = None,
):

if not config:
raise ValueError("config is required!")
if mode not in (WRITE, READ): # pragma: no cover
raise ValueError("mode must be read or write")

self.mode = mode

if not isinstance(df, pd.DataFrame):
raise ValueError("df must be a Data Frame")
if self.mode == READ:
if isinstance(config, dict):
_ckpoint_dir = config.get("checkpoint_dir")
else:
_ckpoint_dir = checkpoint_dir

if FIELD_DELIM not in config:
raise ValueError("field_delimiter must be in config")
if _ckpoint_dir is None:
raise ValueError("checkpoint_dir required for read mode")
else:
self._read_checkpoint_dir = _ckpoint_dir

if GEN_LINES not in config:
config[GEN_LINES] = df.shape[0]
if self.mode == WRITE:
if not config:
raise ValueError("config is required!")

self._source_df = df
self.batch_size = batch_size
self.config = config
if not isinstance(df, pd.DataFrame):
raise ValueError("df must be a DataFrame in write mode")

self._source_df.fillna("", inplace=True)
if FIELD_DELIM not in config:
raise ValueError("field_delimiter must be in config")

self.master_header_list = list(self._source_df.columns)
if GEN_LINES not in config:
config[GEN_LINES] = df.shape[0]

if not batch_headers:
self.batch_headers = self._create_header_batches()
else:
self.batch_headers = batch_headers
self._source_df = df
self.batch_size = batch_size
self.config = config
self._source_df.fillna("", inplace=True)
self.master_header_list = list(self._source_df.columns)

self.batches = _build_batch_dirs(
self.config["checkpoint_dir"], self.batch_headers, self.config
)
if not batch_headers:
self.batch_headers = self._create_header_batches()
else: # pragma: no cover
self.batch_headers = batch_headers

self.batches = _build_batch_dirs(
self.config["checkpoint_dir"], self.batch_headers, self.config
)
else:
self.batches = _crawl_checkpoint_for_batches(self._read_checkpoint_dir)
self.master_header_list = []
for batch in self.batches.values():
self.master_header_list.extend(batch.headers)

def _create_header_batches(self):
num_batches = ceil(len(self._source_df.columns) / self.batch_size)
@@ -246,6 +328,8 @@ def create_training_data(self):
Finally, a training CSV is written to disk in the specific
batch directory
"""
if self.mode == READ: # pragma: no cover
raise RuntimeError("Method cannot be used in read-only mode")
for i, batch in self.batches.items():
logger.info(f"Generating training DF and CSV for batch {i}")
out_df = self._source_df[batch.headers]
@@ -264,6 +348,8 @@ def train_batch(self, batch_idx: int):
Args:
batch_idx: The index of the batch, from the ``batches`` dictionary
"""
if self.mode == READ: # pragma: no cover
raise RuntimeError("Method cannot be used in read-only mode")
try:
train_rnn(self.batches[batch_idx].config)
except KeyError:
@@ -272,6 +358,8 @@ def train_all_batches(self):
def train_all_batches(self):
"""Train a model for each batch.
"""
if self.mode == READ: # pragma: no cover
raise RuntimeError("Method cannot be used in read-only mode")
for idx in self.batches.keys():
self.train_batch(idx)

@@ -286,14 +374,22 @@ def set_batch_validator(self, batch_idx: int, validator: Callable):
which will be the raw line generated from the ``generate_text``
function.
"""
if self.mode == READ: # pragma: no cover
raise RuntimeError("Method cannot be used in read-only mode")
if not callable(validator):
raise ValueError("validator must be callable!")
try:
self.batches[batch_idx].set_validator(validator)
except KeyError:
raise ValueError("invalid batch number!")

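The validator contract here is just a single-argument callable. As an illustrative sketch, not part of this commit, and assuming the generation module counts a raised exception as an invalid line, a field-count check might look like:

```python
# Illustrative validator (not part of this commit). It receives each raw
# generated line; raising an exception marks the line invalid, which
# counts toward max_invalid during generation (assumed behavior).
EXPECTED_FIELDS = 5  # placeholder; match the batch's header count

def field_count_validator(line: str):
    if len(line.split(",")) != EXPECTED_FIELDS:
        raise ValueError("unexpected number of fields")

batcher.set_batch_validator(0, field_count_validator)  # batcher from write mode
```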
def generate_batch_lines(self, batch_idx: int, max_invalid=MAX_INVALID):
def generate_batch_lines(
self,
batch_idx: int,
max_invalid=MAX_INVALID,
raise_on_exceed_invalid: bool = False,
num_lines: int = None,
) -> bool:
"""Generate lines for a single batch. Lines generated are added
to the underlying ``Batch`` object for each batch. The lines
can be accessed after generation and re-assembled into a DataFrame.
@@ -302,6 +398,12 @@ def generate_batch_lines(self, batch_idx: int, max_invalid=MAX_INVALID):
batch_idx: The batch number
max_invalid: The max number of invalid lines that can be generated, if
this is exceeded, generation will stop
raise_on_exceed_invalid: If True and the number of invalid lines generated exceeds the ``max_invalid``
amount, we will re-raise the error thrown by the generation module, which will interrupt
the running process. Otherwise, we will not raise the caught exception and will simply return ``False``,
indicating that the batch failed to generate all of its lines.
num_lines: The number of lines to generate; if ``None``, the number from the
batch's config is used
"""
try:
batch = self.batches[batch_idx]
@@ -310,23 +412,33 @@ def generate_batch_lines(self, batch_idx: int, max_invalid=MAX_INVALID):
batch: Batch
batch.reset_gen_data()
validator = batch.get_validator()
t = tqdm(total=batch.config.gen_lines, desc="Valid record count ")
if num_lines is None:
num_lines = batch.config.gen_lines
t = tqdm(total=num_lines, desc="Valid record count ")
t2 = tqdm(total=max_invalid, desc="Invalid record count ")
line: gen_text
for line in generate_text(
batch.config, line_validator=validator, max_invalid=max_invalid
):
if line.valid is None or line.valid is True:
batch.add_valid_data(line)
t.update(1)
try:
for line in generate_text(
batch.config, line_validator=validator, max_invalid=max_invalid, num_lines=num_lines
):
if line.valid is None or line.valid is True:
batch.add_valid_data(line)
t.update(1)
else:
t2.update(1)
batch.gen_data_invalid.append(line)
except RuntimeError:
if raise_on_exceed_invalid:
raise
else:
t2.update(1)
batch.gen_data_invalid.append(line)
return False
t.close()
t2.close()
return batch.gen_data_count == batch.config.gen_lines
return batch.gen_data_count == num_lines

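A brief usage sketch for the new keyword arguments, assuming a `batcher` built as in the notebook:

```python
# Regenerate batch 0, overriding the configured line count and letting
# the generator's RuntimeError propagate if max_invalid is exceeded.
all_valid = batcher.generate_batch_lines(
    0,
    max_invalid=1000,
    raise_on_exceed_invalid=True,
    num_lines=100,
)
print(all_valid)  # True only if all 100 requested lines were valid
```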
def generate_all_batch_lines(self, max_invalid=MAX_INVALID) -> dict:
def generate_all_batch_lines(
self, max_invalid=MAX_INVALID, raise_on_failed_batch: bool = False, num_lines: int = None
) -> dict:
"""Generate synthetic lines for all batches. Lines for each batch
are added to the individual ``Batch`` objects. Once generation is
done, you may re-assemble the dataset into a DataFrame.
@@ -340,6 +452,11 @@ def generate_all_batch_lines(self, max_invalid=MAX_INVALID) -> dict:
Args:
max_invalid: The number of invalid lines, per batch. If this number
is exceeded for any batch, generation will stop.
raise_on_failed_batch: If True, then an exception will be raised if any single batch
fails to generate the requested number of lines. If False, then the failed batch
will be set to ``False`` in the result dictionary from this method.
num_lines: The number of lines to create from each batch. If ``None`` then the value
from the config template will be used.
Returns:
A dictionary of batch number to a bool value that shows if each batch
@@ -352,7 +469,12 @@ def generate_all_batch_lines(self, max_invalid=MAX_INVALID) -> dict:
"""
batch_status = {}
for idx in self.batches.keys():
batch_status[idx] = self.generate_batch_lines(idx, max_invalid=max_invalid)
batch_status[idx] = self.generate_batch_lines(
idx,
max_invalid=max_invalid,
raise_on_exceed_invalid=raise_on_failed_batch,
num_lines=num_lines
)
return batch_status

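Since the return value is a per-batch status map, a caller can retry just the batches that fell short. A small sketch under the same assumptions as above:

```python
# Generate everything, then retry any failed batch with a larger
# invalid-line budget before reassembling the DataFrame.
statuses = batcher.generate_all_batch_lines(num_lines=100)
for idx, succeeded in statuses.items():
    if not succeeded:
        batcher.generate_batch_lines(idx, max_invalid=5000, num_lines=100)

synthetic_df = batcher.batches_to_df()
```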
def batch_to_df(self, batch_idx: int) -> pd.DataFrame: # pragma: no cover
