In [27]:
import pandas as pd
# Root of all result CSVs, resolved relative to the kernel's working directory.
RESULTS_DIR = '../../results'

# min_scale_v2 runs and their `no_collider` variant.
min_scale_v2 = pd.read_csv(f'{RESULTS_DIR}/gradual/min_scale_v2/cpdag_metrics.csv')
min_scale_v2_no_colliders = pd.read_csv(f'{RESULTS_DIR}/gradual/min_scale_v2_no_collider/cpdag_metrics.csv')

# FGS metrics live in their own run directory; append them to the other existing methods.
fgs_res = pd.read_csv(f'{RESULTS_DIR}/existing/fgs/all_existing_methods_metrics_cpdag.csv')
other_methods_wo_aba = pd.concat(
    [pd.read_csv(f'{RESULTS_DIR}/existing_wo_aba/all_existing_methods_metrics_cpdag.csv'), fgs_res],
    ignore_index=True,
)

In [28]:
# Size of each benchmark network's ground-truth DAG as (|V|, |E|), kept side by side
# so node and arc counts cannot drift apart. process_data maps these onto the results.
_DAG_SIZES = {
    'asia': (8, 8),
    'cancer': (5, 4),
    'earthquake': (5, 4),
    'sachs': (11, 17),
    'survey': (6, 6),
    'alarm': (37, 46),
    'child': (20, 25),
    'insurance': (27, 52),
    'hailfinder': (56, 66),
    'hepar2': (70, 123),
}
DAG_ARCS_MAP = {name: n_arcs for name, (_, n_arcs) in _DAG_SIZES.items()}
DAG_NODES_MAP = {name: n_nodes for name, (n_nodes, _) in _DAG_SIZES.items()}

# Label every row with its variant so the frames stay distinguishable once they are
# concatenated with the other methods for plotting.
min_scale_v2_no_colliders = min_scale_v2_no_colliders.assign(model='min_scale_v2_nc')
min_scale_v2 = min_scale_v2.assign(model='min_scale_v2')

def process_data(df, arcs_map=None, nodes_map=None):
    """Aggregate per-run SID metrics per (dataset, model) and normalise by edge count.

    Args:
        df: Per-run results with at least the columns ``dataset``, ``model``,
            ``sid_low`` and ``sid_high``.
        arcs_map: Mapping dataset name -> number of edges in its ground-truth DAG.
            Defaults to the module-level ``DAG_ARCS_MAP``.
        nodes_map: Mapping dataset name -> number of nodes. Defaults to the
            module-level ``DAG_NODES_MAP``.

    Returns:
        One row per (dataset, model) with columns ``sid_{low,high}_{mean,std}``,
        ``n_edges``, ``n_nodes`` and ``p_SID_{low,high}_{mean,std}``. Each
        ``p_SID_*`` column is the matching ``sid_*`` statistic divided by ``n_edges``.
        The std is pandas' sample std (ddof=1), so groups with a single run get NaN.

    Warns:
        UserWarning: for datasets missing from ``arcs_map``; their ``n_edges`` and
        every ``p_SID_*`` value are NaN.
    """
    import warnings

    # Resolve defaults lazily so callers can pass their own maps.
    if arcs_map is None:
        arcs_map = DAG_ARCS_MAP
    if nodes_map is None:
        nodes_map = DAG_NODES_MAP

    df_grouped = df.groupby(['dataset', 'model'], as_index=False).aggregate(
        sid_low_mean=('sid_low', 'mean'),
        sid_high_mean=('sid_high', 'mean'),
        sid_low_std=('sid_low', 'std'),
        sid_high_std=('sid_high', 'std'),
    )
    df_grouped['n_edges'] = df_grouped['dataset'].map(arcs_map)
    df_grouped['n_nodes'] = df_grouped['dataset'].map(nodes_map)

    # An unmapped dataset would otherwise silently turn every p_SID_* value into NaN.
    unknown = sorted(df_grouped.loc[df_grouped['n_edges'].isna(), 'dataset'].unique())
    if unknown:
        warnings.warn(
            f"No edge count for dataset(s) {unknown}; their p_SID_* values will be NaN.",
            stacklevel=2,
        )

    # Column order: low_mean, high_mean, low_std, high_std.
    for stat in ('low_mean', 'high_mean', 'low_std', 'high_std'):
        df_grouped[f'p_SID_{stat}'] = df_grouped[f'sid_{stat}'] / df_grouped['n_edges']
    return df_grouped

# Summarise each results source into one row per (dataset, model).
(
    min_scale_v2_no_colliders_processed,
    min_scale_v2_processed,
    other_methods_wo_aba_processed,
) = [process_data(raw) for raw in (min_scale_v2_no_colliders, min_scale_v2, other_methods_wo_aba)]

In [30]:
data_to_plot_columns = [
    'dataset', 'n_nodes', 'n_edges', 'model',
    'p_SID_low_mean', 'p_SID_high_mean',
    'p_SID_low_std', 'p_SID_high_std',
]

# Stack the summaries (restricted to the plotted columns) into one frame.
data_to_plot = pd.concat(
    [
        summary[data_to_plot_columns]
        for summary in (
            min_scale_v2_no_colliders_processed,
            other_methods_wo_aba_processed,
            min_scale_v2_processed,
        )
    ],
    ignore_index=True,
)
data_to_plot.head()

Unnamed: 0,dataset,n_nodes,n_edges,model,p_SID_low_mean,p_SID_high_mean,p_SID_low_std,p_SID_high_std
0,asia,8,8,min_scale_v2_nc,3.155,4.455,0.596311,0.656968
1,cancer,5,4,min_scale_v2_nc,0.97,3.49,0.453557,0.486973
2,child,20,25,min_scale_v2_nc,12.72,13.6,,
3,earthquake,5,4,min_scale_v2_nc,0.0,3.5,0.0,0.0
4,sachs,11,17,min_scale_v2_nc,3.096471,3.252941,0.161537,0.193551


In [31]:
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

def plot_sid_grouped_by_type(df: pd.DataFrame, model_order: list[str]):
    """
    Show a bar chart of normalised parent SID for a single dataset.

    Bars form two clusters: "Low" (``p_SID_low_mean``) on the left and "High"
    (``p_SID_high_mean``) on the right, separated by one empty slot. Within a
    cluster, bars follow ``model_order`` and carry the matching ``*_std`` column as
    an error bar. A model listed in ``model_order`` but absent from ``df`` leaves a
    gap. Models in ``df`` that are not listed are not drawn. The figure is displayed
    with ``fig.show()``, and nothing is returned.

    Args:
        df (pd.DataFrame): Rows for exactly one dataset, with columns
            dataset, n_nodes, n_edges, model,
            p_SID_low_mean, p_SID_high_mean, p_SID_low_std, p_SID_high_std.
        model_order (list[str]): Models to draw, left to right. Position in this
            list also fixes each model's colour.

    Raises:
        ValueError: if the rows do not produce exactly one
            "<dataset> |V|=<n_nodes>, |E|=<n_edges>" label.
    """
    # The title needs one dataset/size label, so every row must agree on it.
    labels = df.apply(lambda row: f"{row['dataset']} |V|={row['n_nodes']}, |E|={row['n_edges']}", axis=1).unique()
    if len(labels) != 1:
        raise ValueError("Function only supports a single dataset.")
    dataset_label = labels[0]

    # Long format: one row per (model, SID_type) with SID_mean / SID_std columns.
    halves = []
    for sid_type, level in (("Low", "low"), ("High", "high")):
        mean_col = f"p_SID_{level}_mean"
        std_col = f"p_SID_{level}_std"
        half = df[["model", mean_col, std_col]].rename(columns={mean_col: "SID_mean", std_col: "SID_std"})
        half["SID_type"] = sid_type
        halves.append(half)
    df_long = pd.concat(halves, ignore_index=True)
    df_long = df_long[df_long["model"].isin(model_order)].copy()
    df_long["model"] = pd.Categorical(df_long["model"], categories=model_order, ordered=True)

    # Colours follow position in model_order and cycle through the palette.
    palette = px.colors.qualitative.Set2 + px.colors.qualitative.Plotly
    model_colors = {name: palette[idx % len(palette)] for idx, name in enumerate(model_order)}

    # Slot layout: Low uses x = 0..n-1; slot n stays empty; High starts at n+1.
    n_models = len(model_order)
    group_spacing = n_models + 1
    bars = []  # (x, mean, std, colour, model) for each (SID type, model) present
    for cluster, sid_type in enumerate(["Low", "High"]):
        of_type = df_long[df_long["SID_type"] == sid_type]
        for slot, model in enumerate(model_order):
            hit = of_type[of_type["model"] == model]
            if hit.empty:
                continue
            # Only the first matching row is used if a model appears more than once.
            bars.append((
                cluster * group_spacing + slot,
                hit["SID_mean"].values[0],
                hit["SID_std"].values[0],
                model_colors[model],
                model,
            ))
    x = [bar[0] for bar in bars]
    y = [bar[1] for bar in bars]
    errors = [bar[2] for bar in bars]
    colors = [bar[3] for bar in bars]
    model_names = [bar[4] for bar in bars]

    # One trace carries every bar (legend hidden). An empty trace per model then
    # provides a coloured legend entry.
    data_trace = go.Bar(
        x=x,
        y=y,
        error_y=dict(type='data', array=errors, visible=True),
        marker_color=colors,
        hovertemplate="Model: %{customdata}<br>SID: %{y:.2f}<extra></extra>",
        customdata=model_names,
        showlegend=False
    )
    legend_traces = [
        go.Bar(x=[None], y=[None], marker_color=model_colors[model], name=model, showlegend=True)
        for model in model_order
    ]
    fig = go.Figure(data=[data_trace, *legend_traces])

    # Place each cluster's tick label under the middle of its slots.
    centre = (n_models - 1) / 2
    fig.update_layout(
        title=f"Parent SID (Low vs High) — {dataset_label}",
        xaxis=dict(
            tickmode='array',
            tickvals=[centre, group_spacing + centre],
            ticktext=["Low", "High"],
            title="SID Type"
        ),
        yaxis_title="Parent SID",
        bargap=0,
        showlegend=True,
        legend_title="Model",
        height=800,  # raise for a taller figure
        width=600
    )

    fig.show()


In [32]:
# Left-to-right bar order (and colour assignment) for every plot below: the existing
# methods (presumably from other_methods_wo_aba plus the appended FGS results) first,
# then the two min_scale_v2 variants.
model_order = ['Random', 'NOTEARS-MLP', 'FGS', 'MPC'] + ['min_scale_v2', 'min_scale_v2_nc']

In [33]:
plot_sid_grouped_by_type(
    data_to_plot.loc[data_to_plot['dataset'].eq('cancer')],
    model_order=model_order,
)

In [34]:
plot_sid_grouped_by_type(
    data_to_plot.loc[data_to_plot['dataset'].eq('earthquake')],
    model_order=model_order,
)

In [35]:
plot_sid_grouped_by_type(
    data_to_plot.loc[data_to_plot['dataset'].eq('survey')],
    model_order=model_order,
)

In [36]:
plot_sid_grouped_by_type(
    data_to_plot.loc[data_to_plot['dataset'].eq('asia')],
    model_order=model_order,
)

In [37]:
plot_sid_grouped_by_type(
    data_to_plot.loc[data_to_plot['dataset'].eq('sachs')],
    model_order=model_order,
)

In [38]:
plot_sid_grouped_by_type(
    data_to_plot.loc[data_to_plot['dataset'].eq('child')],
    model_order=model_order,
)

In [39]:
plot_sid_grouped_by_type(
    data_to_plot.loc[data_to_plot['dataset'].eq('insurance')],
    model_order=model_order,
)