<a href="https://colab.research.google.com/github/mars137/synthetic-data/blob/main/docs/notebooks/balance_data_with_conditional_data_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Balancing datasets with conditional data generation

Imbalanced datasets are a common problem in machine learning. There are several different scenarios where an imbalanced dataset can lead to a less than optimal model solution. One scenario is when you're training a multi-class classifier and one or more of the classes have fewer training examples than the others. This can sometimes lead to a model that may look like it's doing well overall,when really the accuracy of the underepresented classes is inferior to that of the classes with good representation.

Another scenario is when the training data has imbalanced demographic data. Part of what the Fair AI movement is about is ensuring that AI models do equally well on all demographic slices.

One approach to improve representational biases in data is through by conditioning Gretel's synthetic data model to generate more examples of different classes of data.

You can use the approach to replace the original data with a balanced synthetic dataset or you can use it to augment the existing dataset, producing just enough synthetic data such that when added back into the original data, the imbalance is resolved.

In this notebook, we're going to step you through how to use Gretel synthetics to resolve demographic bias in a dataset. We will be creating a new synthetic dataset that can be used in place of the original one.


## Begin by authenticating


In [1]:
%%capture
!pip install -U gretel-client

In [2]:
# Specify your Gretel API key

import pandas as pd
from gretel_client import configure_session

pd.set_option("max_colwidth", None)

configure_session(api_key="prompt", cache="yes", validate=True)


Gretel Api Key··········
Caching Gretel config to disk.
Using endpoint https://api.gretel.cloud
Logged in as atif.tahir13@gmail.com ✅


## Load and view the dataset


In [3]:
a = pd.read_csv(
    "https://gretel-public-website.s3.amazonaws.com/datasets/experiments/healthcare_dataset_a.csv"
)

a


Unnamed: 0,DATE,PatientID,BirthDate,RACE,ETHNICITY,GENDER,BIRTHPLACE,ConditionDesc,EncounterDesc,EncounterReasonDesc,ObsDesc,ObsValueFloat,UNITS,ProcedureDate,ProcedureDesc
0,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2010-06-28,Measurement of respiratory function (procedure)
1,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2012-01-31,Sputum examination (procedure)
2,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-01-09,Bone immobilization
3,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-01-09,Bone density scan (procedure)
4,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-05-02,Documentation of current medications
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2280,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Chronic sinusitis (disorder),Outpatient Encounter,,Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2281,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Prediabetes,Outpatient Encounter,,Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2282,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Streptococcal sore throat (disorder),Encounter for symptom,Streptococcal sore throat (disorder),Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2283,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Acute viral pharyngitis (disorder),Encounter for symptom,Acute viral pharyngitis (disorder),Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications


## Isolate the fields that require balancing

- We'll balance "RACE", "ETHNICITY", and "GENDER"


In [4]:
a["RACE"].value_counts()


white       1616
black        332
hispanic     258
asian         79
Name: RACE, dtype: int64

In [5]:
a["ETHNICITY"].value_counts()


irish              426
german             267
african            266
french             240
italian            213
puerto_rican       175
english            162
polish             148
mexican             83
chinese             79
russian             78
swedish             50
dominican           35
west_indian         31
french_canadian     12
portuguese          12
american             8
Name: ETHNICITY, dtype: int64

In [6]:
a["GENDER"].value_counts()


F    1296
M     989
Name: GENDER, dtype: int64

## Create a seed file

- Create a csv with one column for each balance field and one record for each combination of the balance field values.
- Replicate the seeds to reach the desired synthetic data size.


In [7]:
import itertools

# Choose your balance columns
balance_columns = ["GENDER", "ETHNICITY", "RACE"]

# How many total synthetic records do you want
gen_lines = len(a)

# Get the list of values for each seed field and the
# overall percent we'll need for each seed value combination
categ_val_lists = []
seed_percent = 1
for field in balance_columns:
    values = set(pd.Series(a[field].dropna()))
    category_cnt = len(values)
    categ_val_lists.append(list(values))
    seed_percent = seed_percent * 1 / category_cnt
seed_gen_cnt = seed_percent * gen_lines

# Get the combo seeds we'll need. This is all combinations of all
# seed field values
seed_fields = []
for combo in itertools.product(*categ_val_lists):
    seed_dict = {}
    i = 0
    for field in balance_columns:
        seed_dict[field] = combo[i]
        i += 1
    seed = {}
    seed["seed"] = seed_dict
    seed["cnt"] = seed_gen_cnt
    seed_fields.append(seed)

# Create a dataframe with the seed values used to condition the synthetic model
gender_all = []
ethnicity_all = []
race_all = []
for seed in seed_fields:
    gender = seed["seed"]["GENDER"]
    ethnicity = seed["seed"]["ETHNICITY"]
    race = seed["seed"]["RACE"]
    cnt = seed["cnt"]
    for i in range(int(cnt)):
        gender_all.append(gender)
        ethnicity_all.append(ethnicity)
        race_all.append(race)

df_seed = pd.DataFrame(
    {"GENDER": gender_all, "ETHNICITY": ethnicity_all, "RACE": race_all}
)

# Save the seed dataframe to a file
seedfile = "/tmp/balance_seeds.csv"
df_seed.to_csv(seedfile, index=False, header=True)


## Create a synthetic config file


In [8]:
# Grab the default Synthetic Config file
from gretel_client.projects.models import read_model_config

config = read_model_config("synthetics/default")




In [9]:
# Adjust the desired number of synthetic records to generated

config["models"][0]["synthetics"]["generate"]["num_records"] = len(a)


In [10]:
# Adjust params for complex dataset

config["models"][0]["synthetics"]["params"]["data_upsample_limit"] = 10000


## Include a seeding task in the config


In [11]:
task = {"type": "seed", "attrs": {"fields": balance_columns}}
config["models"][0]["synthetics"]["task"] = task


## Train a synthetic model


In [12]:
from gretel_client import projects
from gretel_client.helpers import poll

training_path = "training_data.csv"
a.to_csv(training_path)

project = projects.create_or_get_unique_project(name="balancing-data-example")
model = project.create_model_obj(model_config=config, data_source=training_path)

model.submit_cloud()
poll(model)


INFO: Starting poller


{
    "uid": "641ecd2e7067bc6035077efb",
    "guid": "model_2NVD9qS29UcTWWtvqgDi8Eq0AWB",
    "model_name": "default-config",
    "runner_mode": "cloud",
    "user_id": "61779c3ebff62105d3757a71",
    "user_guid": "user_26hlyPRrQXap2t6NhfbC1G7JA0l",
    "billing_domain": null,
    "billing_domain_guid": null,
    "project_id": "641ecd27e5bc51838c29c8d0",
    "project_guid": "proj_2NVD8qcBYJeL5v7L2cvL7bzakJA",
    "status_history": {
        "created": "2023-03-25T10:30:06.390681Z"
    },
    "last_modified": "2023-03-25T10:30:06.480879Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "annotations": null,
    "container_image": "074762682575.dkr.ecr.us-west-2.amazonaws.com/gretelai/synthetics@sha256:0e0d8d352d355d498b9da449f6ffb4bb33e87f530380a98afe157718f66877d1",
    "container_image_version": "2.10.41",
    "model_type": "synthetics",
    "model_type_alias": null,
    "config"

INFO: Status is pending. A Gretel Cloud worker is being allocated to begin model creation.
INFO: Status is active. A worker has started creating your model!
2023-03-25T10:30:21.682115Z  Analyzing input data and checking for auto-params...
2023-03-25T10:30:37.541703Z  Starting synthetic model training
2023-03-25T10:30:37.543723Z  Loading training data
2023-03-25T10:30:37.551449Z  Running pre-flight data checks on input data

	1 field with missing data: 'EncounterReasonDesc' has more than 50% missing data. 
2023-03-25T10:30:39.003287Z  Training data loaded.
{
    "record_count": 2285,
    "field_count": 16,
    "upsample_count": 7715
}
2023-03-25T10:30:41.806432Z  Creating semantic validators and preparing training data
2023-03-25T10:30:51.451678Z  Beginning ML model training
2023-03-25T10:30:58.451084Z  Running training on 1 batches.
{
    "batch_sizes": "[16]"
}
2023-03-25T10:30:59.825654Z  Tokenizing input data
2023-03-25T10:31:00.418167Z  Shuffling input data
2023-03-25T10:31:02.0416

## Generate data using the balance seeds


In [13]:
rh = model.create_record_handler_obj(
    data_source=seedfile, params={"num_records": len(df_seed)}
)
rh.submit_cloud()
poll(rh)
synth_df = pd.read_csv(rh.get_artifact_link("data"), compression="gzip")
synth_df.head()


INFO: Starting poller


{
    "uid": "641ecf43e2cc0cb9e44697fc",
    "guid": "model_run_2NVEEjo59jBoOPq6KSh7TW3OYvn",
    "model_name": null,
    "runner_mode": "cloud",
    "user_id": "61779c3ebff62105d3757a71",
    "user_guid": "user_26hlyPRrQXap2t6NhfbC1G7JA0l",
    "billing_domain": null,
    "billing_domain_guid": null,
    "project_id": "641ecd27e5bc51838c29c8d0",
    "project_guid": "proj_2NVD8qcBYJeL5v7L2cvL7bzakJA",
    "status_history": {
        "created": "2023-03-25T10:38:59.555000Z"
    },
    "last_modified": "2023-03-25T10:38:59.695000Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "annotations": null,
    "container_image": "074762682575.dkr.ecr.us-west-2.amazonaws.com/gretelai/synthetics@sha256:0e0d8d352d355d498b9da449f6ffb4bb33e87f530380a98afe157718f66877d1",
    "container_image_version": "2.10.41",
    "model_id": "641ecd2e7067bc6035077efb",
    "model_guid": "model_2NVD9qS29UcTW

INFO: Status is created. A Record generation job has been queued.
INFO: Status is pending. A Gretel Cloud worker is being allocated to begin generating synthetic records.
INFO: Status is active. A worker has started!
2023-03-25T10:39:13.078068Z  Loading model to worker
2023-03-25T10:39:30.636239Z  Checking for synthetic smart seeds
2023-03-25T10:39:30.642935Z  Loaded 2176 smart seeds for generation
2023-03-25T10:39:30.643313Z  Loading model
2023-03-25T10:39:35.789825Z  LSTM model is available for generation.
2023-03-25T10:39:35.790235Z  Generating records
{
    "num_records": 2176
}
2023-03-25T10:39:40.799343Z  Generation in progress
{
    "current_valid_count": 26,
    "current_invalid_count": 1,
    "new_valid_count": 26,
    "new_invalid_count": 1,
    "completion_percent": 1.19
}
2023-03-25T10:39:45.806252Z  Generation in progress
{
    "current_valid_count": 64,
    "current_invalid_count": 2,
    "new_valid_count": 38,
    "new_invalid_count": 1,
    "completion_percent": 2.94
}


Unnamed: 0.1,Unnamed: 0,DATE,PatientID,BirthDate,RACE,ETHNICITY,GENDER,BIRTHPLACE,ConditionDesc,EncounterDesc,EncounterReasonDesc,ObsDesc,ObsValueFloat,UNITS,ProcedureDate,ProcedureDesc
0,1748,2015-11-29,1,1945-07-17,hispanic,mexican,M,Mashpee MA US,Rupture of patellar tendon,Emergency room admission,,Potassium,6.50738,mmol/L,2015-11-29,Documentation of current medications
1,2076,2010-06-11,31,1930-07-31,hispanic,mexican,M,Winchester MA US,Chronic sinusitis (disorder),Outpatient Encounter,,Potassium,6.54514,mmol/L,2011-06-24,Documentation of current medications
2,326,2013-02-25,40,1952-10-05,hispanic,mexican,M,Springfield MA US,Chronic sinusitis (disorder),Outpatient Encounter,,Potassium,6.50738,mmol/L,2016-05-29,Documentation of current medications
3,685,2011-09-18,14,1946-02-24,hispanic,mexican,M,Beverly MA US,Recurrent rectal polyp,Encounter for 'check-up',,Potassium,6.51996,mmol/L,2012-06-30,Measurement of respiratory function (procedure)
4,1914,2016-04-02,3,1971-06-20,hispanic,mexican,M,Boston MA US,Acute bronchitis (disorder),Encounter for symptom,Acute bronchitis (disorder),Potassium,6.50738,mmol/L,2016-04-02,Documentation of current medications


## Validate the balanced demographic data


In [14]:
synth_df["GENDER"].value_counts()


M    1088
F    1088
Name: GENDER, dtype: int64

In [15]:
synth_df["ETHNICITY"].value_counts()


mexican            128
swedish            128
puerto_rican       128
german             128
irish              128
chinese            128
italian            128
french_canadian    128
portuguese         128
russian            128
english            128
african            128
west_indian        128
dominican          128
french             128
american           128
polish             128
Name: ETHNICITY, dtype: int64

In [16]:
synth_df["RACE"].value_counts()


hispanic    544
black       544
asian       544
white       544
Name: RACE, dtype: int64