In [None]:
import sys
sys.path.insert(0, '../../')

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# SID (Structural Intervention Distance): normalised best-case / worst-case SID per dataset

In [None]:
# Colour palette (hex strings), in main/secondary pairs.
main_gray, sec_gray = '#262626', '#595959'
main_blue, sec_blue = '#005383', '#0085CA'
main_green, sec_green = '#379f9f', '#196363'
main_purple, sec_purple = '#9454c4', '#441469'
main_orange, sec_orange = '#8a4500', '#b85c00'

# Run options.
save_figs, debug = True, False
datasets = ['cancer', 'earthquake', 'survey', 'asia']

# Size of each benchmark network: number of nodes |V| and number of arcs |E|.
dags_nodes_map = dict(asia=8, cancer=5, earthquake=5, sachs=11, survey=6,
                      alarm=37, child=20, insurance=27, hailfinder=56, hepar2=70)
dags_arcs_map = dict(asia=8, cancer=4, earthquake=4, sachs=17, survey=6,
                     alarm=46, child=25, insurance=52, hailfinder=66, hepar2=123)

# Display names in plotting order; short key -> display name; short key -> bar colour.
methods = ['Random', 'FGS', 'NOTEARS-MLP', 'MPC', 'ABAPC (Existing)', 'ABAPC (ASPforABA)']
names_dict = {
    'fgs': 'FGS',
    'nt': 'NOTEARS-MLP',
    'mpc': 'MPC',
    'random': 'Random',
    'abapc': 'ABAPC (Existing)',
    'ABAPC (ASPforABA)': 'ABAPC (ASPforABA)',
}
colors_dict = {
    'abapc': sec_blue,
    'fgs': sec_orange,
    'nt': main_purple,
    'mpc': main_green,
    'random': 'grey',
    'ABAPC (ASPforABA)': 'black',
}

# Stored-results version: 'bnlearn_50rep' is the 5000-sample run;
# switch to 'bnlearn_dag_v5_2000' for the 2000-sample run.
version = 'bnlearn_50rep'
version_cpdag = f'{version}_cpdag'
# Load the stored CPDAG results (file name is built from `version_cpdag`).
all_sum = pd.read_csv(f"../../results_pure_aba/stored_results_{version_cpdag}.csv")
# Relabel the original ABAPC implementation for display.
all_sum.loc[all_sum['model'] == 'ABAPC (Ours)', 'model'] = 'ABAPC (Existing)'
# all_sum.loc[all_sum['model']=='ABAPC (ASPforABA)', 'model'] = 'ABAPC (New)'
all_sum = all_sum[['dataset', 'model', 'sid_low_mean', 'sid_low_std', 'sid_high_mean', 'sid_high_std']].copy()

# Attach the true-graph size of each dataset. Fail with a readable message if a
# dataset is missing from the maps (otherwise the int cast below raises on NaN).
all_sum['n_edges'] = all_sum['dataset'].map(dags_arcs_map)
all_sum['n_nodes'] = all_sum['dataset'].map(dags_nodes_map)
unmapped = sorted(all_sum.loc[all_sum[['n_edges', 'n_nodes']].isna().any(axis=1), 'dataset'].astype(str).unique())
assert not unmapped, f"datasets missing from dags_arcs_map/dags_nodes_map: {unmapped}"

# Normalise SID (mean and std) by the number of edges in the true graph.
for var in ['SID_low', 'SID_high']:
    src = var.lower()
    all_sum[f'p_{var}_mean'] = all_sum[f'{src}_mean'].astype(float) / all_sum['n_edges'].astype(int)
    all_sum[f'p_{var}_std'] = all_sum[f'{src}_std'].astype(float) / all_sum['n_edges'].astype(int)

# X-axis category label: upper-case dataset name plus graph size; "<br>" is a
# line break in plotly text.
all_sum['dataset'] = (all_sum['dataset'].astype(str).str.upper()
                      + "<br> |V|=" + all_sum['n_nodes'].astype(str)
                      + ", |E|=" + all_sum['n_edges'].astype(str))

all_sum.head()

In [None]:
def _add_best_worst_labels(fig, n_x_cat, font_size):
    """Add a shaded "Best" and "Worst" label above each dataset's two bar groups.

    Positions are fractions of the x-axis domain. NOTE(review): the offsets
    (spacing ~0.335 per dataset) appear hand-tuned for three datasets on a
    1600px-wide figure and do not rescale with ``n_x_cat`` or figure width.
    """
    start_pos = 0.039   # left edge of the first "Best" label
    intra_dis = 0.161   # "Best" -> "Worst" offset within one dataset
    inter_dis = 0.174   # "Worst" -> next dataset's "Best" offset
    pad = ' ' * 9       # padding spaces widen the shaded label box
    for i in range(n_x_cat):
        best_x = start_pos + i * (intra_dis + inter_dis)
        for x, label in ((best_x, 'Best'), (best_x + intra_dis, 'Worst')):
            fig.add_annotation(
                xref="x domain",
                yref="y domain",
                xanchor="left",
                x=x,
                y=1.015,
                text=f"{pad}{label}{pad}",
                showarrow=False,
                font=dict(size=font_size, color="black"),
                bordercolor='#E5ECF6',
                borderwidth=2,
                bgcolor="#E5ECF6",
                opacity=0.8,
            )


def _add_group_separators(fig, n_x_cat):
    """Draw vertical separator lines between bar groups, in paper coordinates.

    Lines alternate: dashed grey between a dataset's "Best" and "Worst" groups,
    solid black after each dataset (the last one sits on the right edge of the
    x-axis domain). Spacing is derived from the x-axis domain, which is narrower
    than [0, 1] when make_subplots reserves room for a secondary y-axis.
    """
    if n_x_cat <= 0:
        return
    dom_start, dom_end = fig.layout.xaxis.domain or (0.0, 1.0)
    step = (dom_end - dom_start) / (2 * n_x_cat)
    for i in range(2 * n_x_cat):
        x_line = dom_start + (i + 1) * step
        is_mid = i % 2 == 0  # even index: between Best/Worst of the same dataset
        fig.add_shape(
            type="line",
            x0=x_line, x1=x_line,
            y0=0, y1=1,
            xref="paper",
            yref="paper",
            line=dict(
                color="grey" if is_mid else "black",
                width=1,
                dash="dash" if is_mid else "solid",
            ),
            layer="below",
        )


def double_bar_chart_plotly(all_sum, names_dict, colors_dict,
                            methods=None,
                            range_y1=None, range_y2=None, font_size=20,
                            save_figs=False, output_name="bar_chart.html", debug=False):
    """Plot best-case and worst-case normalised SID per dataset as grouped bars.

    For each dataset two groups of bars are drawn. The left ("Best") group shows
    ``p_SID_low_mean`` on the primary y-axis and the right ("Worst") group shows
    ``p_SID_high_mean`` on the secondary y-axis. Each group has one bar per
    method, with ``*_std`` error bars.

    Args:
        all_sum: DataFrame with columns ``dataset``, ``model`` and
            ``p_SID_{low,high}_{mean,std}``.
        names_dict: maps a short method key to the display name used in
            ``all_sum['model']`` and in ``methods``.
        colors_dict: maps the same short method key to a bar colour.
        methods: display names to plot, in bar order. Defaults to every
            (distinct) value of ``names_dict``, in its order.
        range_y1, range_y2: optional ``[min, max]`` for the primary/secondary
            y-axis. When both are None, each axis spans ``[0, max(mean) + 0.3]``
            of its variable.
        font_size: font size for axis titles, legend and annotations.
        save_figs: if True, write ``output_name`` (HTML) and a JPEG next to it.
        output_name: HTML output path; the JPEG path swaps ``.html`` for ``.jpeg``.
        debug: unused; kept for backward compatibility.

    Raises:
        ValueError: if an entry of ``methods`` is not a value of ``names_dict``.
    """
    if methods is None:
        methods = list(dict.fromkeys(names_dict.values()))
    # Display name -> short key, built once. setdefault keeps the first match,
    # mirroring the previous list.index() reverse lookup.
    key_by_name = {}
    for key, name in names_dict.items():
        key_by_name.setdefault(name, key)
    unknown = [method for method in methods if method not in key_by_name]
    if unknown:
        raise ValueError(f"methods not found among names_dict values: {unknown}")

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    vars_to_plot = ['p_SID_low', 'p_SID_high']
    n_methods = len(methods)

    for n, var_to_plot in enumerate(vars_to_plot):
        for m, method in enumerate(methods):
            rows = all_sum[all_sum.model == method]
            fig.add_trace(go.Bar(
                x=rows['dataset'],
                yaxis=f"y{n+1}",
                # Offset groups 0..n_methods-1 hold the first variable, n_methods is
                # the spacer below, n_methods+1..2*n_methods hold the second variable.
                offsetgroup=m + n_methods * n + n,
                y=rows[var_to_plot + '_mean'],
                error_y=dict(type='data', array=rows[var_to_plot + '_std'], visible=True),
                name=method,
                marker_color=colors_dict[key_by_name[method]],
                opacity=0.6,
                showlegend=n == 0,  # one legend entry per method
            ))
        if n == 0:
            # Zero-height white bar in its own offset group: leaves a gap between
            # the "Best" and "Worst" groups of each dataset.
            spacer_x = all_sum[all_sum.model.isin(methods)]['dataset'].unique()
            fig.add_trace(go.Bar(
                x=spacer_x,
                y=np.zeros(len(spacer_x)),
                name='',
                offsetgroup=n_methods,
                marker_color='white',
                opacity=1,
                showlegend=False,
            ))

    # Both variables share the SID scale, so the secondary axis hides its ticks.
    second_ticks = not all('SID' in var for var in vars_to_plot)
    fig.update_layout(
        barmode='group',
        bargap=0.15,       # gap between bars of adjacent location coordinates
        bargroupgap=0.15,  # gap between bars of the same location coordinate
        legend=dict(orientation="h", xanchor="center", x=0.5, yanchor="top", y=1.1),
        template='plotly_white',
        width=1600,
        height=700,
        margin=dict(l=40, r=00, b=70, t=20),
        hovermode='x unified',
        font=dict(size=font_size, family="Serif", color="black"),
        # NOTE(review): scaleanchor=0 kept as-is; Plotly documents an axis id or
        # False here - confirm intent.
        yaxis2=dict(scaleanchor=0, showline=False, showgrid=False, showticklabels=second_ticks, zeroline=True),
    )

    fig.add_annotation(
        xref="paper",
        yref="paper",
        xanchor="center",
        x=0,
        yanchor="bottom",
        y=-0.08,
        text="Dataset:",
        showarrow=False,
        font=dict(family="Serif", size=font_size, color="Black"),
    )

    for n, var_to_plot in enumerate(vars_to_plot):
        if range_y1 is None and range_y2 is None:
            # Auto-scale from 0 to the largest mean (NaN-safe) plus headroom.
            range_y = [0, all_sum[var_to_plot + '_mean'].max() + .3]
        else:
            range_y = range_y1 if n == 0 else range_y2
        orig_y = var_to_plot.replace('n_', '').replace('p_', '').replace('_low', '').replace('_high', '').upper()
        fig.update_yaxes(title={'text': f'Normalised {orig_y}', 'font': {'size': font_size}},
                         secondary_y=n == 1, range=range_y)
        if not second_ticks:
            fig.update_yaxes(title={'text': '', 'font': {'size': font_size}},
                             secondary_y=True, range=range_y, showticklabels=False)

    n_x_cat = len(all_sum.dataset.unique())
    _add_best_worst_labels(fig, n_x_cat, font_size)
    _add_group_separators(fig, n_x_cat)

    if save_figs:
        fig.write_html(output_name)
        fig.write_image(output_name.replace('.html', '.jpeg'))

    fig.show()

In [None]:
# Figure 2: normalised SID (CPDAG), both y-axes fixed to [0, 6].
double_bar_chart_plotly(
    all_sum,
    names_dict,
    colors_dict,
    methods,
    range_y1=[0, 6],
    range_y2=[0, 6],
    save_figs=save_figs,
    output_name="./Fig.2_SID_cpdag.html",
    debug=False,
)