<br>

<center><a href=https://gretel.ai/><img src="https://global-uploads.webflow.com/5ea8b9202fac2ea6211667a4/62dae7c82eb3a22ac4bd415e_gretel.ai%20logo.svg" alt="Gretel" width="350"/></a></center>

<br>

## Welcome to the Gretel Advanced Tabular Blueprint!  

In this Blueprint, we will demonstrate two advanced uses of the high-level `Gretel` interface:
1. Customizing model configurations via keyword arguments in the `submit_train` method.

2. Conditionally generating synthetic data using the `seed_data` parameter of the `submit_generate` method.

## In the right place?

If this is your first time using Gretel, we recommend starting with our [Gretel 101 Blueprint](https://colab.research.google.com/drive/1HU22bRwRnXsOtgeQ93HzVdj7WV32TjmG?usp=sharing).

**Note:** You will need a [free Gretel account](https://console.gretel.ai/) to run this notebook.


<br>

#### Ready? Let's go 🚀

## 💾 Install gretel-client and its dependencies

In [None]:
%%capture
!pip install gretel-client

## 🛜 Configure your Gretel session

- Each `Gretel` instance is bound to a single [Gretel project](https://docs.gretel.ai/guides/gretel-fundamentals/projects).  

- You can set the project name at instantiation, or you can use the `set_project` method.

- If you do not set a project, a new project with a random name will be created when you submit your first job.


- You can retrieve your API key [here](https://console.gretel.ai/users/me/key).

In [None]:
from gretel_client import Gretel

gretel = Gretel(project_name="advanced-usage", api_key="prompt", validate=True)

In [None]:
# @title 🗂️ Set the dataset path

dataset_path_dict = {
    "adult income in the USA (14000 records, 15 fields)": "https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/sample_data/us-adult-income.csv",
}
dataset = "adult income in the USA (14000 records, 15 fields)" # @param ["adult income in the USA (14000 records, 15 fields)"]
dataset = dataset_path_dict[dataset]

# @markdown

# @markdown - This Blueprint uses a sample of the [UCI adult income dataset](https://archive.ics.uci.edu/dataset/2/adult).

# @markdown - Run this cell to set `dataset` to [the sample data path](https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/sample_data/us-adult-income.csv).

# @markdown <br>

# @markdown ##### Preview of records:

# @markdown |age|workclass|fnlwgt|education|education\_num|marital\_status|occupation|relationship|
# @markdown |---|---|---|---|---|---|---|---|
# @markdown |33|Private|229051|Some-college|10|Never-married|Prof-specialty|Not-in-family|
# @markdown |38|Local-gov|91711|Bachelors|13|Married-civ-spouse|Prof-specialty|Husband|
# @markdown |56|Private|282023|HS-grad|9|Married-civ-spouse|Adm-clerical|Husband|
# @markdown |32|Private|209538|Masters|14|Married-civ-spouse|Exec-managerial|Husband|
# @markdown |34|Self-emp-inc|215382|Masters|14|Separated|Prof-specialty|Not-in-family|
# @markdown <br>

# @markdown ##### Preview of records (continued):

# @markdown |race|gender|capital\_gain|capital\_loss|hours\_per\_week|native\_country|income\_bracket|
# @markdown |---|---|---|---|---|---|---|
# @markdown |White|Male|0|0|52|United-States|\<=50K|
# @markdown |White|Male|0|0|50|United-States|\>50K|
# @markdown |White|Male|0|0|40|United-States|\<=50K|
# @markdown |White|Male|0|0|55|United-States|\>50K|
# @markdown |White|Female|4787|0|40|United-States|\>50K|

## 🏗️ Train Gretel's ACTGAN with a **custom configuration**

Here is a [base yaml configuration for ACTGAN](https://github.com/gretelai/gretel-blueprints/blob/main/config_templates/gretel/synthetics/tabular-actgan.yml), which we select using `base_config="tabular-actgan"` in the `submit_train` method:

```yaml
schema_version: "1.0"
name: "tabular-actgan"
models:
  - actgan:
        data_source: __tmp__
        params:
            epochs: auto
            generator_dim: [1024, 1024]
            discriminator_dim: [1024, 1024]
            generator_lr: 0.0001
            discriminator_lr: .00033
            batch_size: auto
            auto_transform_datetimes: False
        generate:
            num_records: 5000
        privacy_filters:
            outliers: null
            similarity: null
```

- You can customize this configuration using **keyword arguments** in the `submit_train` method.

- The keywords can be any of the sections under the model; in this case, `params`, `generate`, or `privacy_filters`.

- Each value must be a dictionary of parameters from the associated section, as demonstrated below.

- **Tip:** Use the `job_label` argument to append a descriptive label to the model's name.
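
As a rough mental model of how section keywords combine with the base configuration, the sketch below merges override dictionaries into an abridged copy of the sections shown above. It assumes that each override replaces values key by key within its section and that keys you do not mention keep their base values. The `apply_overrides` helper and the `base_sections` dictionary are hypothetical and exist only for this illustration; they are not part of `gretel-client`, and the client's actual merge logic may differ.

```python
import copy

# Abridged copy of the `tabular-actgan` model sections shown above.
base_sections = {
    "params": {
        "epochs": "auto",
        "generator_dim": [1024, 1024],
        "discriminator_dim": [1024, 1024],
    },
    "generate": {"num_records": 5000},
    "privacy_filters": {"outliers": None, "similarity": None},
}


def apply_overrides(sections, **overrides):
    """Return a copy of `sections` with each override dict merged into its section.

    Hypothetical helper for illustration only; not part of gretel-client.
    """
    merged = copy.deepcopy(sections)
    for section, values in overrides.items():
        if section not in merged:
            raise KeyError(f"unknown config section: {section!r}")
        if not isinstance(values, dict):
            raise TypeError(f"{section!r} must be given as a dictionary")
        # Assumption: values are replaced key by key; lists are replaced whole.
        merged[section].update(values)
    return merged


custom = apply_overrides(
    base_sections,
    params={"epochs": 800, "discriminator_dim": [1024, 1024, 1024]},
    privacy_filters={"similarity": "high", "outliers": None},
)

print(custom["params"]["epochs"])  # 800
print(custom["params"]["generator_dim"])  # [1024, 1024]
print(custom["privacy_filters"]["similarity"])  # high
```

Under this assumption, `generator_dim` and the whole `generate` section stay at their base values because no override mentions them.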

In [None]:
trained = gretel.submit_train(
    base_config="tabular-actgan",
    data_source=dataset,
    job_label="custom-config",
    params={"epochs": 800, "discriminator_dim": [1024, 1024, 1024]},
    privacy_filters={"similarity": "high", "outliers": None},
)

In [None]:
# view synthetic data quality scores
print(trained.report)

## 🌱 Prepare the seed data

- Conditional data generation is accomplished by submitting seed data, which can be given as a file path or `DataFrame`.

- The seed data should contain a subset of the dataset's columns with the desired seed values.

- Currently, only categorical seed columns are supported.

- Here, we will conditionally generate `num_records` synthetic examples of high-school graduates making more than $50k per year.

In [None]:
import pandas as pd

num_records = 500

seed_data = pd.DataFrame(
    {
        "education": ["HS-grad"] * num_records,
        "income_bracket": [">50K"] * num_records
    },
)

seed_data
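
The seed data above repeats a single combination of values. If you want several combinations in one request, a sketch like the following builds the seed rows with the standard library. It assumes the model accepts seed rows whose values differ from row to row, which this Blueprint does not demonstrate. The `build_seed_rows` helper and the `Bachelors` seed value are illustrative choices, not part of `gretel-client`.

```python
from itertools import product

# Seed values per column; every combination gets the same number of rows.
seed_values = {
    "education": ["HS-grad", "Bachelors"],
    "income_bracket": [">50K"],
}
records_per_combination = 250


def build_seed_rows(values_by_column, rows_per_combination):
    """Hypothetical helper: one dict per seed row, covering every value combination."""
    columns = list(values_by_column)
    rows = []
    for combination in product(*values_by_column.values()):
        row = dict(zip(columns, combination))
        # Copy the row so that each seed record is an independent dict.
        rows.extend(dict(row) for _ in range(rows_per_combination))
    return rows


seed_rows = build_seed_rows(seed_values, records_per_combination)
print(len(seed_rows))  # 500
```

Passing `seed_rows` to `pd.DataFrame(seed_rows)` yields a seed `DataFrame` with one column per seed field.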

## 🤖 Conditionally generate synthetic data

- The `submit_generate` method requires either `num_records` **or** `seed_data` as a keyword argument.

- If `seed_data` is given, the number of generated records will equal `len(seed_data)`.

- **Tip:** You can generate data from any trained model in the current project by using its associated `model_id`.

In [None]:
generated = gretel.submit_generate(trained.model_id, seed_data=seed_data)

In [None]:
# inspect conditionally generated data
generated.synthetic_data.head()

In [None]:
# verify that the seeded columns are correct
print(generated.synthetic_data["education"].value_counts(), end="\n\n")
print(generated.synthetic_data["income_bracket"].value_counts())
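
If you prefer a check that fails loudly over reading the value counts by eye, a minimal sketch along these lines may help. The `count_seed_mismatches` helper is hypothetical, and `synthetic_records` holds made-up placeholder rows standing in for `generated.synthetic_data.to_dict("records")` so that the example runs without a trained model.

```python
from collections import Counter

# Made-up placeholder rows standing in for
# `generated.synthetic_data.to_dict("records")`.
synthetic_records = [
    {"age": 41, "education": "HS-grad", "income_bracket": ">50K"},
    {"age": 57, "education": "HS-grad", "income_bracket": ">50K"},
    {"age": 36, "education": "HS-grad", "income_bracket": ">50K"},
]
expected_seeds = {"education": "HS-grad", "income_bracket": ">50K"}


def count_seed_mismatches(records, expected):
    """Hypothetical helper: per seed column, count rows that differ from the seed value."""
    mismatches = Counter()
    for record in records:
        for column, value in expected.items():
            if record.get(column) != value:
                mismatches[column] += 1
    return dict(mismatches)


mismatches = count_seed_mismatches(synthetic_records, expected_seeds)
assert not mismatches, f"seed columns with unexpected values: {mismatches}"
print("all seeded columns match")
```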