In [None]:
from pathlib import Path

from gretel_synthetics.train import train
from gretel_synthetics.generate import generate_text
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.tokenizers import CharTokenizerTrainer

In [None]:
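# This config uses TensorFlow Privacy to inject noise into the model during training.
# The resulting privacy guarantees (epsilon and delta) are read back from the
# training history and reported in the final cell of this notebook.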

# Training configuration with differential privacy switched on (`dp=True`).
# The `dp_*` settings below only matter because DP is enabled.
config = TensorFlowConfig(
    # Line-count cap: written as an int literal rather than the float `1e5`,
    # since a number of lines is an integer quantity (same numeric value).
    max_lines=100_000,
    dp=True,
    predict_batch_size=1,
    rnn_units=256,
    batch_size=16,
    learning_rate=0.0015,
    dp_noise_multiplier=0.2,
    dp_l2_norm_clip=1.0,
    dropout_rate=0.5,
    dp_microbatches=1,
    reset_states=False,
    overwrite=True,
    # Checkpoints go to ./dp-checkpoints relative to the current working
    # directory; the final cell reads model_history.csv from this location.
    checkpoint_dir=(Path.cwd() / 'dp-checkpoints').as_posix(),
    # The "Netflix Challenge" dataset
    input_data_path='https://gretel-public-website.s3.amazonaws.com/datasets/netflix/netflix.txt',
)

In [None]:
# Tokenizer trainer built from the same config; it is handed to `train()` in
# the next cell. Presumably character-level, going by the class name.
tokenizer = CharTokenizerTrainer(config=config)

In [None]:
# Train the model with the config and tokenizer trainer defined above.
# NOTE(review): presumably this writes checkpoints and `model_history.csv`
# under `config.checkpoint_dir` (the final cell reads that file) — confirm.
train(config, tokenizer)

In [None]:
# Build a validator
def validate_record(line):
    """Check that a generated line looks like a 4-field CSV record.

    The line is split on commas. It must yield exactly four fields, and the
    first three must be parseable as integers. The fourth field is not
    inspected.

    Args:
        line: The raw text of one generated record.

    Returns:
        None. A record is considered valid when no exception is raised.

    Raises:
        ValueError: If the line does not have exactly four comma-separated
            fields, or if any of the first three fields is not an integer.
    """
    rec = line.split(",")
    # Raise ValueError (not a bare Exception) so both failure modes -- wrong
    # field count and non-integer field -- surface as the same exception type.
    if len(rec) != 4:
        raise ValueError('record not 4 parts')
    # Same conversion order as before: third, second, then first field.
    for field in (rec[2], rec[1], rec[0]):
        int(field)

In [None]:
from collections import Counter
import pandas as pd
import json
from gretel_synthetics.errors import TooManyInvalidError

# Tally of generated records keyed by their validity flag (True / False).
counter = Counter()

try:
    for line in generate_text(config, line_validator=validate_record, max_invalid=config.gen_lines):
        counter[line.valid] += 1
        total_count = counter[True] + counter[False]
        # Print a progress line every 10th record rather than for each one.
        if total_count % 10 == 0:
            print(f"{total_count}/{config.gen_lines} : {line.text}")
except TooManyInvalidError:
    # Best-effort: keep whatever was generated so far, but make the early
    # stop visible instead of swallowing it silently.
    print("Generation stopped early: too many invalid records")

# extract training params from the history row(s) marked `best == 1`;
# the filter is computed once and reused for every column.
df = pd.read_csv(f"{config.checkpoint_dir}/model_history.csv")
best = df[df['best'] == 1]
loss = best['loss'].values[0]
accuracy = best['accuracy'].values[0]
epsilon = best['epsilon'].values[0]
delta = best['delta'].values[0]
total = sum(counter.values())
# Avoid ZeroDivisionError when generation produced no records at all.
valid_percent = counter[True] / total * 100.0 if total else 0.0

run_stats = {
    "dp": config.dp,
    "epochs": config.epochs,
    "learning_rate": config.learning_rate,
    "loss": loss,
    "accuracy": accuracy,
    "epsilon": epsilon,
    "delta": delta,
    "valid_count": int(counter[True]),
    "invalid_count": int(counter[False]),
    "valid_percent": valid_percent
}

print(json.dumps(run_stats, indent=2))