In [17]:
%pip install synthcity[all]
%pip install catenets

Collecting pytest-cov (from synthcity[all])
  Downloading pytest_cov-5.0.0-py3-none-any.whl (21 kB)
Collecting jupyter (from synthcity[all])
  Downloading jupyter-1.0.0-py2.py3-none-any.whl (2.7 kB)
Collecting notebook (from synthcity[all])
  Downloading notebook-7.1.2-py3-none-any.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m63.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting bandit (from synthcity[all])
  Downloading bandit-1.7.8-py3-none-any.whl (127 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.6/127.6 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting black (from synthcity[all])
  Downloading black-24.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m93.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting black-nb (from synthcity[all])
  Downloading black-nb-0.7.tar.gz (10 kB

In [3]:
from synthcity.plugins import Plugins
from synthcity.metrics.eval_statistical import AlphaPrecision, InverseKLDivergence, MaximumMeanDiscrepancy, WassersteinDistance
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.utils.serialization import save_to_file, load_from_file
import matplotlib.pyplot as plt
import pandas as pd
from catenets.models.jax import *
import numpy as np
from sklearn.metrics import mean_squared_error
from os import listdir
from sklearn.preprocessing import OneHotEncoder


<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.

    The location of Cuda header files cuda.h and nvrtc.h could not be detected on your system.
    You must determine their location and then define the environment variable CUDA_PATH,
    either before launching Python or using os.environ before importing keops. For example
    if these files are in /vol/cuda/10.2.89-cudnn7.6.4.38/include you can do :
      import os
      os.environ['CUDA_PATH'] = '/vol/cuda/10.2.89-cudnn7.6.4.38'
    
[KeOps] Compiling cuda jit compiler engine ... 
/anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/keopscore/binders/nvrtc/nvrtc_jit.cpp:16:10: fatal error: cuda.h: No such file or directory
 #include <cuda.h>
          ^~~~~~~~
compilation terminated.

OK
[pyKeOps] Compiling nvrtc binder for python ... 
In file included from /anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/pykeops/common/keops_io/pykeops_nvrtc.cpp:8:0:
/anaconda/envs/azur

In [4]:
def train_models(gen_list, datasets, n_runs):
    """Fit every generator on every dataset `n_runs` times and save each fitted model.

    Args:
        gen_list: iterable of plugin names accepted by ``Plugins().get``.
        datasets: mapping of dataset name -> training data handed to ``model.fit``.
        n_runs: number of models trained per (generator, dataset) pair.

    Each fitted model is saved as ``models/<generator>_<dataset>_<run>.pkl``.
    """
    import os

    # Create the output directory up front so saving does not depend on it
    # already existing in the working directory.
    os.makedirs('models', exist_ok=True)

    # Build the plugin registry once instead of once per training run.
    plugins = Plugins()

    # NOTE(review): plugins are created with their default arguments. If a
    # plugin's default seed is fixed, the repeated runs may not be independent
    # -- confirm, and pass a per-run seed if that is the intent.
    for g in gen_list:
        for dataset, data in datasets.items():
            for i in range(n_runs):
                print(f'Training model {g} iteration {i} on dataset {dataset}')
                model = plugins.get(g)
                model.fit(data)
                fp = f'models/{g}_{dataset}_{i}.pkl'
                print(f'Saving model {g} iteration {i} on dataset {dataset}')
                save_to_file(fp, model)

In [5]:
def generate_data(models, n):
    """Sample `n` synthetic rows from each saved model and write them to CSV.

    Args:
        models: iterable of file names located in ``models/``.
        n: number of rows requested from each model (``generate(count=n)``).

    For a model file ``<name>.<ext>`` the sample is written to
    ``syn_data/<name>.csv`` without the index column.
    """
    import os

    # Create the output directory so the first to_csv does not fail on a
    # fresh working directory.
    os.makedirs('syn_data', exist_ok=True)

    for fp in models:
        # NOTE(review): the '.pkl' files are presumably pickle-based; only
        # load model files you produced yourself.
        generator = load_from_file('models/' + fp)
        df = generator.generate(count=n).dataframe()

        # Strip only the final extension so stems that contain dots survive.
        name = os.path.splitext(fp)[0]
        df.to_csv('syn_data/' + name + '.csv', index=False)


In [6]:
def run_ihdp_cate_experiment(datasets, X_t, mu0, mu1):
    """Fit each CATE learner on each dataset and score it against ``mu1 - mu0``.

    Args:
        datasets: mapping of label -> DataFrame containing 'treatment',
            'y_factual' and the covariate columns.
        X_t: covariates the fitted learners predict on.
        mu0, mu1: values whose difference ``mu1 - mu0`` is used as the
            reference effect for the rows of ``X_t``.

    Returns:
        DataFrame with columns 'generator', 'learner', 'rmse', one row per
        (dataset, learner) pair.
    """
    # The reference effect does not depend on the dataset or learner.
    cate = mu1 - mu0

    rows = []
    for d, dataset in datasets.items():
        X = np.array(dataset.drop(['treatment', 'y_factual'], axis=1))
        y = np.array(dataset['y_factual'])
        w = np.array(dataset['treatment'])
        # Fresh, unfitted learners for every dataset.
        learners = [TNet(), SNet(), PWNet(), RANet(), DRNet()]
        for learner in learners:
            learner.fit(X, y, w)
            pred = learner.predict(X_t)
            # Take the root explicitly: the `squared=False` flag of
            # mean_squared_error is deprecated in recent scikit-learn.
            rmse = np.sqrt(mean_squared_error(cate, pred))
            rows.append([d, str(learner), rmse])
            print(f'{learner} complete on {d}')

    # Build the frame once instead of enlarging it row by row.
    return pd.DataFrame(rows, columns=['generator', 'learner', 'rmse'])

In [7]:
def ihdp_all_predictions(training_datasets, X_t, mu0, mu1):
    """Collect the predictions of every learner trained on every dataset.

    Args:
        training_datasets: mapping of label -> DataFrame containing
            'treatment', 'y_factual' and the covariate columns.
        X_t: DataFrame of covariates to predict on.
        mu0, mu1: Series named 'mu0' and 'mu1' (the code reads those column
            names after concatenation), aligned with ``X_t`` by index.

    Returns:
        DataFrame holding ``X_t``, 'mu0', 'mu1', 'CATE' (= mu1 - mu0) and one
        column per (dataset, learner) named ``<label before first '.'><learner>``.
    """
    results = pd.concat([X_t, mu0, mu1], axis=1)
    results['CATE'] = results['mu1'] - results['mu0']

    # The prediction inputs are the same for every learner.
    X_t_array = np.array(X_t)

    predictions = {}
    for d, dataset in training_datasets.items():
        X = np.array(dataset.drop(['treatment', 'y_factual'], axis=1))
        y = np.array(dataset['y_factual'])
        w = np.array(dataset['treatment'])
        prefix = d.split('.')[0]
        # Fresh, unfitted learners for every dataset.
        learners = [TNet(), SNet(), PWNet(), RANet(), DRNet()]

        for learner in learners:
            learner.fit(X, y, w)
            pred = learner.predict(X_t_array)
            # Flatten so each prediction is stored as a plain 1-D column.
            predictions[prefix + str(learner)] = np.asarray(pred).reshape(-1)

            print(f'{learner} complete on {d}')

    if predictions:
        # Attach all prediction columns in one step; inserting them one at a
        # time fragments the frame once there are many of them.
        results = pd.concat(
            [results, pd.DataFrame(predictions, index=results.index)], axis=1
        )

    return results


In [8]:
def rate_of_flipping(d, cols):
    """Share of rows in which an estimate has the opposite sign to 'CATE'.

    Args:
        d: DataFrame with a 'CATE' column and the columns listed in `cols`.
        cols: names of the estimate columns to check.

    Returns:
        One-row DataFrame with one column per entry of `cols`, holding the
        fraction of rows where ``d['CATE'] * d[col]`` is negative.
    """
    n_rows = d.shape[0]
    flipped = {}
    for name in cols:
        # A negative product means the two values disagree in sign.
        n_wrong = int((d['CATE'] * d[name] < 0).sum())
        flipped[name] = [n_wrong / n_rows]
    return pd.DataFrame(flipped)

In [9]:
def standard_metrics(datasets, real, metrics):
    """Evaluate every metric on every dataset against `real`.

    Args:
        datasets: mapping of label -> data wrapped in ``GenericDataLoader``.
        real: reference passed as the first argument to ``evaluate``.
        metrics: mapping whose values expose ``evaluate(real, loader)`` and
            return a mapping of result name -> value.

    Returns:
        DataFrame with columns 'dataset', 'metric', 'result', one row per
        value returned by each metric.
    """
    results = pd.DataFrame(columns=['dataset', 'metric', 'result'])
    for label, frame in datasets.items():
        loader = GenericDataLoader(frame)
        for evaluator in metrics.values():
            scores = evaluator.evaluate(real, loader)
            for key, value in scores.items():
                results.loc[len(results.index)] = [label, key, value]

    return results

In [10]:
def encode_acic_datasets(real, syn_data, categorical=None):
    """One-hot encode the categorical columns of the real and synthetic frames.

    The encoder is fitted on `real` only and then applied to every synthetic
    frame, so all outputs share the same dummy columns.

    Args:
        real: DataFrame containing the categorical columns.
        syn_data: mapping of label -> DataFrame with the same categorical columns.
        categorical: column names to encode; defaults to
            ``['x_2', 'x_21', 'x_24']``.

    Returns:
        ``(real_encoded, syn_data_encoded)``: the real frame and a mapping of
        label -> synthetic frame, each with the categorical columns replaced
        by their one-hot columns (appended after the remaining columns).
    """
    if categorical is None:
        categorical = ['x_2', 'x_21', 'x_24']
    else:
        categorical = list(categorical)

    encoder = OneHotEncoder()

    def _with_one_hot(frame, encoded):
        # Reuse the frame's own index: concat(axis=1) aligns on index labels,
        # so dummies on a default 0..n-1 index would misalign (and NaN-pad)
        # any frame whose index is different.
        dummies = pd.DataFrame(
            encoded.toarray(),
            columns=encoder.get_feature_names_out(categorical),
            index=frame.index,
        )
        return pd.concat([frame.drop(categorical, axis=1), dummies], axis=1)

    real_encoded = _with_one_hot(real, encoder.fit_transform(real[categorical]))

    syn_data_encoded = {}
    for d, syn_dataset in syn_data.items():
        syn_data_encoded[d] = _with_one_hot(
            syn_dataset, encoder.transform(syn_dataset[categorical])
        )

    return real_encoded, syn_data_encoded

In [11]:
def run_acic_cate_experiment(datasets, X_t, mu0, mu1):
    """Fit each CATE learner on each dataset and score it against ``mu1 - mu0``.

    Args:
        datasets: mapping of label -> DataFrame containing 'z' (treatment),
            'y' (outcome) and the covariate columns.
        X_t: covariates the fitted learners predict on.
        mu0, mu1: values whose difference ``mu1 - mu0`` is used as the
            reference effect for the rows of ``X_t``.

    Returns:
        DataFrame with columns 'generator', 'learner', 'rmse', one row per
        (dataset, learner) pair.
    """
    # The reference effect does not depend on the dataset or learner.
    cate = mu1 - mu0

    rows = []
    for d, dataset in datasets.items():
        X = np.array(dataset.drop(['z', 'y'], axis=1))
        y = np.array(dataset['y'])
        w = np.array(dataset['z'])
        # Fresh, unfitted learners for every dataset.
        learners = [TNet(), PWNet(), RANet(), DRNet()]
        for learner in learners:
            learner.fit(X, y, w)
            pred = learner.predict(X_t)
            # Take the root explicitly: the `squared=False` flag of
            # mean_squared_error is deprecated in recent scikit-learn.
            rmse = np.sqrt(mean_squared_error(cate, pred))
            rows.append([d, str(learner), rmse])
            print(f'{learner} complete on {d}')

    # Build the frame once instead of enlarging it row by row.
    return pd.DataFrame(rows, columns=['generator', 'learner', 'rmse'])

In [32]:
def acic_all_predictions(training_datasets, X_t, mu0, mu1):
    """Collect the predictions of every learner trained on every dataset.

    Args:
        training_datasets: mapping of label -> DataFrame containing 'z'
            (treatment), 'y' (outcome) and the covariate columns.
        X_t: DataFrame of covariates to predict on.
        mu0, mu1: Series named 'mu0' and 'mu1' (the code reads those column
            names after concatenation), aligned with ``X_t`` by index.

    Returns:
        DataFrame holding ``X_t``, 'mu0', 'mu1', 'CATE' (= mu1 - mu0) and one
        column per (dataset, learner) named ``<label before first '.'>_<learner>``.
    """
    results = pd.concat([X_t, mu0, mu1], axis=1)
    results['CATE'] = results['mu1'] - results['mu0']

    # The prediction inputs are the same for every learner.
    X_t_array = np.array(X_t)

    predictions = {}
    for d, dataset in training_datasets.items():
        X = np.array(dataset.drop(['z', 'y'], axis=1))
        y = np.array(dataset['y'])
        w = np.array(dataset['z'])
        prefix = d.split('.')[0] + '_'
        # Fresh, unfitted learners for every dataset.
        learners = [TNet(), PWNet(), RANet(), DRNet()]

        for learner in learners:
            learner.fit(X, y, w)
            pred = learner.predict(X_t_array)
            # Flatten so each prediction is stored as a plain 1-D column.
            predictions[prefix + str(learner)] = np.asarray(pred).reshape(-1)

            print(f'{learner} complete on {d}')

    if predictions:
        # Attach all prediction columns in one step; inserting them one at a
        # time fragments the frame once there are many of them.
        results = pd.concat(
            [results, pd.DataFrame(predictions, index=results.index)], axis=1
        )

    return results

## Load real datasets

In [12]:
ihdp_full = pd.read_csv('../Datasets/ihdp.csv')
# Working copy without 'y_cfactual' and without 'mu0'/'mu1', which are read
# from ihdp_full later when the reference effect is needed.
ihdp = ihdp_full.drop(columns=['y_cfactual', 'mu0', 'mu1'])

# Rows 0-599 (roughly 80%) are the training split; rows 600-746 are held out
# for testing the CATE estimators.
train_ihdp = ihdp.loc[list(range(600))]
test_ihdp = ihdp.loc[list(range(600, 747))]

In [75]:
# Jobs dataset; not referenced again in the cells shown here.
jobs = pd.read_csv('../Datasets/jobs.csv')

In [20]:
# Twins dataset; not referenced again in the cells shown here.
twins = pd.read_csv('../Datasets/twins.csv')

In [33]:
acic_full = pd.read_csv('../Datasets/acic.csv')
# Observed outcome: 'y1' for rows with z == 1, 'y0' for all other rows.
acic_full['y'] = np.where(acic_full['z'] == 1, acic_full['y1'], acic_full['y0'])

# Working copy without 'y0'/'y1' and 'mu0'/'mu1'.
acic = acic_full.drop(columns=['y0', 'y1', 'mu0', 'mu1'])
# Rows 0-3999 are the training split. The held-out rows (4000 onward) are
# selected later, from the one-hot encoded frame.
acic_train = acic.loc[list(range(4000))]

## Train generative models

In [128]:
# Generator plugins to train, and the training data for each dataset name.
gen_list = ['ctgan', 'nflow', 'tvae', 'arf', 'ddpm']
datasets = {'acic': acic_train}

In [129]:
# Five models per (generator, dataset) pair.
train_models(gen_list, datasets, n_runs=5)

[2024-04-13T12:57:37.877501+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T12:57:37.879245+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T12:57:37.879773+0000][297025][CRITICAL] module plugin_goggle load failed


Training model ctgan iteration 0 on dataset acic


 30%|██▉       | 599/2000 [10:06<23:39,  1.01s/it]  
[2024-04-13T13:07:55.637792+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:07:55.638476+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:07:55.638935+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ctgan iteration 0 on dataset acic
Training model ctgan iteration 1 on dataset acic


 50%|████▉     | 999/2000 [16:05<16:07,  1.04it/s]
[2024-04-13T13:24:11.726908+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:24:11.727543+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:24:11.727941+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ctgan iteration 1 on dataset acic
Training model ctgan iteration 2 on dataset acic


 50%|████▉     | 999/2000 [15:56<15:58,  1.04it/s]
[2024-04-13T13:40:18.635097+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:40:18.636037+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:40:18.636479+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ctgan iteration 2 on dataset acic
Training model ctgan iteration 3 on dataset acic


 30%|██▉       | 599/2000 [09:30<22:13,  1.05it/s]
[2024-04-13T13:49:59.544465+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:49:59.545227+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:49:59.545684+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ctgan iteration 3 on dataset acic
Training model ctgan iteration 4 on dataset acic


 27%|██▋       | 549/2000 [08:37<22:46,  1.06it/s]
[2024-04-13T13:58:47.095883+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:58:47.096639+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:58:47.097054+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ctgan iteration 4 on dataset acic
Training model nflow iteration 0 on dataset acic


 30%|██▉       | 299/1000 [00:50<01:58,  5.90it/s]
[2024-04-13T13:59:48.075593+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:59:48.076212+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T13:59:48.076645+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model nflow iteration 0 on dataset acic
Training model nflow iteration 1 on dataset acic


 30%|██▉       | 299/1000 [00:47<01:51,  6.27it/s]
[2024-04-13T14:00:46.069421+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:00:46.070146+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:00:46.070542+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model nflow iteration 1 on dataset acic
Training model nflow iteration 2 on dataset acic


 30%|██▉       | 299/1000 [00:43<01:42,  6.85it/s]
[2024-04-13T14:01:40.255844+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:01:40.256545+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:01:40.256982+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model nflow iteration 2 on dataset acic
Training model nflow iteration 3 on dataset acic


 30%|██▉       | 299/1000 [00:43<01:42,  6.85it/s]
[2024-04-13T14:02:34.411405+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:02:34.412112+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:02:34.412601+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model nflow iteration 3 on dataset acic
Training model nflow iteration 4 on dataset acic


 30%|██▉       | 299/1000 [00:40<01:35,  7.35it/s]
[2024-04-13T14:03:25.567511+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:03:25.568237+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:03:25.568704+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model nflow iteration 4 on dataset acic
Training model tvae iteration 0 on dataset acic


 30%|███       | 300/1000 [04:22<10:13,  1.14it/s]
[2024-04-13T14:07:59.044520+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:07:59.045257+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:07:59.045744+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model tvae iteration 0 on dataset acic
Training model tvae iteration 1 on dataset acic


 30%|███       | 300/1000 [04:23<10:14,  1.14it/s]
[2024-04-13T14:12:33.576106+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:12:33.576881+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:12:33.577279+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model tvae iteration 1 on dataset acic
Training model tvae iteration 2 on dataset acic


 30%|███       | 300/1000 [04:14<09:54,  1.18it/s]
[2024-04-13T14:16:58.764651+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:16:58.765411+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:16:58.765792+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model tvae iteration 2 on dataset acic
Training model tvae iteration 3 on dataset acic


 30%|███       | 300/1000 [04:17<10:00,  1.16it/s]
[2024-04-13T14:21:26.917019+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:21:26.917832+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:21:26.918334+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model tvae iteration 3 on dataset acic
Training model tvae iteration 4 on dataset acic


 30%|███       | 300/1000 [04:20<10:08,  1.15it/s]
[2024-04-13T14:25:58.278353+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:25:58.278987+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:25:58.279403+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model tvae iteration 4 on dataset acic
Training model arf iteration 0 on dataset acic
Initial accuracy is 0.8685
Iteration number 1 reached accuracy of 0.625625.


[2024-04-13T14:27:29.420701+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:27:29.421284+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:27:29.421622+0000][297025][CRITICAL] module plugin_goggle load failed


Iteration number 2 reached accuracy of 0.62825.
Saving model arf iteration 0 on dataset acic
Training model arf iteration 1 on dataset acic
Initial accuracy is 0.8685
Iteration number 1 reached accuracy of 0.625625.


[2024-04-13T14:29:00.563627+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:29:00.564273+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:29:00.564613+0000][297025][CRITICAL] module plugin_goggle load failed


Iteration number 2 reached accuracy of 0.62825.
Saving model arf iteration 1 on dataset acic
Training model arf iteration 2 on dataset acic
Initial accuracy is 0.8685
Iteration number 1 reached accuracy of 0.625625.


[2024-04-13T14:30:31.554039+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:30:31.554629+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:30:31.554980+0000][297025][CRITICAL] module plugin_goggle load failed


Iteration number 2 reached accuracy of 0.62825.
Saving model arf iteration 2 on dataset acic
Training model arf iteration 3 on dataset acic
Initial accuracy is 0.8685
Iteration number 1 reached accuracy of 0.625625.


[2024-04-13T14:32:02.692121+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:32:02.692709+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:32:02.693093+0000][297025][CRITICAL] module plugin_goggle load failed


Iteration number 2 reached accuracy of 0.62825.
Saving model arf iteration 3 on dataset acic
Training model arf iteration 4 on dataset acic
Initial accuracy is 0.8685
Iteration number 1 reached accuracy of 0.625625.


[2024-04-13T14:33:34.291780+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:33:34.292507+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:33:34.292879+0000][297025][CRITICAL] module plugin_goggle load failed


Iteration number 2 reached accuracy of 0.62825.
Saving model arf iteration 4 on dataset acic
Training model ddpm iteration 0 on dataset acic


Epoch: 100%|██████████| 1000/1000 [01:34<00:00, 10.62it/s, loss=0.868]
[2024-04-13T14:35:08.741386+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:35:08.741940+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:35:08.742403+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ddpm iteration 0 on dataset acic
Training model ddpm iteration 1 on dataset acic


Epoch: 100%|██████████| 1000/1000 [01:33<00:00, 10.66it/s, loss=0.857]
[2024-04-13T14:36:42.780670+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:36:42.781472+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:36:42.781958+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ddpm iteration 1 on dataset acic
Training model ddpm iteration 2 on dataset acic


Epoch: 100%|██████████| 1000/1000 [01:37<00:00, 10.26it/s, loss=0.855]
[2024-04-13T14:38:20.462384+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:38:20.462812+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:38:20.463154+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ddpm iteration 2 on dataset acic
Training model ddpm iteration 3 on dataset acic


Epoch: 100%|██████████| 1000/1000 [01:39<00:00, 10.01it/s, loss=0.851]
[2024-04-13T14:40:00.592737+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:40:00.593290+0000][297025][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_goggle' has no attribute 'plugin'
[2024-04-13T14:40:00.593646+0000][297025][CRITICAL] module plugin_goggle load failed


Saving model ddpm iteration 3 on dataset acic
Training model ddpm iteration 4 on dataset acic


Epoch: 100%|██████████| 1000/1000 [01:34<00:00, 10.54it/s, loss=0.853]

Saving model ddpm iteration 4 on dataset acic





## Generate synthetic datasets

In [132]:
# Saved model files whose name mentions the ACIC dataset.
acic_models = [fname for fname in listdir('models') if 'acic' in fname]

In [135]:
# 4000 synthetic rows per model, the same size as the ACIC training split.
generate_data(acic_models, n=4000)

## Run CATE estimators

In [89]:
#IHDP
X_t = np.array(test_ihdp.drop(['treatment', 'y_factual'], axis=1))
mu0 = ihdp_full.loc[[i+600 for i in range(147)]]['mu0']
mu1 = ihdp_full.loc[[i+600 for i in range(147)]]['mu1']

In [118]:
# Every synthetic IHDP sample on disk, keyed by file name.
syn_data = {
    fname: pd.read_csv('syn_data/' + fname)
    for fname in listdir('syn_data')
    if 'ihdp' in fname
}

# The real data is added under the key 'real' for comparison.
syn_data['real'] = ihdp

# Predictions are made for every IHDP row here, so the covariates and the
# 'mu0'/'mu1' values cover the full table.
X_t = ihdp.drop(columns=['treatment', 'y_factual'])
mu0 = ihdp_full['mu0']
mu1 = ihdp_full['mu1']

In [119]:
# One prediction column per (dataset, learner) pair, next to the reference CATE.
ihdp_cate_estims = ihdp_all_predictions(
    training_datasets=syn_data, X_t=X_t, mu0=mu0, mu1=mu1
)

TNet() complete on tvae_ihdp_0.csv
SNet() complete on tvae_ihdp_0.csv
PWNet() complete on tvae_ihdp_0.csv
RANet() complete on tvae_ihdp_0.csv
DRNet() complete on tvae_ihdp_0.csv
TNet() complete on ctgan_ihdp_4.csv
SNet() complete on ctgan_ihdp_4.csv
PWNet() complete on ctgan_ihdp_4.csv
RANet() complete on ctgan_ihdp_4.csv
DRNet() complete on ctgan_ihdp_4.csv
TNet() complete on ddpm_ihdp_1.csv
SNet() complete on ddpm_ihdp_1.csv
PWNet() complete on ddpm_ihdp_1.csv
RANet() complete on ddpm_ihdp_1.csv
DRNet() complete on ddpm_ihdp_1.csv
TNet() complete on ctgan_ihdp_0.csv
SNet() complete on ctgan_ihdp_0.csv
PWNet() complete on ctgan_ihdp_0.csv
RANet() complete on ctgan_ihdp_0.csv
DRNet() complete on ctgan_ihdp_0.csv
TNet() complete on arf_ihdp_3.csv
SNet() complete on arf_ihdp_3.csv
PWNet() complete on arf_ihdp_3.csv
RANet() complete on arf_ihdp_3.csv
DRNet() complete on arf_ihdp_3.csv
TNet() complete on ddpm_ihdp_2.csv
SNet() complete on ddpm_ihdp_2.csv
PWNet() complete on ddpm_ihdp_2.csv

In [28]:
# Persist the prediction table; a later cell reloads it from this file.
ihdp_cate_estims.to_csv('all_ihdp_cate.csv', index=False)

NameError: name 'ihdp_cate_estims' is not defined

In [30]:
# Reload the saved prediction table instead of recomputing it.
ihdp_cate_estims = pd.read_csv('all_ihdp_cate.csv')

In [31]:
# Display the prediction table.
ihdp_cate_estims

Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,ctgan_ihdp_1TNet(),ctgan_ihdp_1SNet(),ctgan_ihdp_1PWNet(),ctgan_ihdp_1RANet(),ctgan_ihdp_1DRNet(),realTNet(),realSNet(),realPWNet(),realRANet(),realDRNet()
0,-0.528603,-0.343455,1.128554,0.161703,-0.316603,1.295216,1,0,1,0,...,1.525365,6.147883,1.850494,1.411704,2.635499,3.206027,3.856652,2.224463,3.522441,4.327240
1,-1.736945,-1.802002,0.383828,2.244320,-0.629189,1.295216,0,0,0,1,...,2.113898,4.714939,0.594853,1.916519,2.005464,1.499983,3.002945,1.661345,1.981562,2.748401
2,-0.807451,-0.202946,-0.360898,-0.879606,0.808706,-0.526556,0,0,0,1,...,0.712719,0.467486,-3.124680,1.123163,2.495695,4.295106,6.065885,3.594182,4.523798,5.003065
3,0.390083,0.596582,-1.850350,-0.879606,-0.004017,-0.857787,0,0,0,0,...,2.677840,1.870960,5.493809,2.734890,4.052769,4.568575,4.207206,4.570941,4.714533,5.161874
4,-1.045229,-0.602710,0.011465,0.161703,0.683672,-0.360940,1,0,0,0,...,2.934916,3.049065,6.162396,2.790026,4.934308,4.904621,4.534925,5.056879,4.845352,5.415437
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
742,-0.007654,-0.202946,-0.360898,0.161703,-0.316603,1.792063,0,0,1,0,...,3.821021,3.147493,4.427485,3.630419,4.377537,1.321380,2.470347,3.690839,1.796879,2.480177
743,0.727295,-0.202946,-0.733261,-0.879606,0.808706,1.129600,0,0,1,0,...,1.567830,2.053231,-0.930796,1.708021,1.565880,2.498061,2.555947,3.754096,2.717450,2.466795
744,1.181234,0.196818,-1.477987,0.161703,0.746189,0.467138,0,0,0,0,...,0.981259,2.630848,1.702846,1.316971,2.094526,3.524793,3.152847,0.972531,3.774994,3.821777
745,-0.288664,-0.202946,-1.477987,-0.879606,1.621430,0.467138,1,0,1,0,...,2.022070,-2.972390,-0.623636,1.994366,2.480402,4.258956,3.918890,5.591290,4.152879,4.382747


In [183]:
# Prediction columns are named '<dataset><learner>' by ihdp_all_predictions and
# every learner label ends in 'Net()'. Selecting them from the frame keeps the
# list in the frame's own column order and in sync with whatever datasets were
# actually evaluated, instead of repeating 125 names by hand.
cols = [c for c in ihdp_cate_estims.columns if c.endswith('Net()')]

In [184]:
# Share of rows where each estimate has the opposite sign to the reference CATE.
rates = rate_of_flipping(d=ihdp_cate_estims, cols=cols)

In [188]:
# One row per estimate column, with the rate in column 0.
# NOTE: this rebinds `rates`, so running the cell twice undoes the transpose.
rates = rates.T

In [202]:
# Show the rates in ascending order (column 0 holds the rate after the
# transpose). The sorted frame is only displayed; `rates` is not reassigned.
rates.sort_values(by=0)

Unnamed: 0,0
realDRNet(),0.004016
realRANet(),0.006693
realSNet(),0.010710
realTNet(),0.010710
ddpm_ihdp_4DRNet(),0.014726
...,...
nflow_ihdp_0RANet(),0.669344
nflow_ihdp_3PWNet(),0.676037
nflow_ihdp_0DRNet(),0.726908
nflow_ihdp_2PWNet(),0.775100


In [284]:
# After the transpose the estimator names live in the index, so the index must
# be written; with index=False the file would hold unlabeled rates only.
rates.to_csv('ihdp_rate_of_flipping.csv', index=True, index_label='estimator')

In [97]:
#ACIC
# Read every synthetic ACIC table found in syn_data/, keyed by its file name.
syn_data = {
    fname: pd.read_csv('syn_data/'+fname)
    for fname in listdir('syn_data')
    if 'acic' in fname
}

acic_encoded, syn_data_encoded = encode_acic_datasets(acic, syn_data)

# Rows labelled 4000..4801 form the held-out evaluation split.
test_rows = list(range(4000, 4802))
acic_test_encoded = acic_encoded.loc[test_rows]

# Features are every encoded column except 'z' and 'y'.
X_t = np.array(acic_test_encoded.drop(columns=['z', 'y']))
mu0 = acic_full.loc[test_rows, 'mu0']
mu1 = acic_full.loc[test_rows, 'mu1']

In [98]:
# Run the CATE experiment (helper defined outside this section) on every encoded
# synthetic table, evaluated against the held-out features and mu0/mu1 columns.
results = run_acic_cate_experiment(syn_data_encoded, X_t, mu0, mu1)

TNet() complete on nflow_acic_0.csv
PWNet() complete on nflow_acic_0.csv
RANet() complete on nflow_acic_0.csv
DRNet() complete on nflow_acic_0.csv
TNet() complete on arf_acic_3.csv
PWNet() complete on arf_acic_3.csv
RANet() complete on arf_acic_3.csv
DRNet() complete on arf_acic_3.csv
TNet() complete on nflow_acic_4.csv
PWNet() complete on nflow_acic_4.csv
RANet() complete on nflow_acic_4.csv
DRNet() complete on nflow_acic_4.csv
TNet() complete on ctgan_acic_2.csv
PWNet() complete on ctgan_acic_2.csv
RANet() complete on ctgan_acic_2.csv
DRNet() complete on ctgan_acic_2.csv
TNet() complete on tvae_acic_0.csv
PWNet() complete on tvae_acic_0.csv
RANet() complete on tvae_acic_0.csv
DRNet() complete on tvae_acic_0.csv
TNet() complete on ddpm_acic_0.csv
PWNet() complete on ddpm_acic_0.csv
RANet() complete on ddpm_acic_0.csv
DRNet() complete on ddpm_acic_0.csv
TNet() complete on arf_acic_1.csv
PWNet() complete on arf_acic_1.csv
RANet() complete on arf_acic_1.csv
DRNet() complete on arf_acic_1

In [99]:
# Training split of the real data: rows labelled 0..3999.
acic_train_encoded = acic_encoded.loc[list(range(4000))]

# Same experiment as for the synthetic tables, with the real training rows
# passed as a single dataset named 'real'.
real_data = dict(real=acic_train_encoded)
real_results = run_acic_cate_experiment(real_data, X_t, mu0, mu1)

TNet() complete on real
PWNet() complete on real
RANet() complete on real
DRNet() complete on real


In [100]:
# Collapse file names such as 'nflow_acic_0.csv' to '<generator><run digit>'
# ('nflow0'): first '_'-separated token plus the first character of the third.
# Names without two underscores are passed through unchanged, so the cell can
# be re-run without an IndexError on labels that were already shortened.
results['generator'] = [
    name.split('_')[0] + name.split('_')[2][0] if name.count('_') >= 2 else name
    for name in results['generator']
]

In [101]:
# Generator family = shortened generator label minus its last character.
results['gen_type'] = [
    label[:-1]
    for label in results['generator']
]

In [102]:
# The real-data rows have no run suffix to strip, so the generator label is
# reused as-is for the family column.
real_results['gen_type'] = real_results['generator']

In [103]:
# Stack synthetic and real results into one frame (original row labels are kept).
all_results = pd.concat([results, real_results])

In [104]:
# Mean of the numeric result columns per (generator family, learner).
# numeric_only=True keeps the string-valued 'generator' column out of the
# aggregation; without it, pandas 2.x raises a TypeError instead of silently
# dropping that column as older releases did.
all_results.groupby(['gen_type', 'learner']).mean(numeric_only=True)

Unnamed: 0_level_0,Unnamed: 1_level_0,rmse
gen_type,learner,Unnamed: 2_level_1
arf,DRNet(),4.376229
arf,PWNet(),6.643897
arf,RANet(),4.344109
arf,TNet(),4.320492
ctgan,DRNet(),4.430125
ctgan,PWNet(),4.994876
ctgan,RANet(),4.768166
ctgan,TNet(),5.279973
ddpm,DRNet(),7.874033
ddpm,PWNet(),5.425006


In [105]:
# Persist the combined results; the row index is not written.
all_results.to_csv('acic_CATE_results.csv', index=False)

In [35]:
# Reload the synthetic ACIC tables (file name -> DataFrame).
syn_data = {
    fname: pd.read_csv('syn_data/'+fname)
    for fname in listdir('syn_data')
    if 'acic' in fname
}

acic_encoded, syn_data_encoded = encode_acic_datasets(acic, syn_data)

# The encoded real data is added next to the synthetic tables under 'real'.
syn_data_encoded['real'] = acic_encoded

# Unlike the RMSE experiment, every row is used here; features are all
# encoded columns except 'z' and 'y'.
X_t = acic_encoded.drop(columns=['z', 'y'])
mu0, mu1 = acic_full['mu0'], acic_full['mu1']

In [36]:
# Collect predictions for every dataset in syn_data_encoded (including 'real')
# via acic_all_predictions, a helper defined outside this section.
acic_cate_estims = acic_all_predictions(syn_data_encoded, X_t, mu0, mu1)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


TNet() complete on nflow_acic_0.csv
PWNet() complete on nflow_acic_0.csv
RANet() complete on nflow_acic_0.csv
DRNet() complete on nflow_acic_0.csv
TNet() complete on arf_acic_3.csv
PWNet() complete on arf_acic_3.csv
RANet() complete on arf_acic_3.csv
DRNet() complete on arf_acic_3.csv
TNet() complete on nflow_acic_4.csv
PWNet() complete on nflow_acic_4.csv
RANet() complete on nflow_acic_4.csv
DRNet() complete on nflow_acic_4.csv
TNet() complete on ctgan_acic_2.csv
PWNet() complete on ctgan_acic_2.csv


In [None]:
# Persist the ACIC estimates; the row index is not written.
acic_cate_estims.to_csv('all_acic_cate.csv', index=False)

In [None]:
# Display the estimates frame.
# NOTE(review): the output saved with this cell lists `*_ihdp_*` columns, so it
# looks like a leftover from the IHDP section — re-run before relying on it.
acic_cate_estims

Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,ctgan_ihdp_1TNet(),ctgan_ihdp_1SNet(),ctgan_ihdp_1PWNet(),ctgan_ihdp_1RANet(),ctgan_ihdp_1DRNet(),realTNet(),realSNet(),realPWNet(),realRANet(),realDRNet()
0,-0.528603,-0.343455,1.128554,0.161703,-0.316603,1.295216,1,0,1,0,...,1.525365,6.147883,1.850494,1.411704,2.635499,3.206027,3.856652,2.224463,3.522441,4.327240
1,-1.736945,-1.802002,0.383828,2.244320,-0.629189,1.295216,0,0,0,1,...,2.113898,4.714939,0.594853,1.916519,2.005464,1.499983,3.002945,1.661345,1.981562,2.748401
2,-0.807451,-0.202946,-0.360898,-0.879606,0.808706,-0.526556,0,0,0,1,...,0.712719,0.467486,-3.124680,1.123163,2.495695,4.295106,6.065885,3.594182,4.523798,5.003065
3,0.390083,0.596582,-1.850350,-0.879606,-0.004017,-0.857787,0,0,0,0,...,2.677840,1.870960,5.493809,2.734890,4.052769,4.568575,4.207206,4.570941,4.714533,5.161874
4,-1.045229,-0.602710,0.011465,0.161703,0.683672,-0.360940,1,0,0,0,...,2.934916,3.049065,6.162396,2.790026,4.934308,4.904621,4.534925,5.056879,4.845352,5.415437
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
742,-0.007654,-0.202946,-0.360898,0.161703,-0.316603,1.792063,0,0,1,0,...,3.821021,3.147493,4.427485,3.630419,4.377537,1.321380,2.470347,3.690839,1.796879,2.480177
743,0.727295,-0.202946,-0.733261,-0.879606,0.808706,1.129600,0,0,1,0,...,1.567830,2.053231,-0.930796,1.708021,1.565880,2.498061,2.555947,3.754096,2.717450,2.466795
744,1.181234,0.196818,-1.477987,0.161703,0.746189,0.467138,0,0,0,0,...,0.981259,2.630848,1.702846,1.316971,2.094526,3.524793,3.152847,0.972531,3.774994,3.821777
745,-0.288664,-0.202946,-1.477987,-0.879606,1.621430,0.467138,1,0,1,0,...,2.022070,-2.972390,-0.623636,1.994366,2.480402,4.258956,3.918890,5.591290,4.152879,4.382747


In [None]:
# Prediction columns handed to rate_of_flipping for the ACIC estimates.
# NOTE(review): every synthetic entry below carries the `ihdp` tag and includes an
# SNet variant, whereas the ACIC run logs in this notebook list `*_acic_*.csv`
# inputs and only TNet/PWNet/RANet/DRNet. This list looks copied from the IHDP
# section — confirm these names actually exist in acic_cate_estims.
cols = ['tvae_ihdp_0TNet()',
 'tvae_ihdp_0SNet()',
 'tvae_ihdp_0PWNet()',
 'tvae_ihdp_0RANet()',
 'tvae_ihdp_0DRNet()',
 'ctgan_ihdp_4TNet()',
 'ctgan_ihdp_4SNet()',
 'ctgan_ihdp_4PWNet()',
 'ctgan_ihdp_4RANet()',
 'ctgan_ihdp_4DRNet()',
 'ddpm_ihdp_1TNet()',
 'ddpm_ihdp_1SNet()',
 'ddpm_ihdp_1PWNet()',
 'ddpm_ihdp_1RANet()',
 'ddpm_ihdp_1DRNet()',
 'ctgan_ihdp_0TNet()',
 'ctgan_ihdp_0SNet()',
 'ctgan_ihdp_0PWNet()',
 'ctgan_ihdp_0RANet()',
 'ctgan_ihdp_0DRNet()',
 'arf_ihdp_3TNet()',
 'arf_ihdp_3SNet()',
 'arf_ihdp_3PWNet()',
 'arf_ihdp_3RANet()',
 'arf_ihdp_3DRNet()',
 'ddpm_ihdp_2TNet()',
 'ddpm_ihdp_2SNet()',
 'ddpm_ihdp_2PWNet()',
 'ddpm_ihdp_2RANet()',
 'ddpm_ihdp_2DRNet()',
 'tvae_ihdp_4TNet()',
 'tvae_ihdp_4SNet()',
 'tvae_ihdp_4PWNet()',
 'tvae_ihdp_4RANet()',
 'tvae_ihdp_4DRNet()',
 'ddpm_ihdp_3TNet()',
 'ddpm_ihdp_3SNet()',
 'ddpm_ihdp_3PWNet()',
 'ddpm_ihdp_3RANet()',
 'ddpm_ihdp_3DRNet()',
 'tvae_ihdp_3TNet()',
 'tvae_ihdp_3SNet()',
 'tvae_ihdp_3PWNet()',
 'tvae_ihdp_3RANet()',
 'tvae_ihdp_3DRNet()',
 'nflow_ihdp_3TNet()',
 'nflow_ihdp_3SNet()',
 'nflow_ihdp_3PWNet()',
 'nflow_ihdp_3RANet()',
 'nflow_ihdp_3DRNet()',
 'ddpm_ihdp_0TNet()',
 'ddpm_ihdp_0SNet()',
 'ddpm_ihdp_0PWNet()',
 'ddpm_ihdp_0RANet()',
 'ddpm_ihdp_0DRNet()',
 'arf_ihdp_1TNet()',
 'arf_ihdp_1SNet()',
 'arf_ihdp_1PWNet()',
 'arf_ihdp_1RANet()',
 'arf_ihdp_1DRNet()',
 'arf_ihdp_4TNet()',
 'arf_ihdp_4SNet()',
 'arf_ihdp_4PWNet()',
 'arf_ihdp_4RANet()',
 'arf_ihdp_4DRNet()',
 'arf_ihdp_0TNet()',
 'arf_ihdp_0SNet()',
 'arf_ihdp_0PWNet()',
 'arf_ihdp_0RANet()',
 'arf_ihdp_0DRNet()',
 'tvae_ihdp_2TNet()',
 'tvae_ihdp_2SNet()',
 'tvae_ihdp_2PWNet()',
 'tvae_ihdp_2RANet()',
 'tvae_ihdp_2DRNet()',
 'nflow_ihdp_1TNet()',
 'nflow_ihdp_1SNet()',
 'nflow_ihdp_1PWNet()',
 'nflow_ihdp_1RANet()',
 'nflow_ihdp_1DRNet()',
 'arf_ihdp_2TNet()',
 'arf_ihdp_2SNet()',
 'arf_ihdp_2PWNet()',
 'arf_ihdp_2RANet()',
 'arf_ihdp_2DRNet()',
 'nflow_ihdp_4TNet()',
 'nflow_ihdp_4SNet()',
 'nflow_ihdp_4PWNet()',
 'nflow_ihdp_4RANet()',
 'nflow_ihdp_4DRNet()',
 'ctgan_ihdp_3TNet()',
 'ctgan_ihdp_3SNet()',
 'ctgan_ihdp_3PWNet()',
 'ctgan_ihdp_3RANet()',
 'ctgan_ihdp_3DRNet()',
 'nflow_ihdp_0TNet()',
 'nflow_ihdp_0SNet()',
 'nflow_ihdp_0PWNet()',
 'nflow_ihdp_0RANet()',
 'nflow_ihdp_0DRNet()',
 'ddpm_ihdp_4TNet()',
 'ddpm_ihdp_4SNet()',
 'ddpm_ihdp_4PWNet()',
 'ddpm_ihdp_4RANet()',
 'ddpm_ihdp_4DRNet()',
 'ctgan_ihdp_2TNet()',
 'ctgan_ihdp_2SNet()',
 'ctgan_ihdp_2PWNet()',
 'ctgan_ihdp_2RANet()',
 'ctgan_ihdp_2DRNet()',
 'tvae_ihdp_1TNet()',
 'tvae_ihdp_1SNet()',
 'tvae_ihdp_1PWNet()',
 'tvae_ihdp_1RANet()',
 'tvae_ihdp_1DRNet()',
 'nflow_ihdp_2TNet()',
 'nflow_ihdp_2SNet()',
 'nflow_ihdp_2PWNet()',
 'nflow_ihdp_2RANet()',
 'nflow_ihdp_2DRNet()',
 'ctgan_ihdp_1TNet()',
 'ctgan_ihdp_1SNet()',
 'ctgan_ihdp_1PWNet()',
 'ctgan_ihdp_1RANet()',
 'ctgan_ihdp_1DRNet()',
 'realTNet()',
 'realSNet()',
 'realPWNet()',
 'realRANet()',
 'realDRNet()']

In [None]:
# Apply rate_of_flipping (a helper defined outside this section) to the ACIC
# estimates, restricted to the prediction columns listed in `cols`.
rates = rate_of_flipping(acic_cate_estims, cols)

In [None]:
# Swap rows and columns of the rate table.
# Caution: this rebinds `rates`, so running the cell twice flips it back.
rates = rates.T

In [None]:
# Show the table ordered by the column labelled 0, smallest first.
rates.sort_values(by=0)

Unnamed: 0,0
realDRNet(),0.004016
realRANet(),0.006693
realSNet(),0.010710
realTNet(),0.010710
ddpm_ihdp_4DRNet(),0.014726
...,...
nflow_ihdp_0RANet(),0.669344
nflow_ihdp_3PWNet(),0.676037
nflow_ihdp_0DRNet(),0.726908
nflow_ihdp_2PWNet(),0.775100


In [None]:
# The transposed table carries its labels in the row index; write the index so
# the saved values stay attributable instead of becoming an anonymous column.
rates.to_csv('acic_rate_of_flipping.csv')

## Statistical metrics

In [218]:
# Standalone synthcity metric instances.
# NOTE(review): the `metrics` dict built further down creates its own instances,
# and these two names are not referenced again in the cells shown here.
inv_kl = InverseKLDivergence()
alpha_prec = AlphaPrecision()

In [274]:
#IHDP
# Read every synthetic IHDP table from syn_data/ (keyed by file name) and
# recode its 'treatment' column from True/False to 1/0.
syn_data = {}

for fname in listdir('syn_data'):
    if 'ihdp' not in fname:
        continue
    frame = pd.read_csv('syn_data/'+fname)
    frame.loc[:, 'treatment'] = frame['treatment'].replace(True, 1).replace(False, 0)
    syn_data[fname] = frame

In [275]:
# Recode the real IHDP frame's 'treatment' column from True/False to 1/0, the
# same recoding the loading cell applies to the synthetic tables.
# This modifies `ihdp` itself rather than producing a new frame.
ihdp.loc[:, 'treatment']= ihdp['treatment'].replace(True, 1).replace(False, 0)


In [276]:
# Wrap the real IHDP rows up to and including label 599 (.loc slices are
# end-inclusive) in a synthcity loader; `d` is later passed to standard_metrics.
d = GenericDataLoader(ihdp.loc[:599])


In [17]:

# Metric objects passed to standard_metrics, keyed by a short label.
metrics = {
    'inv kl': InverseKLDivergence(),
    'alpha prec': AlphaPrecision(),
}

In [279]:
# Evaluate the synthetic IHDP tables against the real-data loader `d` with the
# metrics above; standard_metrics is a helper defined outside this section.
stat_results =standard_metrics(syn_data, d, metrics)

In [281]:
# Persist the IHDP metric table; the row index is written too (to_csv default).
stat_results.to_csv('ihdp_standard_metrics.csv')

In [15]:
#ACIC
# Read every synthetic ACIC table from syn_data/, keyed by file name, then
# encode the real and synthetic data together.
syn_data = {
    fname: pd.read_csv('syn_data/'+fname)
    for fname in listdir('syn_data')
    if 'acic' in fname
}

acic_encoded, syn_data_encoded = encode_acic_datasets(acic, syn_data)

In [19]:
# Wrap the encoded real ACIC rows up to and including label 3999 (.loc slices
# are end-inclusive) in a synthcity loader. This rebinds `d`, replacing the
# IHDP loader created earlier.
d = GenericDataLoader(acic_encoded.loc[:3999])

In [25]:
# Evaluate the encoded synthetic ACIC tables against the real-data loader `d`,
# using the same `metrics` dict as the IHDP run.
acic_standard_metrics = standard_metrics(syn_data_encoded, d, metrics)

In [27]:
# Persist the ACIC metric table without its row index.
# NOTE(review): the IHDP counterpart is saved with the index included; if
# standard_metrics stores dataset names in the index they are lost here — confirm.
acic_standard_metrics.to_csv('acic_standard_metrics.csv', index=False)