# Testing Synthesizers

The SDV APIs have changed. This notebook tests the updated code needed to get the synthesizers running again.

In [1]:
import os

import numpy as np
import pandas as pd
import torch

from sdv.metadata import SingleTableMetadata

  from .autonotebook import tqdm as notebook_tqdm


## CTGAN

In [3]:
from sdv.single_table import CTGANSynthesizer

In [25]:
# Load the adult training split and preview the first rows to sanity-check it.
data = pd.read_csv(filepath_or_buffer='../data/adult_trn.csv.gz')
data.head(n=5)

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,23,Private,161708,Bachelors,13,Never-married,Other-service,Own-child,White,Female,0,0,30,United-States,<=50K
1,37,Private,114605,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,40,United-States,<=50K
2,35,Private,320305,HS-grad,9,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,32,United-States,<=50K
3,26,Private,106856,Assoc-voc,11,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,40,United-States,<=50K
4,43,Local-gov,147328,Masters,14,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,1977,60,United-States,>50K


In [26]:
# Let SDV infer an sdtype (numerical, categorical, ...) for every column.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

In [27]:
# Show the pandas dtypes so they can be compared with the sdtypes SDV detected.
data.dtypes

age                int64
workclass         object
fnlwgt             int64
education         object
education-num      int64
marital-status    object
occupation        object
relationship      object
race              object
sex               object
capital-gain       int64
capital-loss       int64
hours-per-week     int64
native-country    object
income            object
dtype: object

In [28]:
# Display the auto-detected metadata (the sdtype assigned to each column).
metadata

{
    "columns": {
        "age": {
            "sdtype": "numerical"
        },
        "workclass": {
            "sdtype": "categorical"
        },
        "fnlwgt": {
            "sdtype": "numerical"
        },
        "education": {
            "sdtype": "categorical"
        },
        "education-num": {
            "sdtype": "numerical"
        },
        "marital-status": {
            "sdtype": "categorical"
        },
        "occupation": {
            "sdtype": "categorical"
        },
        "relationship": {
            "sdtype": "categorical"
        },
        "race": {
            "sdtype": "categorical"
        },
        "sex": {
            "sdtype": "categorical"
        },
        "capital-gain": {
            "sdtype": "numerical"
        },
        "capital-loss": {
            "sdtype": "numerical"
        },
        "hours-per-week": {
            "sdtype": "numerical"
        },
        "native-country": {
            "sdtype": "categorical"
        },
    

In [None]:
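# The auto-detected metadata looks OK for most datasets; the exception is
# credit-default, where SEX, EDUCATION and MARRIAGE must be overridden to
# categorical (see below). A hard-coded check could be added here to be
# absolutely sure.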

In [9]:
# Create a CTGAN synthesizer configured by the adult metadata detected above.
model = CTGANSynthesizer(metadata=metadata)

In [11]:
%%time
# NOTE(review): in the saved run this fit was stopped with KeyboardInterrupt,
# so the sampling and saving cells below were never executed.
model.fit(data=data)

KeyboardInterrupt: 

In [None]:
%%time
# Draw 50,000 synthetic rows from the fitted synthesizer.
samples = model.sample(num_rows=50000)

In [None]:
%%time
# `dataset` is only assigned by the batch loop further down, so on a fresh
# kernel it is undefined here (and after that loop it would hold
# 'credit-default', mislabelling these adult samples). This exploratory run
# uses the adult data, so name the output file explicitly.
os.makedirs('../data_new', exist_ok=True)
samples.to_csv('../data_new/adult_ctgan.csv', index=False)

In [29]:
# Names of the datasets to synthesize; each is read from
# ../data/<name>_trn.csv.gz by the generation loop.
datasets = [
    'adult',
    'marketing',
    'online-shoppers',
    'credit-default',
]

In [32]:
# Load the credit-default training split and inspect how pandas typed each column.
data = pd.read_csv(filepath_or_buffer='../data/credit-default_trn.csv.gz')
data.dtypes

LIMIT_BAL                     float64
SEX                             int64
EDUCATION                       int64
MARRIAGE                        int64
AGE                             int64
PAY_0                           int64
PAY_2                           int64
PAY_3                           int64
PAY_4                           int64
PAY_5                           int64
PAY_6                           int64
BILL_AMT1                       int64
BILL_AMT2                     float64
BILL_AMT3                     float64
BILL_AMT4                     float64
BILL_AMT5                     float64
BILL_AMT6                     float64
PAY_AMT1                      float64
PAY_AMT2                      float64
PAY_AMT3                      float64
PAY_AMT4                      float64
PAY_AMT5                      float64
PAY_AMT6                      float64
default payment next month      int64
dtype: object

In [33]:
# Cast the integer-coded SEX, EDUCATION and MARRIAGE columns to `object`,
# hoping SDV will then detect them as categorical (the re-detected metadata
# below shows it still does not).
data = data.astype(dict.fromkeys(['SEX', 'EDUCATION', 'MARRIAGE'], 'object'))

In [36]:
# Confirm the cast: SEX, EDUCATION and MARRIAGE should now be `object`.
data.dtypes

LIMIT_BAL                     float64
SEX                            object
EDUCATION                      object
MARRIAGE                       object
AGE                             int64
PAY_0                           int64
PAY_2                           int64
PAY_3                           int64
PAY_4                           int64
PAY_5                           int64
PAY_6                           int64
BILL_AMT1                       int64
BILL_AMT2                     float64
BILL_AMT3                     float64
BILL_AMT4                     float64
BILL_AMT5                     float64
BILL_AMT6                     float64
PAY_AMT1                      float64
PAY_AMT2                      float64
PAY_AMT3                      float64
PAY_AMT4                      float64
PAY_AMT5                      float64
PAY_AMT6                      float64
default payment next month      int64
dtype: object

In [37]:
# Re-run SDV's sdtype detection on the re-typed credit-default frame.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

In [38]:
# Inspect the re-detected metadata for credit-default.
# NOTE(review): the saved output shows SEX, EDUCATION and MARRIAGE still
# detected as numerical despite the `object` cast above.
metadata

{
    "columns": {
        "LIMIT_BAL": {
            "sdtype": "numerical"
        },
        "SEX": {
            "sdtype": "numerical"
        },
        "EDUCATION": {
            "sdtype": "numerical"
        },
        "MARRIAGE": {
            "sdtype": "numerical"
        },
        "AGE": {
            "sdtype": "numerical"
        },
        "PAY_0": {
            "sdtype": "numerical"
        },
        "PAY_2": {
            "sdtype": "numerical"
        },
        "PAY_3": {
            "sdtype": "numerical"
        },
        "PAY_4": {
            "sdtype": "numerical"
        },
        "PAY_5": {
            "sdtype": "numerical"
        },
        "PAY_6": {
            "sdtype": "numerical"
        },
        "BILL_AMT1": {
            "sdtype": "numerical"
        },
        "BILL_AMT2": {
            "sdtype": "numerical"
        },
        "BILL_AMT3": {
            "sdtype": "numerical"
        },
        "BILL_AMT4": {
            "sdtype": "numerical"
        }

In [40]:
# Detection still labels these coded columns as numerical, so force their
# sdtype to categorical by hand.
for col in ('SEX', 'EDUCATION', 'MARRIAGE'):
    metadata.update_column(col, sdtype='categorical')

In [41]:
# Verify the overrides took effect: SEX, EDUCATION and MARRIAGE are categorical.
metadata

{
    "columns": {
        "LIMIT_BAL": {
            "sdtype": "numerical"
        },
        "SEX": {
            "sdtype": "categorical"
        },
        "EDUCATION": {
            "sdtype": "categorical"
        },
        "MARRIAGE": {
            "sdtype": "categorical"
        },
        "AGE": {
            "sdtype": "numerical"
        },
        "PAY_0": {
            "sdtype": "numerical"
        },
        "PAY_2": {
            "sdtype": "numerical"
        },
        "PAY_3": {
            "sdtype": "numerical"
        },
        "PAY_4": {
            "sdtype": "numerical"
        },
        "PAY_5": {
            "sdtype": "numerical"
        },
        "PAY_6": {
            "sdtype": "numerical"
        },
        "BILL_AMT1": {
            "sdtype": "numerical"
        },
        "BILL_AMT2": {
            "sdtype": "numerical"
        },
        "BILL_AMT3": {
            "sdtype": "numerical"
        },
        "BILL_AMT4": {
            "sdtype": "numerical"
   

In [None]:
%%time
# Generate a CTGAN synthetic dataset for every name in `datasets`.
#
# numpy and torch are re-seeded before each dataset so every run starts from
# the same random state. Samples are written to ../data_new/<name>_ctgan.csv.
#
# Casting columns to `object` before detection does not help: as the
# credit-default exploration above shows, SingleTableMetadata still detects
# them as numerical. Categorical sdtypes are therefore overridden on the
# metadata itself, per dataset, via CATEGORICAL_OVERRIDES.
CATEGORICAL_OVERRIDES = {
    'credit-default': ['SEX', 'EDUCATION', 'MARRIAGE'],
}
N_SYNTHETIC_ROWS = 50000

# to_csv fails if the output directory is missing, so create it up front.
os.makedirs('../data_new', exist_ok=True)

for dataset in datasets:
    print('CTGAN ' + dataset)
    np.random.seed(0)
    torch.manual_seed(0)
    data = pd.read_csv('../data/' + dataset + '_trn.csv.gz')

    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data=data)
    for col in CATEGORICAL_OVERRIDES.get(dataset, []):
        metadata.update_column(column_name=col, sdtype='categorical')

    model = CTGANSynthesizer(metadata)
    model.fit(data)
    samples = model.sample(N_SYNTHETIC_ROWS)
    samples.to_csv('../data_new/' + dataset + '_ctgan.csv', index=False)