<a target="_blank" href="https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/docs/notebooks/demo/gretel-demo-conditional-generation.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook demonstrates conditional data generation with Gretel ACTGAN. Conditioning lets you simulate use cases and test scenarios on your data, such as balancing records across categories, introducing a spike in a categorical field, or generating synthetic data within a specific range for multiple numerical columns.

## Setup and Installation

This section installs the Python dependencies the notebook needs. A later section creates a session with the Gretel API endpoint so that we can communicate with Gretel Cloud. Learn more in our documentation covering [environment setup](https://docs.gretel.ai/guides/environment-setup/cli-and-sdk).

In [None]:
%%capture
!pip install -U gretel-client

## Fetch and prepare data
Load and display the source data.

In [None]:
import pandas as pd
import numpy as np

DATA_PATH = "https://gretel-datasets.s3.us-west-2.amazonaws.com/rossman_store_sales/train_50k.csv"
data_source = pd.read_csv(DATA_PATH)
display(data_source.sample(n=10))

## Configure your Gretel session

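Set up the Gretel API connection. Passing `api_key="prompt"` makes the client ask for your API key interactively, and `validate=True` checks the credentials before continuing.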

In [None]:
from gretel_client import Gretel

GRETEL_PROJECT = 'demo-conditional-generation'

gretel = Gretel(project_name=GRETEL_PROJECT, api_key="prompt", validate=True)

## Train a Gretel ACTGAN model on the data
Note that this example uses the `tabular-actgan` model because it supports conditional data generation.

In [None]:
trained = gretel.submit_train(
    "tabular-actgan",
    data_source=data_source,
    params={
        "conditional_vector_type": "anyway",
        "conditional_select_mean_columns": 2,
        "reconstruction_loss_coef": 10.0,
        "force_conditioning": True,
    },
)

## Use Case 1: Unconditional Synthetic data generation

By default, Gretel generates synthetic data with properties similar to those of the source data. No seed data is passed here, so nothing is conditioned.

In [None]:
# Unconditional generation

NUMBER_OF_RECORDS = len(data_source)

generated = gretel.submit_generate(trained.model_id, num_records=NUMBER_OF_RECORDS)

In [None]:
# Inspect the synthetic data

print("Source data")
display(data_source.head())

print("Synthetic data")
display(generated.synthetic_data.head())
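
Looking at the first few rows only gives a rough impression. One simple way to check "similar properties" numerically is to compare per-column means side by side. The sketch below is a minimal illustration on made-up toy frames. The helper `compare_means` and the toy data are assumptions for this example, not part of the Gretel SDK. In the notebook you would pass `data_source` and `generated.synthetic_data` instead.

```python
import numpy as np
import pandas as pd


def compare_means(source: pd.DataFrame, synthetic: pd.DataFrame) -> pd.DataFrame:
    """Return the column means of both frames for the numeric columns they share."""
    numeric = source.select_dtypes("number").columns.intersection(synthetic.columns)
    return pd.DataFrame(
        {
            "source_mean": source[numeric].mean(),
            "synthetic_mean": synthetic[numeric].mean(),
        }
    )


# Toy stand-ins (hypothetical values, not the Rossmann data)
rng = np.random.default_rng(0)
toy_source = pd.DataFrame(
    {"Sales": rng.normal(5000, 1000, 500), "Customers": rng.normal(600, 100, 500)}
)
toy_synth = pd.DataFrame(
    {"Sales": rng.normal(5000, 1000, 500), "Customers": rng.normal(600, 100, 500)}
)

summary = compare_means(toy_source, toy_synth)
print(summary)
```

Means are only a first check. Matching means do not guarantee that distributions or correlations match.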

## Use Case 2: Balance data in a column

In this example, we use conditional generation to balance the categorical column `DayOfWeek`, so that each weekday ends up with roughly the same number of records once the synthetic records are combined with the original data.

In [None]:
# Inspect the distribution of the source data

import matplotlib.pyplot as plt

SEED_COLUMN = "DayOfWeek"
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
data_source[SEED_COLUMN].value_counts().plot(kind="barh", rot=90, grid=True, ax=ax)
ax.set_xlim(7000,7500)


In [None]:
# Conditional generation

# First, let's create seed data that will produce an equal number of records for each day of the week
n_rows_to_add = data_source[SEED_COLUMN].value_counts().max() - data_source[SEED_COLUMN].value_counts()
seed_data = sum([[day] * n_rows for day, n_rows in n_rows_to_add.items()], [])
seeds = pd.DataFrame(data=seed_data, columns=[SEED_COLUMN])

generated = gretel.submit_generate(trained.model_id, seed_data=seeds)
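
To make the seed construction concrete, here is a small self-contained sketch on a toy column. The helper name `balancing_seeds` and the toy values are assumptions for illustration only. It computes, for each category, how many rows are missing relative to the most frequent category and repeats the category label that many times. It uses `Index.repeat` in place of the `sum(..., [])` list concatenation, which is an alternative formulation of the same idea.

```python
import pandas as pd


def balancing_seeds(df: pd.DataFrame, column: str) -> pd.DataFrame:
    """Build a one-column seed frame that tops every category up to the largest one."""
    counts = df[column].value_counts()
    deficit = counts.max() - counts  # rows missing per category (0 for the largest)
    values = deficit.index.repeat(deficit.to_numpy())  # label repeated `deficit` times
    return pd.DataFrame({column: values})


# Toy data: day 1 appears three times, day 2 twice, day 3 once
toy = pd.DataFrame({"DayOfWeek": [1, 1, 1, 2, 2, 3]})
toy_seeds = balancing_seeds(toy, "DayOfWeek")

# If every seed yields one record with that value, the combined column is balanced
combined = pd.concat([toy, toy_seeds], ignore_index=True)
balanced_counts = combined["DayOfWeek"].value_counts().sort_index()
print(balanced_counts)
```

The seed frame carries only the conditioning column. The model fills in the remaining columns, and the final balance depends on the model returning one record per seed.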

In [None]:
# Inspect synthetic data

df_balanced = pd.concat([data_source, generated.synthetic_data])

# Sort both counts by category so that each row compares the same weekday
source_counts = data_source[SEED_COLUMN].value_counts().sort_index()
balanced_counts = df_balanced[SEED_COLUMN].value_counts().sort_index()

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
df_join = pd.DataFrame(np.column_stack((source_counts, balanced_counts)), columns=['Source','Source+Synthetic'])
df_join[SEED_COLUMN] = source_counts.index
df_join.plot(x=SEED_COLUMN, y=["Source", "Source+Synthetic"], kind="barh", rot=0, grid=True, ax=ax)
ax.set_xlim(7000,7500)

## Use Case 3: Simulate a boost in Store types

Now we will simulate a boost in one of the categories of the `StoreType` column.

In [None]:
SEED_COLUMN = 'StoreType'

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
data_source[SEED_COLUMN].value_counts().plot(kind="barh", rot=0, grid=True, ax=ax)

In [None]:
# Conditional generation

SEED_VALUE = 'c'
SEED_RECORDS = 10000
seeds = pd.DataFrame(data=[SEED_VALUE] * SEED_RECORDS, columns=[SEED_COLUMN])

generated = gretel.submit_generate(trained.model_id, seed_data=seeds)
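
Before plotting, it can help to estimate how large the spike will be. The sketch below computes the share of one category before and after adding a given number of seeded records. The helper `share_after_boost` and the counts are made up for illustration and are not the actual `StoreType` distribution. It also assumes every seed produces exactly one record with the seeded value.

```python
import pandas as pd


def share_after_boost(counts: pd.Series, value, n_added: int) -> float:
    """Share of `value` after adding `n_added` records of that category."""
    boosted = counts.copy()  # leave the caller's counts untouched
    boosted[value] = boosted.get(value, 0) + n_added
    return float(boosted[value] / boosted.sum())


# Hypothetical category counts
toy_counts = pd.Series({"a": 600, "b": 100, "c": 200, "d": 100})

before = float(toy_counts["c"] / toy_counts.sum())
after = share_after_boost(toy_counts, "c", 1000)
print(f"share of 'c': {before:.2f} -> {after:.2f}")  # share of 'c': 0.20 -> 0.60
```

In the notebook, the same calculation would use `data_source[SEED_COLUMN].value_counts()`, `SEED_VALUE`, and `SEED_RECORDS`.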

In [None]:
# Inspect the synthetic data

print("Synthetic data")
display(generated.synthetic_data.head())

In [None]:
# Show class balance

df_augmented = pd.concat([data_source, generated.synthetic_data])

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
df_join = pd.DataFrame(np.column_stack((data_source[SEED_COLUMN].value_counts().sort_index(), df_augmented[SEED_COLUMN].value_counts().sort_index())), columns=['Source','Source+Synthetic'])
df_join[SEED_COLUMN] = data_source[SEED_COLUMN].value_counts().sort_index().index
df_join.plot(x=SEED_COLUMN, y=["Source", "Source+Synthetic"], kind="barh", rot=0, grid=True, ax=ax)

## Use Case 4: How to enhance data through augmentation

Finally, we will use conditional generation to create more data points within a specific range of the numerical fields `Sales` and `Customers`.

In [None]:
from sklearn.utils import resample

SEED_RECORDS = 100
COLUMN1 = "Sales"
COLUMN2 = "Customers"
df_sel = data_source[[COLUMN1, COLUMN2]]

# Sales
min_col1_value = 20_000
max_col1_value = 25_000
df_sel = df_sel[(df_sel[COLUMN1] >= min_col1_value) & (df_sel[COLUMN1] <= max_col1_value)]
print(f"Records within the {COLUMN1} range: {len(df_sel)}")

# Customers
min_col2_value = 2000
max_col2_value = 3000
df_sel = df_sel[(df_sel[COLUMN2] >= min_col2_value) & (df_sel[COLUMN2] <= max_col2_value)]

seeds = resample(df_sel, replace=True, n_samples=SEED_RECORDS)

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
data_source.plot.scatter(x=COLUMN1, y=COLUMN2, c='DarkBlue', grid=True, ax=ax)
seeds.plot.scatter(x=COLUMN1, y=COLUMN2, c='DarkRed', grid=True, ax=ax)

In [None]:
# Conditional generation

generated = gretel.submit_generate(trained.model_id, seed_data=seeds)

In [None]:
# Inspect the synthetic data

print("Synthetic data")
display(generated.synthetic_data.sample(n=5))
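
A random sample of rows does not show whether the generated records actually fall inside the target ranges. A minimal check, sketched here on toy data, is to compute the fraction of rows that satisfy every range at once. The helper `fraction_in_ranges` and the toy values are assumptions for illustration. In the notebook you would pass `generated.synthetic_data` and the `min_col*_value` / `max_col*_value` bounds defined earlier.

```python
import pandas as pd


def fraction_in_ranges(df: pd.DataFrame, ranges: dict) -> float:
    """Fraction of rows whose values lie inside every (low, high) range, bounds included."""
    mask = pd.Series(True, index=df.index)
    for column, (low, high) in ranges.items():
        mask &= df[column].between(low, high)  # inclusive on both ends
    return float(mask.mean())


# Toy stand-in for generated records (hypothetical values)
toy = pd.DataFrame(
    {
        "Sales": [21_000, 24_000, 19_000, 22_000],
        "Customers": [2500, 3100, 2200, 2000],
    }
)
ranges = {"Sales": (20_000, 25_000), "Customers": (2000, 3000)}

in_range = fraction_in_ranges(toy, ranges)
print(in_range)  # 0.5: rows 1 and 4 satisfy both ranges
```

A value below 1.0 on real output would indicate that some generated records drifted outside the seeded region.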