In [None]:
import sys
sys.path.insert(0, '../../')

In [None]:
import pandas as pd
import numpy as np

In [None]:
# Per-run comparison of the semantics on randomly generated graphs.
random_graphs_compare = pd.read_csv('../../results_pure_aba/compare_semantics_random.csv')

# The raw "PR" columns actually hold the maximally complete semantics (not preferred),
# so relabel them as CO_MAX. Names without "PR" come back unchanged from str.replace.
random_graphs_compare = random_graphs_compare.rename(
    columns=lambda name: name.replace('PR', 'CO_MAX')
)
random_graphs_compare.head()

In [None]:
# Per-run comparison of the semantics on the bnlearn benchmark datasets.
bnlearn_compare = pd.read_csv('../../results_pure_aba/compare_semantics_bnlearn.csv')

# The raw "PR" columns actually hold the maximally complete semantics (not preferred),
# so relabel them as CO_MAX. Names without "PR" come back unchanged from str.replace.
bnlearn_compare = bnlearn_compare.rename(
    columns=lambda name: name.replace('PR', 'CO_MAX')
)
bnlearn_compare.head()

In [None]:
# Show the column names of the random-graph comparison table (after the PR -> CO_MAX relabelling).
random_graphs_compare.columns

In [None]:
# Show the column names of the bnlearn comparison table (after the PR -> CO_MAX relabelling).
bnlearn_compare.columns

# Runtime

In [None]:
# Mean and standard deviation of the elapsed time per (n_nodes, n_edges) group, for each
# of the three semantics plus the existing implementation. Output columns are named
# '<label>_elapsed_mean' / '<label>_elapsed_std', in the order listed below.
runtime_df = (
    random_graphs_compare
    .groupby(['n_nodes', 'n_edges'], as_index=False)
    .agg(**{
        f'{label}_elapsed_{stat}': (source_column, stat)
        for label, source_column in [
            ('ST', 'ST_elapsed'),
            ('CO', 'CO_elapsed'),
            ('CO_MAX', 'CO_MAX_elapsed'),
            ('existing', 'abapc_existing_elapsed'),
        ]
        for stat in ('mean', 'std')
    })
)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_runtime_custom(df, plot_width=750, plot_height=300, font_size=20, save_figs=False, output_name="random_graphs_runtime.html"):
    """Draw mean runtime curves (with std error bars) against the number of nodes.

    One line is drawn per series: ST, CO, CO-max and the existing implementation.

    Args:
        df: frame with an 'n_nodes' column and, for each prefix in
            ST / CO / CO_MAX / existing, '<prefix>_elapsed_mean' and
            '<prefix>_elapsed_std' columns.
        plot_width: figure width passed to the layout.
        plot_height: figure height passed to the layout.
        font_size: global font size of the figure.
        save_figs: when true, write the figure to ``output_name`` (HTML) and to the
            same name with '.html' replaced by '.jpeg'.
        output_name: target path of the HTML export.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # (column prefix, colour, legend label), in drawing order.
    series_specs = [
        ('ST', 'red', 'ST'),
        ('CO', 'blue', 'CO'),
        ('CO_MAX', 'green', 'CO-max'),
        ('existing', 'orange', 'ST (with Existing Implementation)'),
    ]
    # Node counts are plotted as strings (the same values are used for every series).
    node_labels = df['n_nodes'].astype(str)

    fig = make_subplots(rows=1, cols=1, shared_yaxes=True)
    for prefix, colour, label in series_specs:
        curve = go.Scatter(
            x=node_labels,
            y=df[f'{prefix}_elapsed_mean'],
            error_y=dict(type='data', array=df[f'{prefix}_elapsed_std'], thickness=2),
            mode='lines+markers',
            name=label,
            line=dict(color=colour, width=2),
            marker=dict(symbol='circle', size=8, color=colour),
            opacity=0.8,
        )
        fig.add_trace(curve)

    # Axes: logarithmic runtime on y, node count on x.
    fig.update_yaxes(type="log", title='log(elapsed time [s])')
    fig.update_xaxes(title='Number of Nodes (|V|)')

    # Horizontal legend centred above the plot area, compact margins.
    fig.update_layout(
        legend=dict(orientation="h", xanchor="center", x=0.5, yanchor="bottom", y=1.05),
        template='plotly_white',
        width=plot_width,
        height=plot_height,
        margin=dict(l=10, r=10, b=80, t=10),
        font=dict(size=font_size, family="Serif", color="black"),
    )

    if save_figs:
        jpeg_name = output_name.replace('.html', '.jpeg')
        fig.write_html(output_name)
        fig.write_image(jpeg_name, scale=2)

    fig.show()


In [None]:
# Plot the random-graph runtimes and export them as runtime_random_graphs.html / .jpeg.
plot_runtime_custom(runtime_df, save_figs=True, output_name='runtime_random_graphs.html')

# SID

In [None]:
import json

def read_sid_from_json(string_to_read, sid_type='low'):
    """Extract one bound of the SID metric from a JSON-encoded metrics string.

    The decoded object must contain a ``'sid'`` entry that is either a two-element
    list ``[low, high]`` or a single number, in which case that number is used for
    both bounds.

    Args:
        string_to_read: JSON document (as a string) with a ``'sid'`` key.
        sid_type: which bound to return, ``'low'`` or ``'high'``.

    Returns:
        The requested SID bound, exactly as decoded from the JSON.

    Raises:
        ValueError: if ``sid_type`` is not ``'low'``/``'high'`` or the ``'sid'``
            entry is neither a number nor a two-element list.
        KeyError: if the decoded object has no ``'sid'`` key.
    """
    # Fail loudly on a misspelled bound instead of silently returning None.
    if sid_type not in ('low', 'high'):
        raise ValueError(f"sid_type must be 'low' or 'high', got {sid_type!r}.")

    sid = json.loads(string_to_read)['sid']
    if isinstance(sid, (list, tuple)):
        if len(sid) != 2:
            raise ValueError("Unexpected format for SID in JSON string.")
        sid_low, sid_high = sid
    elif isinstance(sid, (int, float)) and not isinstance(sid, bool):
        # JSON integers (e.g. a SID of 0) decode to int, not float; accept both.
        sid_low = sid_high = sid
    else:
        raise ValueError("Unexpected format for SID in JSON string.")

    return sid_low if sid_type == 'low' else sid_high


# One frame per semantics: decode the low / high SID bound of every run from the
# JSON stored in '<sem>_mt_cpdag' and tag the rows with the corresponding model label.
sem_dfs = []
for sem in ['ST', 'CO', 'CO_MAX']:
    sem_df = (
        bnlearn_compare[[f'{sem}_mt_cpdag', 'dataset_name', 'seed']]
        .copy()
        .rename(columns={'dataset_name': 'dataset'})
        .assign(
            sid_low=lambda frame: frame[f'{sem}_mt_cpdag'].apply(read_sid_from_json, sid_type='low'),
            sid_high=lambda frame: frame[f'{sem}_mt_cpdag'].apply(read_sid_from_json, sid_type='high'),
            model=f'ABAPC (New {sem})',
        )
    )
    sem_dfs.append(sem_df[['dataset', 'model', 'seed', 'sid_low', 'sid_high']])

# Stack the three semantics and summarise over seeds: mean / std of each SID bound
# per (dataset, model).
sem_df_combined = (
    pd.concat(sem_dfs, ignore_index=True)
    .groupby(['dataset', 'model'], as_index=False)
    .agg(**{
        f'{bound}_{stat}': (bound, stat)
        for bound in ('sid_low', 'sid_high')
        for stat in ('mean', 'std')
    })
)
sem_df_combined

In [None]:
# --- Colour palette: a main and a secondary shade per hue ---
main_gray, sec_gray = '#262626', '#595959'
main_blue, sec_blue = '#005383', '#0085CA'
main_green, sec_green = '#379f9f', '#196363'
main_purple, sec_purple = '#9454c4', '#441469'
main_orange, sec_orange = '#8a4500', '#b85c00'

# --- Plot configuration ---
save_figs = True
debug = False
datasets = ['cancer', 'earthquake', 'survey', 'asia']

# Number of nodes / arcs of each bnlearn network, keyed by dataset name.
dags_nodes_map = {
    'asia': 8, 'cancer': 5, 'earthquake': 5, 'sachs': 11, 'survey': 6,
    'alarm': 37, 'child': 20, 'insurance': 27, 'hailfinder': 56, 'hepar2': 70,
}
dags_arcs_map = {
    'asia': 8, 'cancer': 4, 'earthquake': 4, 'sachs': 17, 'survey': 6,
    'alarm': 46, 'child': 25, 'insurance': 52, 'hailfinder': 66, 'hepar2': 123,
}

# Models drawn in the bar chart, in legend order.
methods = [
    'Random',
    'FGS',
    'NOTEARS-MLP',
    'MPC',
    'ABAPC (Existing)',
    'ABAPC (New ST)',
    'ABAPC (New CO)',
]
# key -> display name; colors_dict is indexed with the same keys.
names_dict = {
    'fgs': 'FGS',
    'nt': 'NOTEARS-MLP',
    'mpc': 'MPC',
    'random': 'Random',
    'abapc': 'ABAPC (Existing)',
    'ABAPC (New ST)': 'ABAPC (New ST)',
    'ABAPC (New CO)': 'ABAPC (New CO)',
    'ABAPC (New CO-max)': 'ABAPC (New CO-max)',
}
# NOTE(review): 'fgs' and 'ABAPC (New CO)' both use sec_orange, so their bars share a
# colour in the chart - confirm this is intended.
colors_dict = {
    'abapc': sec_blue,
    'fgs': sec_orange,
    'nt': main_purple,
    'mpc': main_green,
    'random': 'grey',
    'ABAPC (New ST)': 'black',
    'ABAPC (New CO)': sec_orange,
    'ABAPC (New CO-max)': sec_green,
}

version = 'bnlearn_50rep' ## for 5000 samples
# version = 'bnlearn_dag_v5_2000' ## for 2000 samples
version_cpdag = version+'_cpdag'

# --- Stored baseline results + the newly computed semantics ---
all_sum = pd.read_csv(f"../../results_pure_aba/stored_results_{version}_cpdag.csv")
all_sum.loc[all_sum['model']=='ABAPC (Ours)', 'model'] = 'ABAPC (Existing)'
# all_sum.loc[all_sum['model']=='ABAPC (ASPforABA)', 'model'] = 'ABAPC (New)'
all_sum = pd.concat(
    [
        all_sum[['dataset', 'model', 'sid_low_mean', 'sid_low_std', 'sid_high_mean', 'sid_high_std']],
        sem_df_combined,
    ],
    ignore_index=True,
)
all_sum.loc[all_sum['model']=='ABAPC (New CO_MAX)', 'model'] = 'ABAPC (New CO-max)'

# --- Graph sizes and SID normalised by the number of arcs of the true graph ---
all_sum['n_edges'] = all_sum['dataset'].map(dags_arcs_map)
all_sum['n_nodes'] = all_sum['dataset'].map(dags_nodes_map)
for var in ['SID_low','SID_high']:
    all_sum[f'p_{var}_mean'] = all_sum[f'{var.lower()}_mean'].astype(float).div(all_sum['n_edges'].astype(int))
    all_sum[f'p_{var}_std'] = all_sum[f'{var.lower()}_std'].astype(float).div(all_sum['n_edges'].astype(int))

# --- x-axis labels: upper-cased dataset name plus its |V| and |E| on a second line ---
all_sum['dataset'] = (
    all_sum['dataset'].astype(str).str.upper()
    + "<br> |V|=" + all_sum['n_nodes'].astype(str)
    + ", |E|=" + all_sum['n_edges'].astype(str)
)

all_sum.head()

In [None]:
def double_bar_chart_plotly(all_sum, names_dict, colors_dict, 
                            methods=['Random', 'FGS', 'NOTEARS-MLP', 'Shapley-PC', 'ABAPC (Ours)'],
                            range_y1=None, range_y2=None, font_size=20,
                            save_figs=False, output_name="bar_chart.html", debug=False):
    """Grouped bar chart of the normalised low / high SID per dataset and method.

    For every dataset on the x axis two bar groups are drawn: 'p_SID_low' (labelled
    "Best") on the primary y axis and 'p_SID_high' (labelled "Worst") on the
    secondary y axis, each with one bar per method and std error bars.

    Args:
        all_sum: frame with columns 'model', 'dataset', 'p_SID_low_mean',
            'p_SID_low_std', 'p_SID_high_mean' and 'p_SID_high_std'.
        names_dict: mapping key -> display name; every entry of ``methods`` must
            appear among its values.
        colors_dict: mapping key (same keys as ``names_dict``) -> bar colour.
        methods: display names of the models to draw, in legend order. The default
            list is never mutated by this function.
        range_y1: y range of the primary axis; when both ranges are None the range
            is derived from the data.
        range_y2: y range of the secondary axis.
        font_size: font size used for the figure and its annotations.
        save_figs: when true, write ``output_name`` (HTML) and the same name with
            '.html' replaced by '.jpeg'.
        output_name: target path of the HTML export.
        debug: accepted for interface compatibility; not used by this function.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    vars_to_plot = ['p_SID_low','p_SID_high']
    for n, var_to_plot in enumerate(vars_to_plot):
        for m, method in enumerate(methods):
            # Rows of this method, filtered once and reused for x, y and the error bars.
            method_rows = all_sum[all_sum.model == method]
            trace_name = 'True Graph Size' if var_to_plot=='nnz' and method=='Random' else method
            fig.add_trace(go.Bar(x=method_rows['dataset'],
                                 yaxis=f"y{n+1}",
                                 # Offset groups of the second variable start after the spacer group.
                                 offsetgroup=m+len(methods)*n+(1*n),
                                 y=method_rows[var_to_plot+'_mean'],
                                 error_y=dict(type='data', array=method_rows[var_to_plot+'_std'], visible=True),
                                 name=trace_name,
                                 # Reverse lookup: display name -> key -> colour.
                                 marker_color=colors_dict[list(names_dict.keys())[list(names_dict.values()).index(method)]],
                                 opacity=0.6,
                                 showlegend=n==0
                                 ))
        if n==0 and len(methods) > 0:
            # Invisible zero-height bar separating the two bar groups of a dataset.
            # It reuses the x categories of the last method and takes the offset group
            # right after the first variable's bars (previously read from the loop
            # variables leaking out of the inner loop).
            spacer_rows = all_sum[all_sum.model == methods[-1]]
            fig.add_trace(go.Bar(x=spacer_rows['dataset'],
                                 y=np.zeros(len(spacer_rows['dataset'])),
                                 name='',
                                 offsetgroup=len(methods),
                                 marker_color='white',
                                 opacity=1,
                                 showlegend=False
                                 ))

    # Tick labels on the secondary axis are only shown when the two variables differ in kind.
    second_ticks = False if all('SID' in var for var in vars_to_plot) else True

    fig.update_layout(
        barmode='group',
        bargap=0.15,       # gap between bars of adjacent location coordinates
        bargroupgap=0.15,  # gap between bars of the same location coordinate
        legend=dict(orientation="h", xanchor="center", x=0.5, yanchor="top", y=1.1),
        template='plotly_white',
        width=1600,
        height=700,
        margin=dict(l=40, r=0, b=70, t=20),
        hovermode='x unified',
        font=dict(size=font_size, family="Serif", color="black"),
        # NOTE(review): scaleanchor=0 is not an axis id; presumably meant to disable
        # scale anchoring - confirm.
        yaxis2=dict(scaleanchor=0, showline=False, showgrid=False, showticklabels=second_ticks, zeroline=True),
    )

    # "Dataset:" caption at the bottom-left corner, in front of the x tick labels.
    fig.add_annotation(
        xref="paper",
        yref="paper",
        xanchor="center",
        x=0,
        yanchor="bottom",
        y=-0.08,
        text="Dataset:",
        showarrow=False,
        font=dict(family="Serif", size=font_size, color="Black"),
    )

    for n, var_to_plot in enumerate(vars_to_plot):
        # Reset per axis so an unmatched variable gets plotly's automatic range instead
        # of an unbound name or the range left over from the previous axis.
        range_y = None
        if vars_to_plot == ['precision', 'recall']:
            range_y = [0, 1.3] if range_y1 is None else range_y1
        elif vars_to_plot == ['fdr', 'tpr']:
            range_y = [0, 1]
        elif 'shd' in var_to_plot or 'SID' in var_to_plot:
            if range_y1 is None and range_y2 is None:
                # Series.max() skips NaN; the builtin max() is order-dependent with NaN.
                if 'high' in var_to_plot:
                    range_y = [0, all_sum['p_SID_high_mean'].max()+.3]
                elif 'low' in var_to_plot:
                    range_y = [0, all_sum['p_SID_low_mean'].max()+.3]
                else:
                    range_y = [0, 2] if n==0 else [0, all_sum['p_SID_mean'].max()+.3]
            else:
                range_y = range_y1 if n==0 else range_y2

        if 'n_' in var_to_plot or 'p_' in var_to_plot:
            orig_y = var_to_plot.replace('n_','').replace('p_','').replace('_low','').replace('_high','').upper()
            fig.update_yaxes(title={'text':f'Normalised {orig_y}','font':{'size':font_size}}, secondary_y=n==1, range=range_y)
            if not second_ticks:
                # Both variables share one scale: blank out the secondary axis title and ticks.
                fig.update_yaxes(title={'text':'','font':{'size':font_size}}, secondary_y=True, range=range_y, showticklabels=False)
        elif var_to_plot=='nnz':
            orig_y = 'Number of Edges in DAG'
            fig.update_yaxes(title={'text':f'{orig_y}','font':{'size':font_size}}, secondary_y=n==1, range=range_y)
        else:
            fig.update_yaxes(title={'text':f'{var_to_plot.title()}','font':{'size':font_size}}, secondary_y=n==1, range=range_y)

    # Positions of the "Best" / "Worst" header boxes, as fractions of the x domain.
    # NOTE(review): these spacing constants look hand-tuned for the current figure
    # width and number of datasets - confirm before plotting a different dataset list.
    start_pos = 0.039
    name1 = 'Best'
    name2 = 'Worst'
    lin_space=9
    nl_space=9
    intra_dis = 0.161
    inter_dis = 0.174

    n_x_cat = len(all_sum.dataset.unique())
    list_of_pos = []
    left=start_pos
    for i in range(n_x_cat):
        right = left+intra_dis
        list_of_pos.append((left, right))
        left = right+inter_dis

    for s1,s2 in list_of_pos:
        # One padded label box per bar group: "Best" over the low-SID bars, "Worst"
        # over the high-SID bars.
        for x_pos, label, pad in ((s1, name1, lin_space), (s2, name2, nl_space)):
            fig.add_annotation(
                xref="x domain",
                yref="y domain",
                xanchor="left",
                x=x_pos,
                y=1.015,
                text=f"{' '*pad}{label}{' '*pad}",
                showarrow=False,
                font=dict(size=font_size, color="black"),
                bordercolor='#E5ECF6',
                borderwidth=2,
                bgcolor="#E5ECF6",
                opacity=0.8,
            )

    # Vertical separators at a fixed step in paper coordinates, two per dataset:
    # alternately grey dashed and black solid, starting with grey dashed.
    s1 = 0
    sign = -1
    for i in range(n_x_cat*2):
        sign *= -1
        s1 += 0.1565
        fig.add_shape(
            type="line",
            x0=s1, x1=s1,
            y0=0, y1=1,
            xref="paper",
            yref="paper",
            line=dict(
                color="grey" if sign>0 else "black",
                width=1,
                dash="dash" if sign>0 else "solid",
            ),
            layer="below"
        )

    if save_figs:
        fig.write_html(output_name)
        fig.write_image(output_name.replace('.html','.jpeg'))

    fig.show()

In [None]:
# Draw the normalised SID chart for the configured methods with both y axes fixed to [0, 6];
# with save_figs enabled the figure is exported as ./Fig.2_SID_cpdag.html / .jpeg.
double_bar_chart_plotly(all_sum, names_dict, colors_dict, methods, save_figs=save_figs, output_name="./Fig.2_SID_cpdag.html", debug=False, range_y1=[0,6], range_y2=[0,6])#

# Checks

In [None]:
# Column names available for the consistency checks below.
random_graphs_compare.columns

In [None]:
# Mean of every membership / subset indicator per (n_nodes, n_edges) group, transposed
# so that each indicator is a row.
# NOTE(review): the lower-case "pr" in these column names was not touched by the
# PR -> CO_MAX relabelling; presumably it denotes the same semantics - confirm.
(
    random_graphs_compare
    .groupby(['n_nodes', 'n_edges'])[[
        'is_best_st_in_all_co',
        'is_best_st_in_all_pr',
        'is_best_pr_in_all_co',
        'is_all_st_subset_of_all_co',
        'is_all_st_subset_of_all_pr',
        'is_all_pr_subset_of_all_co',
    ]]
    .mean()
    .T
)

In [None]:
# Percentage of runs for each relation between the best models of the semantics.
# The CO_MAX_* columns are the raw "PR" columns relabelled at load time because they
# hold the maximally complete semantics, not preferred, so the messages say "CO-max".
# The mean of a boolean mask is the share of rows where the condition holds.
print("cases when best stable model is stronger than best CO-max model: ",
      (random_graphs_compare['ST_best_I'] > random_graphs_compare['CO_MAX_best_I']).mean() * 100, '%')
print("cases when best stable model is weaker than best CO-max model: ",
      (random_graphs_compare['ST_best_I'] < random_graphs_compare['CO_MAX_best_I']).mean() * 100, '%')
print("cases when best CO-max model is equal to the best complete model: ",
      (random_graphs_compare['CO_best_model'] == random_graphs_compare['CO_MAX_best_model']).mean() * 100, '%')
print("cases when the strength of the best CO-max model is equal to the strength of the best complete model: ",
      (random_graphs_compare['CO_best_I'] == random_graphs_compare['CO_MAX_best_I']).mean() * 100, '%')

In [None]:
# For CO, CO_MAX and ST (in that order): does every run use all of its facts,
# i.e. is '<sem>_used_num_facts' equal to '<sem>_total_num_facts' on every row?
tuple(
    all(random_graphs_compare[f'{prefix}_used_num_facts'] == random_graphs_compare[f'{prefix}_total_num_facts'])
    for prefix in ('CO', 'CO_MAX', 'ST')
)

In [None]:
# Sanity check: all best models are DAGs

sys.path.insert(0, '../../ArgCausalDisco')
from ArgCausalDisco.utils.graph_utils import is_dag
from src.utils.utils import get_matrix_from_arrow_set

# For each semantics, rebuild the graph matrix of every run's best model and report
# whether all of them pass is_dag.
for sem in ['ST', 'CO', 'CO_MAX']:
    is_dag_list = []
    for n_nodes, model in zip(random_graphs_compare['n_nodes'], random_graphs_compare[f'{sem}_best_model']):
        # The best model is stored in the CSV as a string and turned back into a Python object here.
        # NOTE(review): eval executes arbitrary code found in the CSV; only run this on
        # trusted result files (ast.literal_eval may work if the stored format is a plain
        # literal - TODO confirm).
        model = eval(model)
        graph_matrix = get_matrix_from_arrow_set(model, n_nodes)
        is_dag_list.append(is_dag(graph_matrix))

    print(f"Is all {sem} best models a DAGs? {all(is_dag_list)}")

## Which model is stronger: ST or CO?

### Random graphs

In [None]:
def strength_agg_func(df):
    """Share of rows where the best ST and best CO models tie, or where one is stronger.

    Args:
        df: frame with 'ST_best_I' and 'CO_best_I' columns.

    Returns:
        pd.Series with the fractions 'ST_CO_equal_strength', 'ST_stronger' and
        'CO_stronger', in that order.
    """
    st_strength = df["ST_best_I"]
    co_strength = df["CO_best_I"]
    return pd.Series({
        "ST_CO_equal_strength": st_strength.eq(co_strength).mean(),
        "ST_stronger": st_strength.gt(co_strength).mean(),
        "CO_stronger": co_strength.gt(st_strength).mean(),
    })

In [None]:
# ST vs CO strength shares for every (n_nodes, n_edges) configuration of the random graphs.
random_graphs_strength_compare = (
    random_graphs_compare
    .groupby(['n_nodes', 'n_edges'])
    .apply(strength_agg_func)
    .reset_index()
)
random_graphs_strength_compare.head()

In [None]:
# ST vs CO strength shares per bnlearn dataset. The dataset name is then turned into a
# plot label: upper-cased name plus its |V| and |E| on a second line.
bn_graphs_strength_compare = (
    bnlearn_compare
    .groupby('dataset_name')
    .apply(strength_agg_func)
    .reset_index()
    .assign(
        n_edges=lambda frame: frame['dataset_name'].map(dags_arcs_map),
        n_nodes=lambda frame: frame['dataset_name'].map(dags_nodes_map),
        dataset_name=lambda frame: (
            frame['dataset_name'].astype(str).str.upper()
            + "<br> |V|=" + frame['n_nodes'].astype(str)
            + ", |E|=" + frame['n_edges'].astype(str)
        ),
    )
)
bn_graphs_strength_compare.head()

In [None]:
import plotly.express as px


def plot_strength_comparison(df, id_vars, xaxis_title, filename):
    """Stacked bar chart of the ST-vs-CO strength shares, saved as a JPEG and displayed.

    Args:
        df: frame with the columns 'ST_CO_equal_strength', 'ST_stronger' and
            'CO_stronger' plus the identifier column(s) named by ``id_vars``.
        id_vars: column used as melt identifier and as the x axis.
        xaxis_title: title of the x axis.
        filename: output path without extension; '.jpeg' is appended.

    Returns:
        None. The figure is written to disk and shown.
    """
    # Column name -> legend label.
    case_labels = {
        "ST_CO_equal_strength": "Equal Strength",
        "ST_stronger": "ST Best Model is Stronger",
        "CO_stronger": "CO Best Model is stronger"
    }

    # Long format: one row per (identifier, case) holding its proportion.
    long_df = df.melt(
        id_vars=id_vars,
        value_vars=list(case_labels),
        var_name="Case",
        value_name="Proportion",
    )
    long_df["Case"] = long_df["Case"].map(case_labels)

    fig = px.bar(long_df, x=id_vars, y="Proportion", color="Case", barmode="stack")

    # Axis titles, fixed [0, 1] y range, no legend title.
    fig.update_layout(
        xaxis_title=xaxis_title,
        yaxis_title="Proportion",
        yaxis_range=[0, 1],
        legend_title_text=None,
    )
    # One tick per unit step on the x axis.
    fig.update_xaxes(tickmode='linear', dtick=1)
    # Light horizontal grid lines.
    fig.update_yaxes(showgrid=True, gridcolor='lightgray', gridwidth=1)

    fig.write_image(f'{filename}.jpeg', width=800, height=400, scale=3)
    fig.show()


In [None]:
# Strength comparison on the random graphs, by number of nodes; written to strength_compare_random.jpeg.
plot_strength_comparison(random_graphs_strength_compare, 'n_nodes', "Number of Nodes (|V|)", "strength_compare_random")

In [None]:
# Strength comparison on the bnlearn datasets, by dataset label; written to strength_compare_bnlearn.jpeg.
plot_strength_comparison(bn_graphs_strength_compare, 'dataset_name', "Dataset", "strength_compare_bnlearn")