# Gretel Trainer

This notebook helps you train synthetic models on complex datasets with high row and column counts. It divides a dataset into smaller datasets of correlated columns, trains models on those pieces in parallel, and then joins the results back together.
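
To make the divide-and-rejoin idea concrete, here is a minimal sketch in plain pandas. It does not use the trainer's own partitioning code: the toy data, the `column_groups` list, and the `max_rows` limit are all made up for illustration, and the "training" step is replaced by passing each piece through unchanged.

```python
import pandas as pd

# Toy stand-in for a wide, long dataset (values are made up for illustration).
df = pd.DataFrame({
    "age": [34, 51, 29, 62, 45, 38],
    "weight": [70.1, 82.4, 65.0, 90.3, 77.7, 68.9],
    "city": ["A", "B", "A", "C", "B", "A"],
    "zip": [1, 2, 1, 3, 2, 1],
})

# Hypothetical column groups; the notebook derives these from correlations.
column_groups = [["age", "weight"], ["city", "zip"]]
max_rows = 3  # hypothetical row limit per partition

# Split into row batches, then into column groups within each batch.
partitions = []
for start in range(0, len(df), max_rows):
    batch = df.iloc[start:start + max_rows]
    partitions.append([batch[cols] for cols in column_groups])

# Each piece could be handled independently; here it is passed through as-is.
# Rejoin: column groups side by side, then row batches stacked.
rejoined = pd.concat(
    [pd.concat(pieces, axis=1) for pieces in partitions],
    axis=0,
)

print(rejoined.equals(df))
```

In the real workflow each piece would be replaced by synthetic records from its own model, so the rejoined frame has the original columns but not the original rows.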

**Getting started:**


*   Copy your [Gretel API](https://console.gretel.cloud) key to the clipboard.
*   Update the `DATASET_PATH` to your dataset, or use the provided example.
*   Click Runtime -> Run All.
*   Use the correlation graph to compare your real-world and synthetic data!



In [None]:
import os

!git clone https://github.com/gretelai/trainer.git

os.chdir('./trainer')
!pip install .

In [None]:
from getpass import getpass
import pandas as pd

from gretel_trainer import strategy, runner

from gretel_client import configure_session, ClientConfig
from gretel_client.projects import create_or_get_unique_project
from gretel_client.projects.models import read_model_config
from gretel_client.projects.jobs import Status
from gretel_synthetics.utils.header_clusters import cluster

In [None]:
# Specify your Gretel API key

configure_session(ClientConfig(api_key=getpass(prompt="Enter Gretel API key"), 
                               endpoint="https://api.gretel.cloud"))

## Project Settings

Use the configuration options below to control parallelization settings.
* `MAX_ROWS` sets the maximum number of rows per sub-model, which helps reduce per-job run time and the need to tune neural network parameters such as `learning_rate`, `batch_size`, and `rnn_units` for larger datasets. Typical values are 20,000-100,000 rows.
* `MAX_HEADER_CLUSTERS` sets the maximum number of columns to include for each sub-model. Higher values help with maintaining complex correlations in datasets. Try a lower value if you have mixed text and numeric values, or if the model is not generating valid records. Typical values are 10-20 columns.
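
As a rough way to reason about these two settings, the sketch below estimates how many sub-models a dataset might be split into. `estimate_partition_count` is a hypothetical helper, not part of `gretel_trainer`, and it assumes one sub-model per combination of row batch and column group. Correlation-based clustering may produce more, smaller column groups, so treat the result as a lower bound.

```python
import math


def estimate_partition_count(n_rows: int, n_cols: int,
                             max_rows: int, max_cols: int) -> int:
    """Rough lower bound on the number of sub-models (hypothetical helper)."""
    row_batches = math.ceil(n_rows / max_rows)
    # Clustering may yield more groups than this if columns split unevenly.
    column_groups = math.ceil(n_cols / max_cols)
    return row_batches * column_groups


# Example: 100,000 rows and 45 columns with the settings used in this notebook.
estimate = estimate_partition_count(
    n_rows=100_000, n_cols=45, max_rows=20_000, max_cols=20
)
print(estimate)
```

Lowering either limit increases the number of jobs, so there is a trade-off between per-job run time and the total number of models to train.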

In [None]:
MAX_ROWS = 20000 # Maximum row count per model
MAX_HEADER_CLUSTERS = 20 # Max columns per cluster
PROJECT_NAME = 'health-data'
PROJECT = create_or_get_unique_project(name=PROJECT_NAME)

print(f"Follow model training at: {PROJECT.get_console_url()}")

## Dataset Preprocessing

Preprocess the data before training the network. A few tips:
* Try reducing floating point precision to a consistent number of decimal places (default: 4)
* Fields containing complex random values (such as UIDs) can be difficult for language models to learn. Try dropping them, specifying them as model seeds to preserve them, or encoding them with a `LabelEncoder`.
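
The two simplest options for identifier-like fields can be sketched as follows. This example uses `pd.factorize` as a lightweight stand-in for scikit-learn's `LabelEncoder`, and the column names and values are made up for illustration.

```python
import pandas as pd

# Made-up records; "patient_id" plays the role of a random-looking identifier.
df = pd.DataFrame({
    "patient_id": ["a9f3", "0c1e", "a9f3", "77b2"],
    "bmi": [24.123456, 31.987654, 24.5, 19.000049],
})

# Option 1: drop the identifier entirely.
dropped = df.drop(columns=["patient_id"])

# Option 2: replace each distinct identifier with a small integer code.
# pd.factorize numbers values in order of first appearance.
codes, uniques = pd.factorize(df["patient_id"])
encoded = df.assign(patient_id=codes).round(4)

# The mapping can be reversed later by looking codes up in `uniques`.
restored = uniques.take(encoded["patient_id"].to_numpy())

print(encoded)
```

Keeping `uniques` around lets you map the integer codes back to the original identifiers when that is appropriate for your use case.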

In [None]:
DATASET_PATH = './data/mitre-synthea-health.csv'
ROUND_DECIMALS = 4


def preprocess_data(dataset_path: str) -> pd.DataFrame:
    tmp = pd.read_csv(dataset_path, low_memory=False)
    tmp = tmp.round(ROUND_DECIMALS)
    return tmp


DF = preprocess_data(DATASET_PATH)
DF

## Hyperparameter Settings

View example default configs on [GitHub](https://github.com/gretelai/gretel-blueprints/tree/main/config_templates/gretel/synthetics). A few tips:

* For larger dataset sizes (50,000-100,000 rows) try setting `vocab_size` to 20,000. This uses a `sentencepiece` tokenizer and will speed up model training and generation. For smaller datasets or if you're seeing high invalid record counts, use `vocab_size` of 0 to tokenize per character.
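
One way to apply this tip programmatically is sketched below. `with_vocab_size` is a hypothetical helper, and the 50,000-row threshold simply mirrors the range mentioned above; it is a starting point, not a tuned rule. The example dictionary only imitates the nesting of the notebook's `CONFIG`.

```python
import copy


def with_vocab_size(config: dict, n_rows: int, threshold: int = 50_000) -> dict:
    """Return a copy of `config` with `vocab_size` chosen from the row count.

    Hypothetical helper: the threshold and values mirror the tip above.
    """
    updated = copy.deepcopy(config)
    params = updated["models"][0]["synthetics"]["params"]
    # 20,000 selects the sentencepiece tokenizer; 0 tokenizes per character.
    params["vocab_size"] = 20_000 if n_rows >= threshold else 0
    return updated


# Minimal stand-in with the same nesting as the notebook's CONFIG.
example_config = {"models": [{"synthetics": {"params": {"vocab_size": 0}}}]}
large = with_vocab_size(example_config, n_rows=80_000)
small = with_vocab_size(example_config, n_rows=5_000)

print(large["models"][0]["synthetics"]["params"]["vocab_size"])
print(small["models"][0]["synthetics"]["params"]["vocab_size"])
```

If a large dataset still produces many invalid records, falling back to `vocab_size` of 0 by hand remains a reasonable next step.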

In [None]:
# Fine-tune any configuration settings here

CONFIG = read_model_config("synthetics/default")
CONFIG["models"][0]["synthetics"]["params"]["vocab_size"] = 0
CONFIG["models"][0]["synthetics"]["params"]["learning_rate"] = 0.001
CONFIG["models"][0]["synthetics"]["privacy_filters"] = {}
CONFIG["models"][0]["synthetics"]["privacy_filters"]["outliers"] = None
CONFIG["models"][0]["synthetics"]["privacy_filters"]["similarity"] = None

In [None]:
# Initialize the parallelization strategy

def initialize_run() -> runner.StrategyRunner:
    
    # Create clusters of correlated columns (might take a few minutes)
    plot = len(DF.columns) < 50
    header_clusters = cluster(DF, maxsize=MAX_HEADER_CLUSTERS, plot=plot) 

    constraints = strategy.PartitionConstraints(
        header_clusters=header_clusters, 
        max_row_count=MAX_ROWS
    )
    
    run = runner.StrategyRunner(
        strategy_id="foo",
        df=DF,
        cache_file="runner.json",
        cache_overwrite=True,  # Set to False to load existing cache (and not start over)
        model_config=CONFIG,
        partition_constraints=constraints,
        project=PROJECT
    )    
    return run

run = initialize_run()

In [None]:
# Train all models
run.train_all_partitions()

In [None]:
# Access synthetic data

synthetic = run.get_training_synthetic_data()
synthetic.to_csv('synthetic.csv', index=False)
synthetic

In [None]:
# Uncomment and run these lines to cancel model training in the cloud and delete the project

#run.cancel_all()
#PROJECT.delete()

## Plot correlations

Use the `_get_correlation_matrix` function from `gretel-synthetics` to compare correlations between the real-world and synthetic datasets.
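
A single number can complement the visual comparison. The sketch below computes the mean absolute difference between two correlation matrices. It uses plain Pearson correlation via `DataFrame.corr()` on numeric columns as a stand-in, so its values may differ from what `_get_correlation_matrix` produces. `mean_abs_corr_diff` is a hypothetical helper and the data is made up.

```python
import pandas as pd


def mean_abs_corr_diff(real_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> float:
    """Mean absolute difference between Pearson correlation matrices."""
    real_corr = real_df.select_dtypes("number").corr()
    synth_corr = synthetic_df.select_dtypes("number").corr()
    # Compare only the columns present in both matrices.
    shared = real_corr.columns.intersection(synth_corr.columns)
    diff = (real_corr.loc[shared, shared] - synth_corr.loc[shared, shared]).abs()
    return float(diff.to_numpy().mean())


# Made-up data: one synthetic set preserves the x/y relationship, one inverts it.
real = pd.DataFrame({"x": [1, 2, 3, 4], "y": [2, 4, 6, 8]})
faithful = pd.DataFrame({"x": [2, 3, 4, 5], "y": [4, 6, 8, 10]})
inverted = pd.DataFrame({"x": [1, 2, 3, 4], "y": [8, 6, 4, 2]})

print(mean_abs_corr_diff(real, faithful))
print(mean_abs_corr_diff(real, inverted))
```

Values near 0 suggest the pairwise linear relationships were preserved. Constant columns produce undefined correlations, so you may want to drop them before comparing.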

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from gretel_synthetics.utils.header_clusters import _get_correlation_matrix


def plot_correlations(real_df: pd.DataFrame, synthetic_df: pd.DataFrame):
    # r_corr is the real-world matrix (left plot); s_corr is the synthetic one (right plot)
    r_corr = _get_correlation_matrix(real_df)
    s_corr = _get_correlation_matrix(synthetic_df)

    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.2)
    fig.update_layout(title_text="Real world vs. Synthetic Correlations")
    trace1 = go.Heatmap(z=r_corr, y=r_corr.index, x=r_corr.columns)
    trace2 = go.Heatmap(z=s_corr, y=s_corr.index, x=s_corr.columns)
    fig.add_trace(trace1, row=1, col=1)
    fig.add_trace(trace2, row=1, col=2)
    fig.update_traces(showscale=False)
    fig.show()

plot_correlations(DF, synthetic)

In [None]:
# Use the model to generate additional data

run.generate_data(num_records=5000, max_invalid=None, clear_cache=True)
run.get_synthetic_data()