# DataFrame Batch Training

This notebook explores the new batch training feature in Gretel Synthetics. The interface creates N synthetic training configurations, one for each of N batches of column names. The source DataFrame is split into smaller DataFrames that each keep every row but contain only a subset of the columns.
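As a rough illustration of the column batching described above, the sketch below splits a list of column names into consecutive fixed-size groups. The helper `chunk_columns`, the column names, and the batch size of 3 are all hypothetical and are not part of the Gretel Synthetics API; the library does its own batching internally and may differ in detail.

```python
from typing import List


def chunk_columns(columns: List[str], batch_size: int) -> List[List[str]]:
    """Split column names into consecutive groups of at most batch_size names."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return [columns[i:i + batch_size] for i in range(0, len(columns), batch_size)]


# Hypothetical column names, used only for illustration.
columns = ["date", "channel", "visits", "pageviews", "revenue"]
batches = chunk_columns(columns, batch_size=3)
print(batches)
```

Each inner list would correspond to one smaller DataFrame, and therefore to one training configuration.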

In [None]:
# If you are using Colab, you may wish to mount your Google Drive. Once that is done, you can create a symlinked
# directory to store the checkpoint directories in.
#
# For this example we are using some Google data that can be learned and trained on relatively quickly.
#
# NOTE: Gretel Synthetics paths must NOT contain whitespace, which is why we symlink to a more local directory
# in /content. Unfortunately, Google Drive mounts contain whitespace in either the "My drive" or "Shared drives"
# portion of the path.
#
# !ln -s "/content/drive/Shared drives[My Drive]/YOUR_TARGET_DIRECTORY" checkpoints
#
# !pip install -U gretel-synthetics
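Because whitespace in a checkpoint path causes problems, it can help to check a candidate path before training. The sketch below is a minimal, standalone check; the helper `has_whitespace` and the example paths are hypothetical and not part of Gretel Synthetics.

```python
def has_whitespace(path: str) -> bool:
    """Return True if any character in the path is a whitespace character."""
    return any(ch.isspace() for ch in path)


# Hypothetical example paths: a Drive-style mount and a symlinked local directory.
print(has_whitespace("/content/drive/My Drive/checkpoints"))
print(has_whitespace("/content/checkpoints"))
```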

In [None]:
import pandas as pd
from gretel_synthetics.batch import DataFrameBatch

source_df = pd.read_csv("https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/notebooks/google_marketplace_analytics.csv")

In [None]:
source_df.shape

In [None]:
# Here we create a dict with our config params. These are identical to the params used when creating a LocalConfig object.
#
# NOTE: We do not specify an ``input_data_path``, as this is automatically created for each batch.

In [None]:
from pathlib import Path

checkpoint_dir = str(Path.cwd() / "checkpoints")

config_template = {
    "max_line_len": 2048,
    "vocab_size": 200000,
    "field_delimiter": ",",
    "overwrite": True,
    "checkpoint_dir": checkpoint_dir
}

In [None]:
# Create our batch handler. During construction, checkpoint directories are automatically created
# based on the configured batch size
batcher = DataFrameBatch(df=source_df, config=config_template)

# Optionally, you can also provide your own batches, which can be a list of lists of strings:
#
# my_batches = [["col1", "col2"], ["col3", "col4", "col5"]]
# batcher = DataFrameBatch(df=source_df, batch_headers=my_batches, config=config_template)

In [None]:
# Next we generate our actual training DataFrames and training text files.
#
# Each batch directory will now have its own "train.csv" file.
# Each Batch object now has a ``training_df`` associated with it.
batcher.create_training_data()

In [None]:
# Now we can trigger each batch to train
batcher.train_all_batches()

In [None]:
# Next, we can trigger all batched models to create output. This loops over each model and attempts to generate
# ``gen_lines`` valid lines for each one (overridden here by ``num_lines``). The method returns a dictionary of bools,
# indexed by batch number, that tells us whether each batch generated the requested number of valid lines.
status = batcher.generate_all_batch_lines(num_lines=2000)

In [None]:
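# Peek at the raw generated text buffered for one batch (index 2 here, which assumes at least three batches exist)
batcher.batches[2].gen_data_stream.getvalue()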

In [None]:
status
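Since the status dictionary maps each batch number to a bool, the batches that fell short of the requested line count can be picked out with a small filter. This is a sketch only: `failed_batches` and `example_status` are hypothetical names, and the dictionary below is made-up data shaped like the one described above.

```python
from typing import Dict, List


def failed_batches(status: Dict[int, bool]) -> List[int]:
    """Return the sorted batch indices whose status flag is False."""
    return sorted(idx for idx, ok in status.items() if not ok)


# Hypothetical status dict: batch 1 did not reach the requested number of valid lines.
example_status = {0: True, 1: False, 2: True}
print(failed_batches(example_status))
```

An empty result would mean every batch produced the requested number of valid lines.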

In [None]:
# We can grab a DataFrame for each batch index
batcher.batch_to_df(0)

In [None]:
# Finally, we can re-assemble all synthetic batches into our new synthetic DF
batcher.batches_to_df()
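Conceptually, re-assembly joins the per-batch DataFrames column-wise, since every batch holds the same rows but different columns. The sketch below shows that idea with plain pandas on made-up data. It assumes a simple positional join and is not necessarily how `batches_to_df()` is implemented; the variable names and values are hypothetical.

```python
import pandas as pd

# Hypothetical per-batch synthetic DataFrames with the same number of rows.
batch_0 = pd.DataFrame({"date": ["2020-01-01", "2020-01-02"], "channel": ["organic", "paid"]})
batch_1 = pd.DataFrame({"visits": [10, 12], "revenue": [1.5, 2.0]})

# Reset each index so rows align by position, then join along the column axis.
combined = pd.concat(
    [df.reset_index(drop=True) for df in (batch_0, batch_1)],
    axis=1,
)
print(combined.shape)
```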

# Read-only mode

If you have already trained one or more models and simply want to load them to generate more lines, you can use the read-only mode of the batch interface. No input DataFrame is required; the batcher automatically tries to load model information from a primary checkpoint directory.

Additionally, you can control the number of lines to generate with the ``num_lines`` parameter. This option exists in write mode as well, and it overrides the number of lines specified in the synthetic config that was used.

In [None]:
read_batch = DataFrameBatch(mode="read", checkpoint_dir=checkpoint_dir)

In [None]:
read_batch.generate_all_batch_lines(num_lines=5)

In [None]:
read_batch.batches_to_df()