# Argumentative Causal Discovery

Notebook collecting the results for the causal discovery algorithm ABAPC, presented in the Argumentative Causal Discovery paper (Russo, Rapberger and Toni, 2024).

In [4]:
import warnings
# Silence every warning for the rest of the session. This is installed before
# the heavy imports so that warnings raised while importing are hidden too.
warnings.filterwarnings('ignore')
import sys
# Put the parent directory first on the import path. The membership test keeps
# the cell idempotent: re-running it no longer piles up duplicate entries.
if '../' not in sys.path:
    sys.path.insert(0, '../')
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 2000)
# Same guard for the utils directory that holds the `plotting` module.
if '../utils/' not in sys.path:
    sys.path.append('../utils/')
# NOTE(review): the colour constants (sec_blue, main_purple, ...) and the chart
# helpers used in later cells are not defined in this notebook, so they
# presumably come from this star import -- confirm against utils/plotting.py.
from plotting import *
print(sys.path)

# ---- Notebook-wide configuration -------------------------------------------

# Forwarded as `save_figs=` to every plotting call in the cells below.
save_figs = False
# Not read by the plotting calls visible in this notebook (they pass
# `debug=False` literally).
debug = False

datasets = [
    'cancer',
    'earthquake',
    'survey',
    'asia',
]

# Number of nodes per dataset; mapped onto the results as `n_nodes` and shown
# as |V| in the dataset labels.
dags_nodes_map = {
    'asia': 8,
    'cancer': 5,
    'earthquake': 5,
    'sachs': 11,
    'survey': 6,
    'alarm': 37,
    'child': 20,
    'insurance': 27,
    'hailfinder': 56,
    'hepar2': 70,
}

# Number of arcs per dataset; mapped onto the results as `n_edges`, shown as
# |E| in the dataset labels and used to normalise SHD / SID.
dags_arcs_map = {
    'asia': 8,
    'cancer': 4,
    'earthquake': 4,
    'sachs': 17,
    'survey': 6,
    'alarm': 46,
    'child': 25,
    'insurance': 52,
    'hailfinder': 66,
    'hepar2': 123,
}

# Display names of the compared methods, passed to the bar-chart helpers.
methods = [
    'Random',
    'FGS',
    'NOTEARS-MLP',
    'MPC',
    'ASPCR',
    'ABAPC (Ours)',
]

# Per-method lookup tables (display name, marker symbol, colour), all keyed by
# the same short method identifiers.
names_dict = {
    'random': 'Random',
    'fgs': 'FGS',
    'nt': 'NOTEARS-MLP',
    'mpc': 'MPC',
    'aspcr-log': 'ASPCR',
    'abapc': 'ABAPC (Ours)',
}
symbols_dict = {
    'abapc': 'triangle-down-dot',
    'fgs': 'triangle-up-dot',
    'nt': 'pentagon-dot',
    'mpc': 'hexagon2-dot',
    'random': 'x',
    'aspcr-log': 'diamond-dot',
}
colors_dict = {
    'abapc': sec_blue,
    'fgs': sec_orange,
    'nt': main_purple,
    'mpc': main_green,
    'random': 'grey',
    'aspcr-log': sec_purple,
}

# Tag embedded in the names of the stored result files.
version = 'bnlearn_50rep'  # for 5000 samples
# version = 'bnlearn_dag_v5_2000'  # for 2000 samples

['../', '../', '/vol/bitbucket/fr920/ArgCausalDisco/notebooks', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/vol/bitbucket/fr920/envs/discoclean/lib/python3.10/site-packages', '../utils/', '../utils/']


## Main Paper Results

### CPDAG Evaluation

In [5]:
version_cpdag = version+'_cpdag'

# Column names given to the stored CPDAG summaries: two identifier columns
# followed by a (mean, std) pair for each metric, in this exact order.
cpdag_metrics = ['elapsed', 'nnz', 'fdr', 'tpr', 'fpr', 'precision', 'recall',
                 'F1', 'shd', 'SID_low', 'SID_high']
cpdag_columns = ['dataset', 'model'] + [
    f'{metric}_{stat}' for metric in cpdag_metrics for stat in ('mean', 'std')
]

# NOTE: allow_pickle=True unpickles arbitrary objects -- only load result
# files that come from a trusted source.
all_sum = pd.DataFrame(
    np.load(f"../results/stored_results_{version_cpdag}.npy", allow_pickle=True),
    columns=cpdag_columns,
)
all_sum_aspcr = pd.DataFrame(
    np.load(f"../results/stored_results_aspcr_{version_cpdag}.npy", allow_pickle=True),
    columns=cpdag_columns,
)
# Every row of the separately stored ASPCR results is labelled 'ASPCR'.
all_sum_aspcr['model'] = 'ASPCR'

all_sum = pd.concat([all_sum, all_sum_aspcr])

# Attach the graph size of each dataset.
all_sum['n_edges'] = all_sum['dataset'].map(dags_arcs_map)
all_sum['n_nodes'] = all_sum['dataset'].map(dags_nodes_map)

# Edge-normalised copies ('p_' prefix) of the SHD and SID means and stds.
edge_count = all_sum['n_edges'].astype(int)
for metric in ('shd', 'SID_low', 'SID_high'):
    for stat in ('mean', 'std'):
        all_sum[f'p_{metric}_{stat}'] = all_sum[f'{metric}_{stat}'].astype(float) / edge_count

# Turn the dataset name into the label 'NAME<br> |V|=<nodes>, |E|=<edges>'.
all_sum['dataset'] = (
    all_sum['dataset'].astype(str).str.upper()
    + "<br> |V|=" + all_sum['n_nodes'].astype(str)
    + ", |E|=" + all_sum['n_edges'].astype(str)
)

double_bar_chart_plotly(
    all_sum,
    ['p_SID_low', 'p_SID_high'],
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.2_SID_cpdag.html",
    debug=False,
    range_y1=[0, 6],
    range_y2=[0, 6],
)

### Runtime

In [9]:
# Runtime figure, drawn from the `all_sum` frame built in the CPDAG cell.
# NOTE(review): "n_nodes <= 10" is handed to plot_runtime as a plain string,
# presumably a row filter -- confirm in utils/plotting.py.
plot_runtime(
    all_sum,
    ['n_nodes'],
    "n_nodes <= 10",
    names_dict,
    symbols_dict,
    colors_dict,
    ['fgs', 'nt', 'mpc', 'aspcr-log', 'abapc'],
    share_y=False,
    save_figs=save_figs,
    output_name="../results/figs/Fig.3_runtime.html",
    debug=False,
    font_size=20,
    plot_height=350,
    plot_width=800,
)

## Appendix

### t-Tests for difference in means

In [12]:
from scipy.stats import ttest_ind_from_stats
from itertools import combinations
import math

# Sample size behind every stored mean/std pair.
# NOTE(review): presumably the 50 repetitions of the 'bnlearn_50rep' results --
# confirm, and update if `version` points at a different set of runs.
nobs = 50


def _short_label(model):
    """Return the abbreviated method name used in the LaTeX table rows."""
    if "ABA" in model:
        model = "APC"
    if "NOTEARS" in model:
        model = "NT"
    if "Random" in model:
        model = "RND"
    return model


def _significance_marker(p):
    """Return the significance code for p-value ``p`` ('' when not significant).

    A NaN p-value (e.g. both standard deviations are zero) gets no marker;
    previously it fell through every comparison and was flagged '***'.
    """
    if math.isnan(p):
        return ""
    for threshold, marker in ((0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, ".")):
        if p <= threshold:
            return marker
    return ""


# Test the difference in means between each pair of methods, for each dataset,
# on SID_low and SID_high, using Welch's t-test (equal_var=False) computed from
# the stored summary statistics. Each result is printed as one LaTeX table row.
print("0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 'ns' 1.\n")

# The set of methods does not change while printing, so sort it only once.
models = sorted(all_sum['model'].unique())

for dataset in all_sum['dataset'].unique():
    print(dataset.replace("<br>","").upper())
    print(50*"=")
    # One row per method for this dataset, indexed by method name. The first
    # row wins when a method appears more than once.
    rows = all_sum[all_sum['dataset'] == dataset].drop_duplicates('model').set_index('model')
    for var in ['SID_low','SID_high']:
        print(var)
        for method1, method2 in combinations(models, 2):
            # ASPCR is left out of every comparison on ASIA.
            if 'ASIA' in dataset and ('ASPCR' in method1 or 'ASPCR' in method2):
                continue
            # float(): the summaries are loaded from a pickled array, so the
            # cells may not be plain floats.
            a = float(rows.at[method1, var+'_mean'])
            b = float(rows.at[method2, var+'_mean'])
            a_std = float(rows.at[method1, var+'_std'])
            b_std = float(rows.at[method2, var+'_std'])
            t, p = ttest_ind_from_stats(a, a_std, nobs, b, b_std, nobs, equal_var=False)
            # Raw f-strings keep the LaTeX backslashes literal instead of
            # relying on invalid escape sequences such as '\!' and '\p'.
            print(
                rf'\!\!\!{_short_label(method1)} $({a:.1f}\pm{a_std:.1f})$ '
                rf'\!v\! {_short_label(method2)} $({b:.1f}\pm{b_std:.1f})$ '
                rf'\!\!\!&\!\!\! {t:.3f} \!\!\!&\!\!\! {p:.3f}{_significance_marker(p)} \\'
            )
        print("")

0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 'ns' 1.

CANCER |V|=5, |E|=4
SID_low
\!\!\!APC $(9.8\pm2.2)$ \!v\! ASPCR $(6.8\pm3.4)$ \!\!\!&\!\!\! 5.247 \!\!\!&\!\!\! 0.000*** \\
\!\!\!APC $(9.8\pm2.2)$ \!v\! FGS $(6.3\pm1.9)$ \!\!\!&\!\!\! 8.722 \!\!\!&\!\!\! 0.000*** \\
\!\!\!APC $(9.8\pm2.2)$ \!v\! MPC $(3.9\pm1.8)$ \!\!\!&\!\!\! 14.833 \!\!\!&\!\!\! 0.000*** \\
\!\!\!APC $(9.8\pm2.2)$ \!v\! NT $(9.3\pm1.6)$ \!\!\!&\!\!\! 1.476 \!\!\!&\!\!\! 0.144 \\
\!\!\!APC $(9.8\pm2.2)$ \!v\! RND $(7.9\pm3.6)$ \!\!\!&\!\!\! 3.169 \!\!\!&\!\!\! 0.002** \\
\!\!\!ASPCR $(6.8\pm3.4)$ \!v\! FGS $(6.3\pm1.9)$ \!\!\!&\!\!\! 0.976 \!\!\!&\!\!\! 0.332 \\
\!\!\!ASPCR $(6.8\pm3.4)$ \!v\! MPC $(3.9\pm1.8)$ \!\!\!&\!\!\! 5.360 \!\!\!&\!\!\! 0.000*** \\
\!\!\!ASPCR $(6.8\pm3.4)$ \!v\! NT $(9.3\pm1.6)$ \!\!\!&\!\!\! -4.621 \!\!\!&\!\!\! 0.000*** \\
\!\!\!ASPCR $(6.8\pm3.4)$ \!v\! RND $(7.9\pm3.6)$ \!\!\!&\!\!\! -1.586 \!\!\!&\!\!\! 0.116 \\
\!\!\!FGS $(6.3\pm1.9)$ \!v\! MPC $(3.9\pm1.8)$ \!\!\!&\!\!\! 6.503 \!\!\!&

### Graph Size, SHD and F1

In [13]:
# Appendix figures for the CPDAG results: edge-normalised SHD next to F1,
# precision next to recall, and 'nnz' on its own.
double_bar_chart_plotly(
    all_sum,
    ['p_shd', 'F1'],
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.4_SHD_F1_cpdag.html",
    debug=False,
    range_y1=[0, 2],
    range_y2=[0, 1.1],
)
double_bar_chart_plotly(
    all_sum,
    ['precision', 'recall'],
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.5_pre_rec_cpdag.html",
    debug=False,
)
bar_chart_plotly(
    all_sum,
    'nnz',
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.6_EGS_cpdag.html",
    debug=False,
)

### DAG evaluation

In [15]:
# Column names given to the stored DAG summaries: two identifier columns
# followed by a (mean, std) pair for each metric, in this exact order.
dag_metrics = ['elapsed', 'nnz', 'fdr', 'tpr', 'fpr', 'precision', 'recall',
               'F1', 'shd', 'SID']
dag_columns = ['dataset', 'model'] + [
    f'{metric}_{stat}' for metric in dag_metrics for stat in ('mean', 'std')
]

# NOTE: allow_pickle=True unpickles arbitrary objects -- only load result
# files that come from a trusted source.
all_sum = pd.DataFrame(
    np.load(f"../results/stored_results_{version}.npy", allow_pickle=True),
    columns=dag_columns,
)
all_sum_aspcr = pd.DataFrame(
    np.load(f"../results/stored_results_aspcr_{version}.npy", allow_pickle=True),
    columns=dag_columns,
)
# Every row of the separately stored ASPCR results is labelled 'ASPCR'.
all_sum_aspcr['model'] = 'ASPCR'

all_sum = pd.concat([all_sum, all_sum_aspcr])

# Attach the graph size of each dataset.
all_sum['n_edges'] = all_sum['dataset'].map(dags_arcs_map)
all_sum['n_nodes'] = all_sum['dataset'].map(dags_nodes_map)

# Edge-normalised copies ('p_' prefix) of the SHD and SID means and stds.
edge_count = all_sum['n_edges'].astype(int)
for metric in ('shd', 'SID'):
    for stat in ('mean', 'std'):
        all_sum[f'p_{metric}_{stat}'] = all_sum[f'{metric}_{stat}'].astype(float) / edge_count

# Turn the dataset name into the label 'NAME<br> |V|=<nodes>, |E|=<edges>'.
all_sum['dataset'] = (
    all_sum['dataset'].astype(str).str.upper()
    + "<br> |V|=" + all_sum['n_nodes'].astype(str)
    + ", |E|=" + all_sum['n_edges'].astype(str)
)

double_bar_chart_plotly(
    all_sum,
    ['p_shd', 'p_SID'],
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.7_SHD_SID_dag.html",
    debug=False,
    range_y1=[0, 2],
    range_y2=[0, 5.6],
)
bar_chart_plotly(
    all_sum,
    'nnz',
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.8_EGS_dag.html",
    debug=False,
)
bar_chart_plotly(
    all_sum,
    'F1',
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.9_F1_dag.html",
    debug=False,
)
double_bar_chart_plotly(
    all_sum,
    ['precision', 'recall'],
    names_dict,
    colors_dict,
    methods,
    save_figs=save_figs,
    output_name="../results/figs/Fig.10_pre_rec_dag.html",
    debug=False,
)