# Balancing datasets with conditional data generation

Imbalanced datasets are a common problem in machine learning, and there are several scenarios where they lead to a suboptimal model. One is training a multi-class classifier when one or more classes have far fewer training examples than the others. The resulting model can look like it's doing well overall when the accuracy on the underrepresented classes is actually worse than on the well-represented ones.
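
To make this concrete, here is a minimal sketch with made-up labels (not taken from any dataset in this notebook): a classifier that always predicts the majority class scores 95% overall while getting every minority example wrong. The helper name `per_class_accuracy` is illustrative.

```python
from collections import defaultdict


def per_class_accuracy(y_true, y_pred):
    """Return the fraction of correct predictions for each true label."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for truth, pred in zip(y_true, y_pred):
        total[truth] += 1
        if truth == pred:
            correct[truth] += 1
    return {label: correct[label] / total[label] for label in total}


# Made-up labels: 95 "common" examples and 5 "rare" ones.
y_true = ["common"] * 95 + ["rare"] * 5
# A degenerate classifier that always predicts the majority class.
y_pred = ["common"] * 100

overall = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
by_class = per_class_accuracy(y_true, y_pred)

print(overall)   # 0.95
print(by_class)  # {'common': 1.0, 'rare': 0.0}
```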

Another scenario is when the training data is demographically imbalanced. Part of what the Fair AI movement is about is ensuring that AI models perform equally well on all demographic slices.

One approach to reducing representational bias in data is to condition Gretel's synthetic data model to generate more examples of specific classes of data.

You can use this approach to replace the original data with a balanced synthetic dataset, or to augment the existing dataset by producing just enough synthetic data that, when added back to the original data, the imbalance is resolved.
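
As a rough sketch of the augmentation option, the snippet below works out how many synthetic records each class would need in order to match the largest class. The counts are the `RACE` value counts from the dataset loaded later in this notebook; the helper `augmentation_needed` is a hypothetical name, and matching the largest class is only one possible target.

```python
# RACE value counts from the healthcare dataset used in this notebook.
race_counts = {"white": 1616, "black": 332, "hispanic": 258, "asian": 79}


def augmentation_needed(counts):
    """Return how many extra records each class needs to reach the largest class."""
    target = max(counts.values())
    return {label: target - n for label, n in counts.items()}


needed = augmentation_needed(race_counts)
print(needed)
# {'white': 0, 'black': 1284, 'hispanic': 1358, 'asian': 1537}
```

The notebook itself takes the other route and generates a fully synthetic replacement dataset.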

In this notebook, we're going to step you through how to use Gretel synthetics to resolve demographic bias in a dataset. We will be creating a new synthetic dataset that can be used in place of the original one.


## Begin by authenticating


In [None]:
%%capture
!pip install -U gretel-client

In [1]:
# Specify your Gretel API key

import pandas as pd
from gretel_client import configure_session

pd.set_option("max_colwidth", None)

configure_session(api_key="prompt", cache="yes", validate=True)


Found cached Gretel credentials
Using endpoint https://api-dev.gretel.cloud
Logged in as drew@gretel.ai ✅


## Load and view the dataset


In [17]:
a = pd.read_csv(
    "https://gretel-public-website.s3.amazonaws.com/datasets/experiments/healthcare_dataset_a.csv"
)

a


Unnamed: 0,DATE,PatientID,BirthDate,RACE,ETHNICITY,GENDER,BIRTHPLACE,ConditionDesc,EncounterDesc,EncounterReasonDesc,ObsDesc,ObsValueFloat,UNITS,ProcedureDate,ProcedureDesc
0,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2010-06-28,Measurement of respiratory function (procedure)
1,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2012-01-31,Sputum examination (procedure)
2,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-01-09,Bone immobilization
3,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-01-09,Bone density scan (procedure)
4,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-05-02,Documentation of current medications
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2280,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Chronic sinusitis (disorder),Outpatient Encounter,,Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2281,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Prediabetes,Outpatient Encounter,,Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2282,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Streptococcal sore throat (disorder),Encounter for symptom,Streptococcal sore throat (disorder),Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2283,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Acute viral pharyngitis (disorder),Encounter for symptom,Acute viral pharyngitis (disorder),Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications


## Isolate the fields that require balancing

- We'll balance "RACE", "ETHNICITY", and "GENDER"


In [18]:
a["RACE"].value_counts()


white       1616
black        332
hispanic     258
asian         79
Name: RACE, dtype: int64

In [19]:
a["ETHNICITY"].value_counts()


irish              426
german             267
african            266
french             240
italian            213
puerto_rican       175
english            162
polish             148
mexican             83
chinese             79
russian             78
swedish             50
dominican           35
west_indian         31
french_canadian     12
portuguese          12
american             8
Name: ETHNICITY, dtype: int64

In [20]:
a["GENDER"].value_counts()


F    1296
M     989
Name: GENDER, dtype: int64
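
One simple way to summarize how skewed each field is, is the ratio of its most common value count to its least common. The sketch below hardcodes the largest and smallest counts printed above; with the dataframe in hand, the same numbers could be read from `a[col].value_counts()`. The `imbalance_ratio` helper is an illustrative name, not part of any library.

```python
# (largest count, smallest count) per field, copied from the outputs above.
extremes = {
    "RACE": (1616, 79),       # white vs. asian
    "ETHNICITY": (426, 8),    # irish vs. american
    "GENDER": (1296, 989),    # F vs. M
}


def imbalance_ratio(largest, smallest):
    """Ratio of the most frequent to the least frequent value count."""
    return largest / smallest


ratios = {field: round(imbalance_ratio(*pair), 2) for field, pair in extremes.items()}
print(ratios)
# {'RACE': 20.46, 'ETHNICITY': 53.25, 'GENDER': 1.31}
```

A ratio near 1 means the field is already close to balanced; `ETHNICITY` is by far the most skewed of the three.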

## Create a seed file

- Create a CSV with one column for each balance field and one record for each combination of the balance field values.
- Replicate the seeds to reach the desired synthetic data size.


In [21]:
import itertools

# Choose your balance columns
balance_columns = ["GENDER", "ETHNICITY", "RACE"]

# How many total synthetic records do you want
gen_lines = len(a)

# Get the sorted list of distinct values for each seed field and the
# overall fraction of records we'll need for each seed value combination
categ_val_lists = []
seed_percent = 1
for field in balance_columns:
    values = sorted(a[field].dropna().unique())
    categ_val_lists.append(values)
    seed_percent = seed_percent / len(values)
seed_gen_cnt = seed_percent * gen_lines

# Get the combo seeds we'll need. This is all combinations of all
# seed field values
seed_fields = []
for combo in itertools.product(*categ_val_lists):
    # Pair each balance column with its value in this combination
    seed_dict = dict(zip(balance_columns, combo))
    seed_fields.append({"seed": seed_dict, "cnt": seed_gen_cnt})

# Create a dataframe with the seed values used to condition the synthetic model,
# repeating each combination int(cnt) times
seed_rows = []
for seed in seed_fields:
    seed_rows.extend([seed["seed"]] * int(seed["cnt"]))

df_seed = pd.DataFrame(seed_rows, columns=balance_columns)

# Save the seed dataframe to a file
seedfile = "/tmp/balance_seeds.csv"
df_seed.to_csv(seedfile, index=False, header=True)
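
The arithmetic behind the seed count can be checked by hand. This sketch assumes there are no missing values in the balance columns, so that `len(a)` equals the 2285 records implied by the value counts above, and it uses the number of distinct values per field (2, 17, and 4). Because each combination is repeated `int(cnt)` times, the seed file ends up slightly smaller than the requested size.

```python
import math

# Number of distinct values per balance field, from the value counts above.
category_sizes = {"GENDER": 2, "ETHNICITY": 17, "RACE": 4}
gen_lines = 2285  # assumed len(a): the sum of any one field's value counts

n_combos = math.prod(category_sizes.values())  # every combination of values
per_combo = gen_lines // n_combos              # same result as int(seed_gen_cnt)
seed_rows = n_combos * per_combo               # rows that land in the seed file
shortfall = gen_lines - seed_rows              # records lost to rounding down

print(n_combos, per_combo, seed_rows, shortfall)
# 136 16 2176 109
```

Because the seeds are a Cartesian product, every value of one field is paired with every value of the others, so the seed file can include combinations that never appear together in the original data.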


## Create a synthetic config file


In [3]:
# Grab the default Synthetic Config file
from gretel_client.projects.models import read_model_config

config = read_model_config("synthetics/default")


In [23]:
# Adjust the desired number of synthetic records to generate

config["models"][0]["synthetics"]["generate"]["num_records"] = len(a)


In [24]:
# Adjust params for complex dataset

config["models"][0]["synthetics"]["params"]["data_upsample_limit"] = 10000


## Include a seeding task in the config


In [25]:
task = {"type": "seed", "attrs": {"fields": balance_columns}}
config["models"][0]["synthetics"]["task"] = task
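
Taken together, the config edits in this notebook touch three places in the nested config. The sketch below applies them to a stripped-down stand-in dictionary that contains only the keys this notebook modifies; the real default config returned by `read_model_config` has more settings, and the empty placeholders here are assumptions for illustration.

```python
# Stand-in for the config returned by read_model_config("synthetics/default").
# Only the keys modified in this notebook are included; the real config has more.
config = {"models": [{"synthetics": {"generate": {}, "params": {}}}]}
balance_columns = ["GENDER", "ETHNICITY", "RACE"]

synthetics = config["models"][0]["synthetics"]
synthetics["generate"]["num_records"] = 2285  # stands in for len(a)
synthetics["params"]["data_upsample_limit"] = 10000
synthetics["task"] = {"type": "seed", "attrs": {"fields": balance_columns}}

print(synthetics["task"])
# {'type': 'seed', 'attrs': {'fields': ['GENDER', 'ETHNICITY', 'RACE']}}
```

The fields listed in the seed task are the same columns that appear in the seed file, which is why both are driven by `balance_columns`.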


## Train a synthetic model


In [26]:
from gretel_client import projects
from gretel_client.helpers import poll

training_path = "training_data.csv"
a.to_csv(training_path, index=False)

project = projects.create_or_get_unique_project(name="balancing-data-example")
model = project.create_model_obj(model_config=config, data_source=training_path)

model.submit_cloud()
poll(model)


[32mINFO: [0mStarting poller


{
    "uid": "624e2f167bb99b28a1267855",
    "guid": "model_27Rly8B14FXXQfotbGbMNVO5HvJ",
    "model_name": "awesome-mysterious-snake",
    "runner_mode": "cloud",
    "user_id": "5f28907abff6215b847fea12",
    "user_guid": "user_26hlXr8FAOILKp6R3okyfJy7iPZ",
    "billing_domain": "gretel.ai",
    "billing_domain_guid": null,
    "project_id": "624dd337297bcee982e8e562",
    "project_guid": "proj_27R0IXq18YG8WQlPZ6ZYh8fGNcb",
    "status_history": {
        "created": "2022-04-07T00:23:50.957625Z"
    },
    "last_modified": "2022-04-07T00:23:51.202454Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "container_image": "074762682575.dkr.ecr.us-west-2.amazonaws.com/gretelai/synthetics@sha256:1006058feecf797de94a83dcc54d07265b53d4142f073a3f09902383368006c4",
    "model_type": "synthetics",
    "config": {
        "schema_version": "1.0",
        "name": null,
        "models": [
 

[32mINFO: [0mStatus is created. Model creation has been queued.
[32mINFO: [0mStatus is pending. A Gretel Cloud worker is being allocated to begin model creation.


## Generate data using the balance seeds


In [None]:
rh = model.create_record_handler_obj(
    data_source=seedfile, params={"num_records": len(df_seed)}
)
rh.submit_cloud()
poll(rh)
synth_df = pd.read_csv(rh.get_artifact_link("data"), compression="gzip")
synth_df.head()


## Validate the balanced demographic data


In [None]:
synth_df["GENDER"].value_counts()


In [None]:
synth_df["ETHNICITY"].value_counts()


In [None]:
synth_df["RACE"].value_counts()
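
Per-column counts can look even while individual combinations are not, so it may also be worth checking the spread across combinations of the balance fields. The sketch below is one way to do that using only the standard library; `balance_report` is a hypothetical helper, and the records are a tiny made-up stand-in for what `synth_df.to_dict("records")` would return.

```python
from collections import Counter


def balance_report(records, columns):
    """Count records per combination of `columns` and summarize the spread."""
    combos = Counter(tuple(rec[col] for col in columns) for rec in records)
    counts = combos.values()
    return {"combinations": len(combos), "min": min(counts), "max": max(counts)}


# Made-up stand-in for synth_df.to_dict("records").
records = [
    {"GENDER": "F", "RACE": "white"},
    {"GENDER": "F", "RACE": "asian"},
    {"GENDER": "M", "RACE": "white"},
    {"GENDER": "M", "RACE": "asian"},
]

report = balance_report(records, ["GENDER", "RACE"])
print(report)
# {'combinations': 4, 'min': 1, 'max': 1}
```

When `min` and `max` are equal, every combination that occurs has the same number of records.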
