In [1]:
!pip install synthcity
!pip uninstall -y torchaudio torchdata
!pip install plotly



In [2]:
# stdlib
import sys
import warnings

# third party
import optuna
from sklearn.datasets import load_diabetes

import numpy as np
import pandas as pd

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

log.add(sink=sys.stderr, level="INFO")
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm




In [3]:
# objective function for the optuna optmization
# we optmize for minimizing detection of synthetic vs real data

from synthcity.utils.optuna_sample import suggest_all
from synthcity.benchmark import Benchmarks

def objective(trial: optuna.Trial):
    hp_space = Plugins().get(PLUGIN).hyperparameter_space()
    params = suggest_all(trial, hp_space)
    if PLUGIN == "ddpm":
        params["is_classification"] = False
    ID = f"trial_{trial.number}"
    try:
        report = Benchmarks.evaluate(
            [(ID, PLUGIN, params)],
            train_loader,
            repeats=1,
            metrics={"detection": ["detection_xgb"]}, 
        )
    except Exception as e:  # invalid set of params
        print(f"{type(e).__name__}: {e}")
        print(params)
        raise optuna.TrialPruned()
    score = report[ID].query('direction == "minimize"')['mean'].mean()
    # average score across all metrics with direction="minimize"
    return score


def enforce_dtypes(dat, 
                   num_variables, 
                   cat_variables):
    """
    Enforce "float64" type for numeric variables and "object" type for the
    categorical variables
    Parameters:
        dat (pd.DataFrame): Input data matrix (numeric, categorical, or mixed).
        num_variables (list): Indices of numeric variables.
        cat_variables (list): Indices of categorical variables.

    Returns:
    pd.DataFrame: with transformed data types
    """
    if num_variables is not None and cat_variables is None:
        dat_N = pd.DataFrame(dat.iloc[:, num_variables], dtype = "float64")
        dat = dat_N

    elif num_variables is None and cat_variables is not None:
        dat_C = pd.DataFrame(dat.iloc[:, cat_variables], dtype = "str")
        dat = dat_C

    elif num_variables is not None and cat_variables is not None:
        dat_N = pd.DataFrame(dat.iloc[:, num_variables], dtype = "float64")
        dat_C = pd.DataFrame(dat.iloc[:, cat_variables], dtype = "str")
        dat = pd.concat([dat_N, dat_C], axis=1)
        # Reorder columns to match the order in the original data
        reordered_indices = num_variables + cat_variables
        dat = dat.iloc[:, np.argsort(reordered_indices)]

    else:
        raise ValueError("At least one of num_variables or cat_variables must be specified.")
    
    return dat 


def train_test_data_split(X, my_seed):
    """
    Splits the data X into training and testing sets, using a random seed.
    
    Parameters:
    X (pd.DataFrame): The input data DataFrame.
    my_seed (int): The random seed for reproducibility.
    
    Returns:
    dict: A dictionary containing the training and testing DataFrames.
          {'X_train': X_train, 'X_test': X_test}
    """
    # Set random seed
    np.random.seed(my_seed)
    
    # Get the total number of rows
    n = X.shape[0]
    n_sub = n // 2  # Floor division to get half the rows
    
    # Randomly sample indexes for the training set
    idx_train = np.random.choice(X.index, size=n_sub, replace=False)
    
    # Compute the test indexes as the set difference
    idx_test = X.index.difference(idx_train)

    # Adjust sizes to make them equal if necessary
    if len(idx_train) < len(idx_test):
        idx_test = idx_test[:-1]  # Remove the last test index
    
    # Split the data
    X_train = X.loc[idx_train]
    X_test = X.loc[idx_test]
    
    return {"X_train": X_train, "X_test": X_test}

In [4]:
# load the data

from sklearn.datasets import fetch_california_housing

# Load the dataset
california_housing = fetch_california_housing(as_frame=True)

# Features (X) and target (y)
X = california_housing.data
y = california_housing.target

X["target"] = y

num_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8]
cat_idx = None

X = enforce_dtypes(dat = X, 
                   num_variables = num_idx, 
                   cat_variables = cat_idx)

# Split the data
aux = train_test_data_split(X, my_seed=123)

X_train = aux["X_train"]
X_test = aux["X_test"]

In [5]:
# create data loader

train_loader = GenericDataLoader(
    X_train,
    target_column="target",
)

test_loader = GenericDataLoader(
    X_test,
    target_column="target",
)

In [6]:
# set number of optuna trials

n_trials = 20

In [7]:
# run optuna for ddpm

np.random.seed(123)

PLUGIN = "ddpm"
plugin_cls = type(Plugins().get(PLUGIN))

study_ddpm = optuna.create_study(direction="minimize")
study_ddpm.optimize(objective, n_trials=n_trials)
study_ddpm.best_params

[2025-05-27T03:11:14.274934+0000][118935][CRITICAL] Error importing TabularGoggle: No module named 'dgl'
[2025-05-27T03:11:14.274934+0000][118935][CRITICAL] Error importing TabularGoggle: No module named 'dgl'
[2025-05-27T03:11:14.279188+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:14.279188+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:15.343884+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:15.351529+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
Epoch:   0%|          | 20/6613 [00:02<12:59,  8.46it/s, loss=1.54]


KeyboardInterrupt: 

In [8]:
# run optuna for arf

np.random.seed(123)

PLUGIN = "arf"
plugin_cls = type(Plugins().get(PLUGIN))

study_arf = optuna.create_study(direction="minimize")
study_arf.optimize(objective, n_trials=n_trials)
study_arf.best_params

[2025-05-27T03:11:37.303352+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.307408+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.313533+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.341934+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.347461+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.376037+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.381432+0000][118935][CRITICAL] module disabled: /opt/conda/lib/pyth

AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 90, 'delta': 44, 'max_iters': 3, 'early_stop': True, 'min_node_size': 8}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 80, 'delta': 2, 'max_iters': 1, 'early_stop': False, 'min_node_size': 12}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 80, 'delta': 6, 'max_iters': 5, 'early_stop': True, 'min_node_size': 18}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 80, 'delta': 44, 'max_iters': 4, 'early_stop': True, 'min_node_size': 12}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 70, 'delta': 6, 'max_iters': 5, 'early_stop': True, 'min_node_size': 14}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 70, 'delta': 4, 'max_iters': 4, 'early_stop': False, 'min_node_size': 14}


[2025-05-27T03:11:37.544799+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.550106+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.578179+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.583598+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.611762+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.617759+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.646153+0000][118935][CRITICAL] module disabled: /opt/conda/lib/pyth

AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 60, 'delta': 26, 'max_iters': 4, 'early_stop': False, 'min_node_size': 8}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 10, 'delta': 50, 'max_iters': 5, 'early_stop': False, 'min_node_size': 10}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 90, 'delta': 18, 'max_iters': 5, 'early_stop': True, 'min_node_size': 16}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 60, 'delta': 12, 'max_iters': 1, 'early_stop': True, 'min_node_size': 12}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 30, 'delta': 34, 'max_iters': 2, 'early_stop': True, 'min_node_size': 2}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 100, 'delta': 32, 'max_iters': 1, 'early_stop': False, 'min_node_size': 6}


[2025-05-27T03:11:37.758431+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.786399+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.804198+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.832266+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.850229+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.878774+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:37.896673+0000][118935][CRITICAL] module disabled: /opt/conda/lib/pyth

AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 100, 'delta': 38, 'max_iters': 2, 'early_stop': False, 'min_node_size': 6}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 40, 'delta': 20, 'max_iters': 2, 'early_stop': False, 'min_node_size': 20}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 90, 'delta': 26, 'max_iters': 3, 'early_stop': True, 'min_node_size': 2}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 80, 'delta': 44, 'max_iters': 3, 'early_stop': False, 'min_node_size': 8}
AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 40, 'delta': 12, 'max_iters': 1, 'early_stop': False, 'min_node_size': 10}


[2025-05-27T03:11:37.990214+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:38.018920+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:38.037851+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py


AssertionError: parameter delta must be in range 0 <= delta <= 0.5
{'num_trees': 90, 'delta': 48, 'max_iters': 2, 'early_stop': True, 'min_node_size': 6}
Initial accuracy is 0.9196342054263565


KeyboardInterrupt: 

In [9]:
# run optuna for tvae

np.random.seed(123)

PLUGIN = "tvae"
plugin_cls = type(Plugins().get(PLUGIN))

study_tvae = optuna.create_study(direction="minimize")
study_tvae.optimize(objective, n_trials=n_trials)
study_tvae.best_params

[2025-05-27T03:11:51.157784+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:51.161076+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-05-27T03:11:51.168333+0000][118935][CRITICAL] module disabled: /opt/conda/lib/python3.12/site-packages/synthcity/plugins/generic/plugin_goggle.py
  1%|          | 1/100 [00:09<15:15,  9.25s/it]


KeyboardInterrupt: 

In [None]:
# run optuna for ctgan

np.random.seed(123)

PLUGIN = "ctgan"
plugin_cls = type(Plugins().get(PLUGIN))

study_ctgan = optuna.create_study(direction="minimize")
study_ctgan.optimize(objective, n_trials=n_trials)
study_ctgan.best_params