# Generating synthetic data

This notebook walks through training a probabilistic, generative RNN model<br>
on a rental scooter location dataset, and then generating a synthetic<br>
dataset with greater privacy guarantees. 

For both training and generating data, we can use the ``config.py`` module and<br>
create a ``LocalConfig`` instance that contains all the attributes that we need<br>
for both activities.

In [None]:
# Google Colab support
# Note: Click "Runtime->Change Runtime Type" set Hardware Accelerator to "GPU"
# Note: Use pip install gretel-synthetics[tf] to install tensorflow if necessary
# 
#!pip install gretel-synthetics --upgrade

In [1]:
from pathlib import Path

from gretel_synthetics.config import LocalConfig

# Create a config that we can use for both training and generating data
# The default values for ``max_lines`` and ``epochs`` are optimized for training on a GPU.

config = LocalConfig(
    max_lines=0,         # maximum lines of training data. Set to ``0`` to train on entire file
    max_line_len=2048,   # the max line length for input training data
    epochs=15,           # 15-50 epochs with GPU for best performance
    vocab_size=20000,    # tokenizer model vocabulary size
    gen_lines=1000,      # the number of generated text lines
    dp=True,             # train with differential privacy enabled (privacy assurances, but reduced accuracy)
    field_delimiter=",", # specify if the training text is structured, else ``None``
    overwrite=True,      # overwrite previously trained model checkpoints
    checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),
    input_data_path="https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uber_scooter_rides_1day.csv" # filepath or S3
)


In [None]:
# Train a model
# The training function only requires our config as a single arg
from gretel_synthetics.train import train_rnn

train_rnn(config)

In [2]:
# Let's generate some text!
#
# The ``generate_text`` funtion is a generator that will return
# a line of predicted text based on the ``gen_lines`` setting in your
# config.
#
# There is no limit on the line length as with proper training, your model
# should learn where newlines generally occur. However, if you want to
# specify a maximum char len for each line, you may set the ``gen_chars``
# attribute in your config object
from gretel_synthetics.generate import generate_text

# Optionally, when generating text, you can provide a callable that takes the 
# generated line as a single arg. If this function raises any errors, the 
# line will fail validation and will not be returned.  The exception message
# will be provided as a ``explain`` field in the resulting dict that gets
# created by ``generate_text``
def validate_record(line):
    rec = line.split(", ")
    if len(rec) == 6:
        float(rec[5])
        float(rec[4])
        float(rec[3])
        float(rec[2])
        int(rec[0])
    else:
        raise Exception('record not 6 parts')
        
for line in generate_text(config, line_validator=validate_record):
    print(line)

2020-08-04 14:33:14,009 : MainThread : INFO : Latest checkpoint: /Users/jtm/gretel/gretel-synthetics/examples/checkpoints/synthetic


gen_text(valid=False, text='4,MLD092,39.730938,-121.4647241,34.056905', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='7,211,38.543638', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='13,44118,-121.478581,-97.747006', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,10175,-97.739991,33.9980486', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,45052,-97.742016,38.887371,-97.738651,-122.33708', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='18,MOG294,-122.399096', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,31452,37.761175,-97.740716', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,JDJ641', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1,44449,-122.397001', explain='record not 6 parts', delimiter

gen_text(valid=False, text='18,,S48389,-97.7178418', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='15,JZY080,MSS68,27.945488', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,8177,30.270723,-121.481176', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='23,8804,266393,30.281738,-118.420725', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,ZKU372,-122.037878,36.971041', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='224,KL30540,32.678995', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='16,82,-122.39518,37.754475', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,YCJ494', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='3,743,-122.3904823,47.592495', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='14,QCY514,38.5394931,38.900015', explain='record not 6 parts'

gen_text(valid=False, text='8,8177,-121.505468,4991', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='16,12366,-117.120596', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='7,OLE65,-121.757481', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='12,SBT897,34.024106', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='23,AYS747,38.575291,-121.7515381,445,34.0980131,47.662128,-122.34268,39.303331', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='18,YNU269,-121.490023,-118.428361,-122.326891,774723', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='23,XAH016,38.9232998,23,30.2841748,33.986259,251588', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='2,37453,38.868211,-121.499311,38.5821343,-121.494878,38.579351', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,EMS396,47.596203,38.5389848', explain

gen_text(valid=False, text='4,GUY934,38.557145', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='0,316,-121.757881', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text=',38.939I793,-97.7448305', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1,B94531,39.747718,47.612283268,-122.258407,923,-77.031748', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='2,103,9653,28323', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='23,KHA814,30.20269,-122.332996', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,DPNK008,-121.472218,38.569291', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='2,AMA683,SQ023,-122.442538', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='16,MWH972,30.261165,30.265425,47.61486', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='11,2996', explain='record not 

gen_text(valid=False, text='1,KWSXI915,38.90978,30.2690118,-76.9854685', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='12,14989,-118.45634,47.618706', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1,QMGB44118,661,-118.457461', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='8,14Z513,33.985331,-82.34641', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,45707,38.585206,-122.021911,5981,-118.466301,FJZ1', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='5811,-121.498766,34.025375,-82.476078', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='3,EBNZ916,47.61578,-118.295375', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,EN298,34.06169,33.98418', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,98148', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,YVPB975

gen_text(valid=False, text='1,3730984,EKU45', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='768,31351', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='138,47.605405', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='23,1895,34.098523,-118.43271', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='0,IRG063,-77.043135', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='8', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='0,37567,-122.3987941', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1,CDK791,38.5654098,-118.432765', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='17,GJG424,38.906795', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,WWA584,38.581372,-122.271,-121.44543', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1,94,38.539216,30.

gen_text(valid=False, text='22,OUI134,47.615668,-122.33109', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='2,IRT543,39.28371', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='7,EVI022,9362,38.902831,-118.2916138', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1,216,37.797545', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='OM,-122.003184,30.260351', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,JKW749,-77.031781', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='2,BGW862,34.012475,34.011528,-77.043768,-122.331628', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='9,-,-118.476453,38.900751,873,-76.610108', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,112', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='3,nT110,38.947886,38.5557328,-121.75738,-77.024271

gen_text(valid=False, text='23,31381K993,-97.7436935,g668', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,13680,39.296646,38.581943,006,-122.28891,-118.327116,37.76094', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,DMC608,38.923608', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='17,38.87709478,47.650778', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='14,WOD766,39.75961,-122.009789,-122.329578,-77.032041,38.908605,38.580728,-122.33031,38.5742555,34.014048', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='11,7899,37.776778', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='11,836,-77.0476763', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='14,4050984,-82.468428', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,TGAY958,-122.409201,38.9190696,38.937068,-77.00648,37.78511,38.5

gen_text(valid=False, text='14,41270,-76.9949726', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='23,WMW765,34.053678,QNG8', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='288,28754,86548,37.784056', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='KOL1,17878,38.579971,37.77781', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,31330,38.579005,-118.29283,-122.34108,0905', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='22,UUD637,33.99701,-122.342406', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='12,BUY363,-97.743595,-97.744479,-122.352003,-122.403345,38.58245,-97.75805', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='11,28771,-122.341888,30.262465', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='15,226', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='17,32

gen_text(valid=False, text='3,10595,-97.744221', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='8,48236,-122.33359', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,12601,-77.01914,-77.009003', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='597', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='13,854721,30.20281', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,gGL167,30.279106,-77.04025,05356', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,JON734,38.93056,-97.737108', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='13,NHB266,34.1051573,-76.592436,47.600843,30.239058', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='14,4724,38.565328', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,XZK774,47.61916,88073,38.55820,38.58211', explain='record not 6 parts

gen_text(valid=False, text='14,3756,38.577103', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,KZT453', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='11,NYC630,WN798,-121.48743', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,48115,-82.47278,-97.7304998,-121.466138,-122.420958', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='0,14515', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='0,Q953,-122.054121,38.88808,34.0152591,-118.38741', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,RMQ393,-82.484478,-76.60197,30.268888,38.578373,-97.7454416', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='303,-122.33459,47.9813745', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='518,47.621995', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='14,OEU778,30.267408,-77.017

gen_text(valid=False, text='21,HOM921,5443', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,UGI397,-122.320601,33.991405', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,31934,8509,37.788952', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,2531,38.583128', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='3,30927,30.2637133', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,7927,-118.292565,38.5957103,-122.418111', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,211,-122.42429,47.655311,-97.73656', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='2,303,-97.738053,9428,-77.005676,-82.460265,-97.745281', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='4,30.2591596,-118.291001', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='3,53245,34.019928,37.776728,3

gen_text(valid=False, text='17,QNN263,-122.4252306,BR105,34.052555,34.05779,38.595855', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='13,LYS500,-122.346603', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='17,DTO859,38.901333,30.291118,5733,47.6108218,-97.742393', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='20,793932,-121.4873201,-121.466701', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='2,11577,47.603613,-118.474156', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='21,DFO635,38.925138,885,LI786,-122.388313,-118.47975', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='1,LLP238,27.938093,34.003173,30.282715', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='17,OKG986,-118.28163,-118.4434878,34.042496', explain='record not 6 parts', delimiter=',')
gen_text(valid=False, text='14,KOL386,34.008005,-77.059991,34.