In [None]:
#!pip install gretel-synthetics --upgrade
#!pip install matplotlib
#!pip install smart_open

In [None]:
# load source training set
import logging
import os
import sys
import pandas as pd
from smart_open import open

# Remote CSV holding the UCI heart-disease training set that this notebook reads
source_file = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uci-heart-disease/train.csv"
# Local path the preprocessed training data is written to (and later handed to the model config)
annotated_file = "./heart_annotated.csv"

def annotate_dataset(df):
    """Return a cleaned copy of ``df`` suitable for comma-delimited, one-record-per-line text.

    Missing values become empty strings, commas inside values are swapped for
    the ``[c]`` placeholder, carriage returns are removed and newlines are
    turned into single spaces. The input frame itself is left untouched.
    """
    # (pattern, replacement) pairs, applied in this order as regex substitutions
    substitutions = ((',', '[c]'), ('\r', ''), ('\n', ' '))
    cleaned = df.fillna("")
    for pattern, replacement in substitutions:
        cleaned = cleaned.replace(pattern, replacement, regex=True)
    return cleaned

# Preprocess dataset, store annotated file to disk
# Protip: Training set is very small, repeat so RNN can learn structure
df = annotate_dataset(pd.read_csv(source_file))
if df.empty:
    # Doubling an empty frame never grows it, so the loop below would never finish
    raise ValueError(f"No training rows found in {source_file}")

# Double the frame until it holds more than 15000 rows.
# ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` stacks the two
# copies the same way and works on both old and new pandas versions.
while len(df.index) <= 15000:
    df = pd.concat([df, df])

# Write annotated training data to disk (no index, no header: one raw record per line)
df.to_csv(annotated_file, index=False, header=False)

# Preview dataset
df.head(15)

In [None]:
# Pie chart of how often each ``sex`` value occurs in the training frame, largest slice first
counts = (
    df['sex']
    .value_counts()
    .sort_values(ascending=False)
)
counts.rename(index={1: "Male", 0: "Female"}).plot.pie()

In [None]:
from pathlib import Path

from gretel_synthetics.config import LocalConfig

# A single config object is shared by the training and generation steps.
# These settings are CPU-friendly; the default values for ``max_chars`` and
# ``epochs`` are better suited for GPUs.
config = LocalConfig(
    # --- data and checkpoint locations ---
    input_data_path=annotated_file,  # filepath or S3
    checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),
    overwrite=True,                  # overwrite previously trained model checkpoints
    # --- training ---
    max_lines=0,                     # maximum lines of training data; ``0`` trains on the entire file
    epochs=15,                       # 15-50 epochs with GPU for best performance
    vocab_size=20000,                # tokenizer model vocabulary size
    dp=True,                         # differential privacy enabled (privacy assurances, but reduced accuracy)
    # --- generation ---
    gen_lines=1000,                  # the number of generated text lines
    field_delimiter=",",             # specify if the training text is structured, else ``None``
)

In [None]:
# Train a model
# ``train_rnn`` is called with the config object as its only argument; the
# training data path, epoch count and checkpoint directory were all set on
# ``config`` when it was constructed.
from gretel_synthetics.train import train_rnn

train_rnn(config)

In [None]:
# Let's generate some records!

from collections import Counter
from gretel_synthetics.generate import generate_text

# Target number of validated synthetic records to collect before the generation loop stops
records_to_generate = 111

# Validate each generated record
# Note: This custom validator verifies the record structure matches
# the expected format for UCI healthcare data, and also that
# generated records are Female (e.g. column 1 is 0)

def validate_record(line):
    """Check one generated line; return ``None`` if it is acceptable, raise otherwise.

    A line is accepted when, after stripping surrounding whitespace and
    splitting on ``,``, it has exactly 14 fields, field 1 parses to the
    integer 0 (female), field 9 parses as a float and every other field
    parses as an int.

    Raises:
        ValueError: if the line does not have 14 fields, if field 1 is not 0,
            or if any field fails its numeric conversion.
    """
    rec = line.strip().split(",")

    # Check the shape first so a truncated line reports the structural problem
    # instead of an IndexError from the field lookup below
    if len(rec) != 14:
        raise ValueError('record not 14 parts')

    if int(rec[1]) != 0:
        raise ValueError("record generated must be female")

    # Field 9 is the only float column; all remaining fields must be integers.
    # The conversions are done purely for their side effect of raising on bad input.
    for position, field in enumerate(rec):
        if position == 9:
            float(field)
        else:
            int(field)
        
# Accepted rows are gathered in a plain list and converted to a DataFrame once
# at the end: ``DataFrame.append`` was removed in pandas 2.0, and growing a
# frame row by row (while de-duplicating it on every iteration) is quadratic.
synth_rows = []
seen_rows = set()  # tuples of field values already accepted, used to keep records unique

for record in generate_text(config, line_validator=validate_record):
    # only records that passed validation are kept
    if not record.valid:
        continue

    data = record.values_as_list()

    # ensure all generated records are unique
    row_key = tuple(data)
    if row_key in seen_rows:
        continue
    seen_rows.add(row_key)

    print(f"({len(synth_rows)}/{records_to_generate} : {record.valid})")
    print(f"{record.text}")
    synth_rows.append(dict(zip(df.columns, data)))

    # Stop as soon as the requested number of records has been collected
    # (the previous ``>`` comparison let one extra record through)
    if len(synth_rows) >= records_to_generate:
        break

# Dataframe holding the unique, validated synthetic records
synth_df = pd.DataFrame(synth_rows, columns=df.columns)
            

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns

# Read ``model_history.csv`` from the configured checkpoint directory into a DataFrame
history = pd.read_csv((Path(config.checkpoint_dir) / 'model_history.csv').as_posix())

# Plot output
def plot_training_data(history: pd.DataFrame):
    """Draw loss, perplexity and accuracy against epoch as three side-by-side panels.

    ``history`` must provide ``epoch``, ``loss``, ``perplexity`` and
    ``accuracy`` columns. The figure is shown immediately; nothing is returned.
    """
    sns.set(style="whitegrid")
    fig, axes = plt.subplots(1, 3, figsize=(18,4))
    loss_ax, perplexity_ax, accuracy_ax = axes

    # Left panel: loss curve drawn with seaborn
    loss_plot = sns.lineplot(x=history['epoch'], y=history['loss'], ax=loss_ax, color='orange')
    loss_plot.set(title='Model training loss')

    # Middle and right panels: drawn through the pandas plotting API
    pandas_panels = (
        ('perplexity', perplexity_ax, 'orange', 'Perplexity'),
        ('accuracy', accuracy_ax, 'blue', '% Accuracy'),
    )
    for column, axis, colour, title in pandas_panels:
        history[[column, 'epoch']].plot('epoch', ax=axis, color=colour).set(title=title)

    plt.show()

# Render the loss / perplexity / accuracy panels for the history frame loaded earlier in this cell
plot_training_data(history)


In [None]:
# Preview the synthetic dataset (first 10 rows of the generated records)
synth_df.head(10)

In [None]:
# As a final step, combine the original training data +
# our synthetic records, and shuffle them to prepare for training
train_df = annotate_dataset(pd.read_csv(source_file))

# ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` stacks the two
# frames in the same order. ``sample(frac=1)`` then returns every row in
# random order, i.e. a full shuffle.
combined_df = pd.concat([synth_df, train_df]).sample(frac=1)

# Write our final training dataset to disk (download this for the Kaggle experiment!)
combined_df.to_csv('synthetic_train_shuffled.csv', index=False)
combined_df.head(10)

In [None]:
# Pie chart of the ``sex`` split in the combined frame; values are cast to int before counting
counts = (
    combined_df['sex']
    .astype(int)
    .value_counts()
    .sort_values(ascending=False)
)
counts.rename(index={1: "Male", 0: "Female"}).plot.pie()