<a target="_blank" href="https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/docs/notebooks/demo/demo-gretel-conditional-generation.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Setup and Installation

This section installs required python and system dependencies for the notebook to run, and then it creates a session with the Gretel API endpoint so that we can communicate with Gretel Cloud. Learn more in our documentation covering [environment setup](https://docs.gretel.ai/guides/environment-setup/cli-and-sdk).

In [None]:
%%capture
!pip install -U gretel-client

## Gretel Setup
Set up the Gretel API connection

In [None]:
from getpass import getpass
from gretel_client import configure_session

# The API key is read interactively so it never ends up stored in the notebook.
gretel_endpoint = "https://api.gretel.cloud"
gretel_api_key = getpass("API Key: ")

# Collect the session settings first, then hand them to the client in one call.
session_options = dict(
    api_key=gretel_api_key,
    endpoint=gretel_endpoint,
    validate=True,
    clear=True,
)
configure_session(**session_options)

## Fetch and prepare data
Read in the dataset as a pandas DataFrame and preview a random sample of rows.

In [None]:
import numpy as np
import pandas as pd

# Location of the CSV used as training data for the rest of the notebook.
DATA_PATH = "https://gretel-datasets.s3.us-west-2.amazonaws.com/rossman_store_sales/train_50k.csv"

# Load the whole file into a pandas DataFrame.
data_source = pd.read_csv(filepath_or_buffer=DATA_PATH)

# Preview ten randomly chosen rows.
random_rows = data_source.sample(n=10)
display(random_rows)

## Train Gretel-ACTGAN model on data
Note that in this example we use our tabular-actgan model because it supports conditional data generation.

In [None]:
# Create (or reuse) the Gretel project that will hold the model.
#
# The session configured in the "Gretel Setup" cell is reused here. Calling
# configure_session a second time with a different endpoint would replace that
# validated session, prompt for the API key again, and create the project
# against an endpoint other than the one the key was validated for.

from gretel_client.projects import create_or_get_unique_project

GRETEL_PROJECT_NAME = 'demo-conditional-generation'

project = create_or_get_unique_project(name=GRETEL_PROJECT_NAME)


## Synthesize data using Gretel Synthetics

In [None]:
from gretel_client.projects.models import read_model_config

# Start from the stock tabular-actgan configuration template.
config = read_model_config("synthetics/tabular-actgan")

# Override the ACTGAN parameters used for conditional generation.
# `actgan_params` is a reference into `config`, so the update happens in place.
actgan_params = config['models'][0]['actgan']['params']
actgan_params.update(
    {
        'conditional_vector_type': "anyway",
        'conditional_select_mean_columns': 2,
        'reconstruction_loss_coef': 10.0,
        'force_conditioning': True,
    }
)

In [None]:
from gretel_client.helpers import poll

# Create the model object from the edited config and the source DataFrame,
# then submit it to Gretel Cloud for training.
training_inputs = {"model_config": config, "data_source": data_source}
model = project.create_model_obj(**training_inputs)
model.submit_cloud()

# Poll the submitted job with verbose output turned off.
poll(model, verbose=False)

## Accessing Project and Model Files
All of the project artifacts can be found in the Gretel Console using the below URL.

In [None]:
# Build the Gretel Console link for this project from its GUID and print it.
console_url = f"https://console.gretel.ai/{project.project_guid}"
print(console_url)

## Use Case 1: Unconditional Synthetic data generation

By default, Gretel generates synthetic data with properties similar to those of the source data.

In [None]:
# Request exactly as many synthetic rows as the source data contains.
NUMBER_OF_RECORDS = data_source.shape[0]

generation_params = {"num_records": NUMBER_OF_RECORDS}
rh = model.create_record_handler_obj(params=generation_params)
rh.submit_cloud()
poll(rh)

# The generated records are read as a gzip-compressed CSV artifact.
synthetic_csv_link = rh.get_artifact_link("data")
synth_df = pd.read_csv(synthetic_csv_link, compression="gzip")

In [None]:
# Show the first rows of the source and the synthetic data, one after the other.
for label, frame in (("Source data", data_source), ("Synthetic data", synth_df)):
    print(label)
    display(frame.head())

## Use Case 2: Balance data in a column

Conditionally generate more data to balance across a specific column

In [None]:
# Inspect how the source rows are distributed over the seed column

import matplotlib.pyplot as plt

SEED_COLUMN = "DayOfWeek"

# Number of source rows per distinct value of the seed column.
source_counts = data_source[SEED_COLUMN].value_counts()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
source_counts.plot(kind="barh", rot=90, grid=True, ax=ax)

# Restrict the count axis to a narrow window so small differences are visible.
ax.set_xlim(7000, 7500)

In [None]:
col_value_counts = data_source[SEED_COLUMN].value_counts()

# For each category, how many rows it is short of the most frequent category.
deficits = col_value_counts.max() - col_value_counts.values

# Repeat every category label once per missing row; categories already at the
# maximum contribute nothing.
seed_data = [
    category
    for category, deficit in zip(col_value_counts.index, deficits)
    for _ in range(deficit)
]
seeds = pd.DataFrame(data=seed_data, columns=[SEED_COLUMN])

# Generate one synthetic record per seed row.
rh = model.create_record_handler_obj(data_source=seeds)
rh.submit_cloud()
poll(rh)

synthetic_csv_link = rh.get_artifact_link("data")
synth_df = pd.read_csv(synthetic_csv_link, compression="gzip")

In [None]:
# Inspect the class balance before and after adding the synthetic rows
df_balanced = pd.concat([data_source, synth_df])

# Align both count series on the same, sorted category labels. value_counts()
# orders its result by frequency, so the source and the balanced data are not
# guaranteed to list the categories in the same order (especially once the
# counts are nearly equal). Stacking them positionally could pair a count with
# the wrong label.
source_counts = data_source[SEED_COLUMN].value_counts().sort_index()
balanced_counts = (
    df_balanced[SEED_COLUMN]
    .value_counts()
    .reindex(source_counts.index, fill_value=0)
)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
df_join = pd.DataFrame(
    {
        'Source': source_counts.values,
        'Source+Synthetic': balanced_counts.values,
    }
)
df_join[SEED_COLUMN] = source_counts.index
df_join.plot(x=SEED_COLUMN, y=["Source", "Source+Synthetic"], kind="barh", rot=0, grid=True, ax=ax)
ax.set_xlim(7000,7500)

## Use Case 3: Simulate a boost in Store types

Let's simulate a boost in one of the field values.

In [None]:
SEED_COLUMN = 'StoreType'

# Number of source rows per store type.
store_type_counts = data_source[SEED_COLUMN].value_counts()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
store_type_counts.plot(kind="barh", rot=0, grid=True, ax=ax)

In [None]:
SEED_VALUE = 'c'
SEED_RECORDS = 10000

# One seed row per requested record, each carrying the same seed value.
seeds = pd.DataFrame({SEED_COLUMN: [SEED_VALUE] * SEED_RECORDS})

rh = model.create_record_handler_obj(data_source=seeds)
rh.submit_cloud()
poll(rh)

synthetic_csv_link = rh.get_artifact_link("data")
synth_df = pd.read_csv(synthetic_csv_link, compression="gzip")

In [None]:
# Compare per-category counts of the source with source + synthetic rows

df_augmented = pd.concat([data_source, synth_df])

# Sort both count series by category label so the rows line up.
source_counts = data_source[SEED_COLUMN].value_counts().sort_index()
augmented_counts = df_augmented[SEED_COLUMN].value_counts().sort_index()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
df_join = pd.DataFrame(
    np.column_stack((source_counts, augmented_counts)),
    columns=['Source', 'Source+Synthetic'],
)
df_join[SEED_COLUMN] = source_counts.index
df_join.plot(x=SEED_COLUMN, y=["Source", "Source+Synthetic"], kind="barh", rot=0, grid=True, ax=ax)

In [None]:
# Label the output, then show the first rows of the latest synthetic data.
print("Synthetic data")

display(synth_df.head())

## Use Case 4: How to enhance data through augmentation

In [None]:
import random
from sklearn.utils import resample

SEED_RECORDS = 100
RANDOM_STATE = 42  # fixed so the same seed rows are drawn on every run
COLUMN1 = "Sales"
COLUMN2 = "Customers"
df_sel = data_source[[COLUMN1, COLUMN2]]

# Sales: keep rows whose value lies in the inclusive range below
min_col1_value = 20_000
max_col1_value = 25_000
df_sel = df_sel[(df_sel[COLUMN1] >= min_col1_value) & (df_sel[COLUMN1] <= max_col1_value)]
print(len(df_sel))

# Customers: further restrict to the inclusive range below
min_col2_value = 2000
max_col2_value = 3000
df_sel = df_sel[(df_sel[COLUMN2] >= min_col2_value) & (df_sel[COLUMN2] <= max_col2_value)]

# Fail with a clear message instead of an opaque sampling error when the
# filters leave nothing to draw seeds from.
if df_sel.empty:
    raise ValueError(
        f"No source rows have {COLUMN1} in [{min_col1_value}, {max_col1_value}] "
        f"and {COLUMN2} in [{min_col2_value}, {max_col2_value}]; widen the ranges."
    )

# Sample with replacement so SEED_RECORDS may exceed the number of matching
# rows; the fixed random_state makes the drawn seeds reproducible.
seeds = resample(df_sel, replace=True, n_samples=SEED_RECORDS, random_state=RANDOM_STATE)

# Overlay the selected seeds (red) on the full source data (blue).
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
data_source.plot.scatter(x=COLUMN1, y=COLUMN2, c='DarkBlue', grid=True, ax=ax)
seeds.plot.scatter(x=COLUMN1, y=COLUMN2, c='DarkRed', grid=True, ax=ax)

In [None]:
# Generate synthetic records conditioned on the seed rows

rh = model.create_record_handler_obj(data_source=seeds)
rh.submit_cloud()
poll(rh)

# The generated records are read as a gzip-compressed CSV artifact.
synthetic_csv_link = rh.get_artifact_link("data")
synth_df = pd.read_csv(synthetic_csv_link, compression="gzip")

In [None]:
# Label the output, then show five randomly chosen synthetic rows.
print("Synthetic data")

display(synth_df.sample(n=5))