In [1]:
import dowhy
import pandas as pd
import numpy as np
from sklearn import preprocessing

In [2]:
from itertools import product

In [3]:
import econml
from econml.dml import LinearDML, CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor


In [5]:
import logging

# Root logger: only WARNING and above pass for loggers that inherit its level.
logger = logging.getLogger()
logger.setLevel("WARNING")

In [6]:
# Let wide/long result tables render in full (up to 500 rows and columns).
pd.set_option('display.max_rows', 500, 'display.max_columns', 500)

In [7]:
# Raw experimental data; the first CSV column is used as the row index.
df = pd.read_csv(filepath_or_buffer="trapping_b2.csv", index_col=0)

In [8]:
# DOT-format causal graph. Each edge "A"->"B" encodes the assumption that A
# causes B. Salinity and Dispersant have no incoming edges (root causes).
# The string is handed to DoWhy as the `graph` of each causal model.
causal_graph = """digraph {
"Salinity"->"Dispersion effectiveness";
"Salinity"->"Oil trapping efficiency";
"Salinity"->"Settling efficiency";
"Dispersant"->"Dispersion effectiveness";
"Dispersant"->"Oil trapping efficiency";
"Dispersant"->"Settling efficiency";
"Dispersion effectiveness"->"Settling efficiency";
"Dispersion effectiveness"->"Oil trapping efficiency";
"Settling efficiency"->"Oil trapping efficiency";
}"""


In [9]:
# Min-max scaler mapping each column onto the unit interval.
scaler = preprocessing.MinMaxScaler(feature_range=(0, 1))

In [10]:
# Fit the scaler on the raw data and rebuild a DataFrame with the same column
# names; the result gets a fresh 0..n-1 index.
df_scaled = pd.DataFrame(
    data=scaler.fit_transform(df),
    columns=df.columns,
)

In [11]:
# (treatment, outcome) pairs to analyse: one per edge of `causal_graph`, parsed
# from its DOT text so this list cannot drift from the graph. Order follows the
# order in which the edges are written in `causal_graph`.
t_y = [
    tuple(node.strip().strip('"') for node in line.strip().rstrip(';').split('->'))
    for line in causal_graph.splitlines()
    if '->' in line
]
assert all(len(pair) == 2 for pair in t_y), "unexpected edge syntax in causal_graph"

In [12]:
# Argument tuples for causal_all_in_one: (treatment, outcome, graph, data).
params = [
    (treatment, outcome, causal_graph, df_scaled)
    for treatment, outcome in t_y
]

In [188]:
def causal_all_in_one(treatment, outcome, graph, df):
    """Estimate and refute the causal effect of ``treatment`` on ``outcome``.

    A DoWhy ``CausalModel`` is built from ``graph`` and ``df`` and two
    estimators are run on it:

    * backdoor linear regression, refuted with DoWhy's random-common-cause,
      add-unobserved-common-cause and permutation-placebo refuters;
    * EconML ``CausalForestDML`` (``random_state=12``), run only when the model
      reports at least one confounder or effect modifier for this pair.

    Parameters
    ----------
    treatment, outcome : str
        Column names in ``df`` that are also node names in ``graph``.
    graph : str
        DOT-format causal graph, used both for the DoWhy model and for the
        EconML DoWhy wrapper.
    df : pandas.DataFrame
        Data containing the graph's variables.

    Returns
    -------
    tuple
        ``(li_res, dml_res, modifiers, confounders, backdoor_var)``:

        * ``li_res`` -- ``[ate, ci, refute_random, refute_unobserved,
          refute_placebo]`` for the linear estimator;
        * ``dml_res`` -- ``[ate, cate_per_row, ate_ci, refute_random,
          refute_unobserved, refute_placebo]`` for the forest (the three
          refutations are ``None`` when there are no effect modifiers), or
          ``None`` when the DML step is skipped;
        * ``modifiers`` / ``confounders`` -- as reported by the DoWhy model;
        * ``backdoor_var`` -- backdoor variables of the identified estimand.
    """
    model = dowhy.CausalModel(data=df,
                              treatment=treatment,
                              outcome=outcome,
                              graph=graph)

    modifiers = model.get_effect_modifiers()
    confounders = model.get_common_causes()

    estimand = model.identify_effect(proceed_when_unidentifiable=True)
    backdoor_var = estimand.backdoor_variables

    # ---- Linear regression estimate and refutations ----------------------
    estimate_li = model.estimate_effect(estimand, method_name="backdoor.linear_regression",
                                        method_params=None, confidence_intervals=True)

    print(treatment, outcome, "############### Now refuting: Random Common Cause (Linear)#######################")
    res_random_li = model.refute_estimate(estimand, estimate_li, method_name="random_common_cause")
    print(treatment, outcome, "############### Now refuting: Add Unobserved Common Cause (Linear)######################")
    res_unobserved_li = model.refute_estimate(estimand, estimate_li, method_name="add_unobserved_common_cause",
                                              confounders_effect_on_treatment="binary_flip",
                                              confounders_effect_on_outcome="linear",
                                              effect_strength_on_treatment=0.01,
                                              effect_strength_on_outcome=0.02)
    print(treatment, outcome, "############### Now refuting: Placebo (Linear)##############################")
    res_placebo_li = model.refute_estimate(estimand, estimate_li, method_name="placebo_treatment_refuter",
                                           placebo_type="permute")
    li_res = [estimate_li.value, estimate_li.get_confidence_intervals(),
              res_random_li, res_unobserved_li, res_placebo_li]

    # ---- Causal-forest DML estimate and refutations -----------------------
    # With neither confounders nor effect modifiers there are no covariates to
    # give the forest, so the DML step is skipped for this pair.
    if len(confounders) == 0 and len(modifiers) == 0:
        return li_res, None, modifiers, confounders, backdoor_var

    est_nonparam = CausalForestDML(random_state=12)
    Y = df[outcome].values
    T = df[treatment].values

    if len(modifiers) == 0:
        print('Special case: NO Effect Modifier!')
        # CausalForestDML needs an X matrix, so the confounders stand in for X.
        # The raw EconML interface is used here (no DoWhy wrapper), which is
        # why no DML refutations are produced in this case.
        X = df[confounders].values
        est_nonparam.fit(Y, T, X=X, inference='auto')
        te_pred = est_nonparam.effect(X)
        # Confidence interval of the *average* effect, matching the ordinary
        # branch below (averaging per-row CATE interval endpoints is not a CI
        # for the ATE).
        estimated_ate_ci = est_nonparam.ate_interval(X)

        print(treatment, outcome, "############### NO DML REFUTATION!#######################")
        res_random_dml = res_unobserved_dml = res_placebo_dml = None
    else:
        print('Ordinary case: has effect modifier and confounders.')
        X = df[modifiers].values
        fit_kwargs = {
            'outcome_names': [outcome],
            'treatment_names': [treatment],
            'graph': graph,  # the graph passed in, not a notebook-level global
            'inference': 'auto',
            'X': X,
            'feature_names': modifiers,
        }
        if confounders:
            fit_kwargs['W'] = df[confounders].values
            fit_kwargs['confounder_names'] = confounders

        est_nonparam_dw = est_nonparam.dowhy.fit(Y, T, **fit_kwargs)

        print(treatment, outcome, "############### Now refuting: Random Common Cause (DML)##################")
        res_random_dml = est_nonparam_dw.refute_estimate(method_name="random_common_cause")

        print(treatment, outcome, "############### Now refuting: Add Unobserved Common Cause (DML)##################")
        res_unobserved_dml = est_nonparam_dw.refute_estimate(method_name="add_unobserved_common_cause",
                                                             confounders_effect_on_treatment="binary_flip",
                                                             confounders_effect_on_outcome="linear",
                                                             effect_strength_on_treatment=0.01,
                                                             effect_strength_on_outcome=0.02)

        print(treatment, outcome, "############### Now refuting: Placebo (DML)##############################")
        res_placebo_dml = est_nonparam_dw.refute_estimate(method_name="placebo_treatment_refuter",
                                                          placebo_type="permute", num_simulations=100)

        te_pred = est_nonparam_dw.effect(X)
        estimated_ate_ci = est_nonparam_dw.ate_interval(X)

    # ATE as the mean of the per-row effect estimates.
    estimated_ate = te_pred.mean()
    dml_res = [estimated_ate, te_pred, estimated_ate_ci,
               res_random_dml, res_unobserved_dml, res_placebo_dml]

    return li_res, dml_res, modifiers, confounders, backdoor_var

In [189]:
# Accumulator for causal_all_in_one results, one entry per (treatment, outcome) pair.
res = list()

In [190]:
# Run the full estimate/refute pipeline for every edge, in params order.
res.extend(causal_all_in_one(*args) for args in params)

Dispersion effectiveness Settling efficiency ############### Now refuting: Random Common Cause (Linear)#######################
Dispersion effectiveness Settling efficiency ############### Now refuting: Add Unobserved Common Cause (Linear)######################
Dispersion effectiveness Settling efficiency ############### Now refuting: Placebo (Linear)##############################
Special case: NO Effect Modifier!
Dispersion effectiveness Settling efficiency ############### NO DML REFUTATION!#######################
Dispersion effectiveness Oil trapping efficiency ############### Now refuting: Random Common Cause (Linear)#######################
Dispersion effectiveness Oil trapping efficiency ############### Now refuting: Add Unobserved Common Cause (Linear)######################
Dispersion effectiveness Oil trapping efficiency ############### Now refuting: Placebo (Linear)##############################
Special case: NO Effect Modifier!
Dispersion effectiveness Oil trapping efficiency ##

In [191]:
# Alias (not a copy) of `res`; the table-building cells read the collected results through this name.
results_full = res

In [192]:
# Empty results table; the column groups mirror the pieces returned by
# causal_all_in_one and are filled column by column in the following cells.
df_res_full = pd.DataFrame(
    columns=(
        # the analysed edge
        ['treatment', 'outcome']
        # linear-regression estimate and its refutations
        + ['ate_li', 'ci_li',
           'rand_li', 'rand_li-p-val', 'rand_li-is_statistically_significant',
           'unobserved_li', 'placebo_li', 'li-pl-p-val', 'li-pl_is_statistically_significant']
        # causal-forest DML estimate and its refutations
        + ['ate_dml', 'ate2_dml', 'ci_dml',
           'rand_dml', 'rand_dml-p-val', 'rand_dml-is_statistically_significant',
           'unobserved_dml', 'placebo_dml', 'dml_pl_p_val', 'dml_pl_is_statistically_significant']
        # graph-derived variable sets
        + ['modifiers', 'confounders', 'backdoor_var']
    )
)

In [193]:
# Edge identity: the first two items of each params tuple.
df_res_full['treatment'], df_res_full['outcome'] = (
    [treatment for treatment, *_ in params],
    [outcome for _, outcome, *_ in params],
)

In [194]:
# Linear estimate: point ATE (li_res[0]) and its confidence interval (li_res[1]).
df_res_full['ate_li'], df_res_full['ci_li'] = (
    [li[0] for li, *_ in results_full],
    [li[1] for li, *_ in results_full],
)

In [195]:
# Linear refutations sit at li_res[2] (random common cause), li_res[3]
# (unobserved common cause) and li_res[4] (placebo). A key of None means
# "new_effect"; otherwise the value is read from refutation_result[key].
for _column, _pos, _key in (
    ('rand_li', 2, None),
    ('rand_li-p-val', 2, 'p_value'),
    ('rand_li-is_statistically_significant', 2, 'is_statistically_significant'),
    ('unobserved_li', 3, None),
    ('placebo_li', 4, None),
    ('li-pl-p-val', 4, 'p_value'),
    ('li-pl_is_statistically_significant', 4, 'is_statistically_significant'),
):
    df_res_full[_column] = [
        x[0][_pos].new_effect if _key is None else x[0][_pos].refutation_result[_key]
        for x in results_full
    ]
del _column, _pos, _key

In [196]:
# DML estimate pieces: mean ATE, per-row effects and ATE interval; None when
# the DML step was skipped for that edge.
for _column, _pos in (('ate_dml', 0), ('ate2_dml', 1), ('ci_dml', 2)):
    df_res_full[_column] = [None if not x[1] else x[1][_pos] for x in results_full]
del _column, _pos

In [197]:
# DML refuters sit at positions 3 (random common cause), 4 (unobserved common
# cause) and 5 (placebo) of each result's DML list. The DML list itself is None
# when DML was skipped, and the refuters are None when DML ran without effect
# modifiers. Both cases map to None, so guard the list first, then each
# column on its own refuter.
_rand_dml = [x[1][3] if x[1] else None for x in results_full]
_unobs_dml = [x[1][4] if x[1] else None for x in results_full]
_placebo_dml = [x[1][5] if x[1] else None for x in results_full]

df_res_full['rand_dml'] = [r.new_effect if r is not None else None for r in _rand_dml]
df_res_full['rand_dml-p-val'] = [r.refutation_result['p_value'] if r is not None else None for r in _rand_dml]
df_res_full['rand_dml-is_statistically_significant'] = [r.refutation_result['is_statistically_significant'] if r is not None else None for r in _rand_dml]
df_res_full['unobserved_dml'] = [r.new_effect if r is not None else None for r in _unobs_dml]
df_res_full['placebo_dml'] = [r.new_effect if r is not None else None for r in _placebo_dml]
df_res_full['dml_pl_p_val'] = [r.refutation_result['p_value'] if r is not None else None for r in _placebo_dml]
df_res_full['dml_pl_is_statistically_significant'] = [r.refutation_result['is_statistically_significant'] if r is not None else None for r in _placebo_dml]
del _rand_dml, _unobs_dml, _placebo_dml

In [198]:
# Effect modifiers and common causes reported by DoWhy for each edge.
df_res_full['modifiers'], df_res_full['confounders'] = (
    [modifiers for _, _, modifiers, _, _ in results_full],
    [confounders for _, _, _, confounders, _ in results_full],
)

In [199]:
# Backdoor variables of the identified estimand (last item of each result).
df_res_full['backdoor_var'] = [backdoor for *_, backdoor in results_full]

In [201]:
# Persist the results table (index included, as pandas does by default).
df_res_full.to_csv(path_or_buf="causal_result_dml_hongrui_v2.csv")

In [None]:
# Collect the linear-regression ATEs and their confidence intervals for plotting.
# Relationship labels are built from df_res_full's treatment/outcome columns so
# they always follow the row order of the results table.
node_abbrev = {
    'Salinity': 'SAL',
    'Dispersant': 'DISP',
    'Dispersion effectiveness': 'DE',
    'Settling efficiency': 'SE',
    'Oil trapping efficiency': 'OTE',
}
data = pd.DataFrame({
    'Causal Relationship': [f'{node_abbrev[t]} -> {node_abbrev[o]}'
                            for t, o in zip(df_res_full['treatment'], df_res_full['outcome'])],
    'ATE': df_res_full['ate_li'].values,
    # ci_li is either a (lower, upper) tuple or, presumably, an array nested as
    # [[lower, upper]] (single treatment) -- hence the [0][k] indexing.
    'Lower CI': [x[0] if isinstance(x, tuple) else x[0][0] for x in df_res_full['ci_li'].values],
    'Upper CI': [x[1] if isinstance(x, tuple) else x[0][1] for x in df_res_full['ci_li'].values],
})

colors = ['#4C72B0', '#55A868', '#C44E52', '#8172B2', '#CCB974', '#64B5CD', '#4C4C4C', '#EDC948', '#FF9DA7']

# Color scale for the ATE points: |ATE| normalised onto [vmin, vmax].
vmin = 0
vmax = 1
cmap = sns.cubehelix_palette(start=.5, rot=-.75, as_cmap=True)
norm = plt.Normalize(vmin=vmin, vmax=vmax)
sns.set_palette(colors)

# Apply the whitegrid style *before* the axes exist (styles only affect axes
# created afterwards), then the serif font on top of it, because seaborn's
# style also resets font.family. Both contexts restore rcParams on exit.
with sns.axes_style('whitegrid'), plt.rc_context({'font.family': 'Liberation Serif'}):
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot the ATE points, colored by |ATE|
    sns.pointplot(x='ATE', y='Causal Relationship', data=data, join=False, palette=cmap(norm(np.abs(data['ATE']))), ax=ax)

    # Confidence intervals as horizontal error bars; pointplot puts row i at y=i
    for i, row in data.iterrows():
        ax.errorbar(row['ATE'], i, xerr=[[row['ATE'] - row['Lower CI']], [row['Upper CI'] - row['ATE']]],
                    capsize=5, color=colors[1], alpha=0.6)

    # Vertical reference line at x=0 (no causal effect)
    ax.axvline(x=0, linestyle='--', color='black')

    ax.set_xlim(-1.8, 1.8)

    ax.set_xlabel('ATE')
    ax.set_ylabel('Causal Relationships')
    ax.set_title('Estimated ATEs with 95% Confidence Intervals')

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # Attach the colorbar to `ax` explicitly: a bare ScalarMappable belongs to
    # no Axes, and newer Matplotlib refuses to guess where to take space from.
    cbar = fig.colorbar(sm, ax=ax)
    cbar.ax.set_ylabel('|ATE|', rotation=270, labelpad=15, font_properties='Liberation Serif')
    for label in cbar.ax.get_yticklabels():
        label.set_fontproperties('Liberation Serif')

    fig.savefig('causal_effects.png', dpi=300)


In [309]:
# Put the existing colorbar's tick labels in the serif font. Colorbar.set_ticks
# only accepts Text kwargs (e.g. fontfamily) together with `labels=`, so the
# font is set on the tick-label Text objects directly.
# NOTE: relies on `cbar` created by the ATE plotting cell.
for tick_label in cbar.ax.get_yticklabels():
    tick_label.set_fontfamily('Liberation Serif')

In [281]:
data.loc[:, 'Causal Relationship']

0                 Salinity -> Dispersion effectiveness
1                  Salinity -> Oil trapping efficiency
2                      Salinity -> Settling efficiency
3               Dispersant -> Dispersion effectiveness
4                Dispersant -> Oil trapping efficiency
5                    Dispersant -> Settling efficiency
6      Dispersion effectiveness -> Settling efficiency
7    Dispersion effectiveness -> Oil trapping effic...
8       Settling efficiency -> Oil trapping efficiency
Name: Causal Relationship, dtype: object

In [None]:
# Lower-triangle correlation heatmap: the mask hides the upper triangle and the
# diagonal, so each pair of variables appears once.
# NOTE(review): corr_data is not defined in this section; presumably a square
# correlation matrix ordered SAL, DISP, DE, SE, OTE -- confirm.
corr_labels = ['SAL', 'DISP', 'DE', 'SE', 'OTE']
# Fail loudly instead of silently mislabelling if corr_data has another size.
assert np.shape(corr_data) == (len(corr_labels), len(corr_labels)), \
    "corr_data shape does not match the hard-coded tick labels"

with plt.rc_context({'font.family': 'Liberation Serif'}):
    fig, ax = plt.subplots(figsize=(5, 5))
    mask = np.triu(np.ones_like(corr_data, dtype=bool))
    heatmap = sns.heatmap(
        corr_data, annot=True, fmt='.3f', mask=mask, cmap='coolwarm',
        cbar_kws={'location': 'right', 'pad': -0.05, 'shrink': 0.8},
        square=True, ax=ax, xticklabels=corr_labels, yticklabels=corr_labels,
    )

    # Save the figure to disk (no title is set on this figure).
    fig.savefig('correlations.png', dpi=300)