# Hyperparameter search and evaluation

In [None]:
# Remove the local "workspace" directory before running; stale contents
# sometimes cause errors on re-runs.
# NOTE(review): presumably this is the cache directory written by synthcity's
# Benchmarks — confirm before relying on it.
!rm -rf workspace

In [None]:
# stdlib
import sys
import warnings

warnings.filterwarnings("ignore")

from datetime import datetime, timedelta

import numpy as np
import pandas as pd

import optuna

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.benchmark import Benchmarks
from synthcity.utils.optuna_sample import suggest_all
from synthcity.utils.serialization import load, load_from_file, save, save_to_file


log.add(sink=sys.stderr, level="INFO")

In [None]:
from sklearn.preprocessing import OrdinalEncoder
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector

In [None]:
import matplotlib.pyplot as plt
from synthcity.metrics.plots import plot_marginal_comparison, plot_tsne

## Optimization

In [None]:
# Notebook configuration.
# `peaks` and `days` are only used to build the input/output file names.
peaks = 1
n_iter = 2000 # WARNING change this
# it should be 4000 but it takes 35 minutes on a GPU
# NOTE(review): "it" presumably refers to n_iter above, not num_seq — confirm.
# NOTE(review): num_seq is not referenced by any other cell in this notebook.
num_seq = 4000
days = 1
# real data dir
# NOTE(review): data_dir is not referenced by any other cell; the CSV is read
# relative to the current working directory.
data_dir = "../"
# Set to True to also render the t-SNE comparison plot at the end.
generate_tsne = False

In [None]:
# CSV file holding the real data to learn from.
rd_filename = f"real_data_synthcity_{days}_days_{peaks}_peaks_tabular.csv"
# Name of the synthcity plugin used as the generator.
generator = "adsgan"
# File the fitted generator is saved to at the end of the notebook.
model_filename = f"model_{generator}_synthcity_{days}_days_{peaks}_peaks_tabular_opt.pkl"

### Load real data and instantiate the dataloaders

In [None]:
# Read the real data, using the first CSV column as the index.
# NOTE(review): the path is relative to the working directory; `data_dir` is
# defined above but not used here — confirm this is intended.
real_data = pd.read_csv(rd_filename, index_col=0)

In [None]:
# Wrap the full dataframe in a synthcity data loader.
loader = GenericDataLoader(real_data)

In [None]:
# Train/test loaders derived from the full loader.
train_loader, test_loader = loader.train(), loader.test()

### Load the generator

In [None]:
# Plugins().get() returns a plugin instance; keep its class so that it can be
# instantiated later with explicit hyperparameters.
plugin_cls = type(Plugins().get(generator))
plugin_cls

### Display the hyperparameter space

In [None]:
# List the distributions that make up the plugin's search space.
plugin_cls.hyperparameter_space()

### Set a trial

In [None]:
# Draw one configuration from the plugin's search space through a throwaway
# Optuna study, then pin the iteration count to the notebook-level setting.
trial = optuna.create_study().ask()
params = {**suggest_all(trial, plugin_cls.hyperparameter_space()), "n_iter": n_iter}
params

### Evaluate the generator

In [None]:
%%time
# Fit the generator with the sampled parameters, then benchmark the same
# (generator, params) pair on the training loader.
# NOTE(review): `plugin` is not referenced by any later cell, and
# Benchmarks.evaluate is given the plugin name and params rather than this
# fitted instance — the explicit fit looks redundant; confirm before removing.
plugin = plugin_cls(**params).fit(train_loader)
report = Benchmarks.evaluate(
    [("trial", generator, params)],
    train_loader,  # Benchmarks.evaluate will split out a validation set
    repeats=1,
    metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
)
# Show the result table for the single test named "trial".
report['trial']

In [None]:
def objective(trial: optuna.Trial):
    """Optuna objective: benchmark one sampled configuration of ``generator``.

    Samples a value for every hyperparameter in the plugin's search space,
    runs ``Benchmarks.evaluate`` on ``train_loader`` and returns the average
    of the ``mean`` column over all metrics whose direction is "minimize".

    Args:
        trial: The Optuna trial used to sample the hyperparameters.

    Returns:
        The score to minimise (lower is better).

    Raises:
        optuna.TrialPruned: If the benchmark fails for the sampled parameters.
    """
    hp_space = Plugins().get(generator).hyperparameter_space()
    # Cap the number of training iterations to speed the search up for now.
    # The distribution is looked up by name because the position of `n_iter`
    # in the list depends on the plugin; capping index 0 blindly could widen
    # an unrelated hyperparameter instead.
    for dist in hp_space:
        if dist.name == "n_iter":
            dist.high = 100
    params = suggest_all(trial, hp_space)
    trial_id = f"trial_{trial.number}"
    try:
        report = Benchmarks.evaluate(
            [(trial_id, generator, params)],
            train_loader,
            repeats=1,
            metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
        )
    except Exception as e:  # invalid set of params
        print(f"{type(e).__name__}: {e}")
        print(params)
        # Keep the original failure attached to the pruned trial.
        raise optuna.TrialPruned() from e
    # average score across all metrics with direction="minimize"
    return report[trial_id].query('direction == "minimize"')['mean'].mean()

In [None]:
%%time
# WARNING CUDA out of memory
try:
    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=2)
    study.best_params
except:
    pass

#### Visualize the study

In [None]:
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_slice

# Visualize the optimization history of the study.
plot_optimization_history(study)

In [None]:
# Visualize high-dimensional parameter relationships.
plot_parallel_coordinate(study)

In [None]:
# Visualize high-dimensional parameter relationships.
# NOTE(review): this cell repeats the previous one, while `plot_contour` is
# imported above but never called — possibly the intended plot here; confirm.
plot_parallel_coordinate(study)

In [None]:
# Visualize individual hyperparameters as slice plot.
plot_slice(study)

In [None]:
# Visualize parameter importances.
plot_param_importances(study)

In [None]:
# Learn which hyperparameters are affecting the trial duration with hyperparameter importance.
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.duration.total_seconds(), target_name="duration"
)

In [None]:
# Visualize empirical distribution function of the objective.
plot_edf(study)

## Performance

In [None]:
# FIXME OOM issue: the final benchmark is run best-effort. Failures (including
# a study without any completed trial) are reported rather than silently
# ignored, so the reason for a missing report is visible.
try:
    best_params = study.best_params
    report = Benchmarks.evaluate(
        [("test", generator, best_params)],
        train_loader,
        test_loader,
        repeats=1,
        metrics={"detection": ["detection_mlp", "detection_xgb"]},  # DELETE THIS LINE FOR ALL METRICS
    )
    Benchmarks.print(report)
except Exception as e:
    print(f"Final evaluation failed: {type(e).__name__}: {e}")

## Fit and save model with best parameters

In [None]:
# TEMPORARY — delete this, just for testing because of the out of memory problem.
# It replaces `best_params` with the single randomly sampled `params` from the
# "Set a trial" cell, so the model fitted below does NOT use the study's best
# parameters while this line is present.
best_params = params
best_params

In [None]:
%%time
# Fit the final generator on the full loader (not only the training split).
syn_model = plugin_cls(**best_params).fit(loader)

In [None]:
# Persist the fitted generator to `model_filename`.
save_to_file(model_filename, syn_model)

In [None]:
%%time
# Sample synthetic records (generate() is called without arguments) and
# convert the result to a pandas dataframe.
synthetic_data = syn_model.generate().dataframe()

In [None]:
# Preview the first five synthetic rows.
synthetic_data.head(5)

In [None]:
# Export the synthetic data to CSV in the working directory.
synthetic_data.to_csv(f"synthetic_data_synthcity_{days}_days_{peaks}_peaks_tabular_opt.csv")

## Plots

In [None]:
def convert_to_dloaders(static_df, ct=None):
    """Ordinal-encode a dataframe and wrap it in a ``GenericDataLoader``.

    Object columns are ordinal-encoded; float64 and int64 columns are passed
    through unchanged. The encoded frame keeps the index and the column order
    of ``static_df``.

    Args:
        static_df: Dataframe to encode. Only object, float64 and int64 columns
            are selected by the transformer built here.
        ct: A column transformer already fitted by a previous call. When
            given, it is reused as-is (``transform`` only) so that both
            datasets share the same category codes. When ``None``, a new
            transformer is built and fitted on ``static_df``.

    Returns:
        A ``(loader, ct)`` tuple: the loader over the encoded dataframe and
        the fitted column transformer.
    """
    if ct is None:
        ct = make_column_transformer(
            (OrdinalEncoder(), make_column_selector(dtype_include="object")),
            ("passthrough", make_column_selector(dtype_include=["float64"])),
            ("passthrough", make_column_selector(dtype_include=["int64"])),
        )
        encoded = ct.fit_transform(static_df)
    else:
        # Do not re-fit: re-fitting on a second dataset would assign its own
        # category codes (and overwrite the caller's transformer), making the
        # encoded values of the two datasets incomparable.
        encoded = ct.transform(static_df)

    # Label the output with the columns each fitted transformer actually
    # consumed, in output order (object, float64, int64 for the transformer
    # built above). The optional trailing "remainder" entry is dropped by the
    # default settings and therefore contributes no output columns.
    column_order = [
        column
        for name, _, columns in ct.transformers_
        if name != "remainder"
        for column in columns
    ]
    tr_df = pd.DataFrame(encoded, index=static_df.index, columns=column_order)[static_df.columns]

    loader = GenericDataLoader(tr_df)
    return loader, ct

In [None]:
%%time
# Encode the real data; keep the returned transformer for the synthetic data.
rd_loader, ct = convert_to_dloaders(real_data)

In [None]:
%%time
# Encode the synthetic data, passing in the transformer returned for the real data.
sd_loader, _ = convert_to_dloaders(synthetic_data, ct)

In [None]:
%%time
# Compare the marginal distributions of real and synthetic data.
plot_marginal_comparison(plt, rd_loader, sd_loader)

In [None]:
%%time
# The t-SNE plot is optional and controlled by the `generate_tsne` flag.
if generate_tsne:
    plot_tsne(plt, rd_loader, sd_loader)

## Done!