# Gretel Trainer

This notebook is designed to help users train synthetic models on complex datasets with high row and column counts. It works by dividing a dataset into a set of smaller datasets of correlated columns, which can be trained in parallel and then joined back together.

In [None]:
import strategy
import runner

from gretel_client.projects import create_or_get_unique_project
from gretel_client.projects.models import read_model_config
from gretel_client.projects.jobs import Status
from gretel_synthetics.utils.header_clusters import cluster

import pandas as pd

In [None]:
# Parallelization parameters and options

MAX_ROWS = 20000 # Maximum row count per model; passed as max_row_count to the partition constraints
MAX_HEADER_CLUSTERS = 20 # Max columns per cluster; passed as maxsize to cluster() despite the name
GENERATE_ROWS = 0 # Use zero to match row count from training data. NOTE(review): not referenced in the visible cells - confirm it is still used
CACHE_FILE = "runner.json" # Cache file handed to the StrategyRunner
PROJECT = create_or_get_unique_project(name="complex-dataset") # Gretel project handed to the StrategyRunner

# Print the console URL for this project so training can be followed there
print(f"Follow model training at: {PROJECT.get_console_url()}")

In [None]:
# Load the dataset to synthesize

DATASET_PATH = 'cpu_states.csv'  # CSV file read by preprocess_data()
ROUND_DECIMALS = 4  # Decimal places passed to DataFrame.round() in preprocess_data()


def preprocess_data(dataset_path: str, round_decimals: "int | None" = None) -> pd.DataFrame:
    """Load the training CSV and round its values.

    Args:
        dataset_path: Path of the CSV file to read.
        round_decimals: Number of decimal places to round to. When omitted
            (``None``), the notebook-level ``ROUND_DECIMALS`` setting is
            used, which matches the previous behaviour. An explicit ``0``
            is honoured.

    Returns:
        The loaded frame after ``DataFrame.round`` has been applied.
    """
    if round_decimals is None:
        # Resolved at call time so that changing ROUND_DECIMALS and re-running
        # only the calling cell still takes effect.
        round_decimals = ROUND_DECIMALS
    # low_memory=False turns off pandas' chunked parsing, which can otherwise
    # infer mixed types within a column.
    frame = pd.read_csv(dataset_path, low_memory=False)
    return frame.round(round_decimals)
    
# Load the training data once; initialize_run() reads this module-level DF
DF = preprocess_data(DATASET_PATH)
DF  # bare expression so the notebook displays the frame

In [None]:
# Start from the "synthetics/default" model configuration (the original note
# says this template is loaded from GitHub), then override a few settings.

CONFIG = read_model_config("synthetics/default")

# Set the learning rate used for every partition's model.
CONFIG["models"][0]["synthetics"]["params"]["learning_rate"] = 0.001

# Replace the privacy filter settings with both filters set to None.
CONFIG["models"][0]["synthetics"]["privacy_filters"] = {
    "outliers": None,
    "similarity": None,
}

In [None]:
# Initialize the parallelization strategy

def initialize_run() -> runner.StrategyRunner:
    """Build the StrategyRunner for this notebook's dataset.

    Takes no arguments; it reads the notebook-level settings DF,
    MAX_HEADER_CLUSTERS, MAX_ROWS, CACHE_FILE, CONFIG and PROJECT.

    Returns:
        A runner.StrategyRunner set up with partition constraints built
        from the clustered column headers and the row limit.
    """
    # Group correlated columns first; this step might take a few minutes.
    column_groups = cluster(DF, maxsize=MAX_HEADER_CLUSTERS, plot=True)
    limits = strategy.PartitionConstraints(
        header_clusters=column_groups,
        max_row_count=MAX_ROWS,
    )

    return runner.StrategyRunner(
        strategy_id="foo",
        df=DF,
        cache_file=CACHE_FILE,
        # True starts fresh; False tries to load an existing cache and
        # pick up from it.
        cache_overwrite=True,
        model_config=CONFIG,
        partition_constraints=limits,
        project=PROJECT,
    )

# Build the runner once; the training and retrieval cells use this object
run = initialize_run()

In [None]:
# Train a model for every partition (progress can be followed at the console URL printed earlier).
# NOTE(review): whether this call blocks until training finishes is not visible here - confirm.
run.train_all_partitions()

In [None]:
# Fetch the synthetic data from the runner (presumably a pandas DataFrame - verify)
synthetic = run.get_training_synthetic_data()
# Save it to synthetic.csv; index=False is passed so the index is not written
synthetic.to_csv('synthetic.csv', index=False)
synthetic  # bare expression so the notebook displays the result

In [None]:
# Optional cleanup, left commented out so it never runs by accident.
# NOTE(review): presumably these cancel outstanding jobs and delete the project - verify before uncommenting.
#run.cancel_all()
#PROJECT.delete()