# Tutorial for Disjoint Generative Models
In this notebook we show the basic functionality of the DGMs codebase.

### Example 1: Getting started with DGMs

We start with a basic example of DGMs on a simple dataset. We specify two models, ```synthpop``` and ```privbayes```, each responsible for one part of the dataset.

Unless otherwise specified, the dataset manager module will randomly split the dataset into equal parts for each model.

In [1]:
# Imports
import pandas as pd
from disjoint_generative_model import DisjointGenerativeModels

In [2]:
# Load the training data
df_train = pd.read_csv('experiments/datasets/heart_train.csv')

# Define DGMs using the Synthpop CART model and PrivBayes BN
dgms = DisjointGenerativeModels(df_train, generative_models=['synthpop', 'privbayes'])
df_syn = dgms.fit_generate(num_samples=20)

df_syn.head()

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,39,1,0,178,315,0,0,138,1,2.5,1,0,3,0
1,58,1,0,123,149,0,0,109,1,2.2,1,4,3,0
2,74,1,0,165,342,0,0,132,1,2.0,2,0,3,0
3,69,0,0,150,295,1,0,114,1,1.0,1,0,2,0
4,42,1,1,120,228,0,0,125,0,1.2,0,3,3,0


To specify the split ourselves, we can pass a dictionary that maps part names to lists of column names via the ```prepared_splits``` argument.

```python
prepared_splits = {
    "part1": ["age", "sex", "cp", "trestbps", "chol"],
    "part2": ["fbs", "restecg", "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"]
}

dgms = DisjointGenerativeModels(df_train, generative_models=['synthpop', 'privbayes'], prepared_splits=prepared_splits)
```
Alternatively, we can specify the split by passing a dictionary with model names as keys and the corresponding column names as values. Note that this form cannot assign the same model to two different partitions, since dictionary keys must be unique.

```python
gms_splits = {
    "synthpop": ["age", "sex", "cp", "trestbps", "chol"],
    "privbayes": ["fbs", "restecg", "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"]
}

dgms = DisjointGenerativeModels(df_train, generative_models=gms_splits)
```
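
When writing a split by hand it is easy to drop or repeat a column. The sketch below is a small sanity check, assuming that a manual split is meant to assign every column to exactly one part. The helper `check_split` is hypothetical and not part of the DGMs codebase.

```python
def check_split(columns, splits):
    """Hypothetical helper: report columns that are missing, duplicated, or unknown."""
    assigned = [col for cols in splits.values() for col in cols]
    missing = [c for c in columns if c not in assigned]
    duplicated = sorted({c for c in assigned if assigned.count(c) > 1})
    unknown = sorted({c for c in assigned if c not in columns})
    return missing, duplicated, unknown


# Column names of the heart dataset used in this notebook
columns = ["age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
           "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"]

splits = {
    "synthpop": ["age", "sex", "cp", "trestbps", "chol"],
    "privbayes": ["fbs", "restecg", "thalach", "exang", "oldpeak",
                  "slope", "ca", "thal", "target"],
}

missing, duplicated, unknown = check_split(columns, splits)
print(missing, duplicated, unknown)  # prints: [] [] []
```
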
It is also possible to specify a number of equal-sized parts rather than specific columns in both of the above methods. For example, to send 2 parts to the synthpop model and 1 part to the PrivBayes model:

In [3]:
dgms = DisjointGenerativeModels(df_train, generative_models={'synthpop': 2, 'privbayes': 1}) 
df_syn = dgms.fit_generate(num_samples=5)
 
df_syn



Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,60,1,0,160,269,1,0,120,1,2.8,1,1,3,0
1,56,1,2,150,240,1,1,123,1,0.6,1,0,3,1
2,77,1,0,150,289,0,0,133,1,0.6,1,1,1,0
3,61,0,2,160,394,0,0,133,0,0.8,2,0,2,1
4,70,1,3,200,303,0,1,122,0,3.6,1,2,3,0


Note that we get a ```UserWarning``` since a perfect 2:1 split ratio is not achievable (the dataset has 14 columns, and 14 is not divisible by 3).

Finally, we can also import the method used for randomly splitting the dataset and use it to split the dataset ourselves. This is helpful if we want to use the same split for multiple models, but we don't want to specify the split manually.

In [4]:
from disjoint_generative_model.utils.dataset_manager import random_split_columns

random_split = random_split_columns(df_train, {'part1': 2, 'part2': 1, 'part3': 1})
random_split



{'part1': ['fbs', 'ca', 'restecg', 'chol', 'cp', 'exang', 'target', 'sex'],
 'part2': ['trestbps', 'thalach', 'thal'],
 'part3': ['age', 'oldpeak', 'slope']}
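
The 14 columns cannot be divided exactly in a 2:1:1 ratio (that would be 7, 3.5 and 3.5 columns), and the output above contains 8, 3 and 3 columns. One allocation rule that produces these sizes is to round each share down and hand the leftover columns to the first parts. The sketch below illustrates that rule only. Whether ```random_split_columns``` works this way internally is an assumption, and `split_columns_by_ratio` is a hypothetical name.

```python
import random


def split_columns_by_ratio(columns, ratios, seed=None):
    """Hypothetical sketch: randomly assign columns to parts in proportion to integer ratios."""
    rng = random.Random(seed)
    shuffled = list(columns)
    rng.shuffle(shuffled)

    # Round each share down, then give the leftover columns to the first parts
    total = sum(ratios.values())
    sizes = {name: len(shuffled) * share // total for name, share in ratios.items()}
    leftover = len(shuffled) - sum(sizes.values())
    for name in list(ratios)[:leftover]:
        sizes[name] += 1

    # Cut the shuffled column list into consecutive, non-overlapping slices
    parts, start = {}, 0
    for name, size in sizes.items():
        parts[name] = shuffled[start:start + size]
        start += size
    return parts


columns = ["age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
           "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"]

parts = split_columns_by_ratio(columns, {"part1": 2, "part2": 1, "part3": 1}, seed=0)
print({name: len(cols) for name, cols in parts.items()})
# prints: {'part1': 8, 'part2': 3, 'part3': 3}
```
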

### Example 2: Joining Strategies

The DGMs framework allows for virtually any sort of joining procedure. In this library the following joining strategies are implemented:

Unsupervised:
- ```Concatenating```: Simply concatenates the synthetic data generated by each model.
- ```RandomJoining```: Same as Concatenating, but shuffles the data before concatenating.
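
The difference between the two unsupervised strategies can be sketched on plain lists of rows. This is an illustration only: it assumes "concatenating" means placing the parts side by side column-wise, the rows are toy values, and `concatenate` and `random_join` are hypothetical functions rather than the library's classes.

```python
import random


def concatenate(part_a, part_b):
    """Pair row i of part_a with row i of part_b (column-wise concatenation)."""
    return [{**a, **b} for a, b in zip(part_a, part_b)]


def random_join(part_a, part_b, seed=None):
    """Shuffle the rows of one part, then concatenate as above."""
    rng = random.Random(seed)
    shuffled_b = list(part_b)
    rng.shuffle(shuffled_b)
    return concatenate(part_a, shuffled_b)


# Toy rows standing in for the output of two generative models
part_a = [{"age": 39, "sex": 1}, {"age": 58, "sex": 1}, {"age": 74, "sex": 0}]
part_b = [{"thal": 3, "target": 0}, {"thal": 2, "target": 1}, {"thal": 1, "target": 0}]

print(concatenate(part_a, part_b))
print(random_join(part_a, part_b, seed=1))
```

Both functions return the same set of values per column. Only the pairing of rows across the parts differs.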

Supervised:
- ```UsingJoiningValidator```: Joins the synthetic data using a validator model. The validator can use one of two adapters, ```JoiningValidator``` or ```OneClassValidator```: the former accepts binary classification model backends, the latter one-class/outlier detection models. The validator repeatedly assigns prediction scores to query joins on the synthetic samples, subject to various control parameters. Accepted joins are removed from the pool for the next round.
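
The round-based procedure described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: `validator_join`, `parity_score` and the toy parity rule are hypothetical, both parts are assumed to have the same number of rows, and the real validators are trained models with additional control parameters.

```python
import random


def validator_join(part_a, part_b, score, threshold, n_target, max_rounds=20, seed=0):
    """Hypothetical sketch: propose random joins each round and keep those the validator accepts."""
    rng = random.Random(seed)
    pool_a, pool_b = list(part_a), list(part_b)
    accepted = []
    for _ in range(max_rounds):
        if len(accepted) >= n_target or not pool_a:
            break
        rng.shuffle(pool_b)  # propose a fresh random pairing of the remaining rows
        rest_a, rest_b = [], []
        for a, b in zip(pool_a, pool_b):
            if score(a, b) >= threshold:
                accepted.append({**a, **b})  # accepted joins leave the pool
            else:
                rest_a.append(a)
                rest_b.append(b)
        pool_a, pool_b = rest_a, rest_b
    return accepted[:n_target]


def parity_score(a, b):
    """Toy stand-in for a validator: a join is plausible when x and y share parity."""
    return 1.0 if a["x"] % 2 == b["y"] % 2 else 0.0


part_a = [{"x": i} for i in range(10)]
part_b = [{"y": i} for i in range(10)]

joined = validator_join(part_a, part_b, parity_score, threshold=0.5, n_target=8)
print(len(joined), all(parity_score(r, r) == 1.0 for r in joined))
```
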

The ```UsingJoiningValidator``` strategy has various control parameters that can be overridden by the user, but for most regular use the ```'behaviour'``` argument acts as a shorthand for selecting pre-configured option sets. The following behaviours are available:
- ```'adaptive'```: The parameters are adjusted during the joining process to obtain more joins, and the selection threshold is inferred automatically.
- ```'standard'```: Inherits the default settings from the ```JoiningValidator``` or ```OneClassValidator``` adapter.
- ```'strict'```: No parameters are changed during the joining process. This is likely to fail to find enough good joins; consider adjusting the ```'join_multiplier'``` attribute of the DGMs object.
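
To see why a fixed configuration can come up short while an adaptive one does not, consider a toy selection step with a score threshold. The sketch below is purely illustrative: `select_joins`, `step` and `min_threshold` are hypothetical names, and the library's presets may adjust different parameters in a different way.

```python
def select_joins(scores, threshold, n_needed, adaptive=False, step=0.05, min_threshold=0.0):
    """Hypothetical sketch: return (indices of accepted candidates, final threshold)."""
    while True:
        accepted = [i for i, s in enumerate(scores) if s >= threshold]
        enough = len(accepted) >= n_needed
        if enough or not adaptive or threshold <= min_threshold:
            return accepted[:n_needed], threshold
        # Adaptive mode: relax the threshold and try again
        threshold = max(min_threshold, threshold - step)


# Toy validator scores for six candidate joins
scores = [0.90, 0.40, 0.77, 0.62, 0.95, 0.30]

strict_idx, _ = select_joins(scores, threshold=0.8, n_needed=4)
adaptive_idx, _ = select_joins(scores, threshold=0.8, n_needed=4, adaptive=True)

print(strict_idx)    # prints: [0, 4]
print(adaptive_idx)  # prints: [0, 2, 3, 4]
```

With the fixed threshold only two of the four requested joins are found, while the relaxed threshold eventually admits enough candidates at the cost of accepting lower-scoring joins.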


In [None]:
# Imports
import pandas as pd

from disjoint_generative_model import DisjointGenerativeModels
from disjoint_generative_model.utils.joining_validator import JoiningValidator, OneClassValidator
from disjoint_generative_model.utils.joining_strategies import UsingJoiningValidator

In [None]:
# Load the training data
df_train = pd.read_csv('experiments/datasets/heart_train.csv')

gms = {'synthpop': 2, 'privbayes': 1}

JS = UsingJoiningValidator()
dgms1 = DisjointGenerativeModels(df_train, gms, joining_strategy=JS)

df_syn1 = dgms1.fit_generate()
df_syn1



Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,59,1,0,138,321,0,0,120,0,0.6,1,0,1,1
1,63,0,2,152,233,0,0,168,0,2.3,1,1,2,1
2,40,1,0,110,265,0,0,103,1,0.4,1,1,2,0
3,52,1,0,101,227,0,1,157,0,0.0,1,2,2,1
4,48,1,1,125,207,0,0,174,0,3.1,1,1,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
237,51,1,0,125,215,0,1,168,0,2.8,1,0,3,0
238,64,1,1,125,169,0,1,158,0,0.0,1,0,2,1
239,62,1,1,120,245,0,1,96,1,0.2,1,0,3,0
240,64,0,2,120,160,0,1,138,0,0.0,2,0,2,1


In [None]:
JS = UsingJoiningValidator(OneClassValidator())
dgms2 = DisjointGenerativeModels(df_train, gms, joining_strategy=JS)

df_syn2 = dgms2.fit_generate()
df_syn2



Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,74,0,0,112,265,0,1,140,0,0.4,2,1,2,1
1,76,1,3,160,288,1,0,150,1,0.8,1,0,3,1
2,53,0,0,120,178,0,0,143,0,0.8,1,0,2,1
3,52,0,2,100,175,1,0,140,0,0.6,2,0,2,1
4,45,1,2,130,234,1,0,171,0,2.0,1,0,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
237,63,1,1,200,197,1,0,158,0,4.0,1,0,2,1
238,66,1,0,136,274,0,1,96,1,2.2,0,2,3,0
239,71,0,0,130,197,0,1,126,1,1.6,1,0,2,1
240,59,1,2,160,239,0,0,150,0,3.0,1,0,1,1


In [7]:
from syntheval import SynthEval

# Metrics
metrics = {
    "h_dist"     : {},
    "corr_diff"  : {"mixed_corr": True},
    "auroc_diff" : {"model": "rf_cls"},
    "cls_acc"    : {"F1_type": "macro"},
    "eps_risk"   : {},
    "dcr"        : {},
    "mia"        : {"num_eval_iter": 5},
}

df_train = pd.read_csv('experiments/datasets/heart_train.csv')
df_test = pd.read_csv('experiments/datasets/heart_test.csv')

SE = SynthEval(df_train, df_test)
res, _ = SE.benchmark(
    {'occls': df_syn1, 'cls': df_syn2},
    analysis_target_var="target",
    rank_strategy='summation',
    **metrics,
)

res

SynthEval: inferred categorical columns...


dataset,avg_h_dist (value),avg_h_dist (error),corr_mat_diff (value),corr_mat_diff (error),auroc (value),auroc (error),cls_F1_diff (value),cls_F1_diff (error),cls_F1_diff_hout (value),cls_F1_diff_hout (error),...,median_DCR (value),median_DCR (error),mia_recall (value),mia_recall (error),mia_precision (value),mia_precision (error),rank,u_rank,p_rank,f_rank
occls,0.076695,0.025107,1.284948,,0.040047,,0.044697,0.017128,0.077073,0.01588,...,0.925389,,0.475,0.067315,0.520483,0.054529,8.224302,4.747368,3.476933,0.0
cls,0.035389,0.012471,0.594695,,0.052612,,0.021893,0.016232,0.070182,0.010184,...,1.104452,,0.475,0.04239,0.55641,0.035483,8.230142,4.81339,3.416751,0.0
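
In the table above, the ```rank``` column appears to be the sum of the component ranks, which is what one would expect from ```rank_strategy='summation'```. The short check below uses only the numbers printed above. Since ```f_rank``` is 0 for both datasets, how it enters the total cannot be determined from this output, and adding it is an assumption.

```python
# Rank columns copied from the benchmark output above
ranks = {
    "occls": {"rank": 8.224302, "u_rank": 4.747368, "p_rank": 3.476933, "f_rank": 0.0},
    "cls":   {"rank": 8.230142, "u_rank": 4.813390, "p_rank": 3.416751, "f_rank": 0.0},
}

matches = {}
for name, r in ranks.items():
    total = r["u_rank"] + r["p_rank"] + r["f_rank"]
    # Allow for the six-decimal rounding of the printed values
    matches[name] = abs(total - r["rank"]) < 1e-5
    print(name, round(total, 6), matches[name])
```
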


In [1]:
# Imports
import pandas as pd

from disjoint_generative_model import DisjointGenerativeModels
from disjoint_generative_model.utils.joining_validator import JoiningValidator, OneClassValidator
from disjoint_generative_model.utils.joining_strategies import UsingJoiningValidator

df_train = pd.read_csv('experiments/datasets/heart_train.csv')

gms = {'synthpop': 2, 'privbayes': 1}

JS = UsingJoiningValidator(OneClassValidator(), behaviour='adaptive')
dgms2 = DisjointGenerativeModels(df_train, gms, joining_strategy=JS)

df_syn2 = dgms2.fit_generate()
df_syn2



Bad joins found F1: [0.44680851063829785, 0.45569620253164556, 0.5609756097560976, 0.4943820224719101, 0.3855421686746988]
Mean F1: 0.46868090281452995
Final model trained!
Threshold auto-set to: -0.45309948008843137
Predicted good joins fraction: 0.10055096418732783
Predicted good joins fraction: 0.05972434915773354
Predicted good joins fraction: 0.05537459283387622
Predicted good joins fraction: 0.02586206896551724
Predicted good joins fraction: 0.017699115044247787
Predicted good joins fraction: 0.016216216216216217
Predicted good joins fraction: 0.018315018315018316
Predicted good joins fraction: 0.016791044776119403
Predicted good joins fraction: 0.011385199240986717
Predicted good joins fraction: 0.013435700575815739
Predicted good joins fraction: 0.01556420233463035
Predicted good joins fraction: 0.009881422924901186
Predicted good joins fraction: 0.005988023952095809
Predicted good joins fraction: 0.002008032128514056
Predicted good joins fraction: 0.006036217303822937
Predicte

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,42,0,1,130,284,0,0,170,0,0.0,2,0,2,1
1,53,0,1,130,196,0,0,165,0,0.0,2,0,2,1
2,41,1,2,160,234,0,0,177,0,0.0,1,0,2,1
3,43,1,2,130,282,0,0,175,0,0.6,2,0,2,1
4,60,1,0,130,252,0,0,174,0,1.4,1,1,3,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
237,40,1,1,140,203,0,1,143,0,0.0,2,0,3,1
238,66,1,0,122,214,0,0,139,1,2.0,2,1,3,0
239,45,0,2,128,196,0,0,169,0,0.0,2,1,2,1
240,39,1,1,123,214,0,1,163,0,0.6,1,0,3,1
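
The "Predicted good joins fraction" lines in the log above can be read as per-round acceptance rates. The sketch below checks one interpretation: each fraction is the share of the remaining candidate pool that was accepted in that round, and the pool started at 726 candidates. The starting size is an inference (the first fraction equals 73/726), not a value reported by the library. Only the 15 complete log lines are used, since the log is truncated after that.

```python
# Per-round fractions copied from the log above (complete lines only)
fractions = [
    0.10055096418732783, 0.05972434915773354, 0.05537459283387622,
    0.02586206896551724, 0.017699115044247787, 0.016216216216216217,
    0.018315018315018316, 0.016791044776119403, 0.011385199240986717,
    0.013435700575815739, 0.01556420233463035, 0.009881422924901186,
    0.005988023952095809, 0.002008032128514056, 0.006036217303822937,
]

pool = 726  # assumed initial number of candidate joins (inferred from 73/726)
accepted_per_round = []
for frac in fractions:
    accepted = round(frac * pool)
    # Under this interpretation each round must accept a whole number of joins
    assert abs(frac * pool - accepted) < 1e-6
    accepted_per_round.append(accepted)
    pool -= accepted  # accepted joins are removed from the pool

print(accepted_per_round)
print(sum(accepted_per_round), pool)  # prints: 232 494
```

Under this reading, 232 joins had been accepted after 15 rounds, with progressively fewer acceptances per round as the pool is depleted of good candidates.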
