-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update tokenizer demo notebook * readme updates, DP notebook * dp updates * int tests update * use all procs for ITs * update netflix example for differential privacy * Update tokenizer documentation * Research papers and Slack badge * bugfix * add datetime validator
- Loading branch information
1 parent
5a58a06
commit e15479e
Showing
13 changed files
with
186 additions
and
181 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,4 +136,6 @@ venv* | |
checkpoints | ||
examples/checkpoints.zip | ||
|
||
docs/_build | ||
docs/_build | ||
examples/tokenizer_demo/ | ||
dp-checkpoints |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "synthetics_dp.py", | ||
"provenance": [], | ||
"collapsed_sections": [] | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"accelerator": "GPU" | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "TY5zsaXme67e" | ||
}, | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"\n", | ||
"from gretel_synthetics.config import TensorFlowConfig\n", | ||
"from gretel_synthetics.tokenizers import CharTokenizerTrainer\n", | ||
"from gretel_synthetics.train import train" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "CNTeiB24e4TJ" | ||
}, | ||
"source": [ | ||
"# This config will utilize TensorFlow Privacy to inject noised data into \n", | ||
"# the model during training. Adjust the dp_* parameters to balance\n", | ||
"# privacy vs. accuracy for a synthetic model. \n", | ||
"\n", | ||
"config = TensorFlowConfig(\n", | ||
" gen_lines=1000,\n", | ||
" max_lines=1e5,\n", | ||
" dp=True,\n", | ||
" predict_batch_size=1,\n", | ||
" rnn_units=256,\n", | ||
" batch_size=16,\n", | ||
" learning_rate=0.0015,\n", | ||
" dp_noise_multiplier=0.2,\n", | ||
" dp_l2_norm_clip=1.0,\n", | ||
" dropout_rate=0.5,\n", | ||
" dp_microbatches=1,\n", | ||
" reset_states=False,\n", | ||
" overwrite=True,\n", | ||
" checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),\n", | ||
" # The \"Netflix Challenge\", dataset\n", | ||
" input_data_path='https://gretel-public-website.s3.amazonaws.com/datasets/netflix/netflix.txt'\n", | ||
")\n", | ||
"\n", | ||
"# Initialize the tokenizer\n", | ||
"tokenizer = CharTokenizerTrainer(config=config)\n", | ||
"\n", | ||
"# Train the model\n", | ||
"train(config, tokenizer)" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "EzAlijMSlSJz" | ||
}, | ||
"source": [ | ||
"from collections import Counter\n", | ||
"import datetime\n", | ||
"import pandas as pd\n", | ||
"import json\n", | ||
"\n", | ||
"from gretel_synthetics.generate import generate_text\n", | ||
"\n", | ||
"\n", | ||
"# extract training params\n", | ||
"def get_privacy_guarantees():\n", | ||
" df = pd.read_csv(f\"{config.checkpoint_dir}/model_history.csv\")\n", | ||
" epsilon = df[df['best'] == 1]['epsilon'].values[0]\n", | ||
" delta = df[df['best'] == 1]['delta'].values[0]\n", | ||
" return {\n", | ||
" \"epsilon\": epsilon,\n", | ||
" \"delta\": delta,\n", | ||
" }\n", | ||
"\n", | ||
"# Build a validator\n", | ||
"def validate_record(line):\n", | ||
" rec = line.split(\",\")\n", | ||
" if len(rec) == 4:\n", | ||
" datetime.datetime.strptime(rec[3], '%Y-%m-%d')\n", | ||
" int(rec[2])\n", | ||
" int(rec[1])\n", | ||
" int(rec[0])\n", | ||
" else:\n", | ||
" raise Exception('record not valid')\n", | ||
"\n", | ||
"\n", | ||
"# Print differential privacy epsilon and delta values\n", | ||
"print(json.dumps(get_privacy_guarantees(), indent=2))\n", | ||
"\n", | ||
"# Print CSV header and synthetic lines\n", | ||
"counter = 0\n", | ||
"print(\"movie_id,user_id,rating,date\")\n", | ||
"for line in generate_text(config, \n", | ||
" line_validator=validate_record, \n", | ||
" max_invalid=1e5):\n", | ||
" if line.valid:\n", | ||
" print(f\"{line.text}\")\n", | ||
" counter += 1\n", | ||
" if counter > config.gen_lines:\n", | ||
" break\n" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.