<a target="_blank" href="https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/docs/notebooks/demo/gretel_tuner_advanced_tutorial.ipynb"> 
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> </a>

# 🎛️ Advanced **Gretel Tuner** Tutorial

In this tutorial, we will demonstrate how to sweep ACTGAN's hyperparameters using **Gretel Tuner**, leveraging two advanced features:

1. We will implement a custom, task-specific optimization metric.

2. We will define a `sampler_callback` function, which applies arbitrary constraints on the sampled model configs.


## In the right place?

If you are new to the **Gretel Tuner**, we recommend first working through the [Introductory Gretel Tuner Tutorial](https://colab.research.google.com/drive/1goVEJIufC4AebJ106FOpGmlCgT5UJswW?usp=sharing).


## 💿 Installation

- The tuner requires additional dependencies beyond the minimal requirements of [gretel_client](https://github.com/gretelai/gretel-python-client).

- To install the tuner along with the client, add the `[tuner]` option to the pip install command:

In [None]:
%%capture
!pip install gretel-client[tuner]

## 🛜 Configure your Gretel session

- The [`Gretel` object](https://docs.gretel.ai/guides/high-level-sdk-interface/the-gretel-object) provides a high-level interface for streamlining interactions with Gretel's APIs.

- Each `Gretel` instance is bound to a single [Gretel project](https://docs.gretel.ai/guides/gretel-fundamentals/projects).

- Running the cell below will prompt you for your Gretel API key, which you can retrieve [here](https://console.gretel.ai/users/me/key).

- With `validate=True`, your login credentials will be validated immediately at instantiation.

In [None]:
from gretel_client import Gretel

gretel = Gretel(
    project_name="tuner-advanced-tutorial",
    api_key="prompt",
    validate=True,
)

## 🏦 Preview bank marketing dataset

- For this demo, we will use a subset of a [bank marketing dataset](https://archive.ics.uci.edu/dataset/222/bank+marketing).

In [None]:
import pandas as pd

data_source = "https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/sample_data/bank_marketing_small.csv"

df_ref = pd.read_csv(data_source)
df_ref.head()

## 🏗️ Building custom optimization metrics

- As we saw in the intro Gretel Tuner tutorial, Gretel's quality metrics such as the [Synthetic Data Quality Score](https://docs.gretel.ai/reference/evaluate/synthetic-data-quality-report#synthetic-data-quality-score-sqs) can be set as the tuner's optimization metric via its YAML config.

- To use a custom metric, create a class that inherits from `BaseTunerMetric` and implement a `__call__` method that takes a Gretel `Model` as input and returns the metric score as a float, as we demonstrate in the cell below.

- `BaseTunerMetric` has two helper methods:
    - `get_gretel_report` - fetches the Gretel Synthetic Data Quality Report, which is useful if you want to incorporate Gretel's score(s) into your custom metric.
    - `submit_generate_for_trial` - submits a synthetic data generation job to Gretel, which is useful if you need to generate synthetic data beyond the data used in Gretel's report.

> #### Example use case
Given the bank marketing dataset, suppose our use case requires that **(i)** the synthetic model conditionally generates records with `job = 'entrepreneur'` and **(ii)** the distribution of bank balances is reproduced accurately. In the cell below, we build both of these requirements into a custom metric called `BalanceKSComplementPlusSQS`, which calculates the metric score as follows. For each trial:
- Conditionally generate `num_samples` records with `target_job='entrepreneur'`.
- For records with `job = 'entrepreneur'`, compare the real and conditionally-generated `balance` distributions using the [KS statistic](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test) (lower is better).
- Calculate the custom metric score as a weighted sum of the KS complement (1 - KS, so higher is better) and Gretel's SQS, which measures general synthetic data quality. The code below weights them 0.7 and 0.3, respectively.

> **Note:** This example and custom metric are only meant to demonstrate functionality. Implementing a task-specific optimization metric in the real world requires careful consideration of your dataset and use case requirements.
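
To make the scoring arithmetic concrete, here is a minimal, dependency-free sketch of the two-sample KS statistic and the weighted score. The helper names (`ks_statistic`, `weighted_score`) and the toy balances are hypothetical and exist only for illustration; the tutorial's metric itself relies on `scipy.stats.ks_2samp` and on the SQS from Gretel's report.

```python
from bisect import bisect_right


def ks_statistic(sample_a, sample_b):
    """Two-sample KS statistic: the largest gap between the two empirical CDFs."""
    a, b = sorted(sample_a), sorted(sample_b)
    # Both empirical CDFs are step functions that only change at observed
    # values, so it is enough to compare them at every observed value.
    return max(
        abs(bisect_right(a, x) / len(a) - bisect_right(b, x) / len(b))
        for x in a + b
    )


def weighted_score(ks, sqs, ks_weight=0.7):
    """Weighted sum of the KS complement and an SQS value scaled to [0, 1]."""
    return ks_weight * (1 - ks) + (1 - ks_weight) * sqs


# Toy balances (made up for illustration only).
real = [100, 250, 400, 900, 1500]
close = [120, 240, 420, 880, 1450]
far = [5000, 6000, 7000, 8000, 9000]

ks_close = ks_statistic(real, close)
ks_far = ks_statistic(real, far)

print(f"similar samples:  KS={ks_close:.2f}, score={weighted_score(ks_close, sqs=0.8):.2f}")
print(f"disjoint samples: KS={ks_far:.2f}, score={weighted_score(ks_far, sqs=0.8):.2f}")
```

With the SQS held fixed, the sample whose balances track the real ones gets the higher score, which is the behavior the tuner will try to maximize.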

In [None]:
from scipy.stats import ks_2samp

from gretel_client.tuner import BaseTunerMetric, MetricDirection

class BalanceKSComplementPlusSQS(BaseTunerMetric):
    def __init__(self, df_ref, target_job="entrepreneur", num_samples=500):
        self.target_job = target_job
        self.num_samples = num_samples
        self.balance_ref = df_ref.query("job==@target_job")["balance"]

        # Set metric optimization direction.
        # (note the default is maximize, so this is optional in this case)
        self.direction = MetricDirection.MAXIMIZE

    def __call__(self, model):
        # Fetch Gretel's synthetic data quality score for the model.
        report = self.get_gretel_report(model)
        sqs = report["synthetic_data_quality_score"]["raw_score"] / 100

        # (i) Conditionally generate synthetic records with the target job.
        seed_data = pd.DataFrame({"job": [self.target_job] * self.num_samples})
        df_synth = self.submit_generate_for_trial(model, seed_data=seed_data)

        # (ii) Calculate the KS complement of the real and synthetic balances.
        ks_comp = 1 - ks_2samp(df_synth["balance"], self.balance_ref).statistic

        # Calculate score as weighted sum of the KS complement and SQS.
        score = 0.7 * ks_comp + 0.3 * sqs

        return score

## 🚀 Run Gretel Tuner with the custom metric

- The [Gretel object](https://docs.gretel.ai/guides/high-level-sdk-interface/the-gretel-object) has a convenience `run_tuner` method, which will run the parameter sweeps in a single command.

- The tuner submits training jobs to Gretel with different model configurations. The submitted jobs run remotely in the cloud, while the tuner itself runs **locally** and submits a new job each time a previous training job completes.

- Here we use `n_trials = 4`, which is too small to find an optimal model. In an actual hyperparameter tuning experiment, we recommend roughly 20-50 trials or more, depending on the observed convergence of the metric score.

- The `sampler_callback` function is applied to each trial config before the training job is submitted. Its input argument is the model section of the config. In this example, we use it to set the constraint `generator_dim = discriminator_dim`.
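
As a quick offline check of this behavior, the sketch below applies a callback like the tutorial's to a hand-written dictionary. The dictionary layout (a `params` key holding the ACTGAN hyperparameters) is assumed from how the tutorial's callback indexes its argument, and the function name and sample values are hypothetical. No Gretel job is submitted.

```python
def tie_discriminator_to_generator(model_section):
    """Copy generator_dim into discriminator_dim so both networks match."""
    params = model_section["params"]
    # Copy the list so the two entries do not share one mutable object.
    params["discriminator_dim"] = list(params["generator_dim"])
    return model_section


# Hypothetical model section, roughly as a sampled trial config might look.
sampled = {
    "params": {
        "generator_dim": [1024, 1024, 1024],
        "discriminator_dim": [512, 512],
        "generator_lr": 0.0002,
    }
}

constrained = tie_discriminator_to_generator(sampled)
print(constrained["params"]["discriminator_dim"])
```

Testing a callback this way, before launching a sweep, is a cheap way to confirm that the constraint is applied and that the other sampled parameters are left untouched.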

In [None]:
# This cell should take ~10 minutes to complete.
tuner_config = """
base_config: tabular-actgan

params:

    batch_size:
        fixed: 500

    epochs:
        fixed: 500

    generator_lr:
        log_range: [0.00001, 0.001]

    discriminator_lr:
        log_range: [0.00001, 0.001]

    generator_dim:
        choices:
            - [512, 512, 512, 512]
            - [1024, 1024]
            - [1024, 1024, 1024]
            - [2048, 2048]
            - [2048, 2048, 2048]
"""

def sampler_callback(model_section):
    """Always set discriminator_dim = generator_dim in ACTGAN's config."""
    model_section["params"]["discriminator_dim"] = model_section["params"]["generator_dim"]
    return model_section

target_job = "entrepreneur"

df_ref = pd.read_csv(data_source)

metric = BalanceKSComplementPlusSQS(df_ref, target_job=target_job)

tuner_results = gretel.run_tuner(
    tuner_config,
    data_source=df_ref,
    n_jobs=2,
    n_trials=4,
    metric=metric,
    sampler_callback=sampler_callback,
)
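
The `log_range` entries in the config above span two orders of magnitude. Assuming `log_range` means the value is drawn uniformly on a logarithmic scale (an assumption about the tuner's behavior that this tutorial does not confirm), the standard-library sketch below shows what such sampling looks like. `sample_log_uniform` is a hypothetical helper, not part of `gretel_client`.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed so the sketch is reproducible
samples = [sample_log_uniform(1e-5, 1e-3, rng) for _ in range(10_000)]

# Under log-uniform sampling, each decade receives roughly half of the draws.
below = sum(s < 1e-4 for s in samples) / len(samples)
print(f"share of learning rates below 1e-4: {below:.2f}")
```

Under this assumption, small learning rates are explored as thoroughly as large ones, whereas a plain uniform draw over the same interval would place about 90% of its samples above 1e-4.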

## 📈 Visualize the experiment results

- Under the hood, Gretel Tuner uses [Optuna](https://optuna.readthedocs.io/en/stable/index.html) to drive the sampling of hyperparameters.

- This means we can use Optuna's visualization tools to better understand our tuning experiments.

In [None]:
import optuna.visualization as viz

# Plot the optimization metric as a function of trial number.
viz.plot_optimization_history(tuner_results.study)

In [None]:
# Compare the importances of the sampled hyperparameters.
viz.plot_param_importances(tuner_results.study)

## 🤖 Conditionally generate synthetic data using the "best" model

- We submit the generate job using the `best_model_id` from the above tuner results.

In [None]:
generated = gretel.submit_generate(
    model_id=tuner_results.best_model_id,
    seed_data=pd.DataFrame({"job": [target_job] * 100})
)

In [None]:
# The synthetic data is returned as a DataFrame.
generated.synthetic_data