In [61]:
from sklearn.neural_network import MLPRegressor, MLPClassifier
import numpy as np
import pandas as pd

def dgp(n_obs:int=1000,
        n_covariates:int=10,
        n_confounders:int=10,
        n_treatments:int=1,
        n_outcomes:int=1,
        binary_treatment:bool=False,
        fraction_treated:float=0.5,
        binary_outcome:bool=False,
        fraction_positive:float=0.5,
        scale:float=1,
        seed:int|None=None,
        diagonal_covariance_matrix:bool=True) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Generate data for a treatment effect estimation problem.
    """

    np.random.seed(seed)
    
    # Generate Covariates and Confounders
    means = np.random.normal(0, 1, n_covariates) * scale
    cov_mat = np.eye(n_covariates) if diagonal_covariance_matrix else create_random_covariance_matrix(n_covariates)
    covariates = np.random.multivariate_normal(means, cov_mat, n_obs)

    # Generate Confounders
    means = np.random.normal(0, 1, n_confounders) * scale 
    cov_mat = np.eye(n_confounders) if diagonal_covariance_matrix else create_random_covariance_matrix(n_confounders)
    confounders = np.random.multivariate_normal(means, cov_mat, n_obs)

    ## Neural Network for Nuissance Functions
    activation = 'relu'

    # Generate Treatment
    hidden_layers = (20,20,20)
    pseudo_x = np.random.randn(n_obs, n_confounders) * scale
    if binary_treatment:
        pseudo_t = (np.random.binomial(1, 0.5, (n_obs, n_treatments))).reshape(-1, n_treatments)
        treatment_model = MLPClassifier(hidden_layer_sizes=hidden_layers, activation=activation, random_state=seed, max_iter=1000)
    else:
        pseudo_t = (np.random.randn(n_obs, n_treatments) * scale).reshape(-1, n_treatments)
        treatment_model = MLPRegressor(hidden_layer_sizes=hidden_layers, activation=activation, random_state=seed, max_iter=1000)
    if n_treatments == 1:
        treatment_model.fit(pseudo_x, pseudo_t.ravel())
    else:
        treatment_model.fit(pseudo_x, pseudo_t)

    if n_treatments == 1 and binary_treatment:
        treatments = treatment_model.predict_proba(confounders)[:,1].reshape(-1, n_treatments)
    elif binary_treatment:
        treatments = treatment_model.predict_proba(confounders).reshape(-1, n_treatments)
    else:
        treatments = treatment_model.predict(confounders).reshape(-1, n_treatments)

    treatments = np.where(treatments > find_optimal_threshold(treatments, fraction_treated), 1, 0) if binary_treatment else treatments

    if not binary_treatment:
        treatments += np.random.randn(n_obs, n_treatments) * scale
    
    # Generate Outcome
    hidden_layers = (20,20,20)
    pseudo_x = np.random.randn(n_obs, (n_confounders + n_covariates)) * scale

    if binary_outcome:
        pseudo_y = (np.random.binomial(1, 0.5, (n_obs, n_outcomes))).reshape(-1, n_outcomes)
        outcome_model = MLPClassifier(hidden_layer_sizes=hidden_layers, activation=activation, random_state=seed, max_iter=1000)
    else:
        pseudo_y = (np.random.randn(n_obs, n_outcomes) * scale).reshape(-1, n_outcomes)
        outcome_model = MLPRegressor(hidden_layer_sizes=hidden_layers, activation=activation, random_state=seed, max_iter=1000)
    if n_outcomes == 1:
        outcome_model.fit(np.concatenate((pseudo_x,pseudo_t),axis=1), pseudo_y.ravel())
    else:
        outcome_model.fit(np.concatenate((pseudo_x,pseudo_t),axis=1), pseudo_y)

    X_W = np.concatenate([covariates, confounders], axis=1)
    X_W_t = np.concatenate([X_W, treatments], axis=1)

    if n_outcomes == 1 and binary_outcome:
        outcome = outcome_model.predict_proba(X_W_t)[:,1].reshape(-1, n_outcomes)
    elif binary_outcome:
        outcome = outcome_model.predict_proba(X_W_t).reshape(-1, n_outcomes)
    else:
        outcome = outcome_model.predict(X_W_t).reshape(-1, n_outcomes)
        
    outcome = np.where(outcome > find_optimal_threshold(outcome, fraction_positive), 1, 0) if binary_outcome else outcome
    
    if not binary_outcome:
        outcome += np.random.randn(n_obs, n_outcomes) * scale
    
    # Compute true treatment effect
    cates = np.zeros((n_obs, n_treatments*n_outcomes))
    for t in range(n_treatments):
        # Potential Outcome Under Treatment=1
        t1 = np.copy(treatments) 
        t1[:,t] = 1
        X_W_t1 = np.concatenate((X_W,t1),axis=1)
        if n_outcomes == 1 and binary_outcome:
            y1 = outcome_model.predict_proba(X_W_t1)[:,1].reshape(-1,n_outcomes)
        elif binary_outcome:
            y1 = outcome_model.predict_proba(X_W_t1).reshape(-1,n_outcomes)
        else:
            y1 = outcome_model.predict(X_W_t1).reshape(-1,n_outcomes)

        # Potential Outcome Under Treatment=0
        t0 = np.copy(treatments)
        t0[:,t] = 0
        X_W_t0 = np.concatenate((X_W,t0),axis=1)
        if n_outcomes == 1 and binary_outcome:
            y0 = outcome_model.predict_proba(X_W_t0)[:,1].reshape(-1,n_outcomes)
        elif binary_outcome:
            y0 = outcome_model.predict_proba(X_W_t0).reshape(-1,n_outcomes)
        else:
            y0 = outcome_model.predict(X_W_t0).reshape(-1,n_outcomes)

        cates[:,t*n_outcomes:t*n_outcomes+n_outcomes] = y1 - y0

    # Save results
    results_np = np.concatenate((covariates, confounders, treatments, outcome),axis=1)
    data_columns = [f'W{i}' for i in range(n_covariates)] + [f'X{i}' for i in range(n_confounders)] + [f'T{i}' for i in range(n_treatments)] + [f'Y{i}' for i in range(n_outcomes)]
    data_df = pd.DataFrame(results_np, columns=data_columns)

    cate_columns = [f'cate_T{i}_Y{j}' for i in range(n_treatments) for j in range(n_outcomes)]
    cate_df = pd.DataFrame(cates, columns=cate_columns)

    ate_columns = [f'ate_T{i}_Y{j}' for i in range(n_treatments) for j in range(n_outcomes)]
    ate_df = pd.DataFrame(cates.mean(axis=0).reshape(1,-1), columns=ate_columns)

    return data_df, cate_df, ate_df

def find_optimal_threshold(pred_probs, approx_percentage=0.5):
    """
    Find the optimal threshold to achieve the approx. percentage of positive class. 
    This helps to prevent having no overlap between the treated and control groups.
    """
    sorted_probs = np.sort(pred_probs,axis=0)
    threshold_index = int(len(sorted_probs) * approx_percentage)
    optimal_threshold = sorted_probs[threshold_index]
    return optimal_threshold

def create_random_covariance_matrix(dim):
    """
    Create a random covariance matrix.
    """
    A = np.random.randn(dim, dim)
    cov_matrix = np.dot(A, A.T)

    diag_indices = np.diag_indices_from(cov_matrix)
    cov_matrix[diag_indices] = np.abs(cov_matrix[diag_indices]) + 0.1
    return cov_matrix

In [74]:
data_df, cate_df, ate_df = dgp(n_obs=1000,
                               n_covariates=10,
                               n_confounders=10,
                               n_treatments=5,
                               n_outcomes=1,
                               binary_treatment=True,
                               fraction_treated=0.5,
                               binary_outcome=False,
                               fraction_positive=0.5,
                               scale=30,
                               seed=None,
                               diagonal_covariance_matrix=True)



In [75]:
data_df

Unnamed: 0,W0,W1,W2,W3,W4,W5,W6,W7,W8,W9,...,X6,X7,X8,X9,T0,T1,T2,T3,T4,Y0
0,-63.869717,10.626870,27.480257,-16.048573,-7.718788,-28.804516,-57.662106,36.465700,32.379582,-10.538396,...,-28.926936,30.034871,-2.507630,15.574479,0.0,1.0,1.0,0.0,1.0,57.111071
1,-64.463573,7.476289,26.594316,-17.019390,-8.889804,-30.338451,-58.665304,35.041399,30.371979,-9.419874,...,-28.876089,28.794599,-1.227040,16.080700,0.0,1.0,1.0,0.0,0.0,33.349419
2,-61.746744,10.629522,26.386854,-14.588581,-6.058691,-30.063857,-59.076826,37.872997,31.520530,-8.518584,...,-29.760401,30.128397,-1.533563,14.142542,0.0,1.0,0.0,1.0,1.0,144.353392
3,-62.483201,9.925598,25.187248,-16.930734,-7.824104,-28.545788,-59.951985,35.527234,31.098299,-9.575597,...,-29.144325,29.703110,-1.838615,15.529414,1.0,0.0,0.0,0.0,0.0,55.143905
4,-64.141353,8.824521,25.666961,-15.776921,-7.719007,-28.879479,-59.012123,37.143329,31.138973,-10.017532,...,-28.505809,30.402409,-1.931680,16.584241,0.0,1.0,0.0,1.0,0.0,34.860055
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,-64.112182,9.194732,24.923201,-16.779856,-9.058896,-30.647842,-57.669198,36.953432,31.301440,-9.682108,...,-28.617890,31.652116,-2.755394,17.330656,0.0,0.0,0.0,1.0,0.0,56.587517
996,-62.334175,10.928465,26.142200,-14.163146,-6.805477,-29.655725,-58.486430,36.851087,31.122672,-8.487671,...,-28.807007,28.530435,-1.780780,15.189558,0.0,1.0,1.0,0.0,0.0,90.499861
997,-64.144458,8.141573,27.781207,-15.045452,-6.793817,-31.282990,-57.488416,35.937420,32.677759,-9.607303,...,-29.860703,31.605106,-2.309214,17.967319,1.0,0.0,0.0,0.0,0.0,19.479951
998,-64.089699,9.845137,24.597014,-15.259416,-5.967002,-29.593532,-58.246697,36.328354,32.193929,-9.369815,...,-27.646487,31.642444,-1.879639,16.150896,0.0,1.0,0.0,1.0,0.0,106.057616


In [76]:
data_df[[c for c in data_df.columns if "T" in c and "_" not in c]].sum()

T0    499.0
T1    499.0
T2    499.0
T3    499.0
T4    499.0
dtype: float64

In [77]:
data_df[[c for c in data_df.columns if "Y" in c and "_" not in c]].sum()

Y0    66772.50619
dtype: float64

In [78]:
cate_df

Unnamed: 0,cate_T0_Y0,cate_T1_Y0,cate_T2_Y0,cate_T3_Y0,cate_T4_Y0
0,1.072227,-0.890890,-1.943593,0.565794,-5.538935
1,0.528202,-0.190281,-1.836461,0.106238,-5.762228
2,0.624164,-0.442827,-2.453883,0.576646,-5.090871
3,2.581802,0.469734,-0.553652,0.480468,-4.038498
4,1.786063,0.991513,-0.553652,0.480468,-3.584446
...,...,...,...,...,...
995,1.104402,1.261288,-1.005219,0.391291,-3.546594
996,0.624164,-0.442827,-2.266542,0.117730,-5.847483
997,1.082013,1.590224,-0.780861,0.208762,-2.443432
998,1.246446,1.415971,-0.591702,0.266363,-2.808559


In [79]:
print(data_df['Y0'].mean())
print(data_df['Y0'].min())
print(data_df['Y0'].max())
print(data_df['Y0'].std())

66.77250618968141
-48.25354574232318
172.7054309508692
30.43738357201222


In [80]:
print(cate_df.min(axis=0))
print(cate_df.max(axis=0))
print(cate_df.mean(axis=0))
print(cate_df.std(axis=0))

cate_T0_Y0    0.045799
cate_T1_Y0   -1.678223
cate_T2_Y0   -2.744183
cate_T3_Y0   -0.151698
cate_T4_Y0   -6.570058
dtype: float64
cate_T0_Y0    3.207311
cate_T1_Y0    1.626244
cate_T2_Y0   -0.034552
cate_T3_Y0    2.776847
cate_T4_Y0   -2.443432
dtype: float64
cate_T0_Y0    1.170852
cate_T1_Y0    0.251697
cate_T2_Y0   -1.576044
cate_T3_Y0    0.477604
cate_T4_Y0   -4.360983
dtype: float64
cate_T0_Y0    0.545780
cate_T1_Y0    0.853007
cate_T2_Y0    0.830741
cate_T3_Y0    0.518761
cate_T4_Y0    0.971089
dtype: float64


In [81]:
ate_df

Unnamed: 0,ate_T0_Y0,ate_T1_Y0,ate_T2_Y0,ate_T3_Y0,ate_T4_Y0
0,1.170852,0.251697,-1.576044,0.477604,-4.360983
