In [None]:
# Train a synthetic data model with differential privacy guarantees

#!pip install -Uqq "gretel-synthetics>=0.15.0"
#!pip install -Uqq "tensorflow==2.4.0rc1"

from pathlib import Path

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.tokenizers import CharTokenizerTrainer
from gretel_synthetics.train import train

In [None]:
# Training configuration with differential privacy switched on (dp=True).
# With this setting the library uses TensorFlow Privacy to add noise during
# training; tune the dp_* values below to trade privacy against accuracy of
# the synthetic model.

config = TensorFlowConfig(
    # --- data source and output location ---
    # The "Netflix Challenge" dataset
    input_data_path='https://gretel-public-website.s3.amazonaws.com/datasets/netflix/netflix.txt',
    checkpoint_dir=Path.cwd().joinpath('checkpoints').as_posix(),
    overwrite=True,
    max_lines=1e5,
    # --- network and optimizer ---
    rnn_units=256,
    batch_size=16,
    learning_rate=0.0015,
    dropout_rate=0.5,
    reset_states=False,
    # --- differential privacy ---
    dp=True,
    dp_noise_multiplier=0.2,
    dp_l2_norm_clip=1.0,
    dp_microbatches=1,
    # --- text generation ---
    gen_lines=1000,
    predict_batch_size=1,
)

# Character-level tokenizer trainer bound to the config above
tokenizer = CharTokenizerTrainer(config=config)

# Kick off model training
train(config, tokenizer)

In [None]:
from collections import Counter
import datetime
import pandas as pd
import json

from gretel_synthetics.generate import generate_text


# Read the privacy guarantees (epsilon, delta) recorded for the best model
def get_privacy_guarantees(checkpoint_dir=None):
    """Return the epsilon/delta values logged for the best training epoch.

    Reads ``model_history.csv`` from the checkpoint directory and picks the
    first row whose ``best`` column equals 1.

    Args:
        checkpoint_dir: Directory containing ``model_history.csv``. When
            omitted, the notebook-level ``config.checkpoint_dir`` is used.

    Returns:
        A dict with keys ``"epsilon"`` and ``"delta"`` holding built-in
        ``float`` values.

    Raises:
        IndexError: If no row in the history is flagged as best.
    """
    if checkpoint_dir is None:
        # Default to the directory the notebook's training run wrote to.
        checkpoint_dir = config.checkpoint_dir

    history = pd.read_csv(f"{checkpoint_dir}/model_history.csv")

    # Filter once and reuse the result for both columns.
    best_rows = history.loc[history['best'] == 1]
    if best_rows.empty:
        raise IndexError(
            f"no row with best == 1 found in {checkpoint_dir}/model_history.csv"
        )

    # Convert numpy scalars to plain floats: json.dumps cannot serialize
    # numpy integer types, which appear when a column holds only whole numbers.
    return {
        "epsilon": float(best_rows['epsilon'].iloc[0]),
        "delta": float(best_rows['delta'].iloc[0]),
    }

# Build a validator
def validate_record(line):
    """Validate one comma-separated record.

    A record is accepted when it has exactly four fields, the last one
    parses as a ``YYYY-MM-DD`` date and the first three parse as integers.
    Nothing is returned on success.

    Raises:
        Exception: If the line does not split into exactly four fields.
        ValueError: If the date or any integer field fails to parse.
    """
    fields = line.split(",")
    if len(fields) != 4:
        raise Exception('record not valid')

    first, second, third, date_text = fields
    # The date is checked first, then the integer fields from right to left.
    datetime.datetime.strptime(date_text, '%Y-%m-%d')
    for numeric_text in (third, second, first):
        int(numeric_text)


# Print the differential privacy epsilon and delta values of the trained model
print(json.dumps(get_privacy_guarantees(), indent=2))

# Print a CSV header followed by the synthetic lines that pass validation
print("movie_id,user_id,rating,date")
counter = 0
for line in generate_text(config,
                          line_validator=validate_record,
                          # max_invalid is a count, so pass an int, not 1e5
                          max_invalid=100_000):
    if line.valid:
        print(line.text)
        counter += 1
    # Stop as soon as gen_lines valid records have been printed. The earlier
    # `>` comparison only triggered after gen_lines + 1 records.
    if counter >= config.gen_lines:
        break
