In [24]:
from synthcity import *

In [25]:
from synthcity.plugins import Plugins

# Restrict the plugin registry to the "generic" and "privacy" categories
# and display the names of the generators it exposes.
plugin_loader = Plugins(categories=["generic", "privacy"])
plugin_loader.list()


['ctgan', 'rtvae', 'dpgan', 'ddpm', 'pategan', 'privbayes', 'tvae', 'adsgan']

In [29]:
from sklearn.datasets import load_iris
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.utils.serialization import save_to_file, load_from_file
          
# Single source of truth for the training length and the checkpoint location,
# so the "<n>_epochs" part of the file name cannot drift from the n_iter
# that is actually passed to the plugin.
N_ITER = 100
MODEL_PATH = f'../ddpm_{N_ITER}_epochs.pkl'

# Iris features with the class label appended as a trailing "target" column.
# The raw frame keeps its own name so that X only ever refers to the loader.
iris_features, y = load_iris(as_frame=True, return_X_y=True)
X = GenericDataLoader(iris_features.assign(target=y))

plugin_params = dict(
    n_iter = N_ITER
)
test_plugin = Plugins().get('ddpm',**plugin_params)
test_plugin.fit(X)

# Round-trip the fitted plugin through disk to check that it can be
# persisted and restored.
# NOTE(review): the .pkl extension suggests a pickle-based format — only
# load checkpoint files that come from a trusted source.
save_to_file(MODEL_PATH,test_plugin)
reloaded = load_from_file(MODEL_PATH)




[2023-07-02T18:25:36.309809+0200][7176][INFO] Encoding sepal length (cm) 8461685668942494555
[2023-07-02T18:25:36.359811+0200][7176][INFO] Encoding sepal width (cm) 7372477013158199918
[2023-07-02T18:25:36.373349+0200][7176][INFO] Encoding petal length (cm) 8795408021141068254
[2023-07-02T18:25:36.386437+0200][7176][INFO] Encoding petal width (cm) 1839870727438321343


[2023-07-02T18:25:36.445880+0200][7176][INFO] Encoding target 2443400643551247192
Epoch: 100%|██████████| 100/100 [00:43<00:00,  2.33it/s, loss=2.22]


In [30]:
# Sample ten synthetic rows from the model restored from disk.
# NOTE(review): in the recorded run every feature column of the sample takes
# only two distinct values, which suggests the 100-iteration model is
# under-trained — confirm before relying on this checkpoint.
synthetic_sample = reloaded.generate(count=10)
synthetic_sample

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
0,7.9,4.4,1.0,2.5,0
1,4.3,2.0,6.9,0.1,0
2,7.9,2.0,6.9,2.5,2
3,7.9,2.0,1.0,0.1,0
4,4.3,4.4,6.9,0.1,0
5,4.3,2.0,1.0,2.5,2
6,7.9,4.4,1.0,2.5,2
7,4.3,4.4,1.0,2.5,2
8,4.3,4.4,6.9,2.5,0
9,4.3,2.0,6.9,2.5,2


# Discover hyperparameter space from ddpm

In [5]:
import sys
import warnings

import optuna
import synthcity.logger as log

# Silence every Python warning so the cell outputs stay readable.
warnings.filterwarnings("ignore")
# Send synthcity log records at INFO level to stderr.
log.add(sink=sys.stderr, level="INFO")

In [21]:
PLUGIN = "ddpm"
# Build the plugin with its default arguments purely to recover its class,
# which is what exposes the hyperparameter space below.
plugin_instance = Plugins().get(PLUGIN)
plugin_cls = plugin_instance.__class__
plugin_cls

synthcity.plugins.generic.plugin_ddpm.TabDDPMPlugin

In [22]:
# Display the distributions the plugin class declares for its tunable hyperparameters.
search_space = plugin_cls.hyperparameter_space()
search_space

[LogDistribution(name='lr', data=None, random_state=0, marginal_distribution=None, low=1e-05, high=0.1),
 IntLogDistribution(name='batch_size', data=None, random_state=0, marginal_distribution=None, low=256, high=4096, step=1),
 IntegerDistribution(name='num_timesteps', data=None, random_state=0, marginal_distribution=None, low=10, high=1000, step=1),
 IntLogDistribution(name='n_iter', data=None, random_state=0, marginal_distribution=None, low=1000, high=10000, step=1)]

In [23]:
from synthcity.utils.optuna_sample import suggest_all

# Draw one candidate configuration from the plugin's declared search space.
# The sampler is seeded so that re-running the notebook yields the same draw;
# an unseeded study produces different parameters on every execution.
SAMPLER_SEED = 0
study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=SAMPLER_SEED))
trial = study.ask()
params = suggest_all(trial, plugin_cls.hyperparameter_space())
# Speed up: deliberately replace the sampled n_iter. Note that 100 is below
# the lower bound (1000) that the recorded search space lists for n_iter.
params['n_iter'] = 100
params

{'lr': 0.06959480611212286,
 'batch_size': 697,
 'num_timesteps': 350,
 'n_iter': 100}

# Conditional training and sampling with ctgan and tvae

In [8]:
from sklearn.datasets import load_iris
from synthcity.plugins.core.dataloader import GenericDataLoader
          
# Iris features with the class label appended as a trailing "target" column.
iris_features, y = load_iris(as_frame=True, return_X_y=True)
X = GenericDataLoader(iris_features.assign(target=y))

# Short CTGAN run (10 iterations); the class labels are also passed as the
# conditioning signal during fitting.
plugin_params = {"n_iter": 10}
test_plugin = Plugins().get("ctgan", **plugin_params)
test_plugin.fit(X, cond=y)


100%|██████████| 10/10 [00:01<00:00,  6.93it/s]


<synthcity.plugins.generic.plugin_ctgan.CTGANPlugin at 0x19e5316ffd0>

In [10]:
# Request 30 synthetic rows from the fitted CTGAN plugin and show them as a DataFrame.
ctgan_synthetic = test_plugin.generate(30)
ctgan_synthetic.dataframe()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
0,4.3,2.0,6.641957,1.061961,2
1,4.3,2.667095,6.9,2.5,2
2,4.3,2.448462,6.9,0.929168,2
3,4.3,3.148739,6.435413,0.806541,2
4,4.3,2.0,6.9,2.5,2
5,4.3,2.350046,6.9,1.012499,2
6,4.3,2.0,2.864482,1.061961,2
7,4.3,2.211711,5.924492,0.744429,2
8,4.3,2.864846,6.9,0.907809,2
9,4.3,2.54686,2.577006,0.952131,2


In [12]:
from sklearn.datasets import load_iris
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.plugins.core.constraints import Constraints
          
# Rebuild the Iris loader: the four feature columns plus a trailing "target" label column.
iris_features, y = load_iris(as_frame=True, return_X_y=True)
iris_with_target = iris_features.assign(target=y)
X = GenericDataLoader(iris_with_target)

# TVAE trained for 100 iterations, with the class labels passed as the
# conditioning signal during fitting.
plugin_params = {"n_iter": 100}
test_plugin = Plugins().get("tvae", **plugin_params)
test_plugin.fit(X, cond=y)


100%|██████████| 100/100 [00:28<00:00,  3.57it/s]


<synthcity.plugins.generic.plugin_tvae.TVAEPlugin at 0x19e5728d430>

In [13]:
# Request 30 synthetic rows from the fitted TVAE plugin and show them as a DataFrame.
tvae_synthetic = test_plugin.generate(30)
tvae_synthetic.dataframe()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
0,6.89788,3.232471,4.774508,1.865713,2
1,5.166881,2.632011,4.676399,1.813565,2
2,5.116225,3.155409,4.885044,0.295749,2
3,6.944025,2.804557,5.308022,2.179232,2
4,6.567341,2.680667,5.48194,2.174216,2
5,5.858331,3.140261,4.793322,0.3191,2
6,6.155458,3.185978,5.227595,2.138442,2
7,6.18924,2.855802,5.057013,0.227167,2
8,5.289608,2.899719,1.295587,1.843505,1
9,6.617167,2.811755,5.06841,0.320463,2
