In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn import preprocessing
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification
from sklearn.metrics import precision_score, recall_score, roc_auc_score

from table_evaluator import TableEvaluator

from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning

import os.path, sys
from tests.utils import load_adult

import warnings

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

In [None]:
# Register stderr as a synthcity log sink at level "INFO".
# NOTE(review): re-running this cell presumably registers a second sink and
# duplicates every log line - confirm against synthcity.logger.
log.add(sink=sys.stderr, level="INFO")

# Put the current working directory on the import path (presumably so the
# local PreProcessData / Metrics modules used in later cells can be imported).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
cwd = os.getcwd()
if cwd not in sys.path:
    sys.path.append(cwd)

In [None]:
import PreProcessData

# Build the cleaned adult frame by handing the raw data file to
# PreProcessData.clean_df.
adult_data_path = 'data/adult.data'
clean_df = PreProcessData.clean_df(adult_data_path)

# Disabled alternative, kept for reference: load_adult() yields arr_X (np array
# without label/target), arr_y (np array of just label/target) and adult_df.
#arr_X, arr_y, adult_df = load_adult()

clean_df

In [None]:
# Wrap the cleaned frame in a synthcity GenericDataLoader: "label" is declared
# as the target column and three columns are flagged as sensitive.
sensitive_cols = ["race", "sex", "native-country"]
loader = GenericDataLoader(clean_df, target_column="label", sensitive_columns=sensitive_cols)

loader

In [None]:
# DAG and bias dicts for debiasing adult dataset.
#
# `dag_children` lists, for every parent column, the columns it has an edge
# to. `df_dag` flattens that mapping into (parent, child) tuples, keeping the
# parents and their children in the order written here; it is the value
# passed as `dag=` when the model is fitted.
dag_children = {
    'race': ['marital-status', 'education', 'occupation', 'hours-per-week', 'label'],
    'age': ['marital-status', 'education', 'occupation', 'hours-per-week',
            'workclass', 'relationship', 'label'],
    'sex': ['marital-status', 'education', 'occupation', 'hours-per-week',
            'workclass', 'relationship', 'label'],
    'native-country': ['marital-status', 'education', 'hours-per-week',
                       'workclass', 'relationship', 'label'],
    'marital-status': ['education', 'occupation', 'hours-per-week',
                       'workclass', 'relationship', 'label'],
    'education': ['occupation', 'hours-per-week', 'workclass', 'relationship', 'label'],
    'occupation': ['label'],
    'hours-per-week': ['label'],
    'workclass': ['label'],
    'relationship': ['label'],
}

df_dag = [
    (parent, child)
    for parent, children in dag_children.items()
    for child in children
]

# Bias dicts: each maps a source column to ['label'] and is passed as
# `biased_edges=` to `syn_model.generate`.
# NOTE(review): ftu / dp / cf presumably stand for fairness-through-unawareness,
# demographic parity and conditional fairness - confirm against the DECAF paper.
ftu = {'sex': ['label']}

dp_sources = ('sex', 'marital-status', 'relationship', 'occupation',
              'hours-per-week', 'workclass', 'education')
dp = {source: ['label'] for source in dp_sources}

cf_sources = ('sex', 'marital-status', 'relationship')
cf = {source: ['label'] for source in cf_sources}

In [None]:
# synthcity absolute
from synthcity.plugins import Plugins
from synthcity.plugins.privacy import plugin_decaf

# Iteration settings forwarded unchanged to the DECAFPlugin constructor.
decaf_kwargs = {"n_iter": 10, "n_iter_baseline": 100}
syn_model = plugin_decaf.DECAFPlugin(**decaf_kwargs)

# Fit on the wrapped adult data with the hand-written DAG; whatever fit()
# returns is shown as the cell output.
syn_model.fit(loader, dag=df_dag)

In [None]:
# Generate 70000 synthetic rows with the `ftu` dict supplied as
# `biased_edges`, then convert the result with .dataframe().
synth_ftu = syn_model.generate(
    count=70000,
    biased_edges=ftu,
).dataframe()
synth_ftu

In [None]:
from Metrics import get_metrics

# Run get_metrics on the cleaned real frame and the FTU synthetic frame.
# The first argument is presumably a display label for this run - confirm in
# Metrics.get_metrics.
get_metrics("DECAF-FTU", clean_df, synth_ftu)

In [None]:
# Generate 70000 synthetic rows with the `dp` dict supplied as
# `biased_edges`, then convert the result with .dataframe().
synth_dp = syn_model.generate(
    count=70000,
    biased_edges=dp,
).dataframe()
synth_dp

In [None]:
# Run get_metrics on the cleaned real frame and the DP synthetic frame.
# The first argument is presumably a display label for this run - confirm in
# Metrics.get_metrics.
get_metrics("DECAF-DP", clean_df, synth_dp)

In [None]:
# Generate 70000 synthetic rows with the `cf` dict supplied as
# `biased_edges`, then convert the result with .dataframe().
synth_cf = syn_model.generate(
    count=70000,
    biased_edges=cf,
).dataframe()
synth_cf

In [None]:
# Run get_metrics on the cleaned real frame and the CF synthetic frame.
# NOTE(review): this label is lower-case ("DECAF-cf") while the other two runs
# use upper-case suffixes ("DECAF-FTU", "DECAF-DP") - confirm whether the
# casing matters to get_metrics before normalising it.
get_metrics("DECAF-cf", clean_df, synth_cf)

In [None]:
# TODO: Synthcity benchmarking does not support DECAF at the moment, because there is no way to
# pass in the DAG or the bias dict. Benchmarking for DECAF and other external GANs will need to
# be built from scratch. The disabled call below is kept for reference.
# from SynthBenchmarks import Benchmarks

# score = Benchmarks.evaluate(
#     [("Test Decaf", "decaf", {"n_iter": 10, "n_iter_baseline": 100})],
#     loader,
#     synthetic_size=1000,
#     repeats=2
#     #synthetic_cache=False
#     #synthetic_reuse_if_exists=True
# )