# Add custom generators and metrics
Benchmarking requires a common API: generators must implement `fit()` and `generate()` methods, and metrics must implement a `compute()` method. You can add custom generators and metrics by subclassing the `BaseGenerator` and `BaseMetric` classes, so that they are interoperable with the benchmarking framework.

In [8]:
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


In [2]:
from crnsynth.serialization import paths
from crnsynth.benchmark.benchmark import benchmark_generators
from crnsynth.benchmark.review import SyntheticDataReview
from crnsynth.metrics.privacy.dcr import DistanceClosestRecord
from crnsynth.generators.marginal_generator import MarginalGenerator
from crnsynth.generators.base_generator import BaseGenerator
from crnsynth.metrics.base_metric import BaseMetric
from crnsynth.processing.preprocessing import split_train_holdout

## Add custom generator

For example, here we wrap the CTGAN synthesizer from SDV (Synthetic Data Vault) as a custom generator.

In [3]:
from sdv.single_table import CTGANSynthesizer

class CTGANGenerator(BaseGenerator):
    """Wrap SDV's ``CTGANSynthesizer`` in the crnsynth ``BaseGenerator`` API.

    Exposes the ``fit(real_data)`` / ``generate(n)`` interface used by the
    benchmarking framework by delegating to an underlying ``CTGANSynthesizer``.

    Args:
        metadata: SDV metadata object describing the training table; passed
            straight to ``CTGANSynthesizer``.
        **ctgan_kwargs: Optional extra keyword arguments forwarded unchanged to
            ``CTGANSynthesizer`` (e.g. ``epochs`` -- see the SDV docs). Passing
            none keeps SDV's defaults, exactly as before.
    """

    def __init__(self, metadata, **ctgan_kwargs):
        self.generator = CTGANSynthesizer(metadata, **ctgan_kwargs)

    def fit(self, real_data):
        """Train the underlying CTGAN model on ``real_data``.

        Returns ``self`` so calls can be chained, e.g.
        ``CTGANGenerator(metadata).fit(df).generate(100)``.
        """
        self.generator.fit(real_data)
        return self

    def generate(self, n):
        """Sample ``n`` synthetic records from the fitted model and return them."""
        return self.generator.sample(n)

In [4]:
from sdv.datasets.demo import download_demo

# Load the SDV "fake_hotel_guests" demo table together with its SDV metadata,
# then set 20% of the rows aside as a holdout set for the benchmark below.
df, metadata = download_demo(dataset_name='fake_hotel_guests', modality='single_table')
df_train, df_holdout = split_train_holdout(df, holdout_size=0.2)
df_train.head()

Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
276,sullivanjoshua@dorsey-obrien.com,False,BASIC,,27 Sep 2020,29 Sep 2020,137.67,"PSC 3710, Box 0259\nAPO AA 17597",4214783741962784604
90,melissa09@bryan.com,False,BASIC,0.72,15 Feb 2020,17 Feb 2020,162.9,"1234 Corporate Drive\nBoston, MA 02116",4491757306201
32,spencershawn@clark.com,False,BASIC,30.86,14 Nov 2020,16 Nov 2020,119.96,"5678 Office Road\nSan Francisco, CA 94103",4196831074465
74,lance88@snow-rodriguez.net,False,DELUXE,27.41,16 Oct 2020,19 Oct 2020,181.27,"42400 Bryan View Apt. 310\nGarciaside, NY 01999",213100856267969
38,toddkaitlin@leon-collier.com,True,BASIC,0.0,20 Feb 2020,23 Feb 2020,98.41,"5678 Office Road\nSan Francisco, CA 94103",4942094262703149


In [5]:
# Fit CTGAN on the training split and draw 1000 synthetic records.
generator = CTGANGenerator(metadata=metadata)
generator.fit(real_data=df_train)
df_synth = generator.generate(n=1000)
df_synth.head()

See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.  (Deprecated NumPy 1.25)
  return np.find_common_type(types, [])
See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.  (Deprecated NumPy 1.25)
  return np.find_common_type(types, [])
See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.  (Deprecated NumPy 1.25)
  return np.find_common_type(types, [])
See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.  (Deprecated NumPy 1.25)
  return np.find_common_type(types, [])


Unnamed: 0,guest_email,has_rewards,room_type,amenities_fee,checkin_date,checkout_date,room_rate,billing_address,credit_card_number
0,laurencamacho@example.org,False,DELUXE,10.77,14 Sep 2020,29 May 2020,127.32,"138 Amanda Dale Apt. 815\nEast Susantown, MO 4...",4230479351149
1,xhoffman@example.net,False,SUITE,19.17,03 Nov 2020,08 Jan 2021,154.18,95472 Christopher Fall Apt. 623\nPort Sabrinat...,30303133881252
2,timothypeters@example.org,True,BASIC,9.61,08 Sep 2020,28 Jun 2020,132.91,"804 Dalton Springs Apt. 328\nPort Royfort, SD ...",4755628253310338134
3,edudley@example.org,False,BASIC,0.0,06 Jan 2020,21 Jul 2020,261.38,"46722 Finley Isle\nLake Kathybury, NC 81662",180074697686936
4,adam06@example.com,True,SUITE,,29 Jan 2020,02 Mar 2020,136.54,Unit 2430 Box 4042\nDPO AE 86481,180051900850376


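## Add custom metric

As an example, the metric below computes, for each categorical column, the fraction of categories in the real data that also appear in the synthetic data. It then averages this fraction over the columns.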

In [6]:
class AverageCardinalityPreserved(BaseMetric):
    """Average fraction of real-data categories that also appear in the synthetic data.

    For every categorical column, the column score is the number of distinct values
    in the real data that also occur in the synthetic data, divided by the number of
    distinct values in the real data. The reported score is the mean over columns,
    so 1.0 means every real category was reproduced at least once.

    Args:
        categorical_columns: Columns to evaluate. If None, columns with ``object`` or
            ``category`` dtype in the real data are used. They are detected anew on
            every call to ``compute``, and this setting itself is never overwritten.
    """

    def __init__(self, categorical_columns=None):
        self.categorical_columns = categorical_columns

        self.scores_ = {}

    def compute(self, real_data, synthetic_data, holdout=None):
        """Compute the average category-preservation score.

        Args:
            real_data: DataFrame with the real (training) records.
            synthetic_data: DataFrame with the synthetic records; must contain every
                evaluated column.
            holdout: Unused; accepted for compatibility with the BaseMetric API.

        Returns:
            ``self.scores_``, a dict ``{'score': value}``. The value is NaN when there
            is no column with real values to evaluate.
        """
        columns = self.categorical_columns
        if columns is None:
            # Resolve into a local so a metric reused on another dataset re-detects
            # its columns instead of keeping the first dataset's list.
            columns = real_data.select_dtypes(include=['object', 'category']).columns.tolist()

        column_scores = []
        for col in columns:
            # drop_duplicates keeps a single NaN and isin treats NaN as equal to NaN,
            # so missing values count as one category consistently (a plain set
            # intersection would only match NaN by object identity).
            unique_real = real_data[col].drop_duplicates()
            if unique_real.empty:
                # Nothing to preserve (and would divide by zero): leave it out of the average.
                continue
            column_scores.append(unique_real.isin(synthetic_data[col]).mean())

        # Explicit NaN avoids numpy's "mean of empty slice" RuntimeWarning.
        self.scores_['score'] = np.mean(column_scores) if column_scores else np.nan
        return self.scores_
    
# Evaluate only the two categorical columns of interest on the CTGAN sample.
metric_cardinality = AverageCardinalityPreserved(categorical_columns=['has_rewards', 'room_type'])
metric_cardinality.compute(real_data=df_train, synthetic_data=df_synth)
                            
    

{'score': 1.0}

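## Benchmark

Finally, we benchmark the built-in `MarginalGenerator` against the custom CTGAN generator. Both are scored with the custom metric and the built-in `DistanceClosestRecord` privacy metric.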

In [14]:
# Review each generator's output with the custom metric plus a built-in privacy metric.
reviewer = SyntheticDataReview(
    metrics=[
        AverageCardinalityPreserved(),
        DistanceClosestRecord(quantile=0.05),
    ],
    metric_kwargs={'categorical_columns': ['has_rewards', 'room_type']},
)

# Fit and sample every generator, save its output under PATH_RESULTS/fake_hotel_guests,
# then run the reviewer on each result.
benchmark_generators(
    data_real=df_train,
    data_holdout=df_holdout,
    generators=[
        MarginalGenerator(epsilon=0.1),
        CTGANGenerator(metadata=metadata),
    ],
    reviewer=reviewer,
    path_out=paths.PATH_RESULTS / 'fake_hotel_guests',
)

Running generator MarginalGenerator
Fitting generator MarginalGenerator on input data
Marginal fitted: guest_email
Marginal fitted: has_rewards
Marginal fitted: room_type
Marginal fitted: amenities_fee
Marginal fitted: checkin_date
Marginal fitted: checkout_date
Marginal fitted: room_rate
Marginal fitted: billing_address
Marginal fitted: credit_card_number
Generator fitted. Generating 400 records
Column sampled: guest_email
Column sampled: has_rewards
Column sampled: room_type
Column sampled: amenities_fee
Column sampled: checkin_date
Column sampled: checkout_date
Column sampled: room_rate
Column sampled: billing_address
Column sampled: credit_card_number
Saved synthetic data and generator for 0_MarginalGenerator at /Users/dknoors/Projects/synthesis-dk/crn-synth/results/fake_hotel_guests
Running reviewer for 0_MarginalGenerator
Running metric AverageCardinalityPreserved
Running metric DistanceClosestRecord
Running generator CTGANGenerator
Fitting generator CTGANGenerator on input data




Generator fitted. Generating 400 records
Saved synthetic data and generator for 1_CTGANGenerator at /Users/dknoors/Projects/synthesis-dk/crn-synth/results/fake_hotel_guests
Running reviewer for 1_CTGANGenerator
Running metric AverageCardinalityPreserved
Running metric DistanceClosestRecord
Saved scores at /Users/dknoors/Projects/synthesis-dk/crn-synth/results/fake_hotel_guests/reports/scores.csv


In [15]:
# Read back the score table that benchmark_generators wrote to disk.
scores_path = paths.PATH_RESULTS / 'fake_hotel_guests' / 'reports' / 'scores.csv'
scores_benchmark = pd.read_csv(scores_path)
scores_benchmark

Unnamed: 0,metric,0_MarginalGenerator,1_CTGANGenerator
0,AverageCardinalityPreserved_score,1.0,1.0
1,DistanceClosestRecord_holdout,1.0,1.0
2,DistanceClosestRecord_synth,0.004376,1.0
