In [2]:
%reload_ext autoreload
%autoreload 2

import sys, os, joblib, wandb, json
from collections import defaultdict

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import torch, wandb
torch.set_default_dtype(torch.float64)
import numpy as np
import pandas as pd

from econml import dr, dml, metalearners

import sklearn
from sklearn.cluster import KMeans
from sklearn.linear_model import LassoCV
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import Ridge
from sklearn.svm import SVC, SVR, LinearSVR
from sklearn.decomposition import PCA
from sklearn.neural_network import MLPRegressor, MLPClassifier
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures, PowerTransformer
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier

from xgboost import XGBRegressor, XGBClassifier

from src.data.data_module import LBIDD, IHDP, Twins, Synth
from src.models.nce_ite import NCE
from src.models.cate_model import CATEModel
from src.models.benchmarks import AE




import matplotlib.pyplot as plt
from matplotlib import rc
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
## for Palatino and other serif fonts use:
#rc('font',**{'family':'serif','serif':['Palatino']})
rc('text', usetex=True)


from pytorch_lightning import Trainer

# Exp. prep.

You need to run the below 2 cells regardless of what experiment. These cells load the learnt representations (with parameters at top of first cell), and prep the CATE learners. 

Of course, before loading representations, you first have to learn them + upload to W&B.

In [3]:
data = 'ihdp'                         # or 'twins' or 'synth'
project = f'nce-ite-{data}'           # make sure this matches your W&B project

train_size=500                        # set your training set size
fix=True                              # fix your dimensions regardless of training (note this is only possible for synth. data)
n_models=10                           # amount of different trained models
fix_K = False                         # keep your K fixed (use for covariance checks)

varying_sweeps = {                    # select your sweep number for correct training size:
    100: 'e4cmxk15',                  #    {dim: sweep_id, ...}
    250: 't4bw8nvi',
    500: '3ii8nsyi',
    1000: 'wz29ayzz',
    1500: 'hwu8y0u6',
}

fixed_sweeps = {
    100: 'ajt2xpzj', 
    250: 'ncxqa6kx', 
    500: 'epv99a73', 
    1000:'t3z4442e', 
    1500:'ulralaif', 
}

twin_sweeps = {
    500: 'rgrvgmmm',
    1000:'j8fv1g7f',
    1500:'zq5io55q',
    2000: 'fkch1ixg',
    2500: 'l66oyh02',
    5000: 'xv3krdpx',
    10000:'5h9so1uc',
}

twins_sweeps_fixed_K = {
    500: '1e30plop',
    1000: '03ixmhhc',
    1500: 'y0diivjg',
    2000: 'ehu52p5o',
    2500: 'xwamjm77',
    5000: 'xv3krdpx',
    10000:'5h9so1uc',
}

ihdp_sweeps = {
    100: 'ozt3qms4',
    250: 'g5qa43de',
    500: 'hz03jfu7',
}


if data == 'synth':
    sweep_name = fixed_sweeps[train_size] if fix else varying_sweeps[train_size]
elif data == 'twins':
    sweep_name = twin_sweeps[train_size] if not fix_K else twins_sweeps_fixed_K[train_size]
elif data == 'ihdp':
    sweep_name = ihdp_sweeps[train_size]


# LOAD DATA
data_seed_standard = 524
data_seeds = defaultdict(lambda: data_seed_standard)
data_seeds[100] = 979
data_seeds[500] = 255

if data == 'ihdp':
    dm = IHDP(batch_size=256, limit_train_size=train_size)
elif data == 'lbidd':
    dm =  LBIDD(
        batch_size=256, id='93aab00aeb234a3b985eeb32e04a353d',
        location='../data/LBIDD', limit_train_size=train_size)
elif data == 'twins':
    dm = Twins(batch_size=256, limit_train_size=train_size)
elif data == 'synth':
    dims = {
        100: 50,
        250: 100,
        500: 150,
        1000: 200,
        1500: 250,
    }
    
    if fix:
        dim = 100
    else:
        dim = dims[train_size]

    dm = Synth(
        batch_size=64,
        use_existing_data=False,
        n=train_size,
        seed=data_seeds[train_size],
        dim=dim,
        scale_treatment_balance=1)
else:
    raise ValueError('Error value on --data. Please give one of "ihdp", "lbidd", or "twins".')

dm.prepare_data()
dm.setup(stage='test')
dm.setup(stage='fit')

X_train = dm.train[dm.x_cols].copy(deep=True)
Z_train = dm.train.z.copy(deep=True)
Y_train = dm.train.y.copy(deep=True)

X_test = dm.test[dm.x_cols].copy(deep=True)
Y0_test= dm.test.y0.copy(deep=True)
Y1_test= dm.test.y1.copy(deep=True)

Y_train[Z_train.to_numpy() == 1] += 1.
Y1_test += 1.

    
# LOAD MODELS
api = wandb.Api()
sweep = api.sweep(path=f'jeroenbe/{project}/{sweep_name}') #nce-ite-lbidd
                
summary_list = []
for run in sweep.runs:
    if run.state == "finished":
        d = run.summary._json_dict
        d['id'] = run.id
        summary_list.append(d)
                        

summary_df = pd.DataFrame.from_records(summary_list)
model_ids = summary_df.nsmallest(n=n_models, columns=['PEHE with representation']).id.to_numpy().astype(str)

models = dict()
for model_id in model_ids:
    try:
        params = wandb.restore(f'nce.ckpt.ckpt', run_path=f'jeroenbe/{project}/{model_id}', replace=True)
    except:
        params = wandb.restore(f'nce.ckpt-v0.ckpt', run_path=f'jeroenbe/{project}/{model_id}', replace=True)
    
    models[model_id] = NCE.load_from_checkpoint(params.name).double()

Trying to unpickle estimator SVC from version 0.23.2 when using version 0.24.2. This might lead to breaking code or invalid results. Use at your own risk.
Trying to unpickle estimator KernelRidge from version 0.23.2 when using version 0.24.2. This might lead to breaking code or invalid results. Use at your own risk.
Trying to unpickle estimator KMeans from version 0.23.2 when using version 0.24.2. This might lead to breaking code or invalid results. Use at your own risk.


In [4]:
# Here we build several combinations for CATE learners, 
# please make sure all combinations match to those reported
# in our paper, before reproducing our results.
# For example, the regressor() function should return the
# regressor that is reported in our paper. Some examples
# are provided below.



cate_amount = n_models

kwargs = {
            #'max_depth': 40,
            #'n_estimators': 100,
            'max_iter': 50000,
            'tol': .01
        }


def regressor():
    return Ridge()
    return KernelRidge()
    return Pipeline([
        #('poly', PolynomialFeatures(degree=2)),
        ('power', PowerTransformer()),
        ('regr', Ridge(alpha=.001))
    ])

def classifier():
    return SVC(probability=True)

def get_cate_learners(amount):
    
    return {
        'X': lambda: [
            metalearners.XLearner(models= regressor(),
                propensity_model=classifier(),
                cate_models= regressor()) for _ in range(amount)],
        'DR': lambda: [
            dr.LinearDRLearner(
                model_propensity=classifier(),
                model_regression= regressor()) for _ in range(amount)],
        'S': lambda: [
            metalearners.SLearner(overall_model= regressor()) for _ in range(amount)],
        'T': lambda: [
            metalearners.TLearner(models= regressor()) for _ in range(amount)],
        'R': lambda: [dml.NonParamDML(
                model_y=regressor(),
                model_t=classifier(),
                model_final= KernelRidge(),
                discrete_treatment=True) for _ in range(amount)],    
    }

# Test CATE learners with and without EBM

These cells need only executing when interested in CATE performance with and without EBM

In [5]:
out = dict()
cate_learners = get_cate_learners(4)
for k, v in cate_learners.items():
    out[k] = {
        'with representation': [],
        'without representation': [],
        'with latent': [],
    }
    for _, nce in models.items():
        cate_learner = v()
        cate_nr = CATEModel(model=cate_learner[0], representation=None, standardize=False)
        cate_r = CATEModel(model=cate_learner[1], representation=nce, standardize=False)
        
        
        
        cate_nr.fit(X_train, Z_train, Y_train)
        cate_r.fit(X_train, Z_train, Y_train)
        
        
        pehe_no_repr = cate_nr.eval(X_test, Y0_test, Y1_test)
        pehe_repr = cate_r.eval(X_test, Y0_test, Y1_test)

        
        out[k]['with representation'].append(pehe_repr)
        out[k]['without representation'].append(pehe_no_repr)
        
        
        if data == 'synth':
            cate_u = CATEModel(model=cate_learner[2], representation=None, standardize=False)
            
            U_train = dm.train_u
            mu_u = U_train.mean(axis=0)
            std_u = U_train.std(axis=0)
            U_train -= mu_u
            U_train /= std_u
        
            cate_u.fit(pd.DataFrame(dm.train_u), Z_train, Y_train)
        
            U_test = dm.test_u
            U_test -= mu_u
            U_test /= std_u
            
            pehe_u = cate_u.eval(pd.DataFrame(dm.test_u), Y0_test, Y1_test)
            out[k]['with latent'].append(pehe_u)


In [6]:
results = dict()
for k, v in out.items():
    wr = 'with representation'
    wor = 'without representation'
    results[k] = {
        wr: (np.mean(out[k][wr]), np.std(out[k][wr])),
        wor: (np.mean(out[k][wor]), np.std(out[k][wor])),
        'latent': (np.mean(out[k]['with latent']), np.std(out[k]['with latent'])),
    }
print(json.dumps(results, indent=4))

{
    "X": {
        "with representation": [
            1.619916425156085,
            0.0052781109929950915
        ],
        "without representation": [
            1.5644931207443868,
            1.8417375312065561e-06
        ],
        "latent": [
            NaN,
            NaN
        ]
    },
    "DR": {
        "with representation": [
            1.6349975347585122,
            0.05064465206539426
        ],
        "without representation": [
            1.690497780248042,
            0.04683625822346018
        ],
        "latent": [
            NaN,
            NaN
        ]
    },
    "S": {
        "with representation": [
            1.6232791986577861,
            0.0015007790764973424
        ],
        "without representation": [
            1.6370527343705024,
            2.220446049250313e-16
        ],
        "latent": [
            NaN,
            NaN
        ]
    },
    "T": {
        "with representation": [
            1.6174386049494625,
            0.

Mean of empty slice.
invalid value encountered in double_scalars
Degrees of freedom <= 0 for slice
invalid value encountered in true_divide
invalid value encountered in double_scalars


# Different Dim-red.
Here we train and compare with various dimensionality reduction methods

In [None]:
# DIM RED.s:
# -> AE
# -> PCA
# -> FeatureAgglomeration
# -> TruncatedSVD

from sklearn.cluster import FeatureAgglomeration
from sklearn.decomposition import TruncatedSVD, KernelPCA
from sklearn.manifold import TSNE, SpectralEmbedding, Isomap, LocallyLinearEmbedding, MDS


class DIM_WRAPPER:
    def __init__(self, dm_red):
        self.dm_red = dm_red
        self.device='cpu'
    def __call__(self, x):
        if getattr(self.dm_red, "transform", None) is not None:
            return self.dm_red.transform(x)
        else:
            return self.dm_red.fit_transform(x)


dimred_out = dict()
cate_learners = get_cate_learners(8)
i = 0
for k, v in cate_learners.items():
    dimred_out[k] = {
        'with EBM': [],
        'with PCA': [],
        'with FA': [],
        'with SE': [],
        'with IM': [],
        'with KernelPCA': [],
        'with AE': [],
        'with latent': [],
    }
    for _, nce in models.items():
        cate_learner = v()
        
        print('now on PCA')
        pca_ = PCA(n_components=nce.K)
        pca_.fit(X_train)
        pca = DIM_WRAPPER(pca_)
        
        print('now on FA')
        fa_ = FeatureAgglomeration(n_clusters=nce.K)
        fa_.fit(X_train)
        fa = DIM_WRAPPER(fa_)
        
        print('now on SE')
        se_ = SpectralEmbedding(n_components=nce.K)
        se_.fit(X_train)
        se = DIM_WRAPPER(se_)
        
        print('now on IM')
        im_ = Isomap(n_neighbors=nce.K)
        im_.fit(X_train)
        im = DIM_WRAPPER(im_)
        
        print('now on LLE')
        lle_ = KernelPCA(n_components=nce.K, kernel='rbf')
        lle_ = lle_.fit(X_train)
        lle = DIM_WRAPPER(lle_)
        
        print('now on AE')
        if data == 'synth':
            pl.utilities.seed.seed_everything(i * dm.size(1))
            i+=1
        ae = AE(
            input_dim = dm.size(1), 
            K=nce.K, 
            lr=nce.lr, architecture=[(40,40) for _ in range(6)])
        trainer = Trainer(callbacks=[EarlyStopping(monitor='val_loss')], max_epochs=50)
        trainer.fit(ae, dm)
        
        print('Dim red. Done')
        
        cate_r = CATEModel(model=cate_learner[0], representation=nce, standardize=False)
        cate_pca = CATEModel(model=cate_learner[1], representation=pca, standardize=False)
        cate_fa = CATEModel(model=cate_learner[2], representation=fa, standardize=False)
        cate_se = CATEModel(model=cate_learner[3], representation=se, standardize=False)
        cate_im = CATEModel(model=cate_learner[5], representation=im, standardize=False)
        cate_lle = CATEModel(model=cate_learner[4], representation=lle, standardize=False)
        cate_ae = CATEModel(model=cate_learner[6], representation=ae, standardize=False)
        
        
        
        print('learn EBM')
        cate_r.fit(X_train, Z_train, Y_train)
        print('learn PCA')
        cate_pca.fit(X_train, Z_train, Y_train)
        print('learn FA')
        cate_fa.fit(X_train, Z_train, Y_train)
        print('learn SE')
        cate_se.fit(X_train, Z_train, Y_train)
        print('learn IM')
        cate_im.fit(X_train, Z_train, Y_train)
        print('learn MDS')
        cate_lle.fit(X_train, Z_train, Y_train)
        print('learn AE')
        cate_ae.fit(X_train, Z_train, Y_train)
        
        print('eval EBM')
        pehe_repr = cate_r.eval(X_test, Y0_test, Y1_test)
        print('eval PCA')
        pehe_pca = cate_pca.eval(X_test, Y0_test, Y1_test)
        print('eval FA')
        pehe_fa = cate_fa.eval(X_test, Y0_test, Y1_test)
        print('eval SE')
        pehe_se = cate_se.eval(X_test.iloc[:2000,:], Y0_test.iloc[:2000], Y1_test.iloc[:2000])
        print('eval IM')
        pehe_im = cate_im.eval(X_test, Y0_test, Y1_test)
        print('eval MDS')
        pehe_lle = cate_lle.eval(X_test, Y0_test, Y1_test)
        print('eval AE')
        pehe_ae = cate_ae.eval(X_test, Y0_test, Y1_test)

        
        dimred_out[k]['with EBM'].append(pehe_repr)
        dimred_out[k]['with PCA'].append(pehe_pca)
        dimred_out[k]['with FA'].append(pehe_fa)
        dimred_out[k]['with SE'].append(pehe_se)
        dimred_out[k]['with IM'].append(pehe_im)
        dimred_out[k]['with KernelPCA'].append(pehe_lle)
        dimred_out[k]['with AE'].append(pehe_ae)
        
        
        if data == 'synth':
            cate_u = CATEModel(model=cate_learner[-1], representation=None, standardize=False)
            
            U_train = dm.train_u
            mu_u = U_train.mean(axis=0)
            std_u = U_train.std(axis=0)
            U_train -= mu_u
            U_train /= std_u
        
            cate_u.fit(pd.DataFrame(dm.train_u), Z_train, Y_train)
        
            U_test = dm.test_u
            U_test -= mu_u
            U_test /= std_u
            
            pehe_u = cate_u.eval(pd.DataFrame(dm.test_u), Y0_test, Y1_test)
            dimred_out[k]['with latent'].append(pehe_u)



In [None]:
results = dict()
for k, v in dimred_out.items():
    results[k] = {
        
        'with PCA': (np.mean(dimred_out[k]['with PCA']), np.std(dimred_out[k]['with PCA'])),
        'with FA': (np.mean(dimred_out[k]['with FA']), np.std(dimred_out[k]['with FA'])),
        'with SE': (np.mean(dimred_out[k]['with SE']), np.std(dimred_out[k]['with SE'])),
        'with IM': (np.mean(dimred_out[k]['with IM']), np.std(dimred_out[k]['with IM'])),
        'with KernelPCA': (np.mean(dimred_out[k]['with KernelPCA']), np.std(dimred_out[k]['with KernelPCA'])),
        'with AE': (np.mean(dimred_out[k]['with AE']), np.std(dimred_out[k]['with AE'])),
        'with EBM': (np.mean(dimred_out[k]['with EBM']), np.std(dimred_out[k]['with EBM'])),
        'latent': (np.mean(dimred_out[k]['with latent']), np.std(dimred_out[k]['with latent'])),
    }
print(json.dumps(results, indent=4))