# Example: DP-CTGAN and tabular data

**IMPORTANT:** see the [README](https://github.com/kasra-hosseini/privgem#credits) for credits.

In [None]:
# Notebook setup: IPython completer option, automatic module reloading,
# inline plotting and warning filters.

# Work around an autocomplete issue by turning off jedi in the IPython completer.
%config Completer.use_jedi = False

# Reload edited modules (e.g. a local privgem checkout) automatically before each cell executes.
%load_ext autoreload
%autoreload 2
%matplotlib inline

from warnings import simplefilter
# Hide every FutureWarning so cell output stays readable.
simplefilter(action='ignore', category=FutureWarning)

In [None]:
import os
import pandas as pd
import sys

## Load a tabular dataset

In [None]:
from ctgan import load_demo

# Demo table returned by ctgan's `load_demo` helper; preview the first rows.
data = load_demo()
data.head(n=5)

### Names of the discrete (categorical) columns

In [None]:
# Columns passed to the synthesizer as discrete (see the training cell).
discrete_columns = [
    "workclass", "education", "marital-status", "occupation",
    "relationship", "race", "sex", "native-country", "income",
]

## Synthesize using DP-CTGAN

### Split the data into train/test

In [None]:
from privgem import tabular_utils

# Output locations for the full dataset and its train/test split.
orig_data_dir = "./test_dpctgan/orig_data"
path_save = f"{orig_data_dir}/orig_data.csv"
path_train = f"{orig_data_dir}/orig_train.csv"
path_test = f"{orig_data_dir}/orig_test.csv"

# Split `data` (test_size=0.25, fixed random_state) and write the
# full / train / test CSVs to the paths above.
tabular_utils.split_save_orig_data(
    data,
    path_save=path_save,
    path_train=path_train,
    path_test=path_test,
    label_col="income",
    test_size=0.25,
    random_state=42,
)

### Instantiate a `tabular_dpctgan` object

In [None]:
import torch

from privgem import tabular_dpctgan

# --- inputs ---
# DP settings: epsilon (privacy budget) and sigma (presumably the noise
# multiplier — confirm against the privgem docs).
epsilon = 1
sigma = 5
batch_size = 64
# Kept short for the demo; the library default is 300 epochs.
epochs = 10
# Training-log CSV; the plotting cell below reads this same file.
output_save_path = "./test_dpctgan/dpctgan_training.csv"
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so
# the notebook also runs end-to-end on CPU-only machines.
# Other accepted values: "default", "cpu", "cuda:1".
device = "cuda:0" if torch.cuda.is_available() else "cpu"

dp_model = tabular_dpctgan(verbose=True,
                           epsilon=epsilon,
                           batch_size=batch_size,
                           sigma=sigma,
                           secure_rng=False,
                           epochs=epochs,
                           output_save_path=output_save_path,
                           device=device)

### Train a new model

Note that training can take a long time, depending on the hyperparameters (e.g. `epochs` and `batch_size`).

In [None]:
# Train on the training split written by the split cell rather than on the
# full `data`, so the held-out test rows never influence the synthesizer.
data_train = pd.read_csv(path_train)
# NOTE(review): assumes the split helper wrote the CSV without its index; drop
# an unnamed index column defensively in case it did not — confirm and remove.
data_train = data_train.loc[:, ~data_train.columns.str.startswith("Unnamed:")]
assert set(discrete_columns) <= set(data_train.columns), "discrete column missing from training split"

dp_model.train(data_train, discrete_columns)

### Plot the training log file

In [None]:
# Plot the training log. Reuse `output_save_path` (the file the model was
# configured to write) instead of repeating the path literal, so the two
# cannot drift apart. `tabular_utils` was already imported in the split cell.
tabular_utils.plot_log_dpctgan(filename=output_save_path)

### Sample and save the synthetic output

In [None]:
# Request as many synthetic rows as the original `data` has.
n_rows = len(data)
synth_output = dp_model.sample(n_rows)

# Write the synthetic table to CSV, creating its directory if needed.
path2synth_file = "./test_dpctgan/dpctgan_001/synthetic_output.csv"
synth_dir = os.path.dirname(path2synth_file)
os.makedirs(synth_dir, exist_ok=True)
synth_output.to_csv(path2synth_file, index=False)

synth_output