# Add custom generators and metrics
Benchmarking requires a common API: generators must provide `fit()` and `generate()` methods, and metrics must provide a `compute()` method. You can add your own generators and metrics by subclassing the `BaseGenerator` and `BaseMetric` classes, so that they are interoperable with the benchmarking framework.

In [9]:
import numpy as np
import pandas as pd
import warnings
# Hide DeprecationWarning and FutureWarning notices emitted by the libraries
# used below, keeping the notebook output focused on the benchmark itself.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)


In [2]:
from crnsynth.serialization import paths
from crnsynth.benchmark.benchmark import benchmark_generators
from crnsynth.benchmark.review import SyntheticDataReview
from crnsynth.metrics.privacy.dcr import DistanceClosestRecord
from crnsynth.generators.marginal_generator import MarginalGenerator
from crnsynth.generators.base_generator import BaseGenerator
from crnsynth.metrics.base_metric import BaseMetric
from crnsynth.processing.preprocessing import split_train_holdout

## Add custom generator

For example, here we add the CTGAN generator from SDV by wrapping SDV's `CTGANSynthesizer` in a `BaseGenerator` subclass.

In [3]:
from sdv.single_table import CTGANSynthesizer

class CTGANGenerator(BaseGenerator):
    """Adapter that exposes SDV's ``CTGANSynthesizer`` through the ``BaseGenerator`` API.

    Parameters
    ----------
    metadata
        SDV table metadata, passed unchanged to ``CTGANSynthesizer``.
    """

    def __init__(self, metadata):
        synthesizer = CTGANSynthesizer(metadata)
        # The wrapped SDV model; ``fit``/``generate`` delegate to it.
        self.generator = synthesizer

    def fit(self, real_data):
        """Train the wrapped CTGAN synthesizer on ``real_data``."""
        self.generator.fit(real_data)

    def generate(self, n):
        """Draw ``n`` synthetic rows from the fitted synthesizer and return them."""
        synthetic_rows = self.generator.sample(n)
        return synthetic_rows

In [4]:
from sdv.datasets.demo import download_demo

# Load SDV's "fake_hotel_guests" single-table demo data together with its
# metadata, then split off a holdout set (holdout_size=0.2) for evaluation.
df, metadata = download_demo(modality="single_table", dataset_name="fake_hotel_guests")
df_train, df_holdout = split_train_holdout(df, holdout_size=0.2)
df_train.head()

Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
495,laurabennett@jones-duncan.net,False,BASIC,8.71,04 Jan 2021,06 Jan 2021,103.25,"5678 Office Road\nSan Francisco, CA 94103",3505516387300030
65,craiglawson@wilson.com,False,BASIC,23.72,20 Jul 2020,21 Jul 2020,120.34,"463 Simmons Forks\nPort Eric, VA 11253",38651500078643
174,obrienbrenda@gentry.biz,False,BASIC,11.62,25 Feb 2020,27 Feb 2020,132.06,"5678 Office Road\nSan Francisco, CA 94103",3554356011481199
367,ptaylor@rhodes-johnson.org,False,BASIC,,05 Jul 2020,08 Jul 2020,109.8,"77 Massachusetts Ave\nCambridge, MA 02139",3527546197874381
237,jack98@riley-roberson.biz,False,BASIC,22.54,20 May 2020,24 May 2020,154.27,853 Alexandra Center Apt. 179\nEast Christinev...,180006078094389


In [10]:
# Train the CTGAN wrapper on the training split and draw 1,000 synthetic rows.
generator = CTGANGenerator(metadata=metadata)
generator.fit(df_train)

df_synth = generator.generate(n=1000)
df_synth.head()

Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
0,dawsonalexander@example.com,False,BASIC,0.1,17 May 2020,04 Aug 2020,261.28,"1225 Melissa Neck\nLake Dennis, AS 41867",4657132433426482
1,mmills@example.org,False,SUITE,10.39,05 Jan 2020,20 May 2020,170.0,"0374 Karen Island\nFranklinmouth, IL 83721",347942082550059
2,castrojeremy@example.net,True,BASIC,26.95,23 Jun 2020,27 Oct 2020,255.35,"448 Adkins Field\nWest Ryanburgh, DC 13761",4188558408990755
3,kwerner@example.net,False,BASIC,25.12,05 Jan 2020,15 Oct 2020,105.41,"686 Vance Route Suite 272\nCampbellmouth, UT 7...",4928393065485
4,josecross@example.com,False,SUITE,,08 Jul 2020,05 Dec 2020,218.43,"61851 Stone Via Apt. 773\nSouth Allenville, KY...",503842511399


## Add custom metric

In [6]:
class AverageCardinalityPreserved(BaseMetric):
    """Average share of real categories that also appear in the synthetic data.

    For every categorical column, the per-column score is the fraction of the
    distinct values found in ``real_data`` that also occur at least once in
    ``synthetic_data``. The metric is the mean of these per-column fractions:
    1.0 means every real category was reproduced, lower values mean categories
    were lost.

    Parameters
    ----------
    categorical_columns : list of str, optional
        Columns to evaluate. When ``None``, every column of ``real_data`` with an
        ``object``, ``category`` or ``bool`` dtype is used. The inferred list is
        recomputed on each ``compute`` call and is not stored on the instance.
    """

    # dtypes treated as categorical when ``categorical_columns`` is not given.
    _CATEGORICAL_DTYPES = ("object", "category", "bool")

    def __init__(self, categorical_columns=None):
        self.categorical_columns = categorical_columns
        # Result dict of the most recent ``compute`` call.
        self.scores_ = {}

    @staticmethod
    def type() -> str:
        return 'similarity'

    @staticmethod
    def direction() -> str:
        return 'maximize'

    def compute(self, real_data, synthetic_data, holdout=None):
        """Score how many real categories survive in the synthetic data.

        Parameters
        ----------
        real_data : pandas.DataFrame
            Reference data whose categories should be preserved.
        synthetic_data : pandas.DataFrame
            Generated data; must contain every evaluated column.
        holdout : optional
            Accepted for API compatibility with other metrics; not used.

        Returns
        -------
        dict
            ``{'score': float}``. The score is NaN when there is no column with
            at least one real category to compare.
        """
        # Use a local variable so inferred columns are not cached on the
        # instance and silently reused when the metric is applied to other data.
        columns = self.categorical_columns
        if columns is None:
            columns = real_data.select_dtypes(
                include=list(self._CATEGORICAL_DTYPES)
            ).columns.tolist()

        column_scores = []
        for col in columns:
            real_categories = pd.Series(real_data[col].unique())
            if real_categories.empty:
                # Nothing to preserve (empty real data); avoid dividing by zero.
                continue
            # ``isin`` matches NaN against NaN, unlike a Python set intersection
            # of float NaNs, so missing values are treated as one category.
            overlap = real_categories.isin(synthetic_data[col].unique()).mean()
            column_scores.append(overlap)

        # Average over columns; avoid np.mean([]) (RuntimeWarning) when empty.
        self.scores_['score'] = float(np.mean(column_scores)) if column_scores else np.nan
        return self.scores_
    
# Score the CTGAN sample against the training split on two explicitly chosen
# categorical columns.
metric_cardinality = AverageCardinalityPreserved(
    categorical_columns=["has_rewards", "room_type"]
)
metric_cardinality.compute(df_train, df_synth)
                            
    

{'score': 1.0}

## Benchmark

In [11]:
# Review each synthetic dataset with the custom cardinality metric and the
# DistanceClosestRecord privacy metric. metric_kwargs presumably gets forwarded
# to the metrics that accept it — confirm against SyntheticDataReview.
reviewer = SyntheticDataReview(
    metrics=[
        AverageCardinalityPreserved(),
        DistanceClosestRecord(quantile=0.05),
    ],
    metric_kwargs={"categorical_columns": ["has_rewards", "room_type"]},
)

# Fit every generator on the training split, generate synthetic data, run the
# reviewer and save the results under PATH_RESULTS/fake_hotel_guests.
benchmark_generators(
    data_real=df_train,
    data_holdout=df_holdout,
    generators=[
        MarginalGenerator(epsilon=0.1),
        CTGANGenerator(metadata=metadata),
    ],
    reviewer=reviewer,
    path_out=paths.PATH_RESULTS / "fake_hotel_guests",
)

Running generator MarginalGenerator
Fitting generator MarginalGenerator on input data
Marginal fitted: guest_email
Marginal fitted: has_rewards
Marginal fitted: room_type
Marginal fitted: amenities_fee
Marginal fitted: checkin_date
Marginal fitted: checkout_date
Marginal fitted: room_rate
Marginal fitted: billing_address
Marginal fitted: credit_card_number
Generator fitted. Generating 400 records
Column sampled: guest_email
Column sampled: has_rewards
Column sampled: room_type
Column sampled: amenities_fee
Column sampled: checkin_date
Column sampled: checkout_date
Column sampled: room_rate
Column sampled: billing_address
Column sampled: credit_card_number
Saved synthetic data and generator for 0_MarginalGenerator at /Users/dknoors/Projects/synthesis-dk/crn-synth/results/fake_hotel_guests
Running reviewer for 0_MarginalGenerator
Running metric AverageCardinalityPreserved
Running metric DistanceClosestRecord
Running generator CTGANGenerator
Fitting generator CTGANGenerator on input data


In [12]:
# load results
# Load the scores table written by benchmark_generators (one column per generator).
scores_benchmark = pd.read_csv(
    paths.PATH_RESULTS / "fake_hotel_guests" / "reports" / "scores.csv"
)
scores_benchmark

Unnamed: 0,metric,0_MarginalGenerator,1_CTGANGenerator
0,AverageCardinalityPreserved_score,1.0,1.0
1,DistanceClosestRecord_holdout,1.0,1.0
2,DistanceClosestRecord_synth,0.004561,1.0
