# 04 Disjoint generative models on data supplied in a fragmented state

In many projects involving synthetic data generation the source data is often supplied in a workable state. Conversely, in some projects, especially large colaborative endevours, data is supplied from multiple labs gathering different variables on the same population. In this case study we explore how/if disjoint generative models may improve on the prospect set from the preprossessing steps when data is supplied in a fragmented state. 

## Artificially fragmented data

In order to explore if there is any gain in using disjoint generative models on data that is supplied already partitioned, we will artificially break a dataset so that we can sompare the results to the original. We will focus on the Diabetic Mellitus dataset, and create random subsets of the variables - introduce random missingness across the whole dataset, and missing records in the different subsets to simulate a realistic scenario where the data supplied from the diffrent stakeholders suffer from different problems.

In [1]:
import numpy as np
import pandas as pd

from typing import Dict, List

from disjoint_generative_model.utils.dataset_manager import random_split_columns

df_train = pd.read_csv('experiments\datasets\diabetic_mellitus_train.csv')
df_test = pd.read_csv('experiments\datasets\diabetic_mellitus_test.csv')

def simulate_fragmented_data(dataset, fragment_sizes: Dict[str, float], p_value_miss: float = 0.02, p_row_miss: float = 0.03, verbose: bool = True) -> Dict[str, pd.DataFrame]:
    """
    Systematically break a dataset into partitions with missing elements
    """
    fragments = random_split_columns(dataset, fragment_sizes)

    # drop 2% of all values in the dataset
    dataset = dataset.mask(np.random.rand(*dataset.shape) < p_value_miss)

    if verbose: print(f"Value missingness: {dataset.isnull().sum().sum()/dataset.size}")
    
    dataset.reset_index(inplace=True) # simulate social security number

    partitions = {name: dataset[['index']+fragment] for name, fragment in fragments.items()}

    # drop rows at random
    for name, partition in partitions.items():
        partitions[name] = partition.drop(partition.sample(frac=p_row_miss).index)

    if verbose: 
        data_assembled = pd.concat(partitions.values(), axis=1)
        print(f"Total missingness: {data_assembled.isnull().sum().sum()/data_assembled.size}")

    for name, partition in partitions.items():
        partition.reset_index(drop=True, inplace=True)
    return partitions

parts = simulate_fragmented_data(df_train, {'A': 0.5, 'B': 0.3, 'C': 0.2}, p_row_miss=0.05, p_value_miss=0.07)
parts['A'].head()

Value missingness: 0.0727437641723356
Total missingness: 0.1159075907590759


Unnamed: 0,index,DBP,SFH,DRB,CFS_2,CDC,OCP,WGT,SEX,HDC,...,NRV_2,TRM,SLS,IRT,CIT,DZN,HBP,EXT_2,CFS_3,IRT_2
0,0,110.0,0.0,0.0,0.0,0.0,3.0,65.0,2.0,1.0,...,0.0,0.0,,0.0,0.0,1.0,,1.0,0.0,0.0
1,1,60.0,1.0,0.0,0.0,0.0,3.0,,1.0,1.0,...,1.0,0.0,,0.0,0.0,0.0,0.0,1.0,0.0,1.0
2,2,80.0,,0.0,0.0,0.0,3.0,70.0,2.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,3,90.0,0.0,1.0,0.0,0.0,2.0,57.0,1.0,0.0,...,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0
4,4,40.0,0.0,0.0,0.0,0.0,2.0,67.0,1.0,0.0,...,0.0,0.0,0.0,0.0,,,0.0,1.0,0.0,0.0


In [12]:
# Make a pipeline for repairing the data like one would usually do in practice

from sklearn.impute import KNNImputer

def repair_data(partitions: Dict[str, pd.DataFrame], verbose: bool = True) -> pd.DataFrame:
    """
    Repair the data by filling in missing values
    """
    repaired_dataset = partitions['A'].copy()

    for partition in list(partitions.values())[1:]:
        repaired_dataset = repaired_dataset.merge(partition, how='outer', on='index')

    if verbose: print(f"Total missingness: {repaired_dataset.isnull().sum().sum()/repaired_dataset.size}")

    repaired_dataset.drop('index', axis=1, inplace=True)

    # impute using knn imputer
    imputer = KNNImputer(n_neighbors=1)
    repaired_dataset = pd.DataFrame(imputer.fit_transform(repaired_dataset), columns=repaired_dataset.columns)

    return repaired_dataset.reset_index(drop=True)

repaired_dataset = repair_data(parts)
repaired_dataset = repaired_dataset[df_train.columns]
repaired_dataset.head()

Total missingness: 0.11676767676767677


Unnamed: 0,AGE,GLU,DBP,BMI,WGT,OCP,SEX,DIT,MST,RSB,...,PRE,ETN,CIT,FTG_5,SBF,WTG,CSS,LHV,PLH,TYPE
0,51.0,2.4,110.0,21.0,65.0,3.0,2.0,2.0,3.0,2.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
1,64.0,12.2,60.0,18.0,60.0,3.0,1.0,1.0,3.0,1.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0
2,23.0,3.9,80.0,33.0,70.0,3.0,2.0,1.0,3.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,61.0,11.9,90.0,18.0,57.0,2.0,1.0,1.0,3.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
4,74.0,14.3,40.0,22.5,67.0,2.0,1.0,1.0,3.0,3.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


## Conduct experiments 

In the experiments we aim on comparing regular synthesis pipeline and DGMs to the repaired and origin data

In [18]:

from disjoint_generative_model import DisjointGenerativeModels
from disjoint_generative_model.utils.joining_validator import JoiningValidator
from disjoint_generative_model.utils.joining_strategies import UsingJoiningValidator
from disjoint_generative_model.utils.generative_model_adapters import generate_synthetic_data

def disjoint_scenario(partitions: Dict[str, pd.DataFrame], models = List[str], join_multiplier: int = 4, verbose: bool = True) -> pd.DataFrame:
    """
    Simulate a disjoint scenario where synthetic data are generated based on the individual partitions
    """

    imputed_data = repair_data(partitions, verbose=False)

    joining_validator_model = JoiningValidator(save_proba=True)
    joining_method = UsingJoiningValidator(join_validator_model=joining_validator_model)
    joining_method.max_size = len(imputed_data)

    joining_method.join_validator.fit_classifier(imputed_data, num_batches_of_bad_joins=2)

    syns = {}
    if verbose: print(f"Generating...")
    for model, (name, partition) in zip(models, partitions.items()):
        partition = partition.drop('index', axis=1)
        partition = partition.dropna(axis=0)
        syn_part = generate_synthetic_data(partition, model, num_to_generate=join_multiplier*len(imputed_data))
        syns[name] = syn_part.reset_index(drop=True)

    if verbose: print(f"Now joining the synthetic data")
    assembled_data = joining_method.join(syns)

    return assembled_data

syn_data = disjoint_scenario(parts, models=['synthpop', 'dpgan', 'datasynthesizer'], verbose=True)
syn_data = syn_data[df_train.columns]
syn_data.head()

Validator: No search parameters specified. Using default configuration.
Validator: Calibration improved the model from 0.0189 to 0.0178
Generating...


[2025-03-16T09:13:26.775467+0100][10908][CRITICAL] module disabled: c:\Users\danho\AppData\Local\Programs\Python\Python310\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
 20%|█▉        | 399/2000 [01:21<05:25,  4.92it/s]


Adding ROOT CSS
Adding attribute RIH
Adding attribute BMI
Adding attribute FTG_5
Adding attribute IIU
Adding attribute FTG_4
Adding attribute FAW
Adding attribute RPP
Adding attribute LCR
Adding attribute TLF
Adding attribute SBF
Adding attribute NSA_2
Adding attribute MST
Adding attribute CFS
Adding attribute EYP
Adding attribute SAD
Adding attribute ETN
Adding attribute LHV
Adding attribute CVS
Now joining the synthetic data
Threshold auto-set to: 0.5
Predicted good joins fraction: 0.14888888888888888
Predicted good joins fraction: 0.06919060052219321
Predicted good joins fraction: 0.0546984572230014


Unnamed: 0,AGE,GLU,DBP,BMI,WGT,OCP,SEX,DIT,MST,RSB,...,PRE,ETN,CIT,FTG_5,SBF,WTG,CSS,LHV,PLH,TYPE
0,65,10.032462,80,23.136477,58,3,2,2.0,2.0,3.0,...,0,0.0,0,0.0,0.0,0,0.0,0.0,0,1.0
1,65,8.176858,87,18.69494,82,3,1,2.0,3.0,3.0,...,0,0.0,0,1.0,1.0,0,0.0,0.0,0,1.0
2,72,8.622371,87,20.528088,83,2,2,2.0,3.0,3.0,...,0,0.0,0,1.0,0.0,1,0.0,0.0,0,1.0
3,70,10.609033,70,21.937537,58,3,2,2.0,3.0,3.0,...,0,0.0,0,0.0,0.0,0,0.0,0.0,0,1.0
4,70,7.976543,80,21.239573,76,3,2,2.0,3.0,3.0,...,0,0.0,0,0.0,1.0,0,0.0,0.0,0,1.0


In [21]:
df_base = generate_synthetic_data(repaired_dataset, 'dpgan', num_to_generate=len(repaired_dataset))
df_base.head()

[2025-03-16T10:15:21.628342+0100][10908][CRITICAL] module disabled: c:\Users\danho\AppData\Local\Programs\Python\Python310\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
 15%|█▍        | 299/2000 [05:14<29:49,  1.05s/it]  


Unnamed: 0,AGE,GLU,DBP,BMI,WGT,OCP,SEX,DIT,MST,RSB,...,PRE,ETN,CIT,FTG_5,SBF,WTG,CSS,LHV,PLH,TYPE
0,56.433159,11.15854,82.503295,14.002731,31.0,1.0,1.0,2.0,2.0,3.0,...,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0
1,58.070791,11.15854,84.097229,17.797582,31.0,1.0,1.0,2.0,2.0,3.0,...,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0
2,58.174965,11.15854,83.386248,14.47827,31.0,1.0,1.0,2.0,2.0,3.0,...,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0
3,56.502774,11.15854,82.619444,16.473061,31.0,1.0,1.0,2.0,2.0,3.0,...,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0
4,82.0,11.15854,85.653308,14.67442,31.0,1.0,1.0,2.0,2.0,3.0,...,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0


In [22]:
df_sp = generate_synthetic_data(repaired_dataset, 'synthpop', num_to_generate=len(repaired_dataset))
df_sp.head()

Unnamed: 0,AGE,GLU,DBP,BMI,WGT,OCP,SEX,DIT,MST,RSB,...,PRE,ETN,CIT,FTG_5,SBF,WTG,CSS,LHV,PLH,TYPE
0,44,9.2,84,28.0,88,3,2,1,3,3,...,0,0,0,1,1,0,0,0,0,1
1,43,7.9,80,27.0,85,2,1,2,3,3,...,0,0,0,1,1,0,0,0,0,1
2,55,9.8,88,28.0,88,3,2,1,3,3,...,0,0,0,0,0,0,0,0,0,1
3,60,13.6,85,17.5,49,3,2,1,3,3,...,0,0,0,0,1,0,0,0,0,1
4,50,11.5,100,30.5,92,3,2,2,3,1,...,0,0,1,1,0,1,0,0,0,0


In [23]:
from syntheval import SynthEval

SE = SynthEval(real_dataframe=df_train, holdout_dataframe=df_test)
res, _ = SE.benchmark({'dgms': syn_data, 'dpgan': df_base, 'synthpop': df_sp}, 'TYPE', 'full_eval')
res.T

Inferred categorical columns (unique threshold: 10):
['OCP', 'SEX', 'DIT', 'MST', 'RSB', 'LOE', 'DCD', 'EXT', 'FRU', 'WLG', 'FLS', 'BRV', 'IRT', 'SHC', 'TLF', 'RIG', 'RIV', 'SWT', 'SHK', 'VDS', 'WKN', 'HNG', 'DZN', 'NRV', 'HDC', 'FHB', 'IRT_2', 'NSA', 'CCS', 'SLS', 'DLB', 'DRN_2', 'CFS', 'CVS', 'EXT_2', 'IIU', 'WKN_2', 'LCR', 'CFS_2', 'RPP', 'COM', 'DRB', 'SFS', 'LOA_2', 'NSA_2', 'FVR', 'STP', 'WTL', 'WKN_3', 'FTG_2', 'CFS_3', 'DRN', 'SFH', 'SOB', 'HBP', 'CDC', 'LOA', 'NSA_3', 'VMT', 'DIS', 'FTG_3', 'SCE', 'GSV', 'BWR', 'DSV', 'DRV', 'EYP', 'FRO', 'SAD', 'VLS', 'SOB_2', 'PCJ', 'FAW', 'SWG', 'LHN', 'RIH', 'EXP', 'NRV_2', 'HIT', 'PPT', 'TRM', 'FTG_4', 'WTL_2', 'PRE', 'ETN', 'CIT', 'FTG_5', 'SBF', 'WTG', 'CSS', 'LHV', 'PLH', 'TYPE']


Unnamed: 0,dataset,dgms,dpgan,synthpop
avg_dwm_diff,value,0.061525,0.24874,0.008264
avg_dwm_diff,error,0.007476,0.006297,0.008224
pca_eigval_diff,value,0.113214,0.463275,0.058438
pca_eigval_diff,error,,,
pca_eigvec_ang,value,0.676611,0.790511,0.077114
pca_eigvec_ang,error,,,
avg_cio,value,0.169505,0.0,0.698149
avg_cio,error,0.104966,0.0,0.168678
corr_mat_diff,value,6.952502,,3.400939
corr_mat_diff,error,,,
