# Add custom generators and metrics
Benchmarking requires a common API: generators implement `fit()` and `generate()` methods, and metrics implement a `compute()` method. You can add custom generators and metrics by subclassing the `BaseGenerator` and `BaseMetric` classes, which makes them interoperable with the benchmarking framework.

In [6]:
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


In [2]:
from crnsynth.serialization import paths
from crnsynth.benchmark.benchmark import benchmark_generators
from crnsynth.benchmark.review import SyntheticDataReview
from crnsynth.metrics.privacy.dcr import DistanceClosestRecord
from crnsynth.generators.marginal import MarginalGenerator
from crnsynth.generators.base import BaseGenerator
from crnsynth.metrics.base import BaseMetric
from crnsynth.processing.preprocessing import split_train_holdout

## Add custom generator

For example, here we add the CTGAN generator from SDV.

In [3]:
from sdv.single_table import CTGANSynthesizer

class CTGANGenerator(BaseGenerator):
    """Adapter that exposes SDV's CTGANSynthesizer through fit() and generate()."""

    def __init__(self, metadata):
        """Build the wrapped CTGANSynthesizer from ``metadata``."""
        synthesizer = CTGANSynthesizer(metadata)
        self.generator = synthesizer

    def fit(self, real_data):
        """Fit the wrapped synthesizer on ``real_data``."""
        synthesizer = self.generator
        synthesizer.fit(real_data)

    def generate(self, n):
        """Sample from the wrapped synthesizer, passing ``n`` as the requested size."""
        synthetic_rows = self.generator.sample(n)
        return synthetic_rows

In [7]:
from sdv.datasets.demo import download_demo

# Download SDV's 'fake_hotel_guests' single-table demo. The call returns the
# data together with its metadata; the metadata is what CTGANGenerator is
# constructed from further down.
df, metadata = download_demo(
    modality='single_table',
    dataset_name='fake_hotel_guests'
)

# Split into a training part and a holdout part; both are passed to the benchmark later.
# holdout_size=0.2 presumably reserves 20% of the rows — confirm against split_train_holdout.
# NOTE(review): no seed is passed here, so the split may differ between runs — verify
# whether split_train_holdout has a fixed default random state.
df_train, df_holdout = split_train_holdout(df, holdout_size=0.2)
df_train.head()

Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
414,adamsmark@phillips-barnes.com,False,BASIC,46.27,26 May 2020,27 May 2020,141.68,"7874 Joshua Hills Apt. 837\nFullerfort, HI 81547",2224569421948456
84,myersmonica@jacobs.net,False,BASIC,10.45,04 Feb 2020,,155.61,"216 Stephanie Islands\nElainechester, DE 62055",4997172877158950
437,joshua15@gaines.com,True,DELUXE,0.0,22 Jun 2020,24 Jun 2020,204.11,"361 Compton Harbor\nYvetteland, KS 17306",4523783681085860804
259,castromelissa@scott-flores.com,False,BASIC,19.93,21 May 2020,22 May 2020,119.39,"0784 Todd Manors\nJonesmouth, WY 42593",4150943237171848754
331,elizabethvaldez@torres.org,False,BASIC,19.36,15 Sep 2020,17 Sep 2020,115.68,"77 Massachusetts Ave\nCambridge, MA 02139",4626586438747


In [8]:
# Fit the CTGAN wrapper on the training split, then request 1000 synthetic rows from it.
generator = CTGANGenerator(metadata=metadata)
generator.fit(df_train)
df_synth = generator.generate(1000)
# Last expression of the cell: preview the first synthetic rows.
df_synth.head()

Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
0,marshallmatthew@example.net,False,DELUXE,0.0,03 Aug 2020,07 Oct 2020,144.96,"401 Gould Glen Suite 244\nWalterville, FM 20462",6543609979607503
1,nrodriguez@example.org,False,SUITE,5.45,06 Jan 2021,10 Oct 2020,149.44,"815 Michael Throughway\nToddhaven, MP 80482",6503840662536059
2,whitney60@example.org,False,BASIC,0.0,16 Nov 2020,07 Jan 2020,218.57,"030 Diane Parks\nBartonmouth, ND 67470",4228737522997993
3,anthonyrosales@example.org,False,BASIC,0.0,27 Sep 2020,04 Feb 2020,167.6,Unit 1489 Box 4235\nDPO AE 64043,3544641794792067
4,brentaguirre@example.com,False,SUITE,,25 Oct 2020,29 Apr 2020,257.67,"735 Hobbs Ridges\nPort Anthony, ND 57086",3529271628138698


## Add custom metric

In [9]:
class AverageCardinalityPreserved(BaseMetric):
    """Average share of real categories that also appear in the synthetic data.

    For every categorical column the metric computes the fraction (between 0
    and 1, not a percentage) of distinct values in the real data that also
    occur in the synthetic data, then averages these fractions over the columns.

    Parameters
    ----------
    categorical_columns : list of str, optional
        Columns to evaluate. When None, the columns with ``object`` dtype in
        the real data are used; they are looked up again on every call to
        ``compute`` and are never stored on the instance.

    Attributes
    ----------
    scores_ : dict
        Result of the latest ``compute`` call, stored under the key 'score'.
    """

    def __init__(self, categorical_columns=None):
        self.categorical_columns = categorical_columns
        self.scores_ = {}

    @staticmethod
    def type() -> str:
        """Return the metric type label, 'similarity'."""
        return 'similarity'

    @staticmethod
    def direction() -> str:
        """Return 'maximize' (presumably: higher scores are better — confirm against BaseMetric)."""
        return 'maximize'

    def compute(self, real_data, synthetic_data, holdout=None):
        """Compute the average fraction of real categories found in the synthetic data.

        Parameters
        ----------
        real_data : pandas.DataFrame
            Reference data whose distinct values define the categories per column.
        synthetic_data : pandas.DataFrame
            Data checked for the presence of those categories. Must contain
            every evaluated column.
        holdout : optional
            Accepted but not used by this metric.

        Returns
        -------
        dict
            ``self.scores_`` with the key 'score'. The score is NaN when there
            is no column to evaluate or no evaluated column has any real value.
        """
        # Resolve the columns into a local. Writing inferred columns back to
        # self would make them stick, so a reused instance would silently keep
        # the columns of the first dataset it was computed on.
        columns = self.categorical_columns
        if columns is None:
            columns = real_data.select_dtypes(include='object').columns.tolist()

        column_scores = []
        for col in columns:
            unique_real = real_data[col].unique()
            if len(unique_real) == 0:
                # No real categories to preserve (e.g. empty input): skip the
                # column instead of dividing by zero.
                continue
            unique_synth = synthetic_data[col].unique()

            # fraction of categories from the real data that are in the synth data
            shared = set(unique_real).intersection(unique_synth)
            column_scores.append(len(shared) / len(unique_real))

        # take average of all columns; NaN (without a NumPy warning) if none were scored
        self.scores_['score'] = np.mean(column_scores) if column_scores else np.nan
        return self.scores_
    
# Try the metric on the CTGAN sample, restricted to two explicitly named columns.
metric_cardinality = AverageCardinalityPreserved(categorical_columns=['has_rewards', 'room_type'])
metric_cardinality.compute(df_train, df_synth)
                            
    

{'score': 1.0}

## Benchmark

In [10]:
# Metrics handed to the reviewer, plus the keyword arguments passed alongside them.
review_metrics = [AverageCardinalityPreserved(), DistanceClosestRecord(quantile=0.05)]
review_metric_kwargs = {'categorical_columns': ['has_rewards', 'room_type']}
reviewer = SyntheticDataReview(metrics=review_metrics, metric_kwargs=review_metric_kwargs)

# Generators to benchmark and the directory the run writes to.
candidate_generators = [MarginalGenerator(epsilon=0.1), CTGANGenerator(metadata=metadata)]
benchmark_output_dir = paths.PATH_RESULTS / 'fake_hotel_guests'

benchmark_generators(
    data_real=df_train,
    data_holdout=df_holdout,
    generators=candidate_generators,
    reviewer=reviewer,
    path_out=benchmark_output_dir,
)

Running generator MarginalGenerator
Fitting generator MarginalGenerator on input data
Marginal fitted: guest_email
Marginal fitted: has_rewards
Marginal fitted: room_type
Marginal fitted: amenities_fee
Marginal fitted: checkin_date
Marginal fitted: checkout_date
Marginal fitted: room_rate
Marginal fitted: billing_address
Marginal fitted: credit_card_number
Generator fitted. Generating 400 records
Column sampled: guest_email
Column sampled: has_rewards
Column sampled: room_type
Column sampled: amenities_fee
Column sampled: checkin_date
Column sampled: checkout_date
Column sampled: room_rate
Column sampled: billing_address
Column sampled: credit_card_number
Saved synthetic data and generator for 0_MarginalGenerator at /Users/dknoors/Projects/synthesis-dk/crn-synth/results/fake_hotel_guests
Running reviewer for 0_MarginalGenerator
Running metric AverageCardinalityPreserved
Running metric DistanceClosestRecord
Running generator CTGANGenerator
Fitting generator CTGANGenerator on input data


In [11]:
# Read the scores table from the benchmark output directory
# (presumably written by the benchmark run in the previous cell).
scores_path = paths.PATH_RESULTS / 'fake_hotel_guests/reports/scores.csv'
scores_benchmark = pd.read_csv(scores_path)
scores_benchmark

Unnamed: 0,metric,0_MarginalGenerator,1_CTGANGenerator
0,AverageCardinalityPreserved_score,0.833333,1.0
1,DistanceClosestRecord_holdout,1.0,1.0
2,DistanceClosestRecord_synth,0.003982,1.0
