# Add custom generators and metrics
Benchmarking requires a common API: generators implement `fit()` and `generate()` methods, and metrics implement a `compute()` method. You can add custom generators and metrics by subclassing the `BaseGenerator` and `BaseMetric` classes, so that they interoperate with the benchmarking framework.

In [7]:
import numpy as np
import pandas as pd
import warnings

In [8]:
# ignore warnings
# Silence DeprecationWarning and FutureWarning messages so the notebook output stays readable.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# autoreload changes from local files
# IPython magics (only valid inside IPython/Jupyter, not plain Python): with mode 2,
# modules are reloaded before each cell executes, so edits to local source files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# pandas show full output
# Display up to 500 rows and 200 columns before pandas truncates printed DataFrames.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 200)

In [1]:
from crnsynth.serialization import paths
from crnsynth.benchmark.benchmark import benchmark_generators
from crnsynth.benchmark.review import SyntheticDataReview
from crnsynth.metrics.privacy.dcr import DistanceClosestRecord
from crnsynth.generators.marginal import MarginalGenerator
from crnsynth.generators.base import BaseGenerator
from crnsynth.metrics.base import BaseMetric
from crnsynth.processing.preprocessing import split_train_holdout

## Add custom generator

For example, here we add the CTGAN generator from SDV.

In [2]:
from sdv.single_table import CTGANSynthesizer

class CTGANGenerator(BaseGenerator):
    """Adapter that exposes SDV's ``CTGANSynthesizer`` through the ``BaseGenerator`` API.

    Args:
        metadata: SDV metadata describing the table; passed unchanged to
            ``CTGANSynthesizer``.
    """

    def __init__(self, metadata):
        # The wrapped SDV model. fit() and generate() both delegate to it.
        synthesizer = CTGANSynthesizer(metadata)
        self.generator = synthesizer

    def fit(self, real_data):
        """Train the wrapped CTGAN model on ``real_data``. Returns nothing."""
        model = self.generator
        model.fit(real_data)

    def generate(self, n):
        """Draw ``n`` synthetic rows, returned as given by ``CTGANSynthesizer.sample``."""
        model = self.generator
        samples = model.sample(n)
        return samples

In [3]:
from sdv.datasets.demo import download_demo

# Fetch SDV's 'fake_hotel_guests' single-table demo dataset along with its SDV metadata.
# NOTE(review): download_demo presumably needs network access; confirm whether it caches locally.
df, metadata = download_demo(
    modality='single_table',
    dataset_name='fake_hotel_guests'
)

# Split off a holdout set (holdout_size=0.2, presumably 20% of rows). Generators are trained on
# df_train only; df_holdout is passed to the metrics and the benchmark further down.
df_train, df_holdout = split_train_holdout(df, holdout_size=0.2)
df_train.head()

Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
165,robertomorris@long.com,True,SUITE,0.0,07 Dec 2020,10 Dec 2020,230.38,"77 Massachusetts Ave\nCambridge, MA 02139",561674411369
144,juanwatson@chung.org,True,SUITE,0.0,20 Feb 2020,21 Feb 2020,204.06,"5678 Office Road\nSan Francisco, CA 94103",4432016585990225
362,stephanie09@sexton-spencer.com,False,BASIC,9.88,17 Feb 2020,19 Feb 2020,108.53,"77 Massachusetts Ave\nCambridge, MA 02139",4285102311649280378
467,belljose@goodwin-farrell.net,False,BASIC,23.41,06 Nov 2020,07 Nov 2020,103.17,"45274 Andrew Bridge\nNew Elizabethton, FL 05037",4534892217780995
231,elizabethedwards@edwards.net,False,BASIC,34.92,03 Aug 2020,04 Aug 2020,107.92,"9280 Laura Prairie Suite 706\nRomerochester, M...",6576864934777079


In [9]:
# Fit the custom CTGAN wrapper on the training split and sample 1000 synthetic rows.
# CTGANSynthesizer is built with only `metadata`, so all other settings are SDV defaults.
generator = CTGANGenerator(metadata=metadata)
generator.fit(df_train)
df_synth = generator.generate(1000)
df_synth.head()

Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
0,sarah63@example.net,False,BASIC,0.0,04 Apr 2020,27 Jun 2020,86.31,"63520 John Stream\nDawnview, DE 36491",4384922664587519
1,sarahortiz@example.com,False,DELUXE,0.0,07 Jan 2021,08 Jan 2021,352.39,"36680 Lori Village Apt. 439\nEast Ray, MS 95382",5568254531057640
2,dominguezchristopher@example.net,False,SUITE,25.34,17 Jun 2020,08 Jan 2021,86.62,"74167 Laura Street Apt. 788\nWest Robertfurt, ...",4816260594545267
3,cantrelljoshua@example.org,True,DELUXE,0.0,04 Sep 2020,23 Oct 2020,118.07,"710 Fuller Station Apt. 020\nEast Beth, NV 87797",4618523508526800
4,yvonnenorton@example.org,False,DELUXE,,24 Apr 2020,19 May 2020,83.8,"097 Christina Avenue\nGordonfort, AL 38610",5526638030276589


## Add custom metric

In [10]:
class AverageCardinalityPreserved(BaseMetric):
    """Average fraction of real-data categories that also appear in the synthetic data.

    For each categorical column, the metric computes the share of distinct real
    values that occur at least once in the synthetic data. It then averages these
    shares over all columns. A score of 1.0 means every real category is
    represented in the synthetic data.

    Args:
        categorical_columns: Columns to evaluate. When ``None``, the columns with
            ``object`` or ``category`` dtype in ``real_data`` are detected on every
            ``compute`` call. The detected list is not stored on the instance.
    """

    def __init__(self, categorical_columns=None):
        self.categorical_columns = categorical_columns
        # Result dict of the most recent compute() call.
        self.scores_ = {}

    @staticmethod
    def type() -> str:
        return 'similarity'

    @staticmethod
    def direction() -> str:
        return 'maximize'

    def compute(self, real_data, synthetic_data, holdout=None, **kwargs):
        """Score how many real categories survive in ``synthetic_data``.

        Args:
            real_data: Real DataFrame. Its distinct values define the reference
                categories.
            synthetic_data: Synthetic DataFrame containing the same columns.
            holdout: Unused. Accepted for compatibility with the BaseMetric call signature.
            **kwargs: Unused. Accepted for the same reason.

        Returns:
            ``self.scores_`` with key ``'score'`` set to the mean per-column overlap
            fraction in [0, 1]. The score is NaN when there are no columns to evaluate.

        Raises:
            ValueError: If a column has no values in ``real_data``, because the
                overlap fraction is undefined in that case.
        """
        columns = self.categorical_columns
        if columns is None:
            # Detect into a local variable rather than assigning to self.categorical_columns.
            # Caching would silently reuse the first dataset's columns on later calls.
            columns = real_data.select_dtypes(include=['object', 'category']).columns.tolist()

        overlap_fractions = []
        for col in columns:
            real_categories = set(real_data[col].unique())
            if not real_categories:
                raise ValueError(
                    f"Column '{col}' has no values in real_data; cardinality overlap is undefined."
                )
            synth_categories = set(synthetic_data[col].unique())
            overlap_fractions.append(len(real_categories & synth_categories) / len(real_categories))

        # np.mean([]) would also yield NaN, but it emits a RuntimeWarning; make the empty case explicit.
        self.scores_['score'] = float(np.mean(overlap_fractions)) if overlap_fractions else np.nan
        return self.scores_
    
# Evaluate only the two explicitly listed columns instead of relying on dtype-based auto-detection.
# The result is the metric's scores_ dict, e.g. {'score': ...}.
metric_cardinality = AverageCardinalityPreserved(categorical_columns=['has_rewards', 'room_type'])
metric_cardinality.compute(df_train, df_synth)
                            
    

{'score': 1.0}

Metrics can also be imported from other libraries; they only need to implement a `compute()` method. For popular libraries such as `synthcity`, we provide a wrapper class that lets you import a metric and use it in the benchmarking framework.

In [12]:
from crnsynth.integration.metrics import SynthcityMetricWrapper
from synthcity.metrics.eval_statistical import JensenShannonDistance
    
# Wrap synthcity's JensenShannonDistance so it exposes the compute() API used by the benchmark.
# NOTE(review): encoder='ordinal' presumably ordinal-encodes non-numeric columns before scoring;
# confirm in SynthcityMetricWrapper.
sc_js = SynthcityMetricWrapper(metric=JensenShannonDistance(), encoder='ordinal')
sc_js.compute(df_train, df_synth, df_holdout)

{'marginal': 0.06122683436385463}

## Benchmark

In [13]:
# The reviewer runs every metric on each generator's output: the custom metric,
# a DCR privacy metric and the wrapped synthcity metric.
# NOTE(review): how metric_kwargs are applied to individual metrics is not visible here;
# presumably categorical_columns reaches AverageCardinalityPreserved. Confirm in SyntheticDataReview.
reviewer = SyntheticDataReview(
    metrics=[AverageCardinalityPreserved(), DistanceClosestRecord(quantile=0.05), sc_js],
    metric_kwargs = {'categorical_columns': ['has_rewards', 'room_type']}
)

# Fit each generator on df_train, generate synthetic data and score it with the reviewer.
# Configs, synthetic data and reports are written under path_out.
benchmark_generators(
    data_real=df_train,
    data_holdout=df_holdout,    
    generators=[MarginalGenerator(epsilon=0.1), CTGANGenerator(metadata=metadata)],
    reviewer=reviewer,
    path_out = paths.PATH_RESULTS / 'fake_hotel_guests',
)

Running generator MarginalGenerator
Fitting generator MarginalGenerator on input data
Marginal fitted: guest_email
Marginal fitted: has_rewards
Marginal fitted: room_type
Marginal fitted: amenities_fee
Marginal fitted: checkin_date
Marginal fitted: checkout_date
Marginal fitted: room_rate
Marginal fitted: billing_address
Marginal fitted: credit_card_number
Generator fitted. Generating 400 records
Column sampled: guest_email
Column sampled: has_rewards
Column sampled: room_type
Column sampled: amenities_fee
Column sampled: checkin_date
Column sampled: checkout_date
Column sampled: room_rate
Column sampled: billing_address
Column sampled: credit_card_number
Saved to disk: /Users/dknoors/Projects/synthesis-dk/crn-synth/results/fake_hotel_guests/configs/0_MarginalGenerator.json
Saved synthetic data, generator and configs for 0_MarginalGenerator at /Users/dknoors/Projects/synthesis-dk/crn-synth/results/fake_hotel_guests
Running reviewer for 0_MarginalGenerator
Running metric AverageCardinal

In [14]:
# load results
scores_benchmark = pd.read_csv(paths.PATH_RESULTS / 'fake_hotel_guests/reports/scores.csv')
scores_benchmark

Unnamed: 0,metric,0_MarginalGenerator,1_CTGANGenerator
0,AverageCardinalityPreserved_score,1.0,1.0
1,DistanceClosestRecord_holdout,1.0,1.0
2,DistanceClosestRecord_synth,0.004453,1.0
3,JensenShannonDistance_marginal,0.019594,0.061232
