In [None]:
%%capture
# This Notebook serves as RFC / demo for a new low level API for synthetic record generation. Currently, when training
# on a DataFrame, we will break the columns up into clusters and train individual models on DataFrames built
# from those clusters. So a 3-cluster model will have 3 actual TF models where each TF model covers a certain subset
# of columns.
#
# When we generate data, let's say 500 records, we generate the 500 records for each batch, buffering them in
# memory, re-creating DataFrames, and eventually concatenating these DFs together. This uses an unbounded, growing
# amount of memory and also makes it challenging to do entire-record validation.
#
# This update introduces a new low-level primitive, ``RecordFactory``, that uses an underlying generator to load all
# the TF models into memory up front, then generates partial records sequentially from each model to construct one
# entire record at a time.
#
# For this demo, we'll use an already built model, which you can download and extract to its own directory. For this
# demo I use "test-model" as the checkpoint dir.
#
# https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/safecast-batch-sp-0-14.tar.gz
#
from gretel_synthetics.batch import DataFrameBatch

batch = DataFrameBatch(
    checkpoint_dir="test-model",  # directory the downloaded model tarball was extracted into
    mode="read",  # reuse the already-built model rather than configuring a new one
)

In [None]:
# You can create the factory with a method on the batch object. You can also pass the entire-record
# validator directly to that method.

def validator(rec: dict):
    """Entire-record validator handed to the record factory.

    NOTE: The values of each record will be the raw strings
    that were generated from the NN so you will have to handle
    any type casting.

    A record is rejected by raising: ``KeyError`` if the
    ``payload.loc_lat`` field is missing, ``ValueError`` if its
    value does not parse as a float. Returns ``None`` otherwise.
    """
    # Deliberately not ``assert float(...)``: that form rejects a legitimate
    # 0.0 (falsy), and ``assert`` statements are stripped entirely when Python
    # runs with ``-O``, which would silently disable validation.
    float(rec["payload.loc_lat"])
    

# Build the factory, wiring in the entire-record validator defined above
factory = batch.create_record_factory(
    validator=validator,
    num_lines=50,
)
type(factory)

In [None]:
# The factory is stateful and tracks how much more it can generate; ``summary`` reports its current
# counters (e.g. the valid and invalid record counts used in the cells below)
factory.summary

In [None]:
%%capture
# The entire factory can be treated as an iterator, and it will only yield valid records while still tracking
# the number of invalid records under the hood (NOTE: This model might need a few iterations to generate invalids).
# ``%%capture`` on this cell hides any output produced while pulling the record.
rec = next(factory)

In [None]:
# Display the record pulled with ``next(factory)``; its values are the raw generated strings
rec

In [None]:
# See that our valid count has now been incremented by the single ``next()`` call above
factory.summary

In [None]:
# Drain whatever the underlying record iterator still has, then peek at the first record collected
the_rest = [*factory]
the_rest[0]

In [None]:
# Number of records drained above (presumably ``num_lines - 1`` since one was taken with ``next`` — confirm)
len(the_rest)

In [None]:
# We may also have hit some invalid records while generating; the summary's invalid count shows how many
factory.summary

In [None]:
# The factory is now exhausted, so iterating it again yields nothing:
assert not list(factory)

In [None]:
# We can reset its state so we're ready to generate again:
factory.reset()

In [None]:
# We'll temporarily swap in a validator that always fails and lower ``max_invalid`` for demo purposes.
# This forces a RuntimeError when generating. We catch it so this cell (and "Run All") completes, and we
# restore the original validator and ``max_invalid`` afterwards so later cells generate normally.
original_max_invalid = factory.max_invalid
factory.validator = lambda x: False
factory.max_invalid = 10
try:
    list(factory)
except RuntimeError as err:
    print(f"Generation stopped as expected: {err}")
finally:
    # invalid count should match the max invalid now
    print(factory.summary)
    factory.validator = validator  # reset our original validator
    factory.max_invalid = original_max_invalid  # and the original invalid-record limit

In [None]:
# There's another utility that will auto-reset the factory state and attempt to generate all records, optionally
# with a specific output type. Currently DataFrames (``output="df"``) are supported.
#
# This buffers records as they are generated, so it will consume memory accordingly. When returning a DataFrame,
# we try to infer the dtypes as if we were loading the DF from a CSV on disk.

syn_df = factory.generate_all(output="df")
# ``display`` is available in IPython kernels and gives the rich table rendering instead of plain text
display(syn_df.head())
syn_df.dtypes

In [None]:
# Next steps
# - Do we still need a progress bar chart to show the number of invalid records? That has confused folks previously.
# - Create some other helper methods directly on the DataFrameBatch to get out a synthetic DF using this factory