# Generating synthetic data

This notebook walks through training a probabilistic, generative RNN model
on a rental scooter location dataset, and then generating a synthetic
dataset with greater privacy guarantees. 

For both training and generating data, we can use the ``config.py`` module and
create a ``LocalConfig`` instance that contains all the attributes that we need
for both activities.

In the example below, we will create a config that works on a CPU. A GPU is
recommended for more complex settings.

In [1]:
# Uncomment the next line to install the package into this kernel's environment:
# %pip install gretel-synthetics

In [2]:
# Change to project directory
from pathlib import Path
import os

# Locate the repository root (``working_dir``) regardless of where the kernel
# was started, so that later cells can always build paths from it.
current_path = Path.cwd()
if current_path.name == 'examples':
    # Started from <repo>/examples: the repo root is one level up. Switch into
    # <repo>/src, presumably so that ``gretel_synthetics`` is importable.
    working_dir = current_path.parent
    src_dir = working_dir / 'src'
    os.chdir(src_dir)
elif current_path.name == 'src':
    # Already inside <repo>/src (e.g. this cell ran earlier and the kernel was
    # restarted there): only the repo root has to be recovered.
    working_dir = current_path.parent
    src_dir = current_path
else:
    # Any other start directory is assumed to be the repo root itself, so
    # ``working_dir`` is always defined for the cells below.
    working_dir = current_path

# Pure-Python equivalent of ``!pwd``; also works where no ``pwd`` binary exists.
print(Path.cwd())

/Users/redlined/Documents/GitHub/gretel-synthetics/src


In [3]:
import os
from gretel_synthetics.config import LocalConfig

# Create a config that we can use for both training and generating, with CPU-friendly settings
# The default values for ``max_chars`` and ``epochs`` are better suited for GPUs

config = LocalConfig(
    # --- input data and checkpoint locations (relative to the repo root) ---
    training_data=os.path.join(working_dir, 'examples/data/uber_scooter_rides_1day.csv'),
    checkpoint_dir=os.path.join(working_dir, 'checkpoints'),

    # --- training size: values chosen to be friendly towards CPUs ---
    max_chars=100000,
    epochs=15,
    batch_size=64,
    buffer_size=1000,  # size of the buffer used to shuffle the dataset

    # --- network shape ---
    rnn_units=256,     # dimensionality of the LSTM output space
    dropout_rate=0.2,  # fraction of the inputs to drop

    # --- differential privacy ---
    dp=True,                  # enable differential privacy
    dp_learning_rate=0.015,   # learning rate
    dp_noise_multiplier=1.1,  # how much noise is added to gradients
    dp_l2_norm_clip=1.0,      # bounds the optimizer's sensitivity to individual training points
    # NOTE(review): 256 microbatches is larger than batch_size=64 -- confirm
    # this combination is intended.
    dp_microbatches=256,      # number of microbatches each batch is split into

    # --- generation ---
    gen_lines=100,  # number of text lines to generate
    # Maximum characters per generated line; 0 appears to mean "no limit"
    # -- TODO confirm against the LocalConfig documentation.
    gen_chars=0,
)

In [4]:
# Train a model
# The training function only requires our config as a single arg
from gretel_synthetics.train import train_rnn

# Training is opt-in instead of being hidden in a commented-out call. It stays
# off by default, which matches the previous behaviour of this cell. Set
# RUN_TRAINING = True for the first run: the generation cell logs that it
# loads the latest checkpoint, so one has to exist before generating.
RUN_TRAINING = False

if RUN_TRAINING:
    train_rnn(config)

  (tf.__version__, ', '.join(versions)))


In [5]:
# Let's generate some text!
#
# The ``generate_text`` function is a generator that yields one result dict
# per line of predicted text, based on the ``gen_lines`` setting in your
# config.
#
# There is no limit on the line length as with proper training, your model
# should learn where newlines generally occur. However, if you want to
# specify a maximum char len for each line, you may set the ``gen_chars``
# attribute in your config object
from gretel_synthetics.generate import generate_text

# Optionally, when generating text, you can provide a callable that takes the
# generated line as a single arg. If this function raises any errors, the
# line fails validation: its result dict is marked ``'valid': False`` and the
# exception message is provided in the ``explain`` field of the dict that
# ``generate_text`` creates.
def validate_record(line):
    """Check that a generated line looks like a well-formed 6-field record.

    The line is split on ``", "``. A valid record has exactly six fields,
    where fields 2-5 parse as floats and field 0 parses as an int. Field 1
    is not checked.

    Args:
        line: The generated text line to validate.

    Returns:
        None. Validity is signalled purely by the absence of an exception.

    Raises:
        Exception: If the line does not split into exactly six fields.
        ValueError: If a numeric field cannot be parsed.
    """
    fields = line.split(", ")
    if len(fields) != 6:
        raise Exception('record not 6 parts')

    # Same order as before: last float column first, then the leading int.
    for position in (5, 4, 3, 2):
        float(fields[position])
    int(fields[0])
        
# Each item yielded is a result dict (``valid`` / ``text`` / ``explain``), so
# print it as soon as it is produced.
for result in generate_text(config, line_validator=validate_record):
    print(result)


2020-03-17 17:00:54,117 : MainThread : INFO : Latest checkpoint: /Users/redlined/Documents/GitHub/gretel-synthetics/checkpoints/ckpt_15
2020-03-17 17:00:54,118 : MainThread : INFO : Loading SentencePiece tokenizer
2020-03-17 17:00:54,470 : MainThread : INFO : Utilizing differential privacy in optimizer
2020-03-17 17:00:54,472 : MainThread : INFO : None


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (1, None, 256)            128000    
_________________________________________________________________
dropout (Dropout)            (1, None, 256)            0         
_________________________________________________________________
lstm (LSTM)                  (1, None, 256)            525312    
_________________________________________________________________
dropout_1 (Dropout)          (1, None, 256)            0         
_________________________________________________________________
lstm_1 (LSTM)                (1, None, 256)            525312    
_________________________________________________________________
dropout_2 (Dropout)          (1, None, 256)            0         
_________________________________________________________________
dense (Dense)                (1, None, 500)            1

{'valid': True, 'text': '16, 29928, 38.576585, -121.501841, 38.573711, -121.463351', 'explain': None}
{'valid': True, 'text': '13, 54951, 37.719593, -122.42825, 37.778366, -122.406816', 'explain': None}
{'valid': True, 'text': '0, 12922, 37.792286, -122.421713, 37.767493, -122.422755', 'explain': None}
{'valid': False, 'text': '6, BQU602, 34.058011, -118.329066, 34.061848, -118.299856, 34.029668, -118.50692', 'explain': 'record not 6 parts'}
{'valid': True, 'text': '20, OLX820, 34.036233, -118.433851, 34.031006, -118.433735', 'explain': None}
{'valid': True, 'text': '5, 30411, 34.027801, -118.50206, 34.037171, -118.492075', 'explain': None}
{'valid': True, 'text': '22, HMS554, 38.559681, -121.454175, 38.590566, -121.461841', 'explain': None}
{'valid': True, 'text': '5, CDQ265, 34.073625, -118.29199, 34.062088, -118.303166', 'explain': None}
{'valid': True, 'text': '15, YIJ105, 34.050821, -118.344206, 34.066261, -118.30198', 'explain': None}
{'valid': True, 'text': '22, RPW109, 34.03669