In [46]:
!pip install pmlb
!pip install ydata-synthetic

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [47]:
from pmlb import fetch_data

from ydata_synthetic.synthesizers.regular import RegularSynthesizer
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters

#Load data and define the data processor parameters
data = fetch_data('adult')
num_cols = ['fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
cat_cols = ['age','workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
            'native-country', 'target']

# DRAGAN training
#Defining the training parameters of DRAGAN

noise_dim = 128
dim = 128
batch_size = 500

log_step = 100
epochs = 10+1
learning_rate = 1e-5
beta_1 = 0.5
beta_2 = 0.9
models_dir = '../cache'

gan_args = ModelParameters(batch_size=batch_size,
                           lr=learning_rate,
                           betas=(beta_1, beta_2),
                           noise_dim=noise_dim,
                           layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
                             sample_interval=log_step)

synth = RegularSynthesizer(modelname='dragan', model_parameters=gan_args, n_discriminator=3)
synth.fit(data = data, train_arguments = train_args, num_cols = num_cols, cat_cols = cat_cols)

synth.save('adult_synth.pkl')

#########################################################
#    Loading and sampling from a trained synthesizer    #
#########################################################
synthesizer = RegularSynthesizer.load('adult_synth.pkl')
a = synthesizer.sample(1000)

  9%|▉         | 1/11 [01:22<13:40, 82.04s/it]

Epoch: 0 | disc_loss: -0.4180566668510437 | gen_loss: -0.024495083838701248


 18%|█▊        | 2/11 [02:43<12:17, 81.96s/it]

Epoch: 1 | disc_loss: -0.515661358833313 | gen_loss: -0.042718514800071716


 27%|██▋       | 3/11 [03:47<09:48, 73.53s/it]

Epoch: 2 | disc_loss: -0.5230857133865356 | gen_loss: -0.018776515498757362


 36%|███▋      | 4/11 [05:09<08:57, 76.84s/it]

Epoch: 3 | disc_loss: -0.5004736185073853 | gen_loss: -0.04081151261925697


 45%|████▌     | 5/11 [06:20<07:28, 74.77s/it]

Epoch: 4 | disc_loss: -0.4549264907836914 | gen_loss: -0.042799513787031174


 55%|█████▍    | 6/11 [07:42<06:25, 77.20s/it]

Epoch: 5 | disc_loss: -0.4156056046485901 | gen_loss: -0.0643736869096756


 64%|██████▎   | 7/11 [08:43<04:48, 72.07s/it]

Epoch: 6 | disc_loss: -0.42996877431869507 | gen_loss: -0.06048325076699257


 73%|███████▎  | 8/11 [09:45<03:25, 68.64s/it]

Epoch: 7 | disc_loss: -0.35278311371803284 | gen_loss: -0.08285617083311081


 82%|████████▏ | 9/11 [11:07<02:25, 72.79s/it]

Epoch: 8 | disc_loss: -0.39341050386428833 | gen_loss: -0.07045663893222809


 91%|█████████ | 10/11 [12:17<01:12, 72.13s/it]

Epoch: 9 | disc_loss: -0.4389007091522217 | gen_loss: -0.06099839508533478


100%|██████████| 11/11 [13:39<00:00, 74.51s/it]


Epoch: 10 | disc_loss: -0.4862366020679474 | gen_loss: -0.04345722123980522


Synthetic data generation: 100%|██████████| 3/3 [00:00<00:00, 82.04it/s]


In [48]:
data.columns

Index(['age', 'workclass', 'fnlwgt', 'education', 'education-num',
       'marital-status', 'occupation', 'relationship', 'race', 'sex',
       'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
       'target'],
      dtype='object')

In [49]:
a.head()

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,target
0,49.0,4,898057.4,6,13.0,2,11,0,4,1,-171000.640625,-783.133789,-10.440917,40,1
1,68.0,4,853895.1,6,13.0,2,0,5,4,1,-77862.398438,-2994.961182,51.81192,39,1
2,68.0,4,1207276.0,6,13.0,2,0,0,4,1,-36544.351562,-1252.957642,74.764465,39,1
3,46.0,4,581865.3,6,13.0,2,9,5,4,1,-164715.796875,-1698.143066,2.250433,39,1
4,51.0,4,1481534.0,6,13.0,2,0,0,4,1,-105568.65625,-1675.071777,41.464359,39,1


In [50]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split


In [51]:
def evaluate_forest(original, generated, target_column):
  train_data = generated.drop([target_column], axis=1).to_numpy()
  train_values = generated[target_column].to_numpy()
  test_data = original.drop([target_column],axis=1).to_numpy()
  test_values = original[target_column].to_numpy()

  x_train, x_test, y_train, y_test = train_test_split(test_data, test_values)

  #classification
  model = RandomForestClassifier()
  test_model = RandomForestClassifier()
  
  model.fit(train_data, train_values)
  test_model.fit(x_train, y_train)

  model_score = model.score(test_data, test_values)
  test_score = test_model.score(x_test, y_test)

  return model_score, test_score

In [52]:
model_score = []
test_score = []
for _ in range(30):
    samples = synth.sample(1000)
    samples_df = pd.DataFrame(data, columns=['age', 'workclass', 'fnlwgt', 'education', 'education-num',
       'marital-status', 'occupation', 'relationship', 'race', 'sex',
       'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
       'target'])
    x, y = evaluate_forest(data, samples_df, target_column="capital-gain")
    model_score.append(x)
    test_score.append(y)

Synthetic data generation: 100%|██████████| 3/3 [00:00<00:00, 53.44it/s]
Synthetic data generation: 100%|██████████| 3/3 [00:00<00:00, 71.91it/s]
Synthetic data generation: 100%|██████████| 3/3 [00:00<00:00, 75.87it/s]
Synthetic data generation: 100%|██████████| 3/3 [00:00<00:00, 72.67it/s]
Synthetic data generation: 100%|██████████| 3/3 [00:00<00:00, 70.77it/s]


In [53]:
import matplotlib.pyplot as plt

In [54]:
x.mean()

0.9997747839973793

In [55]:
y.mean()

0.9169601179264597