# YData-Synthetic Demo

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
from pmlb import fetch_data
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters

In [3]:
# NOTE(review): machine-specific absolute path — this cell only works where this
# directory exists. Confirm whether the directory change is still needed and, if
# so, replace it with a configurable/relative location.
%cd /Users/alex/PETsARD

/Users/alex/PETsARD


In [4]:
# Fetch the 'adult' dataset via pmlb; `data` is the input both synthesizers are fitted on.
data = fetch_data('adult')

In [5]:
# Columns handed to the synthesizers as numerical features.
num_cols = [
    'age',
    'fnlwgt',
    'capital-gain',
    'capital-loss',
    'hours-per-week',
]

# Columns handed to the synthesizers as categorical features (includes `target`).
cat_cols = [
    'workclass',
    'education',
    'education-num',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
    'target',
]

In [6]:
# CTGAN hyperparameters: batch size, learning rate and betas.
ctgan_args = ModelParameters(
    batch_size=500,
    lr=2e-4,
    betas=(0.5, 0.9),
)
# Training schedule: number of epochs to run.
train_args = TrainParameters(epochs=101)

# Build the CTGAN synthesizer and fit it on `data`, using the
# numerical/categorical column split defined earlier in the notebook.
synth = RegularSynthesizer(modelname='ctgan', model_parameters=ctgan_args)
synth.fit(
    data=data,
    train_arguments=train_args,
    num_cols=num_cols,
    cat_cols=cat_cols,
)

# Draw 1000 synthetic rows from the fitted synthesizer.
synth_data = synth.sample(1000)



Epoch: 0 | critic_loss: 0.15915349125862122 | generator_loss: 1.5187902450561523
Epoch: 1 | critic_loss: 0.11161187291145325 | generator_loss: 1.1965501308441162
Epoch: 2 | critic_loss: 0.04443097114562988 | generator_loss: 0.8039807677268982
Epoch: 3 | critic_loss: 0.12427829205989838 | generator_loss: 0.2787080407142639
Epoch: 4 | critic_loss: 0.1292436420917511 | generator_loss: -0.28288543224334717
Epoch: 5 | critic_loss: 0.002009451389312744 | generator_loss: -0.461927592754364
Epoch: 6 | critic_loss: 0.06914292275905609 | generator_loss: -0.7213209867477417
Epoch: 7 | critic_loss: -0.04778724163770676 | generator_loss: -0.6719639301300049
Epoch: 8 | critic_loss: -0.019918859004974365 | generator_loss: -0.48722484707832336
Epoch: 9 | critic_loss: 0.013895466923713684 | generator_loss: -0.42522501945495605
Epoch: 10 | critic_loss: 0.020033370703458786 | generator_loss: -0.487899512052536
Epoch: 11 | critic_loss: -0.06784175336360931 | generator_loss: -0.6133951544761658
Epoch: 12 |

In [10]:
# Display the samples drawn from the CTGAN synthesizer.
synth_data

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,target
0,54.237576,4,224884.987302,15,10.0,2,11,0,4,1,9.924221,0.050711,30.670093,39,0
1,31.556741,5,110257.221908,11,9.0,2,7,0,2,1,17.779265,-0.429337,40.019158,39,1
2,28.219478,4,214364.109204,11,9.0,4,1,3,4,1,-7.326760,0.359110,50.258479,39,1
3,67.363677,4,97572.830871,7,12.0,2,3,0,4,1,20.174761,-0.627301,40.013682,39,0
4,64.260167,6,443488.287979,5,4.0,2,10,0,4,1,4733.406175,-0.152154,40.009208,39,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,36.199577,4,292719.623111,11,9.0,2,3,0,4,1,1.892537,-1.182285,39.976687,39,1
996,43.551592,4,467389.208820,1,7.0,2,3,0,4,1,12.319821,-0.737305,40.024131,39,1
997,52.870376,4,134327.548889,15,10.0,4,14,1,4,1,-12.026845,-0.737271,40.000898,39,1
998,30.395759,4,151769.496431,15,8.0,2,4,0,4,1,0.999651,-0.444731,40.019469,39,0


In [11]:
# Fit the 'fast' (GMM) synthesizer on the same data and column split;
# no model or training parameters are passed here.
synth_gmm = RegularSynthesizer(modelname='fast')
synth_gmm.fit(
    data=data,
    num_cols=num_cols,
    cat_cols=cat_cols,
)

# Draw 1000 synthetic rows from the fitted model.
synth_gmm_data = synth_gmm.sample(1000)

Hyperparameter search: 100%|██████████| 8/8 [04:55<00:00, 36.94s/it]


In [12]:
# Display the samples drawn from the 'fast' (GMM) synthesizer.
synth_gmm_data

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,target
0,35.145432,4,244782.531528,8,11.0,4,12,3,4,0,2481.133979,176.650445,42.118269,39,1
1,38.273739,4,177868.906322,7,12.0,4,1,3,4,0,-3498.526098,54.190801,27.735051,39,1
2,28.127961,4,250592.696569,1,7.0,4,8,2,2,0,-883.172613,52.587449,43.578224,39,1
3,37.472839,4,75724.203372,1,7.0,4,8,2,4,0,1656.297077,325.978016,32.116786,39,1
4,45.864424,4,239463.113718,8,11.0,0,10,4,4,0,6889.443250,350.348402,56.277495,39,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,31.296466,4,297761.621401,8,11.0,4,6,1,2,1,-5205.744944,-393.359549,49.265973,39,1
996,38.450852,4,265844.327125,1,7.0,4,1,1,4,1,-12442.576781,74.614141,13.594040,39,1
997,12.271465,4,228976.621268,1,7.0,4,5,3,4,1,-14367.734848,130.324927,34.124153,39,1
998,28.734200,4,11232.319469,14,15.0,4,10,1,4,1,-3914.468912,542.377450,50.866178,39,1
