In [1]:
import pandas as pd
import numpy as np
from stratified_dataset import ParallelStratifiedSynthesizer
from snsynth.mst import MSTSynthesizer
from snsynth.aim import AIMSynthesizer
from gem_synthesizer import GEMSynthesizer
import dill
from helpers.data_utils import get_employment, force_data_categorical_to_numeric
import itertools
import os
from IPython.display import clear_output
from stratified_dataset import StratifiedDataset
import warnings
warnings.filterwarnings('ignore')

all_data, features, target, group = get_employment()

df = all_data.copy()
df = df.drop(columns=['CIT', 'MIG', 'DEAR', 'DEYE', 'NATIVITY', 'ANC'])


  from .autonotebook import tqdm as notebook_tqdm


['RELP']


In [2]:
grouped_original_df = df.astype(float).groupby(['SEX','RAC1P']).mean().reset_index()
grouped_original_df['SEX'] = grouped_original_df['SEX'] - 1.0
grouped_original_df['RAC1P'] = grouped_original_df['RAC1P'] - 1.0
grouped_original_df

Unnamed: 0,SEX,RAC1P,AGEP,SCHL,MAR,DIS,ESP,MIL,DREM,ESR
0,0.0,0.0,1.725445,2.881674,2.954433,1.861228,0.472135,3.05002,1.85173,0.50128
1,0.0,1.0,1.466101,2.597544,3.641148,1.851314,0.928771,3.074569,1.834161,0.397581
2,0.0,2.0,1.490975,2.599278,3.505415,1.812274,1.050542,3.021661,1.877256,0.458484
3,0.0,3.0,2.0,3.5,3.0,2.0,0.0,4.0,2.0,1.0
4,0.0,4.0,1.36,2.4,3.672,1.816,1.248,2.976,1.776,0.384
5,0.0,5.0,1.48893,2.815129,2.918204,1.940467,0.410086,3.255228,1.878598,0.538868
6,0.0,6.0,1.309524,2.52381,3.571429,1.880952,0.952381,2.619048,1.761905,0.309524
7,0.0,7.0,1.210914,2.257175,3.605551,1.896337,1.035498,2.953172,1.825906,0.479796
8,0.0,8.0,0.960687,2.176718,3.854198,1.886641,1.265267,2.35687,1.692748,0.360687
9,1.0,0.0,1.841565,2.970697,2.83169,1.856746,0.426884,3.350475,1.861204,0.447228


In [72]:
with open('models/GEMSynthesizer_epsilon_5.0_SEX_RAC1P_seed_1.dill', "rb") as file:
    model = dill.load(file)
synth_df_strat = model.sample(len(df))

In [73]:
def calculate_disparity(real_train_df, synth_df, strata_cols, func):
    assert len(strata_cols) > 0, "strata_cols must be a list with at least one element"

    # Create multi-index DataFrames grouped by strata_cols
    real_grouped = real_train_df.groupby(strata_cols)
    synth_grouped = synth_df.groupby(strata_cols)

    # Initialize disparity as negative infinity
    max_disparity = float('-inf')
    max_key = None

    # Iterate over unique combinations of strata_cols
    for key in real_grouped.groups.keys():
        # Check if the group also exists in the synthetic data
        if key in synth_grouped.groups.keys():
            real_group = real_grouped.get_group(key)
            synth_group = synth_grouped.get_group(key)

            # Apply the function to the real and synthetic groups
            real_result = func(real_group)
            synth_result = func(synth_group)

            # Calculate the absolute difference normalized
            disparity = abs((real_result - synth_result) / real_result) 
            # If this disparity is greater than the current maximum, update maximum
            for disp in disparity:
                if not np.isinf(disp):
                    if disp > max_disparity:
                        max_disparity = disp
                        max_key = key
                        
    return max_disparity, max_key

In [74]:
def a_mean(df):
    return df.astype(float).mean().values
df = force_data_categorical_to_numeric(df, cat_columns=df.columns)
calculate_disparity(df, synth_df_strat, ['SEX','RAC1P'], a_mean)

(0.33082706766917297, (1, 2))

In [75]:
from sklearn.metrics import accuracy_score
from fairlearn.metrics import false_positive_rate, false_negative_rate, equalized_odds_ratio, demographic_parity_ratio
from fairlearn.metrics import MetricFrame
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier

def evaluate_on_dataframes_with_fairlearn(train_df, test_df, target_col = 'ESR'):
    # Feature columns
    feature_cols = [col for col in train_df.columns if col != target_col]

    # # Convert all columns to categorical
    for col in train_df.columns:
        train_df[col] = train_df[col].astype('float')

    for col in test_df.columns:
        test_df[col] = test_df[col].astype('float')

    test_df = test_df.dropna(subset=['SEX', 'RAC1P'])
    # Prepare the dataset
    X_train = train_df[feature_cols]
    y_train = train_df[target_col]
    X_test = test_df[feature_cols]
    y_test = test_df[target_col]

    # Train the classifier
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)

    # Make predictions on the test set
    y_pred = clf.predict(X_test)

    # Evaluate the model
    accuracy = accuracy_score(y_test, y_pred)

    # Compute fairness metrics
    metrics = MetricFrame({
        'accuracy': accuracy_score,
        'false_positive_rate': false_positive_rate,
        'false_negative_rate': false_negative_rate,
    }, y_test.values, y_pred, sensitive_features=test_df[['SEX','RAC1P']])

    # Compute difference and ratio
    m_dif = metrics.difference()
    m_ratio = metrics.ratio()
    fpr_difference = m_dif['false_positive_rate']
    fnr_difference = m_dif['false_negative_rate']
    fpr_ratio = m_ratio['false_positive_rate']
    fnr_ratio = m_ratio['false_negative_rate']

    # Define sensitive features
    sensitive_features_test = X_test[['SEX','RAC1P']]  # Replace 'strata_cols' with the actual column name(s)

    # Compute equalized odds ratio
    eor = equalized_odds_ratio(y_true=y_test, 
                            y_pred=y_pred, 
                            sensitive_features=sensitive_features_test)

    # Compute equalized odds ratio
    dpr = demographic_parity_ratio(y_true=y_test, 
                            y_pred=y_pred, 
                            sensitive_features=sensitive_features_test)

    results = {
        'accuracy': accuracy,
        'false_positive_rate': metrics.overall['false_positive_rate'],
        'false_negative_rate': metrics.overall['false_negative_rate'],
        'fpr_difference': fpr_difference,
        'fnr_difference': fnr_difference,
        'fpr_ratio': fpr_ratio,
        'fnr_ratio': fnr_ratio,
        'equalized_odds_ratio': eor,
        'demographic_parity_ratio': dpr,
    }
    
    return results

In [76]:
target_col = 'ESR'
train_df, test_df = synth_df_strat[:int(len(synth_df_strat) * 0.8)], synth_df_strat[int(len(synth_df_strat) * 0.8):]
# Feature columns
feature_cols = [col for col in train_df.columns if col != target_col]

# # Convert all columns to categorical
for col in train_df.columns:
    train_df[col] = train_df[col].astype('float')

for col in test_df.columns:
    test_df[col] = test_df[col].astype('float')

test_df = test_df.dropna(subset=['SEX', 'RAC1P'])
# Prepare the dataset
X_train = train_df[feature_cols]
y_train = train_df[target_col]
X_test = test_df[feature_cols]
y_test = test_df[target_col]

# Train the classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Make predictions on the test set
y_pred = clf.predict(X_test)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)

# Compute fairness metrics
metrics = MetricFrame({
    'accuracy': accuracy_score,
    'false_positive_rate': false_positive_rate,
    'false_negative_rate': false_negative_rate,
}, y_test.values, y_pred, sensitive_features=test_df[['SEX','RAC1P']])

# Compute difference and ratio
m_dif = metrics.difference()
m_ratio = metrics.ratio()
fpr_difference = m_dif['false_positive_rate']
fnr_difference = m_dif['false_negative_rate']
fpr_ratio = m_ratio['false_positive_rate']
fnr_ratio = m_ratio['false_negative_rate']

# Define sensitive features
sensitive_features_test = X_test[['SEX','RAC1P']]  # Replace 'strata_cols' with the actual column name(s)

# Compute equalized odds ratio
eor = equalized_odds_ratio(y_true=y_test, 
                        y_pred=y_pred, 
                        sensitive_features=sensitive_features_test)

# Compute equalized odds ratio
dpr = demographic_parity_ratio(y_true=y_test, 
                        y_pred=y_pred, 
                        sensitive_features=sensitive_features_test)

results = {
    'accuracy': accuracy,
    'false_positive_rate': metrics.overall['false_positive_rate'],
    'false_negative_rate': metrics.overall['false_negative_rate'],
    'fpr_difference': fpr_difference,
    'fnr_difference': fnr_difference,
    'fpr_ratio': fpr_ratio,
    'fnr_ratio': fnr_ratio,
    'equalized_odds_ratio': eor,
    'demographic_parity_ratio': dpr,
}


In [77]:
results

{'accuracy': 0.7806254767353166,
 'false_positive_rate': 0.26159016619860964,
 'false_negative_rate': 0.1673008120847294,
 'fpr_difference': 0.2793423364075538,
 'fnr_difference': 0.32883072733737456,
 'fpr_ratio': 0.21473765432098768,
 'fnr_ratio': 0.17482100498356945,
 'equalized_odds_ratio': 0.21473765432098768,
 'demographic_parity_ratio': 0.49616690061169977}

In [50]:
y_test[y_test==y_pred]

157574    1.0
157577    1.0
157578    1.0
157579    1.0
157580    1.0
         ... 
196961    1.0
196962    1.0
196963    0.0
196964    1.0
196965    1.0
Name: ESR, Length: 31373, dtype: float64

In [36]:
{'SEX': df['SEX'].values, 'RAC1P': df['RAC1P'].values}

{'SEX': array([1, 0, 1, ..., 0, 0, 1], dtype=int8),
 'RAC1P': array([0, 0, 7, ..., 0, 0, 0], dtype=int8)}

In [33]:
evaluate_on_dataframes_with_fairlearn(df[:int(len(df) * 0.8)], df[int(len(df) * 0.8):])

TypeError: '<' not supported between instances of 'int' and 'NoneType'

In [10]:
synth_df_strat['RAC1P'].unique()

array([0, 7, 5, 8, 1, 4, 2], dtype=int8)

In [135]:
grouped_synth_df_strat = synth_df_strat.groupby(['SEX','RAC1P']).mean().reset_index()
grouped_synth_df_strat.astype(float)

Unnamed: 0,SEX,RAC1P,AGEP,SCHL,MAR,DIS,ESP,MIL,DREM,ESR
0,0.0,0.0,1.727865,2.882697,1.955056,0.860787,0.476854,3.050562,1.851236,0.501236
1,0.0,1.0,1.47496,2.604219,2.628127,0.85136,0.929312,3.074608,1.837853,0.398606
2,0.0,2.0,2.114992,2.315144,2.929003,0.856307,0.67315,3.510008,1.834592,0.497168
3,0.0,5.0,1.492372,2.816737,1.918751,0.945834,0.406428,3.236526,1.883181,0.536788
4,0.0,7.0,1.217579,2.268556,2.60207,0.900981,1.046806,2.908396,1.813625,0.477949
5,0.0,8.0,0.960687,2.200382,2.835115,0.887405,1.283206,2.344275,1.693893,0.367176
6,1.0,0.0,1.843507,2.970889,1.830275,0.856916,0.424488,3.351976,1.861503,0.446718
7,1.0,1.0,1.681903,2.781197,2.621422,0.835555,0.762475,3.326618,1.854474,0.442262
8,1.0,2.0,1.948052,2.008658,2.562771,0.848485,1.47619,3.194805,2.0,0.515152
9,1.0,5.0,1.559815,2.799075,1.771646,0.925975,0.374091,3.346662,1.896894,0.461335


In [136]:
with open('models/MSTSynthesizer_epsilon_1.0_seed_0.dill', "rb") as file:
    model = dill.load(file)
synth_df = model.sample(len(df))

In [137]:
grouped_synth_df = synth_df.groupby(['SEX','RAC1P']).mean().reset_index()
# subtract 1 from SEX and RACE
grouped_synth_df['SEX'] = grouped_synth_df['SEX'] - 1.0
grouped_synth_df['RAC1P'] = grouped_synth_df['RAC1P'] - 1.0
grouped_synth_df.astype(float)

Unnamed: 0,SEX,RAC1P,AGEP,SCHL,MAR,DIS,ESP,MIL,DREM,ESR
0,0.0,0.0,1.777163,2.849468,2.902956,1.864874,0.51216,3.102743,1.85551,0.457984
1,0.0,1.0,1.501589,2.689223,3.612881,1.869987,0.730872,2.857965,1.836582,0.422928
2,0.0,2.0,1.583333,2.558333,3.5125,1.870833,0.991667,2.816667,1.841667,0.395833
3,0.0,3.0,1.631579,2.473684,3.473684,1.894737,0.631579,2.631579,1.947368,0.210526
4,0.0,4.0,1.520325,2.577236,3.528455,1.869919,0.715447,2.886179,1.821138,0.430894
5,0.0,5.0,1.726579,2.851897,2.836141,1.858563,0.516907,3.077082,1.852866,0.450006
6,0.0,6.0,2.058824,2.941176,3.470588,1.882353,0.117647,3.647059,1.882353,0.411765
7,0.0,7.0,1.483481,2.699075,3.547668,1.87087,0.706438,2.846328,1.838777,0.429866
8,0.0,8.0,1.34965,2.630107,3.815237,1.855355,0.800515,2.747516,1.818918,0.421789
9,1.0,0.0,1.733473,2.889531,2.881281,1.865536,0.482314,3.367462,1.858961,0.484255


In [138]:
# Merge the original dataframe with the vanilla synthetic dataframe using outer join
merged_diff = pd.merge(grouped_original_df, grouped_synth_df, on=['SEX', 'RAC1P'], how='outer', suffixes=('_orig', '_vanilla'))
merged_diff

Unnamed: 0,SEX,RAC1P,AGEP_orig,SCHL_orig,MAR_orig,DIS_orig,ESP_orig,MIL_orig,DREM_orig,ESR_orig,AGEP_vanilla,SCHL_vanilla,MAR_vanilla,DIS_vanilla,ESP_vanilla,MIL_vanilla,DREM_vanilla,ESR_vanilla
0,0.0,0.0,1.725445,2.881674,2.954433,1.861228,0.472135,3.05002,1.85173,0.50128,1.777163,2.849468,2.902956,1.864874,0.51216,3.102743,1.85551,0.457984
1,0.0,1.0,1.466101,2.597544,3.641148,1.851314,0.928771,3.074569,1.834161,0.397581,1.501589,2.689223,3.612881,1.869987,0.730872,2.857965,1.836582,0.422928
2,0.0,2.0,1.490975,2.599278,3.505415,1.812274,1.050542,3.021661,1.877256,0.458484,1.583333,2.558333,3.5125,1.870833,0.991667,2.816667,1.841667,0.395833
3,0.0,3.0,2.0,3.5,3.0,2.0,0.0,4.0,2.0,1.0,1.631579,2.473684,3.473684,1.894737,0.631579,2.631579,1.947368,0.210526
4,0.0,4.0,1.36,2.4,3.672,1.816,1.248,2.976,1.776,0.384,1.520325,2.577236,3.528455,1.869919,0.715447,2.886179,1.821138,0.430894
5,0.0,5.0,1.48893,2.815129,2.918204,1.940467,0.410086,3.255228,1.878598,0.538868,1.726579,2.851897,2.836141,1.858563,0.516907,3.077082,1.852866,0.450006
6,0.0,6.0,1.309524,2.52381,3.571429,1.880952,0.952381,2.619048,1.761905,0.309524,2.058824,2.941176,3.470588,1.882353,0.117647,3.647059,1.882353,0.411765
7,0.0,7.0,1.210914,2.257175,3.605551,1.896337,1.035498,2.953172,1.825906,0.479796,1.483481,2.699075,3.547668,1.87087,0.706438,2.846328,1.838777,0.429866
8,0.0,8.0,0.960687,2.176718,3.854198,1.886641,1.265267,2.35687,1.692748,0.360687,1.34965,2.630107,3.815237,1.855355,0.800515,2.747516,1.818918,0.421789
9,1.0,0.0,1.841565,2.970697,2.83169,1.856746,0.426884,3.350475,1.861204,0.447228,1.733473,2.889531,2.881281,1.865536,0.482314,3.367462,1.858961,0.484255


In [152]:
res_og_df = grouped_original_df - grouped_synth_df

In [145]:
filtered_grouped_df = pd.merge(
    grouped_synth_df_strat[['SEX', 'RAC1P']],
    grouped_original_df.reset_index(),
    on=['SEX', 'RAC1P'],
    how='left'
).set_index(['SEX', 'RAC1P'])

In [147]:
filtered_grouped_df.drop(columns=['index'], inplace=True)
filtered_grouped_df.reset_index(inplace=True)
filtered_grouped_df

Unnamed: 0,SEX,RAC1P,AGEP,SCHL,MAR,DIS,ESP,MIL,DREM,ESR
0,0,0,1.725445,2.881674,2.954433,1.861228,0.472135,3.05002,1.85173,0.50128
1,0,1,1.466101,2.597544,3.641148,1.851314,0.928771,3.074569,1.834161,0.397581
2,0,2,1.490975,2.599278,3.505415,1.812274,1.050542,3.021661,1.877256,0.458484
3,0,5,1.48893,2.815129,2.918204,1.940467,0.410086,3.255228,1.878598,0.538868
4,0,7,1.210914,2.257175,3.605551,1.896337,1.035498,2.953172,1.825906,0.479796
5,0,8,0.960687,2.176718,3.854198,1.886641,1.265267,2.35687,1.692748,0.360687
6,1,0,1.841565,2.970697,2.83169,1.856746,0.426884,3.350475,1.861204,0.447228
7,1,1,1.680235,2.775006,3.611339,1.83556,0.757704,3.332509,1.85626,0.441492
8,1,2,1.497835,2.666667,3.363636,1.861472,1.151515,3.108225,1.865801,0.38961
9,1,5,1.560112,2.811011,2.762809,1.925955,0.355169,3.353146,1.897753,0.459775


In [153]:
res_strat_df = filtered_grouped_df.astype(float) - grouped_synth_df_strat.astype(float)

In [164]:
np.abs(res_strat_df)['AGEP'].sum()

1.1415179896153436

In [165]:
np.abs(res_og_df)['AGEP'].sum()

4.257854852914879

In [159]:
res_og_df

Unnamed: 0,SEX,RAC1P,AGEP,SCHL,MAR,DIS,ESP,MIL,DREM,ESR
0,0.0,0.0,-0.051718,0.032207,0.051478,-0.003646,-0.040026,-0.052723,-0.00378,0.043295
1,0.0,0.0,-0.035488,-0.091678,0.028267,-0.018674,0.1979,0.216604,-0.002421,-0.025348
2,0.0,0.0,-0.092359,0.040945,-0.007085,-0.058559,0.058875,0.204994,0.03559,0.06265
3,0.0,0.0,0.368421,1.026316,-0.473684,0.105263,-0.631579,1.368421,0.052632,0.789474
4,0.0,0.0,-0.160325,-0.177236,0.143545,-0.053919,0.532553,0.089821,-0.045138,-0.046894
5,0.0,0.0,-0.237649,-0.036768,0.082063,0.081905,-0.106821,0.178146,0.025731,0.088862
6,0.0,0.0,-0.7493,-0.417367,0.10084,-0.001401,0.834734,-1.028011,-0.120448,-0.102241
7,0.0,0.0,-0.272567,-0.4419,0.057883,0.025467,0.329061,0.106844,-0.01287,0.04993
8,0.0,0.0,-0.388963,-0.453389,0.038961,0.031286,0.464752,-0.390645,-0.12617,-0.061102
9,0.0,0.0,0.108093,0.081166,-0.049592,-0.00879,-0.05543,-0.016988,0.002243,-0.037026


In [128]:
merged_df = pd.merge(grouped_synth_df, grouped_synth_df_strat, on=['SEX', 'RAC1P'], suffixes=('_vanilla', '_strat'))

In [103]:
df.astype(float).groupby(['SEX','RAC1P']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,AGEP,SCHL,MAR,DIS,ESP,MIL,DREM,ESR
SEX,RAC1P,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1.0,1.0,1.725445,2.881674,2.954433,1.861228,0.472135,3.05002,1.85173,0.50128
1.0,2.0,1.466101,2.597544,3.641148,1.851314,0.928771,3.074569,1.834161,0.397581
1.0,3.0,1.490975,2.599278,3.505415,1.812274,1.050542,3.021661,1.877256,0.458484
1.0,4.0,2.0,3.5,3.0,2.0,0.0,4.0,2.0,1.0
1.0,5.0,1.36,2.4,3.672,1.816,1.248,2.976,1.776,0.384
1.0,6.0,1.48893,2.815129,2.918204,1.940467,0.410086,3.255228,1.878598,0.538868
1.0,7.0,1.309524,2.52381,3.571429,1.880952,0.952381,2.619048,1.761905,0.309524
1.0,8.0,1.210914,2.257175,3.605551,1.896337,1.035498,2.953172,1.825906,0.479796
1.0,9.0,0.960687,2.176718,3.854198,1.886641,1.265267,2.35687,1.692748,0.360687
2.0,1.0,1.841565,2.970697,2.83169,1.856746,0.426884,3.350475,1.861204,0.447228


In [93]:
def get_subgroup_key(group, groupby_cols):
    key = []
    for col in groupby_cols:
        unique_values = group[col].unique()
        if len(unique_values) == 1:
            key.append((col, unique_values[0]))
        else:
            print(f"More than one unique value found for column '{col}' in the given group.")
            print(f"Unique values found: {unique_values}")
            print(f"Group:\n{group}")
            raise ValueError(f"More than one unique value found for column '{col}' in the given group.")
    return tuple(key)

def create_subgroups_dict(X, groupby_cols):
    subgroups = {}
    for _, group in X.groupby(groupby_cols):
        if not group.empty:
            key = get_subgroup_key(group, groupby_cols)
            subgroups[key] = group
        else:
            print('This weird thing happens sometimes where a group is empty. Not sure why.')
    return subgroups

def f(df, col='ESR'):
    return df[col].mean()

def parity_error_synth_data(X, X_prime, groupby_cols, f, omega):
    subgroups_real = create_subgroups_dict(X, groupby_cols)
    subgroups_synth = create_subgroups_dict(X_prime, groupby_cols)
    # print(len(subgroups_real.keys()))
    # print(len(subgroups_synth.keys()))
    f_values_real = []
    f_values_synth = []
    print(subgroups_real.keys())
    print(subgroups_synth.keys())
    # Calculate f and M values for each stratum
    for key, s in subgroups_real.items():
        f_value_real = f(s)
        f_values_real.append(f_value_real)
        
        if key in subgroups_synth:
            f_value_synth = f(subgroups_synth[key])
        else:
            f_value_synth = 0
            #f(X_prime)

        f_values_synth.append(f_value_synth)

    # Calculate the global f and M values
    f_global = f(X)
    f_synth_global = f(X_prime)

    # Compute the parity error
    beta = omega * abs(f_global - f_synth_global) + sum([abs(t - s) for t, s in zip(f_values_real, f_values_synth)])

    return beta

def f(df, col='ESR'):
    return df[col].mean()



In [37]:
# Define the groupby columns
strata_cols = ['SEX', 'RAC1P']

# Calculate the parity error
omega = 0.0
# error = parity_error_synth_data(df, synth_df, strata_cols, f, omega)
# print("Parity error:", error)

In [14]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
def force_data_categorical_to_numeric(df, cat_columns=[]):
    # convert columns to categorical if they are not already
    for col in cat_columns:
        if col in df.columns:
            df[col] = df[col].astype('category')
            df[col] = df[col].cat.codes
    return df
    
df_numeric = force_data_categorical_to_numeric(df, cat_columns=df.columns)
X_real = df_numeric.drop('ESR', axis=1)
y_real = df_numeric['ESR']
X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(X_real, y_real, test_size=0.2, random_state=42)
train_df_real = X_train_real.copy()
train_df_real['ESR'] = y_train_real
test_df_real = X_test_real.copy()
test_df_real['ESR'] = y_test_real

In [47]:
list(keys_strat.names)

['SEX', 'RAC1P']

In [59]:
def min_max_eval(train_df, test_df, strata_cols, target_col = 'ESR'):
    # Feature columns
    feature_cols = [col for col in train_df.columns if col != target_col]

    # Convert all columns to categorical
    for col in train_df.columns:
        train_df[col] = train_df[col].astype('category')

    # Prepare the dataset
    combinations = []
    for i in range(1, len(strata_cols) + 1):
        combinations.extend(list(itertools.combinations(strata_cols, i)))
    
    accuracies = []
    for combination in combinations:
        keys_strat = synth_df_strat[list(combination)].value_counts().keys()
        for key in keys_strat:
            for var in keys_strat.names:
                if list(keys_strat.names) == ['SEX', 'RAC1P']:
                    subset = synth_df_strat.loc[(synth_df_strat['SEX'] == key[0]) & (synth_df_strat['RAC1P'] == key[1])]
                elif list(keys_strat.names) == ['RAC1P']:
                    subset = synth_df_strat.loc[(synth_df_strat['RAC1P'] == key[0])]
                else:
                    subset = synth_df_strat.loc[(synth_df_strat['SEX'] == key[0])]
            X_train = train_df[feature_cols]
            y_train = train_df[target_col]
            X_test = subset[feature_cols]
            y_test = subset[target_col]

            # Train the classifier
            clf = RandomForestClassifier(n_estimators=100, random_state=42)
            clf.fit(X_train, y_train)

            # Make predictions on the test set
            y_pred = clf.predict(X_test)

            # Evaluate the model
            accuracy = accuracy_score(y_test, y_pred)
            accuracies.append((accuracy, (str(list(combination)), key)))
            print((accuracy, (str(list(combination)), key)))
    min_tuple = min(accuracies, key=lambda x: x[0])
    max_tuple = max(accuracies, key=lambda x: x[0])

    return (min_tuple[0], max_tuple[0])

def evaluate_worst_accuracy(train_df, test_df, strata_cols, target_col = 'ESR'):
    # Feature columns
    feature_cols = [col for col in train_df.columns if col != target_col]

    # Convert all columns to categorical
    for col in train_df.columns:
        train_df[col] = train_df[col].astype('category')
    
    for col in test_df.columns:
        test_df[col] = test_df[col].astype('category')

    # Create subgroups for train_df and test_df using strata_cols
    train_groups = train_df.groupby(strata_cols)
    test_groups = test_df.groupby(strata_cols)

    accuracies = []

    for train_key, train_group in train_groups:
        # try:
        print(train_key)
        test_group = test_groups.get_group(train_key)
        # except KeyError:
        #     print(f"No test group found for key: {train_key}")
        #     continue
        
        X_train = train_group[feature_cols]
        y_train = train_group[target_col]
        X_test = test_group[feature_cols]
        y_test = test_group[target_col]

        # Train the classifier
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(X_train, y_train)

        # Make predictions on the test set
        y_pred = clf.predict(X_test)

        # Evaluate the model
        accuracy = accuracy_score(y_test, y_pred)
        accuracies.append(accuracy)

    # Return the worst accuracy
    return min(accuracies)

In [21]:
# Prepare the dataset
combinations = []
for i in range(1, len(['SEX','RAC1P']) + 1):
    combinations.extend(list(itertools.combinations(['SEX','RAC1P'], i)))

In [28]:
keys_strat = synth_df_strat[list(combinations[2])].value_counts().keys()

In [None]:
for key in keys_strat:
    subset = synth_df_strat.loc[(synth_df_strat['SEX'] == key[0]) & (synth_df_strat['RAC1P'] == key[1])]
    print(subset)

In [63]:
with open('models/GEMSynthesizer_epsilon_0.1_seed_0.dill', "rb") as file:
    model = dill.load(file)
synth_df_strat = model.sample(len(df))

In [64]:
evaluate_on_dataframes(synth_df_strat, test_df_real, strata_cols)

(0.7968677359367736, ("['SEX']", (1,)))
(0.812962152818532, ("['SEX']", (0,)))
(0.8027939515077779, ("['RAC1P']", (0,)))
(0.800578273440727, ("['RAC1P']", (1,)))
(0.8059396697590291, ("['RAC1P']", (5,)))
(0.8076536761751707, ("['RAC1P']", (7,)))
(0.829687884000983, ("['RAC1P']", (8,)))
(0.8450834879406308, ("['RAC1P']", (6,)))
(0.8495850622406639, ("['RAC1P']", (3,)))
(0.8652849740932642, ("['RAC1P']", (2,)))
(0.8661870503597122, ("['RAC1P']", (4,)))
(0.7963190530672204, ("['SEX', 'RAC1P']", (1, 0)))
(0.8096066104893962, ("['SEX', 'RAC1P']", (0, 0)))
(0.7887178731865381, ("['SEX', 'RAC1P']", (1, 1)))
(0.8157275891261405, ("['SEX', 'RAC1P']", (0, 1)))
(0.7971257736284485, ("['SEX', 'RAC1P']", (1, 5)))
(0.8169865895345779, ("['SEX', 'RAC1P']", (0, 5)))
(0.796093023255814, ("['SEX', 'RAC1P']", (1, 7)))
(0.8212180746561886, ("['SEX', 'RAC1P']", (0, 7)))
(0.8226196230062832, ("['SEX', 'RAC1P']", (1, 8)))
(0.837, ("['SEX', 'RAC1P']", (0, 8)))
(0.8443708609271523, ("['SEX', 'RAC1P']", (1, 6))

(0.0865556497893919,
 ("['SEX', 'RAC1P']", (1, 1)),
 ("['SEX', 'RAC1P']", (0, 3)))

In [None]:
import numpy as np
import random

def random_queries(num_queries, columns, min_values, max_values):
    queries = []
    for _ in range(num_queries):
        query = {}
        for col in columns:
            query[col] = (random.uniform(min_values[col], max_values[col]), random.uniform(min_values[col], max_values[col]))
        queries.append(query)
    return queries

def query_result(df, query):
    mask = np.full(df.shape[0], True, dtype=bool)
    for col, (min_val, max_val) in query.items():
        mask &= (df[col] >= min_val) & (df[col] <= max_val)
    return df[mask].shape[0]

def f(X, X_prime, num_queries=10, columns=None):
    if columns is None:
        columns = X.columns
    
    min_values = {col: min(X[col].min(), X_prime[col].min()) for col in columns}
    max_values = {col: max(X[col].max(), X_prime[col].max()) for col in columns}
    
    queries = random_queries(num_queries, columns, min_values, max_values)
    errors = []
    for query in queries:
        result_real = query_result(X, query)
        result_synth = query_result(X_prime, query)
        error = abs(result_real - result_synth)
        if result_real > 0:
            error /= result_real
        errors.append(error)
    return np.mean(errors)