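# Fidelity

How closely does a synthetic dataset reproduce its training data? Both frames are binned on the training data's quantiles / most frequent categories, and their 1-, 2- and 3-way marginal distributions are compared with several distances (TVD, mean/max absolute difference, L1, L2, Hellinger, Jensen-Shannon).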

## Implementation

In [1]:
import os.path
import pandas as pd
import numpy as np
from os import listdir
from os.path import isfile, join

def bin_data(dt1, dt2, c = 10):
    """Discretise `dt1` and `dt2` onto a common set of bins learned from `dt1`.

    - Non-object columns are cut at `dt1`'s `c`-quantiles (duplicate edges are
      dropped). Labels are the interval strings, and NaN becomes the string
      'nan'. `dt2` values are coerced to numeric, and anything outside `dt1`'s
      [min, max] range becomes '_other_'.
    - Object columns keep `dt1`'s `c` most frequent values. Every other
      non-missing value becomes '_other_'; missing values stay NaN.

    Args:
        dt1: Reference frame that defines the bins.
        dt2: Frame binned with `dt1`'s bins; must contain `dt1`'s columns.
        c: Number of quantile bins (numeric) / kept top values (categorical).

    Returns:
        ``[dt1_binned, dt2_binned]``: binned copies. The inputs are not modified.
    """
    dt1 = dt1.copy()
    dt2 = dt2.copy()

    def lump_rare(values, keep):
        # NaN is never ranked by value_counts, so leave it as NaN explicitly;
        # every other value not in `keep` collapses to '_other_'.
        return values.where(values.isin(keep) | values.isna(), '_other_')

    # quantile binning of numerics
    num_cols = dt1.dtypes[dt1.dtypes != 'object'].index
    for col in num_cols:
        # breaks come from `dt1` only; unique() drops repeated edges of skewed columns
        breaks = dt1[col].quantile(np.linspace(0, 1, c + 1)).unique()
        dt1[col] = pd.cut(dt1[col], bins=breaks, include_lowest=True).astype(str)
        dt2_vals = pd.to_numeric(dt2[col], errors='coerce')
        dt2_bins = pd.cut(dt2_vals, bins=breaks, include_lowest=True).astype(str)
        # values outside dt1's range have no bin of their own
        dt2_bins[dt2_vals < min(breaks)] = '_other_'
        dt2_bins[dt2_vals > max(breaks)] = '_other_'
        dt2[col] = dt2_bins

    # top-C binning of categoricals
    cat_cols = dt1.dtypes[dt1.dtypes == 'object'].index
    for col in cat_cols:
        # determine top values based on `dt1`
        top_vals = dt1[col].value_counts().head(c).index
        # Assign the result back instead of `dt1[col].replace(..., inplace=True)`:
        # that is chained assignment, which pandas copy-on-write turns into a no-op.
        dt1[col] = lump_rare(dt1[col], top_vals)
        dt2[col] = lump_rare(dt2[col], top_vals)
    return [dt1, dt2]

def hellinger(p1, p2):
    """Hellinger distance between two discrete distributions on the same support.

    H = sqrt(1 - BC), where BC = sum(sqrt(p1 * p2)) is the Bhattacharyya
    coefficient. For (near-)identical distributions, floating-point rounding can
    push BC a hair above 1. The difference is clamped at 0 so that case returns
    ~0 instead of NaN.

    Args:
        p1, p2: Aligned probability vectors (numpy arrays or pandas Series).

    Returns:
        Scalar distance in [0, 1].
    """
    bc = np.sum(np.sqrt(p1 * p2))
    return np.sqrt(np.maximum(0.0, 1 - bc))

def kullback_leibler(p1, p2):
    """Kullback-Leibler divergence KL(p1 || p2), using the natural log.

    Only cells where `p1` is positive contribute (0 * log 0 is taken as 0).
    The result is infinite if `p2` is zero somewhere `p1` is positive.
    """
    support = p1 > 0
    ratio = p1[support] / p2[support]
    return np.sum(p1[support] * np.log(ratio))

def jensen_shannon(p1, p2):
    """Jensen-Shannon divergence between two discrete distributions.

    Averages the KL divergences of `p1` and `p2` from their midpoint mixture.
    For non-negative inputs, the mixture is positive wherever either input is,
    so neither KL term can be infinite.
    """
    mixture = 0.5 * (p1 + p2)
    kl_first = kullback_leibler(p1, mixture)
    kl_second = kullback_leibler(p2, mixture)
    return 0.5 * kl_first + 0.5 * kl_second

def _interaction_triples(cols, k):
    """Enumerate the (dim1, dim2, dim3) column triples compared for interaction order `k`.

    Every row has three dims:
    - k=1 repeats one column three times.
    - k=2 uses a pair of columns with dim3 == dim2.
    - k=3 uses three distinct columns.
    Distinct columns are ordered by name (dim1 < dim2 < dim3), so each
    combination appears once.
    """
    grid = np.array(np.meshgrid(cols, cols, cols)).reshape(3, len(cols) ** 3).T
    triples = pd.DataFrame(grid, columns=['dim1', 'dim2', 'dim3'])
    d1, d2, d3 = triples['dim1'], triples['dim2'], triples['dim3']
    if k == 1:
        keep = (d1 == d2) & (d2 == d3)
    elif k == 2:
        keep = (d1 < d2) & (d2 == d3)
    else:
        keep = (d1 < d2) & (d2 < d3)
    return triples.loc[keep]


def _distance_metrics(p1, p2):
    """Distance measures between two aligned probability vectors, keyed by output column name."""
    abs_diff = np.abs(p1 - p2)
    l1 = np.sum(abs_diff)
    return {
        'tvd': l1 / 2,
        'mae': np.mean(abs_diff),
        'max': np.max(abs_diff),
        'l1d': l1,
        'l2d': np.sqrt(np.sum((p1 - p2) ** 2)),
        'hellinger': hellinger(p1, p2),
        'jensen_shannon': jensen_shannon(p1, p2),
    }


def fidelity(dt1, dt2, c = 100, k = 1):
    """Compare the k-way marginal distributions of `dt2` against reference `dt1`.

    Both frames are discretised with `bin_data` (bins learned from `dt1`). Then,
    for every k-way combination of `dt1`'s columns, the empirical joint
    frequencies of the two frames are compared.

    Args:
        dt1: Reference frame (e.g. training data); its columns define the interactions.
        dt2: Frame to evaluate (e.g. synthetic data); must contain `dt1`'s columns.
        c: Bins / top categories per column, passed to `bin_data`.
        k: Interaction order, 1, 2 or 3.

    Returns:
        One row per interaction, with columns k, dim1, dim2, dim3, tvd, mae, max,
        l1d, l2d, hellinger and jensen_shannon. mae and max are taken over the
        union of joint cells observed in either frame.

    Raises:
        ValueError: if `k` is not 1, 2 or 3.
    """
    # validate before the (expensive) binning step
    if k not in (1, 2, 3):
        raise ValueError(f'k must be 1, 2 or 3, got {k!r}')
    dt1_bin, dt2_bin = bin_data(dt1, dt2, c=c)
    # interactions come from the reference frame itself, not from a notebook global
    interactions = _interaction_triples(dt1.columns, k)
    # separator between per-column labels in the joint key, so that e.g.
    # ('a', 'bc') and ('ab', 'c') do not both collapse onto 'abc'
    sep = '\x1f'

    results = []
    for row in interactions.itertuples(index=False):
        val1 = dt1_bin[row.dim1] + sep + dt1_bin[row.dim2] + sep + dt1_bin[row.dim3]
        val2 = dt2_bin[row.dim1] + sep + dt2_bin[row.dim2] + sep + dt2_bin[row.dim3]
        # align both frequency tables on the union of observed cells;
        # a cell missing from one frame gets probability 0 there
        freq1 = val1.value_counts(normalize=True).to_frame(name='p1')
        freq2 = val2.value_counts(normalize=True).to_frame(name='p2')
        freq = freq1.join(freq2, how='outer').fillna(0.0)
        metrics = _distance_metrics(freq['p1'], freq['p2'])
        out = pd.DataFrame({
            'k': k,
            'dim1': [row.dim1], 'dim2': [row.dim2], 'dim3': [row.dim3],
            **{name: [value] for name, value in metrics.items()}})
        results.append(out)

    if not results:
        # e.g. k=3 with fewer than three columns: return an empty, well-formed frame
        return pd.DataFrame(columns=['k', 'dim1', 'dim2', 'dim3', 'tvd', 'mae', 'max',
                                     'l1d', 'l2d', 'hellinger', 'jensen_shannon'])
    return pd.concat(results)
    

## Test Drive

In [2]:
# Quick check: univariate (k=1) fidelity of one synthesizer on one dataset.
# Set `synthesizer` to another generator (e.g. 'synthpop') to score that output instead.
synthesizer = 'mostly'
trn = pd.read_csv('data/credit-default_trn.csv.gz')
syn = pd.read_csv(f'data/credit-default_{synthesizer}.csv.gz')

# numeric_only: the result also holds the string columns dim1..dim3, which
# pandas >= 2.0 refuses to average (older versions silently dropped them)
fidelity(trn, syn, k=1, c=100).mean(numeric_only=True)

k                 1.000000
tvd               0.037641
mae               0.003392
max               0.013368
l1d               0.075281
l2d               0.019235
hellinger         0.034161
jensen_shannon    0.001531
dtype: float64

## Benchmark

In [4]:
# Score every available synthesizer output against its training data, for all datasets.
datasets = ['adult', 'credit-default', 'bank-marketing', 'online-shoppers']
fns = ['mostly', 'copulagan', 'ctgan', 'tvae', 'gaussian_copula', 'gretel', 'synthpop',
       'mostly_e1', 'mostly_e2', 'mostly_e4', 'mostly_e8', 'mostly_e16',
       'flip10', 'flip20', 'flip30', 'flip40', 'flip50',
       'flip60', 'flip70', 'flip80', 'flip90',
       'val']
# bins per column for each interaction order k; presumably coarser for higher k
# so the joint cells stay populated (TODO confirm rationale)
bins_per_k = {1: 100, 2: 10, 3: 5}

results = []
for dataset in datasets:
    trn = pd.read_csv(f'data/{dataset}_trn.csv.gz')
    for fn in fns:
        syn_fn = f'data/{dataset}_{fn}.csv.gz'
        print(syn_fn)
        # not every dataset has output from every synthesizer
        if not os.path.exists(syn_fn):
            continue
        syn = pd.read_csv(syn_fn)
        out = pd.concat([fidelity(trn, syn, k=k, c=c) for k, c in bins_per_k.items()])
        out['dataset'] = dataset
        out['synthesizer'] = fn
        results.append(out)

x = pd.concat(results)
x.to_csv('fidelity.csv', index=False)
x

data/adult_mostly.csv.gz
data/adult_copulagan.csv.gz
data/adult_ctgan.csv.gz
data/adult_tvae.csv.gz
data/adult_gaussian_copula.csv.gz
data/adult_gretel.csv.gz
data/adult_synthpop.csv.gz
data/adult_mostly_e1.csv.gz
data/adult_mostly_e2.csv.gz
data/adult_mostly_e4.csv.gz
data/adult_mostly_e8.csv.gz
data/adult_mostly_e16.csv.gz
data/adult_flip10.csv.gz
data/adult_flip20.csv.gz
data/adult_flip30.csv.gz
data/adult_flip40.csv.gz
data/adult_flip50.csv.gz
data/adult_flip60.csv.gz
data/adult_flip70.csv.gz
data/adult_flip80.csv.gz
data/adult_flip90.csv.gz
data/adult_val.csv.gz
data/credit-default_mostly.csv.gz
data/credit-default_copulagan.csv.gz
data/credit-default_ctgan.csv.gz
data/credit-default_tvae.csv.gz
data/credit-default_gaussian_copula.csv.gz
data/credit-default_gretel.csv.gz


  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


data/credit-default_synthpop.csv.gz
data/credit-default_mostly_e1.csv.gz
data/credit-default_mostly_e2.csv.gz
data/credit-default_mostly_e4.csv.gz
data/credit-default_mostly_e8.csv.gz
data/credit-default_mostly_e16.csv.gz
data/credit-default_flip10.csv.gz
data/credit-default_flip20.csv.gz
data/credit-default_flip30.csv.gz
data/credit-default_flip40.csv.gz
data/credit-default_flip50.csv.gz
data/credit-default_flip60.csv.gz
data/credit-default_flip70.csv.gz
data/credit-default_flip80.csv.gz
data/credit-default_flip90.csv.gz
data/credit-default_val.csv.gz
data/bank-marketing_mostly.csv.gz
data/bank-marketing_copulagan.csv.gz
data/bank-marketing_ctgan.csv.gz
data/bank-marketing_tvae.csv.gz
data/bank-marketing_gaussian_copula.csv.gz
data/bank-marketing_gretel.csv.gz
data/bank-marketing_synthpop.csv.gz
data/bank-marketing_mostly_e1.csv.gz
data/bank-marketing_mostly_e2.csv.gz
data/bank-marketing_mostly_e4.csv.gz
data/bank-marketing_mostly_e8.csv.gz
data/bank-marketing_mostly_e16.csv.gz
data/b

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


data/online-shoppers_synthpop.csv.gz
data/online-shoppers_mostly_e1.csv.gz
data/online-shoppers_mostly_e2.csv.gz
data/online-shoppers_mostly_e4.csv.gz
data/online-shoppers_mostly_e8.csv.gz
data/online-shoppers_mostly_e16.csv.gz
data/online-shoppers_flip10.csv.gz
data/online-shoppers_flip20.csv.gz
data/online-shoppers_flip30.csv.gz
data/online-shoppers_flip40.csv.gz
data/online-shoppers_flip50.csv.gz
data/online-shoppers_flip60.csv.gz
data/online-shoppers_flip70.csv.gz
data/online-shoppers_flip80.csv.gz
data/online-shoppers_flip90.csv.gz
data/online-shoppers_val.csv.gz


Unnamed: 0,k,dim1,dim2,dim3,tvd,mae,max,l1d,l2d,hellinger,jensen_shannon,dataset,synthesizer
0,1,age,age,age,0.028415,0.001114,0.005074,0.056831,0.010763,0.029035,8.420200e-04,adult,mostly
0,1,workclass,workclass,workclass,0.005931,0.001186,0.003501,0.011862,0.005244,0.023622,4.055906e-04,adult,mostly
0,1,fnlwgt,fnlwgt,fnlwgt,0.034485,0.000690,0.002471,0.068970,0.008635,0.030822,9.494985e-04,adult,mostly
0,1,education,education,education,0.012151,0.001519,0.004540,0.024302,0.008049,0.013165,1.732776e-04,adult,mostly
0,1,education-num,education-num,education-num,0.011926,0.001704,0.004540,0.023853,0.008105,0.013329,1.775985e-04,adult,mostly
...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,3,Browser,Revenue,Weekend,0.000811,0.000811,0.000811,0.001622,0.001147,0.000793,6.288789e-07,online-shoppers,val
0,3,Region,Revenue,SpecialDay,0.009084,0.004542,0.009084,0.018167,0.012304,0.008111,6.579021e-05,online-shoppers,val
0,3,Region,Revenue,TrafficType,0.022709,0.002839,0.008110,0.045418,0.014866,0.021295,4.533528e-04,online-shoppers,val
0,3,Region,Revenue,VisitorType,0.012165,0.002028,0.007299,0.024331,0.009978,0.014966,2.219092e-04,online-shoppers,val


In [5]:
# Mean of every metric per dataset / synthesizer / interaction order k.
# numeric_only skips the string columns dim1..dim3, which pandas >= 2.0 cannot average.
x.groupby(['dataset', 'synthesizer', 'k']).mean(numeric_only=True).head(20)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,tvd,mae,max,l1d,l2d,hellinger,jensen_shannon
dataset,synthesizer,k,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
adult,copulagan,1,0.131132,0.024476,0.088083,0.262265,0.127586,0.140999,0.035665
adult,copulagan,2,0.207133,0.027109,0.093471,0.414266,0.147127,0.215095,0.063381
adult,copulagan,3,0.264201,0.012318,0.082287,0.528402,0.142305,0.275851,0.092314
adult,ctgan,1,0.158024,0.032822,0.109902,0.316048,0.150377,0.156599,0.040073
adult,ctgan,2,0.209406,0.026458,0.097115,0.418813,0.148934,0.215927,0.058027
adult,ctgan,3,0.263186,0.011949,0.081526,0.526372,0.140765,0.272194,0.083775
adult,flip10,1,0.005298,0.000953,0.002204,0.010597,0.003359,0.006028,5.1e-05
adult,flip10,2,0.016747,0.001341,0.005056,0.033495,0.008773,0.023427,0.001432
adult,flip10,3,0.029523,0.001281,0.006973,0.059046,0.013089,0.043117,0.003421
adult,flip20,1,0.005175,0.000588,0.001766,0.01035,0.002896,0.005788,5.4e-05
