-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from gretelai/augment-data-example-aw
Augment data example aw
- Loading branch information
Showing
1 changed file
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,323 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 1000 | ||
}, | ||
"colab_type": "code", | ||
"id": "iVlbpffzGxby", | ||
"outputId": "4d56f959-108b-41d1-bd94-19fb08479cc1", | ||
"tags": [ | ||
"outputPrepend" | ||
] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#!pip install gretel-synthetics --upgrade\n", | ||
"#!pip install matplotlib\n", | ||
"#!pip install smart_open" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 514 | ||
}, | ||
"colab_type": "code", | ||
"id": "DL60W_-jGxb5", | ||
"outputId": "fbd4fce8-6fbe-40dc-b479-52c3f2f8bbb6" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# load source training set\n", | ||
"import logging\n", | ||
"import os\n", | ||
"import sys\n", | ||
"import pandas as pd\n", | ||
"from smart_open import open\n", | ||
"\n", | ||
"source_file = \"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uci-heart-disease/train.csv\"\n", | ||
"annotated_file = \"./heart_annotated.csv\"\n", | ||
"\n", | ||
"def annotate_dataset(df):\n", | ||
" df = df.fillna(\"\")\n", | ||
" df = df.replace(',', '[c]', regex=True)\n", | ||
" df = df.replace('\\r', '', regex=True)\n", | ||
" df = df.replace('\\n', ' ', regex=True)\n", | ||
" return df\n", | ||
"\n", | ||
"# Preprocess dataset, store annotated file to disk\n", | ||
"# Protip: Training set is very small, repeat so RNN can learn structure\n", | ||
"df = annotate_dataset(pd.read_csv(source_file))\n", | ||
"while not len(df.index) > 15000:\n", | ||
" df = df.append(df)\n", | ||
" \n", | ||
"# Write annotated training data to disk\n", | ||
"df.to_csv(annotated_file, index=False, header=False)\n", | ||
"\n", | ||
"# Preview dataset\n", | ||
"df.head(15)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 265 | ||
}, | ||
"colab_type": "code", | ||
"id": "n8hXICCwGxb7", | ||
"outputId": "5c821f71-075b-4aa3-b0d1-c5023908771f" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Plot distribution\n", | ||
"counts = df['sex'].value_counts().sort_values(ascending=False)\n", | ||
"counts.rename({1:\"Male\", 0:\"Female\"}).plot.pie()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "9vvh8q12Gxb-", | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"\n", | ||
"from gretel_synthetics.config import LocalConfig\n", | ||
"\n", | ||
"# Create a config that we can use for both training and generating, with CPU-friendly settings\n", | ||
"# The default values for ``max_chars`` and ``epochs`` are better suited for GPUs\n", | ||
"config = LocalConfig(\n", | ||
" max_lines=0, # read all lines (zero)\n", | ||
" epochs=15, # 15-30 epochs for production\n", | ||
" vocab_size=200, # tokenizer model vocabulary size\n", | ||
" character_coverage=1.0, # tokenizer model character coverage percent\n", | ||
" gen_chars=0, # the maximum number of characters possible per-generated line of text\n", | ||
" gen_lines=10000, # the number of generated text lines\n", | ||
" rnn_units=256, # dimensionality of LSTM output space\n", | ||
" batch_size=64, # batch size\n", | ||
" buffer_size=1000, # buffer size to shuffle the dataset\n", | ||
" dropout_rate=0.2, # fraction of the inputs to drop\n", | ||
" dp=True, # let's use differential privacy\n", | ||
" dp_learning_rate=0.015, # learning rate\n", | ||
" dp_noise_multiplier=1.1, # control how much noise is added to gradients\n", | ||
" dp_l2_norm_clip=1.0, # bound optimizer's sensitivity to individual training points\n", | ||
" dp_microbatches=256, # split batches into minibatches for parallelism\n", | ||
" checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),\n", | ||
" save_all_checkpoints=False,\n", | ||
" input_data_path=annotated_file # filepath or S3\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 1000 | ||
}, | ||
"colab_type": "code", | ||
"id": "lEhcYd8gGxcD", | ||
"outputId": "9d7ccb8e-3e77-46eb-e0b9-8c1394ea7132" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Train a model\n", | ||
"# The training function only requires our config as a single arg\n", | ||
"from gretel_synthetics.train import train_rnn\n", | ||
"\n", | ||
"train_rnn(config)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 1000 | ||
}, | ||
"colab_type": "code", | ||
"id": "GKEHQEHnGxcG", | ||
"outputId": "0d3bcd56-07c0-41d4-8958-dcf3127c3c15" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Let's generate some records!\n", | ||
"\n", | ||
"from collections import Counter\n", | ||
"from gretel_synthetics.generate import generate_text\n", | ||
"\n", | ||
"# Generate this many records\n", | ||
"records_to_generate = 111\n", | ||
"\n", | ||
"# Validate each generated record\n", | ||
"# Note: This custom validator verifies the record structure matches\n", | ||
"# the expected format for UCI healthcare data, and also that \n", | ||
"# generated records are Female (e.g. column 1 is 0)\n", | ||
"\n", | ||
"def validate_record(line):\n", | ||
" rec = line.strip().split(\",\")\n", | ||
" if not int(rec[1]) == 0:\n", | ||
" raise Exception(\"record generated must be female\")\n", | ||
" if len(rec) == 14:\n", | ||
" int(rec[0])\n", | ||
" int(rec[2])\n", | ||
" int(rec[3])\n", | ||
" int(rec[4])\n", | ||
" int(rec[5])\n", | ||
" int(rec[6])\n", | ||
" int(rec[7])\n", | ||
" int(rec[8])\n", | ||
" float(rec[9])\n", | ||
" int(rec[10])\n", | ||
" int(rec[11])\n", | ||
" int(rec[12])\n", | ||
" int(rec[13])\n", | ||
" else:\n", | ||
" raise Exception('record not 14 parts')\n", | ||
" \n", | ||
"# Dataframe to hold synthetically generated records \n", | ||
"synth_df = pd.DataFrame(columns=df.columns)\n", | ||
"\n", | ||
"\n", | ||
"for idx, line in enumerate(generate_text(config, line_validator=validate_record)):\n", | ||
" status = line['valid']\n", | ||
" \n", | ||
" # ensure all generated records are unique\n", | ||
" synth_df = synth_df.drop_duplicates()\n", | ||
" synth_cnt = len(synth_df.index)\n", | ||
" if synth_cnt > records_to_generate:\n", | ||
" break \n", | ||
"\n", | ||
" # if generated record passes validation, save it\n", | ||
" if status:\n", | ||
" print(f\"({synth_cnt}/{records_to_generate} : {status})\") \n", | ||
" print(f\"{line['text']}\")\n", | ||
" data = line['text'].split(\",\")\n", | ||
" synth_df = synth_df.append({k:v for k,v in zip(df.columns, data)}, ignore_index=True)\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "hs6y7RyYGxcI" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"from pathlib import Path\n", | ||
"import seaborn as sns\n", | ||
"\n", | ||
"# Load model history from file\n", | ||
"history = pd.read_csv(f\"{(Path(config.checkpoint_dir) / 'model_history.csv').as_posix()}\")\n", | ||
"\n", | ||
"# Plot output\n", | ||
"def plot_training_data(history: pd.DataFrame):\n", | ||
" sns.set(style=\"whitegrid\")\n", | ||
" fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18,4))\n", | ||
" sns.lineplot(x=history['epoch'], y=history['loss'], ax=ax1, color='orange').set(title='Model training loss')\n", | ||
" history[['perplexity', 'epoch']].plot('epoch', ax=ax2, color='orange').set(title='Perplexity')\n", | ||
" history[['accuracy', 'epoch']].plot('epoch', ax=ax3, color='blue').set(title='% Accuracy')\n", | ||
" plt.show()\n", | ||
"\n", | ||
"plot_training_data(history)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "irurlPx1GxcK" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Preview the synthetic dataset\n", | ||
"synth_df.head(10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "CT_m94w6GxcM" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# As a final step, combine the original training data + \n", | ||
"# our synthetic records, and shuffle them to prepare for training\n", | ||
"train_df = annotate_dataset(pd.read_csv(source_file))\n", | ||
"combined_df = synth_df.append(train_df).sample(frac=1)\n", | ||
"\n", | ||
"# Write our final training dataset to disk (download this for the Kaggle experiment!)\n", | ||
"combined_df.to_csv('synthetic_train_shuffled.csv', index=False)\n", | ||
"combined_df.head(10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "9w2IdK_8GxcO" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Plot distribution\n", | ||
"counts = combined_df['sex'].astype(int).value_counts().sort_values(ascending=False)\n", | ||
"counts.rename({1:\"Male\", 0:\"Female\"}).plot.pie()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"name": "heart_disease_uci.ipynb", | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |