Commit

Merge pull request #12 from gretelai/augment-data-example-aw
Augment data example aw
zredlined committed May 8, 2020
2 parents 8def903 + 0aaf9be commit a92a026
Showing 1 changed file with 323 additions and 0 deletions.
323 changes: 323 additions & 0 deletions examples/research/heart_disease_uci.ipynb
@@ -0,0 +1,323 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"colab_type": "code",
"id": "iVlbpffzGxby",
"outputId": "4d56f959-108b-41d1-bd94-19fb08479cc1",
"tags": [
"outputPrepend"
]
},
"outputs": [],
"source": [
"#!pip install gretel-synthetics --upgrade\n",
"#!pip install matplotlib\n",
"#!pip install smart_open"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 514
},
"colab_type": "code",
"id": "DL60W_-jGxb5",
"outputId": "fbd4fce8-6fbe-40dc-b479-52c3f2f8bbb6"
},
"outputs": [],
"source": [
"# load source training set\n",
"import logging\n",
"import os\n",
"import sys\n",
"import pandas as pd\n",
"from smart_open import open\n",
"\n",
"source_file = \"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uci-heart-disease/train.csv\"\n",
"annotated_file = \"./heart_annotated.csv\"\n",
"\n",
"def annotate_dataset(df):\n",
" df = df.fillna(\"\")\n",
" df = df.replace(',', '[c]', regex=True)\n",
" df = df.replace('\\r', '', regex=True)\n",
" df = df.replace('\\n', ' ', regex=True)\n",
" return df\n",
"\n",
"# Preprocess dataset, store annotated file to disk\n",
"# Protip: Training set is very small, repeat so RNN can learn structure\n",
"df = annotate_dataset(pd.read_csv(source_file))\n",
"while not len(df.index) > 15000:\n",
" df = df.append(df)\n",
" \n",
"# Write annotated training data to disk\n",
"df.to_csv(annotated_file, index=False, header=False)\n",
"\n",
"# Preview dataset\n",
"df.head(15)"
]
},
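{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (not part of the original notebook): annotate_dataset() above swaps\n",
"# literal commas for a '[c]' token so each record stays a single comma-delimited\n",
"# line for the character-level model. If the original field values are needed\n",
"# later, a hypothetical inverse helper could look like this:\n",
"\n",
"def deannotate_dataset(df):\n",
"    # undo the '[c]' substitution applied by annotate_dataset\n",
"    return df.replace(r'\\[c\\]', ',', regex=True)\n",
"\n",
"# Example usage:\n",
"# deannotate_dataset(df).head()"
]
},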
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 265
},
"colab_type": "code",
"id": "n8hXICCwGxb7",
"outputId": "5c821f71-075b-4aa3-b0d1-c5023908771f"
},
"outputs": [],
"source": [
"# Plot distribution\n",
"counts = df['sex'].value_counts().sort_values(ascending=False)\n",
"counts.rename({1:\"Male\", 0:\"Female\"}).plot.pie()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9vvh8q12Gxb-",
"scrolled": true
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from gretel_synthetics.config import LocalConfig\n",
"\n",
"# Create a config that we can use for both training and generating, with CPU-friendly settings\n",
"# The default values for ``max_chars`` and ``epochs`` are better suited for GPUs\n",
"config = LocalConfig(\n",
" max_lines=0, # read all lines (zero)\n",
" epochs=15, # 15-30 epochs for production\n",
" vocab_size=200, # tokenizer model vocabulary size\n",
" character_coverage=1.0, # tokenizer model character coverage percent\n",
" gen_chars=0, # the maximum number of characters possible per-generated line of text\n",
" gen_lines=10000, # the number of generated text lines\n",
" rnn_units=256, # dimensionality of LSTM output space\n",
" batch_size=64, # batch size\n",
" buffer_size=1000, # buffer size to shuffle the dataset\n",
" dropout_rate=0.2, # fraction of the inputs to drop\n",
" dp=True, # let's use differential privacy\n",
" dp_learning_rate=0.015, # learning rate\n",
" dp_noise_multiplier=1.1, # control how much noise is added to gradients\n",
" dp_l2_norm_clip=1.0, # bound optimizer's sensitivity to individual training points\n",
" dp_microbatches=256, # split batches into minibatches for parallelism\n",
" checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),\n",
" save_all_checkpoints=False,\n",
" input_data_path=annotated_file # filepath or S3\n",
")"
]
},
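{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (not part of the original notebook): a rough way to estimate the\n",
"# (epsilon, delta) privacy budget implied by the DP settings above. It assumes\n",
"# tensorflow_privacy is installed and exposes compute_dp_sgd_privacy with this\n",
"# signature; check the installed version before relying on the numbers.\n",
"\n",
"# from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy import compute_dp_sgd_privacy\n",
"#\n",
"# n_samples = len(df.index)  # size of the (repeated) training set\n",
"# eps, opt_order = compute_dp_sgd_privacy(\n",
"#     n=n_samples,\n",
"#     batch_size=config.batch_size,\n",
"#     noise_multiplier=config.dp_noise_multiplier,\n",
"#     epochs=config.epochs,\n",
"#     delta=1.0 / n_samples,\n",
"# )\n",
"# print(f'Estimated epsilon: {eps:.2f} at delta={1.0 / n_samples:.2e}')"
]
},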
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"colab_type": "code",
"id": "lEhcYd8gGxcD",
"outputId": "9d7ccb8e-3e77-46eb-e0b9-8c1394ea7132"
},
"outputs": [],
"source": [
"# Train a model\n",
"# The training function only requires our config as a single arg\n",
"from gretel_synthetics.train import train_rnn\n",
"\n",
"train_rnn(config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"colab_type": "code",
"id": "GKEHQEHnGxcG",
"outputId": "0d3bcd56-07c0-41d4-8958-dcf3127c3c15"
},
"outputs": [],
"source": [
"# Let's generate some records!\n",
"\n",
"from collections import Counter\n",
"from gretel_synthetics.generate import generate_text\n",
"\n",
"# Generate this many records\n",
"records_to_generate = 111\n",
"\n",
"# Validate each generated record\n",
"# Note: This custom validator verifies the record structure matches\n",
"# the expected format for UCI healthcare data, and also that \n",
"# generated records are Female (e.g. column 1 is 0)\n",
"\n",
"def validate_record(line):\n",
" rec = line.strip().split(\",\")\n",
" if not int(rec[1]) == 0:\n",
" raise Exception(\"record generated must be female\")\n",
" if len(rec) == 14:\n",
" int(rec[0])\n",
" int(rec[2])\n",
" int(rec[3])\n",
" int(rec[4])\n",
" int(rec[5])\n",
" int(rec[6])\n",
" int(rec[7])\n",
" int(rec[8])\n",
" float(rec[9])\n",
" int(rec[10])\n",
" int(rec[11])\n",
" int(rec[12])\n",
" int(rec[13])\n",
" else:\n",
" raise Exception('record not 14 parts')\n",
" \n",
"# Dataframe to hold synthetically generated records \n",
"synth_df = pd.DataFrame(columns=df.columns)\n",
"\n",
"\n",
"for idx, line in enumerate(generate_text(config, line_validator=validate_record)):\n",
" status = line['valid']\n",
" \n",
" # ensure all generated records are unique\n",
" synth_df = synth_df.drop_duplicates()\n",
" synth_cnt = len(synth_df.index)\n",
" if synth_cnt > records_to_generate:\n",
" break \n",
"\n",
" # if generated record passes validation, save it\n",
" if status:\n",
" print(f\"({synth_cnt}/{records_to_generate} : {status})\") \n",
" print(f\"{line['text']}\")\n",
" data = line['text'].split(\",\")\n",
" synth_df = synth_df.append({k:v for k,v in zip(df.columns, data)}, ignore_index=True)\n",
" "
]
},
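{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (not part of the original notebook): the generated values come back\n",
"# as strings from line['text'].split(','), so synth_df holds object dtypes.\n",
"# Optionally casting the columns to numeric types makes downstream comparisons\n",
"# with the source data easier; this is a judgment call, not part of the original flow.\n",
"\n",
"# synth_df = synth_df.apply(pd.to_numeric)\n",
"# synth_df.dtypes"
]
},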
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "hs6y7RyYGxcI"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"import seaborn as sns\n",
"\n",
"# Load model history from file\n",
"history = pd.read_csv(f\"{(Path(config.checkpoint_dir) / 'model_history.csv').as_posix()}\")\n",
"\n",
"# Plot output\n",
"def plot_training_data(history: pd.DataFrame):\n",
" sns.set(style=\"whitegrid\")\n",
" fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18,4))\n",
" sns.lineplot(x=history['epoch'], y=history['loss'], ax=ax1, color='orange').set(title='Model training loss')\n",
" history[['perplexity', 'epoch']].plot('epoch', ax=ax2, color='orange').set(title='Perplexity')\n",
" history[['accuracy', 'epoch']].plot('epoch', ax=ax3, color='blue').set(title='% Accuracy')\n",
" plt.show()\n",
"\n",
"plot_training_data(history)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "irurlPx1GxcK"
},
"outputs": [],
"source": [
"# Preview the synthetic dataset\n",
"synth_df.head(10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "CT_m94w6GxcM"
},
"outputs": [],
"source": [
"# As a final step, combine the original training data + \n",
"# our synthetic records, and shuffle them to prepare for training\n",
"train_df = annotate_dataset(pd.read_csv(source_file))\n",
"combined_df = synth_df.append(train_df).sample(frac=1)\n",
"\n",
"# Write our final training dataset to disk (download this for the Kaggle experiment!)\n",
"combined_df.to_csv('synthetic_train_shuffled.csv', index=False)\n",
"combined_df.head(10)"
]
},
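{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (not part of the original notebook): a quick sanity check that no\n",
"# synthetic record is an exact copy of a source training record. This is only a\n",
"# surface-level check, not a formal privacy guarantee (that comes from the DP\n",
"# training settings above).\n",
"\n",
"# overlap = pd.merge(\n",
"#     synth_df.astype(str), train_df.astype(str), how='inner', on=list(df.columns)\n",
"# )\n",
"# print(f'Exact matches between synthetic and training records: {len(overlap.index)}')"
]
},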
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9w2IdK_8GxcO"
},
"outputs": [],
"source": [
"# Plot distribution\n",
"counts = combined_df['sex'].astype(int).value_counts().sort_values(ascending=False)\n",
"counts.rename({1:\"Male\", 0:\"Female\"}).plot.pie()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "heart_disease_uci.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
