<a href="https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/AW%2Fgtcc-example/docs/notebooks/balance_uci_heart_disease.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to use Gretel.ai's conditional sampling to balance the gender classes in a popular healthcare dataset, resulting in both better ML model accuracy and, potentially, a more ethically fair training set.

The Heart Disease dataset published by the University of California, Irvine (UCI) is one of the top 5 datasets on the data science competition site Kaggle, with 9 data science tasks listed and 1,014+ notebook kernels created by data scientists. It consists of 14 health attributes and is labeled with whether or not the patient had heart disease, making it a great dataset for prediction.


In [None]:
%%capture
!pip install gretel_client xgboost

In [None]:
from gretel_client import configure_session

# api_key="prompt" asks for the API key interactively instead of hardcoding it.
# NOTE(review): cache="yes" presumably persists the key for later sessions and
# validate=True presumably checks it against the Gretel API — confirm in the
# gretel_client documentation.
configure_session(
    api_key="prompt",
    cache="yes",
    validate=True,
)

Enter Gretel API key··········


In [None]:
# Load and preview the dataset.
# The train/test files were created from the Kaggle dataset using a 70/30 split.
import pandas as pd

DATA_URL = 'https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uci-heart-disease'
train = pd.read_csv(f'{DATA_URL}/heart_train.csv')
test = pd.read_csv(f'{DATA_URL}/heart_test.csv')

train

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,63,0,1,140,195,0,1,179,0,0.0,2,2,2,1
1,55,0,0,180,327,0,2,117,1,3.4,1,0,2,0
2,53,1,2,130,246,1,0,173,0,0.0,2,3,2,1
3,59,1,0,170,326,0,0,140,1,3.4,0,0,3,0
4,66,1,0,160,228,0,0,138,0,2.3,2,0,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
712,42,1,0,136,315,0,1,125,1,1.8,1,0,1,0
713,49,1,2,118,149,0,0,126,0,0.8,2,3,2,0
714,50,0,2,120,219,0,1,158,0,1.6,1,0,2,1
715,61,0,0,145,307,0,0,146,1,1.0,1,0,3,0


In [None]:
# Plot the gender distribution of the real-world training data.
# `sex` is encoded as 0 = female, 1 = male.

pd.options.plotting.backend = 'plotly'

# How many more male (1) than female (0) rows the training set has.
sex_counts = train.sex.value_counts()
print(f"We will need to augment training set with an additional {sex_counts[1] - sex_counts[0]} records to balance gender class")

df = train.sex.replace({0: 'female', 1: 'male'})
df.value_counts().sort_values().plot(kind='barh', title='Real world distribution')


We will need to augment training set with an additional 271 records to balance gender class


In [None]:
# Train a synthetic model on the training set.
#
# The model is configured with a `seed` task on the `sex` field so that later
# generation runs can be conditioned on specific `sex` values.

# Import the module rather than `from smart_open import open`: the latter would
# shadow the builtin `open` for every later cell in the notebook.
import smart_open
import yaml

from gretel_client import projects
from gretel_client.helpers import poll

CONFIG_URL = "https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/config_templates/gretel/synthetics/default.yml"
PREVIEW_RECORDS = 500  # records generated with the model, read back below as "data_preview"

# Create a project and model configuration.
project = projects.create_or_get_unique_project(name='uci-heart-disease')

with smart_open.open(CONFIG_URL, 'r') as stream:
    config = yaml.safe_load(stream)

# Conditional data generation task: fields whose values can be supplied as seeds.
fields = ["sex"]
task = {
    'type': 'seed',
    'attrs': {
        'fields': fields,
    },
}

synthetics_config = config['models'][0]['synthetics']
synthetics_config['task'] = task
synthetics_config['generate'] = {'num_records': PREVIEW_RECORDS}
# NOTE(review): setting both privacy filters to None presumably disables them
# for this demo — confirm against the Gretel synthetics config reference.
synthetics_config['privacy_filters'] = {'similarity': None, 'outliers': None}

# Fit the model on the training set (the index is not a feature, so drop it).
model = project.create_model_obj(model_config=config)
train.to_csv('train.csv', index=False)
model.data_source = 'train.csv'
model.submit_cloud()

poll(model)

synthetic = pd.read_csv(model.get_artifact_link("data_preview"), compression='gzip')
synthetic

[32mINFO: [0mStarting poller


{
    "uid": "622691fb4c1fc91c717fbc0b",
    "guid": "model_264tuQIYO3qNkLQstGK1BdVhqDR",
    "model_name": "pretty-delightful-toad",
    "runner_mode": "cloud",
    "user_id": "5f3c3afbbff62139634c66ca",
    "user_guid": null,
    "billing_domain": "gretel.ai",
    "billing_domain_guid": null,
    "project_id": "62268cc207d85f2fc307c400",
    "project_guid": "proj_264rCKf3f0vOrfCYpMbeIdYmtTt",
    "status_history": {
        "created": "2022-03-07T23:15:07.948097Z"
    },
    "last_modified": "2022-03-07T23:15:08.042493Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "container_image": "074762682575.dkr.ecr.us-west-2.amazonaws.com/gretelai/synthetics@sha256:717a68c0e4ef3000c8b650bbed308162ef10c1b2cb4bfc3026b773bc908ee577",
    "model_type": "synthetics",
    "config": {
        "schema_version": "1.0",
        "name": null,
        "models": [
            {
                "sy

[32mINFO: [0mStatus is created. Model creation has been queued.
[32mINFO: [0mStatus is pending. A Gretel Cloud worker is being allocated to begin model creation.
[32mINFO: [0mStatus is active. A worker has started creating your model!
2022-03-07T23:15:25.445570Z  Starting synthetic model training
2022-03-07T23:15:25.447592Z  Loading training data
2022-03-07T23:15:25.658573Z  Training data loaded, detected format: 'csv'
2022-03-07T23:15:25.673007Z  Training data loaded
{
    "record_count": 717,
    "field_count": 14,
    "upsample_count": 9283
}
2022-03-07T23:15:29.080470Z  Creating semantic validators and preparing training data
2022-03-07T23:15:36.033983Z  Beginning ML model training
2022-03-07T23:15:44.263820Z  Training epoch completed
{
    "epoch": 0,
    "accuracy": 0.3424,
    "loss": 2.3506,
    "val_accuracy": 0,
    "val_loss": 0,
    "batch": 0
}
2022-03-07T23:15:45.836130Z  Training epoch completed
{
    "epoch": 1,
    "accuracy": 0.5623,
    "loss": 1.3575,
    "val

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,63,0,0,150,407,0,0,154,0,4.0,1,3,3,0
1,58,0,2,120,340,0,1,172,0,0.0,2,0,2,1
2,52,1,1,134,201,0,1,158,0,0.8,2,1,2,1
3,56,1,1,130,221,0,0,163,0,0.0,2,0,3,1
4,52,1,2,138,223,0,1,169,0,0.0,2,4,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,46,1,0,140,311,0,1,120,1,1.8,1,2,3,0
496,48,1,1,110,229,0,1,168,0,1.0,0,0,3,0
497,61,1,3,134,234,0,1,145,0,2.6,1,2,2,0
498,42,0,0,102,265,0,0,122,0,0.6,1,0,2,1


In [None]:
# Conditionally sample records from the synthetic data model using `seeds`
# to augment the real-world training data.
#
# Each seed row fixes the `sex` value (0 = female, 1 = male) of one generated
# record. The female/male split of the seeds is skewed by `delta` (how many more
# male than female rows the real training set has). Assuming every generated
# record keeps its seeded `sex`, the concatenation of synthetic records and
# `train` ends up with class counts that differ by at most one.

num_rows = 5000       # total number of synthetic records to generate
SHUFFLE_SEED = 42     # makes the seed-row order (and seeds.csv) reproducible

sex_counts = train.sex.value_counts()
# .get(..., 0) keeps this working even if one class is absent from `train`.
delta = sex_counts.get(1, 0) - sex_counts.get(0, 0)
assert abs(delta) <= num_rows, "num_rows is too small to balance the gender classes"

# Female seeds first, remaining rows male. Built directly instead of the chained
# assignment `seeds['sex'][i:] = 1`, which is a silent no-op under pandas
# Copy-on-Write.
num_female = (num_rows + delta) // 2
seeds = pd.DataFrame({'sex': [0] * num_female + [1] * (num_rows - num_female)})
seeds.sample(frac=1, random_state=SHUFFLE_SEED).to_csv('seeds.csv', index=False)

rh = model.create_record_handler_obj(data_source="seeds.csv", params={"num_records": len(seeds)})
rh.submit_cloud()

poll(rh)

synthetic = pd.read_csv(rh.get_artifact_link("data"), compression='gzip')
# ignore_index avoids duplicate row labels between the synthetic and real rows.
augmented = pd.concat([synthetic, train], ignore_index=True)
augmented

[32mINFO: [0mStarting poller


{
    "uid": "622693fc6b0dce1074b369be",
    "guid": "model_run_264uwoNwWwQdNA1CQvh9YrXWIHX",
    "model_name": null,
    "runner_mode": "cloud",
    "user_id": "5f3c3afbbff62139634c66ca",
    "user_guid": null,
    "billing_domain": "gretel.ai",
    "billing_domain_guid": null,
    "project_id": "62268cc207d85f2fc307c400",
    "project_guid": "proj_264rCKf3f0vOrfCYpMbeIdYmtTt",
    "status_history": {
        "created": "2022-03-07T23:23:40.318000Z"
    },
    "last_modified": "2022-03-07T23:23:40.430000Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "container_image": "074762682575.dkr.ecr.us-west-2.amazonaws.com/gretelai/synthetics@sha256:717a68c0e4ef3000c8b650bbed308162ef10c1b2cb4bfc3026b773bc908ee577",
    "model_id": "622691fb4c1fc91c717fbc0b",
    "model_guid": "model_264tuQIYO3qNkLQstGK1BdVhqDR",
    "action": "generate",
    "config": {
        "params": {
           

[32mINFO: [0mStatus is created. A Record generation job has been queued.
[32mINFO: [0mStatus is pending. A Gretel Cloud worker is being allocated to begin generating synthetic records.
[32mINFO: [0mStatus is active. A worker has started!
2022-03-07T23:23:56.950331Z  Loading model to worker
2022-03-07T23:23:57.322729Z  Checking for synthetic smart seeds
2022-03-07T23:23:57.453394Z  Loaded 5000 smart seeds for generation
2022-03-07T23:23:57.454106Z  Loading model
2022-03-07T23:23:59.263863Z  Generating records
{
    "num_records": 5000
}
2022-03-07T23:24:04.276116Z  Generation in progress
{
    "current_valid_count": 104,
    "current_invalid_count": 4,
    "new_valid_count": 104,
    "new_invalid_count": 4,
    "completion_percent": 2.08
}
2022-03-07T23:24:09.283978Z  Generation in progress
{
    "current_valid_count": 256,
    "current_invalid_count": 9,
    "new_valid_count": 152,
    "new_invalid_count": 5,
    "completion_percent": 5.12
}
2022-03-07T23:24:14.290434Z  Generatio

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,67,1,0,120,229,0,0,129,1,2.6,1,2,3,0
1,52,1,0,112,230,0,1,160,0,0.0,2,1,2,0
2,63,0,0,108,269,0,1,169,1,1.8,1,2,2,0
3,39,1,2,140,321,0,0,182,0,0.0,2,0,2,1
4,41,1,1,135,203,0,1,132,0,0.0,1,0,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
712,42,1,0,136,315,0,1,125,1,1.8,1,0,1,0
713,49,1,2,118,149,0,0,126,0,0.8,2,3,2,0
714,50,0,2,120,219,0,1,158,0,1.6,1,0,2,1
715,61,0,0,145,307,0,0,146,1,1.0,1,0,3,0


In [None]:
# Plot the gender distribution of the augmented (synthetic + real world) training set.

# Report how many synthetic rows were actually added (previously this printed
# `delta`, the class gap, rather than the number of records added).
num_added = len(augmented) - len(train)
print(f"Augmented training set with {num_added} synthetic records to balance gender class")
df = augmented.sex.replace({0: 'female', 1: 'male'})
df.value_counts().sort_values().plot(kind='barh', title='Augmented dataset distribution')

Augmented synthetic dataset with an additional 271 records to balance gender class


In [None]:
# Compare real world vs. synthetic accuracies using popular classifiers

import matplotlib.pyplot as plt

from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier

import plotly.express as px


def classification_accuracy(data_type, dataset, test, models=None) -> list:
    """Fit several classifiers on ``dataset`` and score each one on ``test``.

    Args:
        data_type: Label stored in every result row (e.g. 'real world').
        dataset: Training DataFrame containing the feature columns and 'target'.
        test: Evaluation DataFrame with the same columns.
        models: Optional sequence of ``(name, label, estimator)`` tuples, where
            ``name`` goes into the result rows, ``label`` is the printed prefix
            and ``estimator`` exposes ``fit(X, y)`` / ``score(X, y)``. Defaults to
            the five classifiers compared in this notebook, freshly constructed
            on every call.

    Returns:
        A list of ``[data_type, name, accuracy_percent]`` rows, one per model,
        in evaluation order.
    """
    x_cols = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach',
              'exang', 'oldpeak', 'slope', 'ca', 'thal']
    y_col = 'target'

    if models is None:
        models = [
            ('RandomForest', 'Random Forest:',
             RandomForestClassifier(n_estimators=1000, random_state=1)),
            # NOTE(review): SVC's default RBF kernel is sensitive to feature
            # scale and these features are unscaled, which likely explains its
            # lower accuracy.
            ('SVM', 'SVM:', SVC(random_state=1)),
            # n_neighbors is k; an even k can produce voting ties.
            ('KNN', 'KNN:', KNeighborsClassifier(n_neighbors=2)),
            # random_state makes the tree (and its accuracy) reproducible.
            ('DecisionTree', 'Decision Tree Test Accuracy',
             DecisionTreeClassifier(random_state=1)),
            # `use_label_encoder` was deprecated in newer xgboost releases; recent
            # versions may warn that it is unused.
            ('XGBoost', 'XGBoostClassifier:',
             XGBClassifier(use_label_encoder=False, eval_metric='error')),
        ]

    X_train, y_train = dataset[x_cols], dataset[y_col]
    X_test, y_test = test[x_cols], test[y_col]

    accuracies = []
    for name, label, estimator in models:
        estimator.fit(X_train, y_train)
        acc = estimator.score(X_test, y_test) * 100
        accuracies.append([data_type, name, acc])
        print(f" -- {label} {acc:.2f}%")

    return accuracies

print("Calculating real world accuracies")
realworld_acc = classification_accuracy('real world', train, test)
print("Calculating synthetic accuracies")
synthetic_acc = classification_accuracy('synthetic', augmented, test)

# Long-format results: one row per (data_type, algorithm) pair.
comparison = pd.DataFrame(
    [*realworld_acc, *synthetic_acc],
    columns=['data_type', 'algorithm', 'acc'],
)

# Bar colour per data_type value.
colours = {
    "synthetic": "#3EC1CD",
    "synthetic1": "#FCB94D",
    "real world": "#9ee0e6",
    "real world1": "#fddba5",
}

fig = px.bar(
    comparison,
    x='algorithm',
    y='acc',
    color='data_type',
    color_discrete_map=colours,
    barmode='group',
    text_auto='.4s',
    title='Real World vs. Synthetic Data for <b>all classes</b>',
).update_layout(legend_title_text='<b>Real world v. Synthetic</b>')
fig.show()

Calculating real world accuracies
 -- Random Forest: 95.13%
 -- SVM: 70.78%
 -- KNN: 89.61%
 -- Decision Tree Test Accuracy 95.13%
 -- XGBoostClassifier: 94.16%
Calculating synthetic accuracies
 -- Random Forest: 96.10%
 -- SVM: 68.83%
 -- KNN: 94.16%
 -- Decision Tree Test Accuracy 96.10%
 -- XGBoostClassifier: 96.43%


In [None]:
# Score each training set separately on the male (sex == 1) and
# female (sex == 0) slices of the held-out test set.
test_male = test.loc[test['sex'] == 1]
test_female = test.loc[test['sex'] == 0]

print("Calculating real world class accuracies")
realworld_male = classification_accuracy('realworld_male', train, test_male)
realworld_female = classification_accuracy('realworld_female', train, test_female)
print("Calculating synthetic class accuracies")
synthetic_male = classification_accuracy('synthetic_male', augmented, test_male)
synthetic_female = classification_accuracy('synthetic_female', augmented, test_female)

Calculating real world class accuracies
 -- Random Forest: 94.52%
 -- SVM: 73.97%
 -- KNN: 91.32%
 -- Decision Tree Test Accuracy 95.89%
 -- XGBoostClassifier: 93.15%
 -- Random Forest: 96.63%
 -- SVM: 62.92%
 -- KNN: 85.39%
 -- Decision Tree Test Accuracy 93.26%
 -- XGBoostClassifier: 96.63%
Calculating synthetic class accuracies
 -- Random Forest: 95.89%
 -- SVM: 65.30%
 -- KNN: 93.15%
 -- Decision Tree Test Accuracy 95.89%
 -- XGBoostClassifier: 94.98%
 -- Random Forest: 96.63%
 -- SVM: 77.53%
 -- KNN: 96.63%
 -- Decision Tree Test Accuracy 96.63%
 -- XGBoostClassifier: 100.00%


In [None]:
# Plot male and female heart disease detection accuracies (real world vs. synthetic).
# Males (sex == 1) are the majority class in the real-world training set.
colours = {
    "synthetic_male": "#3EC1CD",
    "synthetic_female": "#FCB94D",
    "realworld_male": "#9ee0e6",
    "realworld_female": "#fddba5",
}

per_class_rows = [*realworld_male, *synthetic_male, *realworld_female, *synthetic_female]
comparison = pd.DataFrame(per_class_rows, columns=['data_type', 'algorithm', 'acc'])

fig = px.bar(
    comparison,
    x='algorithm',
    y='acc',
    color='data_type',
    color_discrete_map=colours,
    barmode='group',
    text_auto='.4s',
    title='Real World vs. Synthetic Accuracy for <b>Male and Female Heart Disease Detection</b>',
)
fig.update_layout(legend_title_text='<b>Real world v. Synthetic</b>')
fig.show()