# Results Notebook for Large Bnlearn Graphs

In [1]:
import pandas as pd
import numpy as np

import importlib
import utils
importlib.reload(utils)
from utils import plot_runtime, double_bar_chart_plotly, process_model_names_and_runtime_v1_data, process_mean_std_sid_data, DAG_NODES_MAP, DAG_EDGES_MAP

In [2]:
import glob

# Collect every cpdag_metrics.csv found two directory levels below the V2 bnlearn
# results folder.
# glob.glob returns matches in an arbitrary, filesystem-dependent order, so the
# paths are sorted to keep the row order of the concatenated frame (and therefore
# the row-0 preview below) reproducible across machines and re-runs.
v2_metrics_pattern = '../../results/gradual/v2_run_for_bnlearn/*/*/cpdag_metrics.csv'
csv_files = sorted(glob.glob(v2_metrics_pattern))

# Fail with an actionable message instead of pandas' generic
# "No objects to concatenate" when the results folder is missing or empty.
if not csv_files:
    raise FileNotFoundError(f'No cpdag_metrics.csv files match {v2_metrics_pattern!r}')

# Concatenate all the csv files into a single DataFrame
v2_data_bn = pd.concat([pd.read_csv(f) for f in csv_files], ignore_index=True)

# Preview the first row as a column -> value listing
# (a Series is its own transpose, so no `.T` is needed).
v2_data_bn.loc[0]

nnz                                                                      12
fdr                                                                  0.8333
tpr                                                                    0.25
fpr                                                                     0.5
precision                                                            0.1667
recall                                                                 0.25
F1                                                                      0.2
shd                                                                    16.0
sid_low                                                                15.0
sid_high                                                               15.0
dataset                                                                asia
seed                                                                   5864
n_nodes                                                                   8
n_edges     

In [3]:
# Same loading steps for the "search every step" runs (the DAG construction
# version used for larger DAGs).
# glob.glob returns matches in an arbitrary, filesystem-dependent order, so the
# paths are sorted to keep the row order of the concatenated frame reproducible.
search_every_step_metrics_pattern = '../../results/gradual/v2_run_for_bnlearn_search_every_step/*/*/cpdag_metrics.csv'
csv_files_search_every_step = sorted(glob.glob(search_every_step_metrics_pattern))

# Fail with an actionable message instead of pandas' generic
# "No objects to concatenate" when the results folder is missing or empty.
if not csv_files_search_every_step:
    raise FileNotFoundError(
        f'No cpdag_metrics.csv files match {search_every_step_metrics_pattern!r}'
    )

# Concatenate all the csv files into a single DataFrame
v2_data_bn_search_every_step = pd.concat(
    [pd.read_csv(f) for f in csv_files_search_every_step],
    ignore_index=True,
)

# Number of result rows per dataset.
# NOTE(review): the recorded output lists 50 rows for asia and sachs but only 19
# for child -- confirm the child runs are complete before comparing aggregates.
v2_data_bn_search_every_step['dataset'].value_counts()

dataset
asia     50
sachs    50
child    19
Name: count, dtype: int64

In [4]:
# Metrics of the existing (baseline) methods on the bnlearn graphs; the file name
# indicates CPDAG-level metrics.
baselines_data_bn = pd.read_csv(
    filepath_or_buffer='../../results/existing/bnlearn_graphs/all_existing_methods_metrics_cpdag.csv',
)

In [5]:
# Distinct dataset names present in the baseline results, in order of first appearance.
baselines_data_bn['dataset'].unique()

array(['cancer', 'earthquake', 'survey', 'asia', 'sachs', 'child',
       'insurance'], dtype=object)

In [6]:
# Baseline metrics are re-read from disk here so this cell can be re-run on its own:
# the general baselines file first, then the separate FGS results.
_baseline_metric_files = [
    '../../results/existing/bnlearn_graphs/all_existing_methods_metrics_cpdag.csv',
    '../../results/existing/bnlearn_graphs/fgs/all_existing_methods_metrics_cpdag.csv',
]
baselines_data_bn = pd.concat(
    [pd.read_csv(_metrics_file) for _metrics_file in _baseline_metric_files],
    ignore_index=True,
)

# In this notebook the earlier ABAPC results are labelled "Original" rather than "Ours".
_is_abapc_ours = baselines_data_bn['model'] == 'ABAPC (Ours)'
baselines_data_bn.loc[_is_abapc_ours, 'model'] = 'ABAPC (Original)'

# Attach the node / edge counts looked up by dataset name
# (datasets missing from a lookup get NaN from Series.map).
for _count_column, _count_lookup in (('n_nodes', DAG_NODES_MAP), ('n_edges', DAG_EDGES_MAP)):
    baselines_data_bn[_count_column] = baselines_data_bn['dataset'].map(_count_lookup)

In [7]:
# Tag each V2 result frame with its model identifier and derive the total runtime
# as the sum of the four per-stage timing columns (added left to right, so a NaN in
# any stage yields a NaN total).
_timing_columns = (
    'elapsed_bsaf_creation',
    'elapsed_model_solution',
    'aba_elapsed',
    'ranking_elapsed',
)
_labelled_runs = (
    (v2_data_bn, 'V2'),
    (v2_data_bn_search_every_step, 'V2_search_every_step'),
)
for _runs, _model_label in _labelled_runs:
    _runs['model'] = _model_label
    _total_elapsed = _runs[_timing_columns[0]]
    for _stage_column in _timing_columns[1:]:
        _total_elapsed = _total_elapsed + _runs[_stage_column]
    _runs['elapsed'] = _total_elapsed

In [8]:


# One (display label, colour) pair per plotted method, in the order the labels are
# handed to the plotting helper.
_method_palette = [
    ('Random', 'grey'),
    ('FGS', '#b85c00'),
    ('NOTEARS-MLP', '#9454c4'),
    ('MPC', '#379f9f'),
    ('Causal ABA (Original)', '#0085CA'),
    ('Gradual Causal ABA (Ours)', "#ff8c00"),
    ('Gradual Causal ABA (Ours, Adapted)', "#ef6262"),
]

# Method labels, in plotting order.
methods = [_label for _label, _colour in _method_palette]

# Identity mapping: every method is displayed under its own label.
names_dict = dict(zip(methods, methods))

# Method label -> bar colour.
colors_dict = dict(_method_palette)

In [9]:
# Run the same summarising helper over the baselines and both V2 variants, in that
# order. (Judging by the `*_mean` / `*_std` columns used further down, it aggregates
# metrics per dataset and model -- see utils.process_mean_std_sid_data to confirm.)
(
    baselines_bn_processed,
    v2_data_bn_processed,
    v2_data_bn_search_every_step_processed,
) = [
    process_mean_std_sid_data(_raw_metrics)
    for _raw_metrics in (baselines_data_bn, v2_data_bn, v2_data_bn_search_every_step)
]

In [10]:

# Datasets kept for the figure, in the order their rows are arranged.
_dataset_order = ['asia', 'sachs', 'child']

# Stack the baselines (restricted to the datasets above) with both V2 variants.
_baselines_subset = baselines_bn_processed[baselines_bn_processed['dataset'].isin(_dataset_order)]
_all_methods = pd.concat(
    [_baselines_subset, v2_data_bn_processed, v2_data_bn_search_every_step_processed],
    ignore_index=True,
)

# Regroup the rows dataset by dataset; rows of any other dataset are dropped here.
plot_data = pd.concat(
    [_all_methods[_all_methods['dataset'] == _dataset_name] for _dataset_name in _dataset_order]
)

# Hand-entered summary row for ABAPC (Original) on asia.
# NOTE(review): the provenance of these numbers is not visible in this notebook --
# confirm their source. The `n_*` values equal the matching raw values divided by 8.
abapc_asia_data = pd.DataFrame([{
    'dataset': 'asia',
    'model': 'ABAPC (Original)',
    'n_nodes': 8,
    'n_edges': 8,
    'sid_low_mean': 11.72,
    'sid_high_mean': 33.52,
    'sid_low_std': 6.79,
    'sid_high_std': 7.92,
    'precision_mean': 0.49,
    'precision_std': 0.18,
    'recall_mean': 0.51,
    'recall_std': 0.18,
    'f1_mean': 0.5,
    'f1_std': 0.18,
    'shd_mean': 5.24,
    'shd_std': 2.26,
    'n_shd_mean': 0.655,
    'n_shd_std': 0.2825,
    'nnz_mean': 8.44,
    'nnz_std': 0.84,
    'n_sid_low_mean': 1.465,
    'n_sid_high_mean': 4.19,
    'n_sid_low_std': 0.84875,
    'n_sid_high_std': 0.99,
}])

# The manual row is appended after all other rows; the index is renumbered from 0.
plot_data = pd.concat([plot_data, abapc_asia_data], ignore_index=True)

# Replace the dataset name with a two-line axis label: upper-cased name, an HTML
# line break, then the node and edge counts.
_node_count_text = plot_data['n_nodes'].astype(str)
_edge_count_text = plot_data['n_edges'].astype(str)
plot_data['dataset'] = (
    plot_data['dataset'].str.upper() + '<br>' + '|V|=' + _node_count_text + ', |E|=' + _edge_count_text
)

In [11]:
# Raw `model` identifiers -> labels shown in the figures.
model_display_names = {
    'Random': 'Random',
    'FGS': 'FGS',
    'NOTEARS-MLP': 'NOTEARS-MLP',
    'MPC': 'MPC',
    'ABAPC (Original)': 'Causal ABA (Original)',
    'V1.1 Refined Fact Ranking': 'Causal ABA (Refined Fact Ranking)',
    'V1.2 Model Selection by Refined Fact Strengths': 'Causal ABA (Refined Model Ranking)',
    'V1.3 Model Selection by Arrows Sum': 'Causal ABA (Arrows Sum Model Ranking)',
    'V1.4 Model Selection by Arrows Mean': 'Causal ABA (Arrows Mean Model Ranking)',
    'V2': 'Gradual Causal ABA (Ours)',
    'V2_search_every_step': 'Gradual Causal ABA (Ours, Adapted)',
}

# `plot_data['model']` is overwritten in place, so this cell must tolerate being run
# more than once. With the raw-name keys alone, a second run would look up labels
# that were already renamed (e.g. 'Causal ABA (Original)'), find no key, and turn
# them into NaN. Mapping every display label to itself makes a re-run a no-op.
# Raw-name entries take precedence, and models absent from the mapping still become
# NaN exactly as before.
_idempotent_model_names = {_label: _label for _label in model_display_names.values()}
_idempotent_model_names.update(model_display_names)

plot_data['model'] = plot_data['model'].map(_idempotent_model_names)
plot_data.head()

Unnamed: 0,dataset,n_nodes,n_edges,model,sid_low_mean,sid_high_mean,sid_low_std,sid_high_std,precision_mean,precision_std,...,shd_mean,shd_std,nnz_mean,nnz_std,n_sid_low_mean,n_sid_high_mean,n_sid_low_std,n_sid_high_std,n_shd_mean,n_shd_std
0,"ASIA<br>|V|=8, |E|=8",8,8,FGS,32.26,41.7,5.045952,2.808515,0.254134,0.047901,...,12.12,0.773014,7.68,0.767716,4.0325,5.2125,0.630744,0.351064,1.515,0.096627
1,"ASIA<br>|V|=8, |E|=8",8,8,MPC,15.24,41.02,5.59723,3.771483,0.364068,0.062073,...,4.64,0.662709,6.58,0.702474,1.905,5.1275,0.699654,0.471435,0.58,0.082839
2,"ASIA<br>|V|=8, |E|=8",8,8,NOTEARS-MLP,19.36,41.98,3.921474,2.094599,0.244302,0.026503,...,8.24,0.893514,8.32,0.652781,2.42,5.2475,0.490184,0.261825,1.03,0.111689
3,"ASIA<br>|V|=8, |E|=8",8,8,Random,24.98,37.24,6.988299,6.086318,0.1125,0.091925,...,13.0,1.714286,8.0,0.0,3.1225,4.655,0.873537,0.76079,1.625,0.214286
4,"ASIA<br>|V|=8, |E|=8",8,8,Gradual Causal ABA (Ours),22.96,25.48,7.637074,6.494079,0.21077,0.097715,...,11.62,2.202874,8.26,2.087964,2.87,3.185,0.954634,0.81176,1.4525,0.275359


In [12]:
# Canvas size shared by the figure and the exported PNG, and the common range
# passed for both y-axes.
_fig_width, _fig_height = 1400, 700
_sid_axis_range = (0, 16)

# Spacing / positioning values forwarded unchanged to the plotting helper
# (their exact meaning is defined in utils.double_bar_chart_plotly).
_layout_options = dict(
    dist_between_lines=0.1565,
    lin_space=6,
    nl_space=6,
    intra_dis=0.161,
    inter_dis=0.174,
    start_pos=0.04,
)

# Double bar chart of the two normalised SID columns ('n_sid_low' labelled "Best",
# 'n_sid_high' labelled "Worst") for the selected methods.
fig = double_bar_chart_plotly(
    plot_data,
    names_dict,
    colors_dict,
    vars_to_plot=['n_sid_low', 'n_sid_high'],
    names=['Best', 'Worst'],
    labels=['Normalised SID', ''],
    methods=methods,
    **_layout_options,
    width=_fig_width,
    height=_fig_height,
    range_y1=_sid_axis_range,
    range_y2=_sid_axis_range,
)

# Export the figure as a PNG next to the notebook, at the same canvas size.
fig.write_image('v2-sid-large.png', scale=3, width=_fig_width, height=_fig_height)