# Balancing datasets with conditional data generation
Imbalanced datasets are a common problem in machine learning. There are several scenarios where an imbalanced dataset can lead to a suboptimal model. One is when you're training a multi-class classifier and one or more of the classes have far fewer training examples than the others. This can produce a model that looks like it's doing well overall, when in fact its accuracy on the underrepresented classes is noticeably worse than on the well-represented ones.
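
A minimal sketch of how this happens, using made-up labels rather than the dataset in this notebook. A degenerate classifier that always predicts the majority class scores 90% overall while getting every minority example wrong. The helper `per_class_accuracy` is our own, not part of any library.

```python
from collections import defaultdict


def per_class_accuracy(y_true, y_pred):
    """Return (overall accuracy, {true class: accuracy on that class})."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    overall = sum(correct.values()) / len(y_true)
    return overall, {c: correct[c] / total[c] for c in total}


# Hypothetical toy labels: 90 "majority" examples, 10 "minority" examples
y_true = ["majority"] * 90 + ["minority"] * 10
# A model that always predicts the majority class
y_pred = ["majority"] * 100

overall, by_class = per_class_accuracy(y_true, y_pred)
print(f"overall accuracy: {overall:.2f}")
for label, acc in by_class.items():
    print(f"{label}: {acc:.2f}")
```

Reporting accuracy per class (or per demographic slice) is what exposes the gap that the overall number hides.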

Another scenario is when the training data has imbalanced demographic data. Part of the goal of the Fair AI movement is to ensure that AI models perform equally well on all demographic slices.

One approach to reducing representational bias in data is to condition Gretel's synthetic data model to generate more examples of the underrepresented classes.

You can use this approach to replace the original data with a balanced synthetic dataset, or to augment the existing dataset by generating just enough synthetic data that, once added back to the original data, the imbalance is resolved.
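
For the augmentation route, the arithmetic for a single column can be sketched as follows, assuming the target is to bring every class up to the size of the largest one. The helper `records_needed` is hypothetical; the counts are the `RACE` counts of the dataset used in this notebook.

```python
def records_needed(class_counts):
    """Synthetic records to add per class so every class matches the largest one."""
    target = max(class_counts.values())
    return {label: target - count for label, count in class_counts.items()}


# RACE counts from the original dataset in this notebook
race_counts = {"white": 1616, "black": 332, "hispanic": 258, "asian": 79}
needed = records_needed(race_counts)
print(needed)
print("total synthetic records to add:", sum(needed.values()))
```

Balancing several columns at once, as this notebook does, is harder with augmentation, because records added to fix one column also shift the counts of the others.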

In this notebook, we'll step you through using Gretel synthetics to resolve demographic bias in a dataset. We'll create a new synthetic dataset that can be used in place of the original one.


## Begin by authenticating

In [None]:
%%capture
!pip install -U gretel-client

In [None]:
# Specify your Gretel API key

import pandas as pd
from gretel_client import configure_session

pd.set_option('max_colwidth', None)

configure_session(api_key="prompt", cache="yes", validate=True)

## Load and view the dataset

In [None]:
a = pd.read_csv('https://gretel-public-website.s3.amazonaws.com/datasets/experiments/healthcare_dataset_a.csv')

a

Unnamed: 0,DATE,PatientID,BirthDate,RACE,ETHNICITY,GENDER,BIRTHPLACE,ConditionDesc,EncounterDesc,EncounterReasonDesc,ObsDesc,ObsValueFloat,UNITS,ProcedureDate,ProcedureDesc
0,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2010-06-28,Measurement of respiratory function (procedure)
1,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2012-01-31,Sputum examination (procedure)
2,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-01-09,Bone immobilization
3,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-01-09,Bone density scan (procedure)
4,2013-05-02,34,1940-04-24,white,irish,F,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2013-05-02,Documentation of current medications
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2280,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Chronic sinusitis (disorder),Outpatient Encounter,,Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2281,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Prediabetes,Outpatient Encounter,,Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2282,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Streptococcal sore throat (disorder),Encounter for symptom,Streptococcal sore throat (disorder),Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
2283,2015-11-05,0,1956-03-31,white,swedish,M,Lanesborough MA US,Acute viral pharyngitis (disorder),Encounter for symptom,Acute viral pharyngitis (disorder),Potassium,6.50738,mmol/L,2016-11-18,Documentation of current medications
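
The `Unnamed: 0` column is most likely a saved row index. When a DataFrame is written with `to_csv()` and the default `index=True`, the index becomes an unnamed first column, and `read_csv()` labels it `Unnamed: 0` when reading it back. A small self-contained illustration on toy data:

```python
import io

import pandas as pd

df = pd.DataFrame({"GENDER": ["F", "M"], "RACE": ["white", "asian"]})

# Writing with the default index=True adds an unnamed index column...
with_index = pd.read_csv(io.StringIO(df.to_csv()))
print(list(with_index.columns))

# ...which can be avoided on write with index=False
without_index = pd.read_csv(io.StringIO(df.to_csv(index=False)))
print(list(without_index.columns))

# ...or absorbed on read with index_col=0
as_index = pd.read_csv(io.StringIO(df.to_csv()), index_col=0)
print(list(as_index.columns))
```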


## Isolate the fields that require balancing
* We'll balance "RACE", "ETHNICITY", and "GENDER"

In [None]:
a["RACE"].value_counts()

white       1616
black        332
hispanic     258
asian         79
Name: RACE, dtype: int64

In [None]:
a["ETHNICITY"].value_counts()

irish              426
german             267
african            266
french             240
italian            213
puerto_rican       175
english            162
polish             148
mexican             83
chinese             79
russian             78
swedish             50
dominican           35
west_indian         31
french_canadian     12
portuguese          12
american             8
Name: ETHNICITY, dtype: int64

In [None]:
a["GENDER"].value_counts()

F    1296
M     989
Name: GENDER, dtype: int64
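
To put a number on the imbalance, one simple summary (our choice here, not a Gretel metric) is the ratio of the largest class to the smallest class in each column. A sketch using the counts printed above:

```python
# Value counts copied from the outputs above
counts = {
    "RACE": {"white": 1616, "black": 332, "hispanic": 258, "asian": 79},
    "GENDER": {"F": 1296, "M": 989},
    "ETHNICITY": {
        "irish": 426, "german": 267, "african": 266, "french": 240,
        "italian": 213, "puerto_rican": 175, "english": 162, "polish": 148,
        "mexican": 83, "chinese": 79, "russian": 78, "swedish": 50,
        "dominican": 35, "west_indian": 31, "french_canadian": 12,
        "portuguese": 12, "american": 8,
    },
}


def imbalance_ratio(class_counts):
    """Size of the largest class divided by the size of the smallest."""
    return max(class_counts.values()) / min(class_counts.values())


ratios = {col: imbalance_ratio(c) for col, c in counts.items()}
for col, ratio in ratios.items():
    print(f"{col}: largest/smallest = {ratio:.2f}")
```

By this measure ETHNICITY is the most skewed column (426 Irish records versus 8 American), while GENDER is comparatively close to balanced.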

## Create a seed file
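* Create a CSV with one column for each balance field and one record for each combination of the balance field values (the full Cartesian product, including combinations that may never occur in the original data).
* Replicate the seeds so the total reaches the desired synthetic data size.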
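
With the three balance columns above (2 genders, 17 ethnicities, 4 races) and the 2,285 rows counted above, the replication arithmetic works out as follows. This plain-Python sketch only reproduces the counts; it matches the totals reported later in the notebook.

```python
from math import prod

n_target = 2285  # len(a): rows in the original dataset
cardinalities = {"GENDER": 2, "ETHNICITY": 17, "RACE": 4}

n_combos = prod(cardinalities.values())  # 2 * 17 * 4 = 136 combinations
per_combo = int(n_target / n_combos)     # 2285 / 136 = 16.8, truncated to 16
total = per_combo * n_combos             # seed rows actually written

# Each value of a column appears in (n_combos / its cardinality) combinations
per_value = {col: total // k for col, k in cardinalities.items()}
print(n_combos, per_combo, total, per_value)
```

The truncation to 16 seeds per combination is why the generated dataset has 2,176 rows rather than 2,285.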

In [None]:
import itertools

# Choose your balance columns
balance_columns = ["GENDER", "ETHNICITY", "RACE"]

# How many total synthetic records do you want
gen_lines = len(a)

# Get the distinct values for each balance field and the overall
# fraction of records each combination of values should receive
categ_val_lists = []
seed_percent = 1
for field in balance_columns:
    values = sorted(a[field].dropna().unique())
    categ_val_lists.append(values)
    seed_percent *= 1 / len(values)

# Seed rows per combination, rounded down (so the total can be
# slightly smaller than gen_lines)
seed_gen_cnt = int(seed_percent * gen_lines)

# One seed row per combination of balance field values (the Cartesian
# product), repeated seed_gen_cnt times. Building the rows from
# balance_columns means changing that list is the only edit needed.
seed_rows = [
    dict(zip(balance_columns, combo))
    for combo in itertools.product(*categ_val_lists)
    for _ in range(seed_gen_cnt)
]
df_seed = pd.DataFrame(seed_rows, columns=balance_columns)

# Save the seed dataframe to a file
seedfile = "/tmp/balance_seeds.csv"
df_seed.to_csv(seedfile, index=False, header=True)
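
Before training, it can be worth confirming that the seed file really contains every combination the same number of times. A minimal sketch with a hypothetical helper `is_balanced`, shown on a toy seed frame so it runs on its own; in the notebook you would call it as `is_balanced(df_seed, balance_columns)`.

```python
import itertools

import pandas as pd


def is_balanced(df, columns):
    """True if every combination of the observed values in `columns`
    is present and all combinations occur equally often."""
    sizes = df.groupby(columns).size()
    n_possible = 1
    for col in columns:
        n_possible *= df[col].nunique()
    return sizes.nunique() == 1 and len(sizes) == n_possible


# Toy seed frame: 2 x 2 combinations, each repeated 3 times
cols = ["GENDER", "RACE"]
rows = [
    dict(zip(cols, combo))
    for combo in itertools.product(["F", "M"], ["asian", "white"])
    for _ in range(3)
]
toy_seed = pd.DataFrame(rows, columns=cols)

print(is_balanced(toy_seed, cols))            # every combination 3 times
print(is_balanced(toy_seed.iloc[:-1], cols))  # last combination has only 2
```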

## Create a synthetic config file

In [None]:
# Grab the default Synthetic Config file

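# smart_open's open() can read directly from a URL (install smart_open if it
# isn't already available). Note that it shadows the built-in open().
from smart_open import open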
import yaml

with open("https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/config_templates/gretel/synthetics/default.yml", 'r') as stream:
    config = yaml.safe_load(stream)

In [None]:
# Adjust the desired number of synthetic records to generate

config['models'][0]['synthetics']['generate']['num_records'] = len(a)

In [None]:
# Adjust params for complex dataset

config['models'][0]['synthetics']['params']['data_upsample_limit'] = 10000


## Include a seeding task in the config

In [None]:
task = {
    'type': 'seed',
    'attrs': {
        'fields': balance_columns
    }
}
config['models'][0]['synthetics']['task'] = task
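
A cheap sanity check before submitting the job (our own addition, assuming the seed fields are expected to be columns of both the training data and the seed file) is that the field names in the task line up with both. A self-contained sketch, with toy frames standing in for `a` and `df_seed`; `check_seed_fields` is a hypothetical helper.

```python
import pandas as pd


def check_seed_fields(task, training_df, seed_df):
    """Raise ValueError if any seed field is missing from either DataFrame."""
    fields = task["attrs"]["fields"]
    missing_train = [f for f in fields if f not in training_df.columns]
    missing_seed = [f for f in fields if f not in seed_df.columns]
    if missing_train or missing_seed:
        raise ValueError(
            f"missing in training data: {missing_train}; "
            f"missing in seed file: {missing_seed}"
        )
    return fields


task = {"type": "seed", "attrs": {"fields": ["GENDER", "ETHNICITY", "RACE"]}}
training = pd.DataFrame(columns=["DATE", "GENDER", "ETHNICITY", "RACE"])
seeds = pd.DataFrame(columns=["GENDER", "ETHNICITY", "RACE"])
print(check_seed_fields(task, training, seeds))
```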

## Train a synthetic model

In [None]:
from gretel_client import projects
from gretel_client.helpers import poll

training_path = 'training_data.csv'
# Write without the pandas index so it isn't learned as an extra column
a.to_csv(training_path, index=False)

project = projects.create_or_get_unique_project(name='balancing-data-example')
model = project.create_model_obj(model_config=config, data_source=training_path)

model.submit_cloud()
poll(model)


[32mINFO: [0mStarting poller


{
    "uid": "61bb99bc4e42e9e012d0f358",
    "model_name": "marvelous-talented-skunk",
    "runner_mode": "cloud",
    "user_id": "5f45aedbbff62139017abfeb",
    "project_id": "61bb99b628c6db305b210906",
    "status_history": {
        "created": "2021-12-16T19:55:40.069638Z"
    },
    "last_modified": "2021-12-16T19:55:40.265795Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "container_image": "074762682575.dkr.ecr.us-west-2.amazonaws.com/gretelai/synthetics@sha256:58fa70040c5cf1c5105820b35de35ba0b7d2c3a07ce84b6ac41ddad94933143f",
    "model_type": "synthetics",
    "config": {
        "schema_version": "1.0",
        "name": null,
        "models": [
            {
                "synthetics": {
                    "params": {
                        "field_delimiter": null,
                        "epochs": 100,
                        "batch_size": 64,
                   

[32mINFO: [0mStatus is pending. A Gretel Cloud worker is being allocated to begin model creation.


## Generate data using the balance seeds

In [None]:
rh = model.create_record_handler_obj(data_source=seedfile, params={"num_records": len(df_seed)})
rh.submit_cloud()
poll(rh)
synth_df = pd.read_csv(rh.get_artifact_link("data"), compression='gzip')
synth_df.head()

[32mINFO: [0mStarting poller


{
    "uid": "61bb962034de0442ad49f2de",
    "model_name": null,
    "runner_mode": "cloud",
    "user_id": "5f3c3afbbff62139634c66ca",
    "project_id": "61b93765539a405c2032dc0e",
    "status_history": {
        "created": "2021-12-16T19:40:15.967000Z"
    },
    "last_modified": "2021-12-16T19:40:16.108000Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "container_image": "074762682575.dkr.ecr.us-west-2.amazonaws.com/gretelai/synthetics@sha256:58fa70040c5cf1c5105820b35de35ba0b7d2c3a07ce84b6ac41ddad94933143f",
    "model_id": "61bb9458b270de2538a554f7",
    "action": "generate",
    "config": {
        "data_source": "gretel_21138d722eef44be955910eae63e7a70_balance_seeds.csv",
        "params": {
            "num_records": 2176,
            "max_invalid": 4352
        }
    }
}


[32mINFO: [0mStatus is created. A Record generation job has been queued.
[32mINFO: [0mStatus is pending. A Gretel Cloud worker is being allocated to begin generating synthetic records.
[32mINFO: [0mStatus is active. A worker has started!
2021-12-16T19:40:32.721531Z  Loading model to worker
2021-12-16T19:40:33.210957Z  Checking for synthetic smart seeds
2021-12-16T19:40:33.332231Z  Loaded 2176 smart seeds for generation
2021-12-16T19:40:33.333067Z  Loading model
2021-12-16T19:40:35.780181Z  Generating records
{
    "num_records": 2176
}
2021-12-16T19:40:40.789659Z  Generation in progress
{
    "current_valid_count": 41,
    "current_invalid_count": 3,
    "new_valid_count": 41,
    "new_invalid_count": 3,
    "completion_percent": 1.88
}
2021-12-16T19:40:45.796923Z  Generation in progress
{
    "current_valid_count": 94,
    "current_invalid_count": 5,
    "new_valid_count": 53,
    "new_invalid_count": 2,
    "completion_percent": 4.32
}
2021-12-16T19:40:50.803249Z  Generation in

Unnamed: 0.1,Unnamed: 0,DATE,PatientID,BirthDate,RACE,ETHNICITY,GENDER,BIRTHPLACE,ConditionDesc,EncounterDesc,EncounterReasonDesc,ObsDesc,ObsValueFloat,UNITS,ProcedureDate,ProcedureDesc
0,5,2013-05-02,34,1940-04-24,white,french,M,Braintree Town MA US,Viral sinusitis (disorder),Encounter for symptom,Viral sinusitis (disorder),Potassium,6.53255,mmol/L,2015-03-23,Bone immobilization
1,1179,2010-11-27,30,1984-03-09,white,french,M,Fitchburg MA US,Normal pregnancy,Prenatal visit,Normal pregnancy,Potassium,6.50738,mmol/L,2011-01-13,Injection of tetanus antitoxin
2,1222,2014-09-06,7,1966-03-02,white,french,M,West Boylston MA US,Laceration of hand,Emergency room admission,,Potassium,6.54514,mmol/L,2016-08-07,Replacement of contraceptive intrauterine device
3,1865,2016-10-26,33,1974-11-20,white,french,M,West Springfield Town MA US,Asthma,Prenatal visit,Normal pregnancy,Potassium,6.51996,mmol/L,2011-12-22,Bilateral tubal ligation
4,2280,2014-02-25,42,1955-07-01,white,french,M,Framingham MA US,Prediabetes,Outpatient Encounter,,Potassium,6.51996,mmol/L,2012-12-25,Documentation of current medications
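
The `Unnamed: 0.1` and `Unnamed: 0` columns in this output are most likely index columns carried through the CSV round trips rather than real fields. A minimal sketch of dropping them before further use, shown on a toy frame with the same column names; `drop_unnamed` is a hypothetical helper.

```python
import pandas as pd


def drop_unnamed(df):
    """Drop columns whose names start with 'Unnamed' (leftover CSV index columns)."""
    return df.loc[:, ~df.columns.str.startswith("Unnamed")]


toy = pd.DataFrame({
    "Unnamed: 0.1": [0, 1],
    "Unnamed: 0": [5, 1179],
    "GENDER": ["M", "M"],
    "RACE": ["white", "white"],
})
clean = drop_unnamed(toy)
print(list(clean.columns))
```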


## Validate the balanced demographic data

In [None]:
synth_df["GENDER"].value_counts()

F    1088
M    1088
Name: GENDER, dtype: int64

In [None]:
synth_df["ETHNICITY"].value_counts()

italian            128
american           128
dominican          128
russian            128
mexican            128
polish             128
french_canadian    128
portuguese         128
german             128
west_indian        128
french             128
african            128
swedish            128
puerto_rican       128
chinese            128
irish              128
english            128
Name: ETHNICITY, dtype: int64

In [None]:
synth_df["RACE"].value_counts()

white       544
asian       544
hispanic    544
black       544
Name: RACE, dtype: int64
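
To see the effect side by side, a small sketch comparing each race's share of the data before and after balancing, using the `RACE` counts printed earlier for the original and synthetic datasets:

```python
def shares(counts):
    """Fraction of all records that falls in each category."""
    total = sum(counts.values())
    return {k: v / total for k, v in counts.items()}


original = {"white": 1616, "black": 332, "hispanic": 258, "asian": 79}
synthetic = {"white": 544, "asian": 544, "hispanic": 544, "black": 544}

before, after = shares(original), shares(synthetic)
for race in original:
    print(f"{race:<10} {before[race]:6.1%} -> {after[race]:6.1%}")
```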