(WIP) Batching functionality, doc strings, and notebook param updates (#30)

* update config.py
* Converting to Google DocString format: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
* docstrings for default training configurations
* Updated input data config docstrings
* added tokenizer settings docstrings
* Updated line generation behavior based on validator existence
* updated differential privacy docstrings
* text generation docstrings
* updates
* formatting, updated gen_lines docstring
* simplified notebook configuration settings
* simplify model params
* updated doc strings per #27 recommendations
* Updated test to match new defaults
* Initial DataFrameBatch trainer and sample notebook
* Added pip install cmd for the new feature branch
* Create default basic validator for each batch
* Add counter for valid line generation
* Add invalid counter too?
* typo
* Updated notebook
* update
* Set generated valid lines to write into a string buffer for more reliable DF creation
* Docs updates, format updates, test updates
* Enable batch module for Sphinx, notebook and docstring updates
* Update default data for notebook
* More docs updates, notebook update, test updates
* Automatically set gen_lines to num rows if not overridden
* Doc updates
* newline

Co-authored-by: John Myers <john@gretel.ai>
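The core idea this commit introduces is splitting a source DataFrame's columns into fixed-size batches so a separate model can be trained per batch. A minimal stdlib-only sketch of that column partitioning (the `batch_headers` helper below is hypothetical, for illustration only, and is not the library's API):

```python
# Hypothetical sketch of the column-batching idea: partition a list of
# column headers into consecutive groups of at most `batch_size` columns.
# A DataFrameBatch-style trainer would then build one training
# configuration per group.

from typing import List


def batch_headers(headers: List[str], batch_size: int) -> List[List[str]]:
    """Split column headers into consecutive batches of at most batch_size."""
    return [headers[i:i + batch_size] for i in range(0, len(headers), batch_size)]


columns = ["col1", "col2", "col3", "col4", "col5"]
print(batch_headers(columns, 2))
# → [['col1', 'col2'], ['col3', 'col4'], ['col5']]
```

The last batch is simply shorter when the column count is not a multiple of the batch size, which matches the behavior described in the notebook below.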
Showing 18 changed files with 16,468 additions and 101 deletions.
@@ -1 +1 @@
-0.9.3
+0.10.0
@@ -0,0 +1,5 @@
+Batch
+=====
+
+.. automodule:: gretel_synthetics.batch
+   :members:
@@ -16,6 +16,7 @@ Modules
    api/config.rst
    api/train.rst
    api/generate.rst
+   api/batch.rst


 Indices and tables
@@ -0,0 +1,178 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# DataFrame Batch Training\n",
+    "\n",
+    "This notebook explores the new batch training feature in Gretel Synthetics. This interface creates N synthetic training configurations, where N is the number of batches of column names. We break the source DataFrame down into smaller DataFrames that have the same number of rows but only a subset of the total columns."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# If you are using Colab, you may wish to mount your Google Drive. Once that is done, you can create a symlinked\n",
+    "# directory to store the checkpoint directories in.\n",
+    "#\n",
+    "# For this example we are using some Google data that can be learned and trained relatively quickly.\n",
+    "#\n",
+    "# NOTE: Gretel Synthetics paths must NOT contain whitespace, which is why we symlink to a more local directory\n",
+    "# in /content. Unfortunately, Google Drive mounts contain whitespace in either the \"My drive\" or \"Shared drives\"\n",
+    "# portion of the path.\n",
+    "#\n",
+    "# !ln -s \"/content/drive/Shared drives[My Drive]/YOUR_TARGET_DIRECTORY\" checkpoints\n",
+    "#\n",
+    "# !pip install -U gretel-synthetics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "from gretel_synthetics.batch import DataFrameBatch\n",
+    "\n",
+    "source_df = pd.read_csv(\"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/notebooks/google_marketplace_analytics.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Here we create a dict with our config params; these are identical to those used when creating a LocalConfig object.\n",
+    "#\n",
+    "# NOTE: We do not specify an ``input_data_path`` as this is automatically created for each batch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pathlib import Path\n",
+    "\n",
+    "config_template = {\n",
+    "    \"max_lines\": 0,\n",
+    "    \"max_line_len\": 2048,\n",
+    "    \"epochs\": 15,\n",
+    "    \"vocab_size\": 20000,\n",
+    "    \"gen_lines\": 100,\n",
+    "    \"dp\": True,\n",
+    "    \"field_delimiter\": \",\",\n",
+    "    \"overwrite\": True,\n",
+    "    \"checkpoint_dir\": str(Path.cwd() / \"checkpoints\")\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create our batch handler. During construction, checkpoint directories are automatically created\n",
+    "# based on the configured batch size.\n",
+    "batcher = DataFrameBatch(df=source_df, batch_size=1, config=config_template)\n",
+    "\n",
+    "# Optionally, you can also provide your own batches as a list of lists of column names:\n",
+    "#\n",
+    "# my_batches = [[\"col1\", \"col2\"], [\"col3\", \"col4\", \"col5\"]]\n",
+    "# batcher = DataFrameBatch(df=source_df, batch_headers=my_batches, config=config_template)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Next we generate our actual training DataFrames and training text files.\n",
+    "#\n",
+    "# Each batch directory will now have its own \"train.csv\" file.\n",
+    "# Each Batch object now has a ``training_df`` associated with it.\n",
+    "batcher.create_training_data()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Now we can trigger each batch to train\n",
+    "batcher.train_all_batches()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Next, we can trigger all batched models to create output. This loops over each model and attempts to generate\n",
+    "# ``gen_lines`` valid lines for each model. The method returns a dictionary of bools, indexed by batch number,\n",
+    "# that tells us whether each batch was able to generate the requested number of valid lines.\n",
+    "status = batcher.generate_all_batch_lines()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "status"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# We can grab a DataFrame for each batch index\n",
+    "batcher.batch_to_df(0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Finally, we can re-assemble all synthetic batches into our new synthetic DataFrame\n",
+    "batcher.batches_to_df()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
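The notebook ends by stitching the per-batch outputs back into one table. A minimal stdlib-only sketch of that reassembly step, assuming each trained batch yields values only for its own subset of columns (the `reassemble` helper and the sample data below are hypothetical; the library's `batches_to_df()` returns a pandas DataFrame):

```python
# Hypothetical sketch of the final reassembly step: each batch produces
# synthetic values for its own columns, and the batches are stitched back
# together column-wise, in batch order, into one combined table. Plain
# dicts stand in for DataFrames to show the shape of the operation.

from typing import Dict, List


def reassemble(batches: List[Dict[str, List[str]]]) -> Dict[str, List[str]]:
    """Merge per-batch column dicts into one table, preserving batch order."""
    combined: Dict[str, List[str]] = {}
    for batch in batches:
        combined.update(batch)  # dict insertion order keeps columns in batch order
    return combined


batch_0 = {"col1": ["a", "b"], "col2": ["c", "d"]}
batch_1 = {"col3": ["e", "f"]}
print(list(reassemble([batch_0, batch_1]).keys()))
# → ['col1', 'col2', 'col3']
```

Because each batch holds a disjoint subset of columns with the same number of rows, a column-wise merge is all that is needed to recover the original table shape.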