In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

# Set plotly theme
pio.templates.default = "plotly_white"

# Load the data. The default is the path this analysis was produced from; set the
# ABLATION_METRICS_CSV environment variable to point at a copy stored elsewhere
# instead of editing an absolute, machine-specific path.
DEFAULT_DATA_PATH = "/Users/yeva/imperial/master-proj/gradual-sem-causal-aba/results/gradual/v2_ablation_random_graphs_7nodes/cpdag_metrics.csv"
data_path = Path(os.environ.get("ABLATION_METRICS_CSV", DEFAULT_DATA_PATH))
df = pd.read_csv(data_path)

# Columns the analysis cells in this notebook read; fail early with a clear message
# if the CSV schema differs, rather than with a KeyError deep inside a plotting helper.
REQUIRED_COLUMNS = {
    'seed', 'sid_low', 'sid_high',
    'max_cycle_length', 'max_ct_depth', 'max_c_set_size', 'search_depth',
}
missing_columns = REQUIRED_COLUMNS - set(df.columns)
if missing_columns:
    raise KeyError(f"{data_path} is missing expected columns: {sorted(missing_columns)}")

print(f"Data shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")
print("\nFirst few rows:")
df.head()

Data shape: (850, 30)
Columns: ['nnz', 'fdr', 'tpr', 'fpr', 'precision', 'recall', 'F1', 'shd', 'sid_low', 'sid_high', 'dataset', 'seed', 'n_nodes', 'n_edges', 'neighbourhood_n_nodes', 'max_cycle_length', 'max_ct_depth', 'max_path_length', 'max_c_set_size', 'search_depth', 'elapsed_bsaf_creation', 'elapsed_model_solution', 'is_converged', 'fact_ranking_method', 'model_ranking_method', 'num_edges_est', 'best_model', 'aba_elapsed', 'ranking_elapsed', 'best_I']

First few rows:


Unnamed: 0,nnz,fdr,tpr,fpr,precision,recall,F1,shd,sid_low,sid_high,...,elapsed_bsaf_creation,elapsed_model_solution,is_converged,fact_ranking_method,model_ranking_method,num_edges_est,best_model,aba_elapsed,ranking_elapsed,best_I
0,0,0.0,0.0,0.0,,0.0,,7.0,22.0,22.0,...,136.762595,429.702305,True,v2,original_ranking,0,[],0.013444,0.046985,-2.454362
1,0,0.0,0.0,0.0,,0.0,,7.0,22.0,22.0,...,136.762595,429.702305,True,v2,original_ranking,0,[],0.013444,0.093017,-2.454362
2,0,0.0,0.0,0.0,,0.0,,7.0,22.0,22.0,...,136.762595,429.702305,True,v2,original_ranking,0,[],0.013444,0.186441,-2.454362
3,0,0.0,0.0,0.0,,0.0,,7.0,22.0,22.0,...,136.762595,429.702305,True,v2,original_ranking,0,[],0.013444,0.375774,-2.454362
4,0,0.0,0.0,0.0,,0.0,,7.0,22.0,22.0,...,136.762595,429.702305,True,v2,original_ranking,0,[],0.013444,0.757659,-2.454362


In [6]:
# Column index of the loaded results (DataFrame.keys() returns df.columns).
df.keys()

Index(['nnz', 'fdr', 'tpr', 'fpr', 'precision', 'recall', 'F1', 'shd',
       'sid_low', 'sid_high', 'dataset', 'seed', 'n_nodes', 'n_edges',
       'neighbourhood_n_nodes', 'max_cycle_length', 'max_ct_depth',
       'max_path_length', 'max_c_set_size', 'search_depth',
       'elapsed_bsaf_creation', 'elapsed_model_solution', 'is_converged',
       'fact_ranking_method', 'model_ranking_method', 'num_edges_est',
       'best_model', 'aba_elapsed', 'ranking_elapsed', 'best_I'],
      dtype='object')

In [7]:
# Derive the ablation knobs used throughout the analysis from the logged columns.
df = df.assign(
    # NOTE(review): the logged neighbourhood_n_nodes column is replaced by max_cycle_length;
    # presumably the neighbourhood size was recorded under that column — confirm.
    neighbourhood_n_nodes=df['max_cycle_length'],
    # Assuming -1 indicates no collider arguments used
    use_collider_arguments=df['max_ct_depth'].gt(-1),
)

In [15]:
# Ablation knobs (grouping keys) and the SID metrics summarised for each configuration.
groupby_cols = ['neighbourhood_n_nodes', 'use_collider_arguments', 'max_c_set_size', 'search_depth']
metrics_of_interest = ['sid_low', 'sid_high']

# List the grid of values each knob takes in the results.
print("Unique values for ablation parameters:")
print("\n".join(f"{knob}: {sorted(df[knob].unique())}" for knob in groupby_cols))

# Mean / std / count of each metric, one row per ablation configuration.
grouped_stats = (
    df.groupby(groupby_cols)[metrics_of_interest]
    .agg(['mean', 'std', 'count'])
    .reset_index()
)
# Flatten the (metric, stat) column MultiIndex: ('sid_low', 'mean') -> 'sid_low_mean';
# grouping keys have an empty second level and keep their plain names.
grouped_stats.columns = ['_'.join(part for part in col if part) for col in grouped_stats.columns]

print(f"\nGrouped statistics shape: {grouped_stats.shape}")
grouped_stats.head(10)

Unique values for ablation parameters:
neighbourhood_n_nodes: [np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7)]
use_collider_arguments: [np.False_, np.True_]
max_c_set_size: [np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5)]
search_depth: [np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10)]

Grouped statistics shape: (17, 10)


Unnamed: 0,neighbourhood_n_nodes,use_collider_arguments,max_c_set_size,search_depth,sid_low_mean,sid_low_std,sid_low_count,sid_high_mean,sid_high_std,sid_high_count
0,3,True,5,10,19.36,4.822397,50,21.42,5.522089,50
1,4,True,5,10,17.74,5.00942,50,22.72,6.857262,50
2,5,True,5,10,18.08,5.375075,50,22.48,6.810077,50
3,6,True,5,10,18.02,5.460395,50,22.46,6.842872,50
4,7,False,5,10,18.64,3.685382,50,23.44,5.89763,50
5,7,True,0,10,18.48,5.003019,50,23.28,5.547715,50
6,7,True,1,10,18.16,6.109077,50,22.6,6.948792,50
7,7,True,2,10,18.02,5.460395,50,22.46,6.842872,50
8,7,True,3,10,18.76,5.208157,50,22.7,6.609363,50
9,7,True,4,10,18.76,5.208157,50,22.7,6.609363,50


In [9]:
# Sanity check: every ablation configuration has 50 non-null sid_high values.
bool(grouped_stats['sid_high_count'].eq(50).all())

True

In [16]:
# Function to create ablation plots for a specific parameter
def create_ablation_plot(param_name, fixed_values_dict=None, n_edges=7, data=None):
    """
    Plot mean ± std of SID (low and high) while varying one ablation parameter.

    All other ablation parameters are held at their default values (optionally
    overridden via ``fixed_values_dict``). The matching runs are grouped by
    ``param_name`` and summarised across the remaining rows (seeds).

    Args:
        param_name: The parameter to vary ('neighbourhood_n_nodes', 'use_collider_arguments',
            'max_c_set_size', 'search_depth').
        fixed_values_dict: Overrides for the default fixed values of the other parameters.
        n_edges: Positive divisor applied to the SID mean and std before plotting
            (normalisation constant, default 7). Counts are not scaled.
        data: Results frame to analyse; defaults to the notebook-level ``df``.

    Returns:
        DataFrame with one row per value of ``param_name`` and columns
        ``sid_low_mean``, ``sid_low_std``, ``sid_low_count`` (likewise for ``sid_high``),
        with mean/std divided by ``n_edges``; or None when no run matches the fixed
        parameters. The figure is shown as a side effect.

    Raises:
        ValueError: If ``n_edges`` is not positive.
    """
    if n_edges <= 0:
        raise ValueError(f"n_edges must be positive, got {n_edges}")
    if data is None:
        data = df

    # Metric column -> display label.
    metrics = {'sid_low': 'SID Low', 'sid_high': 'SID High'}

    # Default fixed values (median or reasonable defaults)
    default_fixed = {
        'neighbourhood_n_nodes': 7,
        'use_collider_arguments': True,
        'max_c_set_size': 5,
        'search_depth': 10
    }
    if fixed_values_dict:
        default_fixed.update(fixed_values_dict)

    # Every parameter except the varied one is held fixed.
    varying_param = param_name
    fixed_params = {k: v for k, v in default_fixed.items() if k != varying_param}

    # Keep only runs matching every fixed parameter (boolean indexing yields new frames,
    # so the caller's data is never modified).
    filtered_data = data
    for param, value in fixed_params.items():
        filtered_data = filtered_data[filtered_data[param] == value]

    if len(filtered_data) == 0:
        print(f"No data found for the specified fixed parameters: {fixed_params}")
        return None

    param_stats = filtered_data.groupby(varying_param)[list(metrics)].agg(['mean', 'std', 'count']).reset_index()
    # Flatten (metric, stat) MultiIndex columns, e.g. ('sid_low', 'mean') -> 'sid_low_mean'.
    param_stats.columns = [col[0] if col[1] == '' else f'{col[0]}_{col[1]}' for col in param_stats.columns]

    # Normalise location and spread by n_edges; counts stay as raw run counts.
    for metric in metrics:
        for stat in ('mean', 'std'):
            param_stats[f'{metric}_{stat}'] = param_stats[f'{metric}_{stat}'] / n_edges

    # One subplot per metric, side by side.
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=list(metrics.values()),
        horizontal_spacing=0.1
    )

    x_label = varying_param.replace('_', ' ').title()
    for col_idx, (metric, label) in enumerate(metrics.items(), start=1):
        fig.add_trace(
            go.Scatter(
                x=param_stats[varying_param],
                y=param_stats[f'{metric}_mean'],
                error_y=dict(type='data', array=param_stats[f'{metric}_std'], visible=True),
                mode='lines+markers',
                name=label,
                line=dict(width=3),
                marker=dict(size=8)
            ),
            row=1, col=col_idx
        )
        fig.update_xaxes(title_text=x_label, row=1, col=col_idx)
        # The plotted values were divided by n_edges above, so the label says so.
        fig.update_yaxes(title_text=f'{label} / {n_edges} (Mean ± Std)', row=1, col=col_idx)

    title_text = f'Ablation Analysis: {varying_param}'
    if fixed_params:
        fixed_str = ', '.join(f'{k}={v}' for k, v in fixed_params.items())
        title_text += f'<br><sub>Fixed parameters: {fixed_str}</sub>'

    fig.update_layout(
        title=title_text,
        height=400,
        showlegend=False
    )

    fig.show()

    return param_stats



In [17]:
# Run one ablation per parameter: vary it while the others stay at their defaults.
ablation_runs = [
    ("1. Neighbourhood N Nodes Ablation", 'neighbourhood_n_nodes'),
    ("2. Use Collider Arguments Ablation", 'use_collider_arguments'),
    ("3. Max C Set Size Ablation", 'max_c_set_size'),
    ("4. Search Depth Ablation", 'search_depth'),
]
ablation_stats = {}
for heading, ablated_param in ablation_runs:
    print(f"\n{heading}")
    ablation_stats[ablated_param] = create_ablation_plot(ablated_param)

# Expose each per-ablation summary under its own name.
stats_cycle = ablation_stats['neighbourhood_n_nodes']
stats_ct_depth = ablation_stats['use_collider_arguments']
stats_c_set = ablation_stats['max_c_set_size']
stats_search = ablation_stats['search_depth']


1. Neighbourhood N Nodes Ablation



1. Neighbourhood N Nodes Ablation



2. Use Collider Arguments Ablation



1. Neighbourhood N Nodes Ablation



2. Use Collider Arguments Ablation



3. Max C Set Size Ablation



1. Neighbourhood N Nodes Ablation



2. Use Collider Arguments Ablation



3. Max C Set Size Ablation



4. Search Depth Ablation


In [30]:
# Function to create ablation plots for a specific parameter
def create_ablation_plot2(param_name, fixed_values_dict=None, n_edges=7, data=None):
    """
    Plot the per-seed relative change in SID when one ablation parameter is varied.

    The runs using the default (or overridden) value of ``param_name`` form the
    baseline. Every other run sharing the remaining fixed parameters is paired with
    the baseline run of the same ``seed`` and scored as
    ``(sid - sid_baseline) / sid_baseline``. The plot shows mean ± std of that
    relative change for each non-baseline value of ``param_name``.

    Args:
        param_name: The parameter to vary ('neighbourhood_n_nodes', 'use_collider_arguments',
            'max_c_set_size', 'search_depth').
        fixed_values_dict: Overrides for the default fixed values. An entry for
            ``param_name`` itself selects the baseline value.
        n_edges: Accepted for signature compatibility with ``create_ablation_plot``.
            A constant divisor cancels out of a relative change, so it has no effect.
        data: Results frame to analyse; defaults to the notebook-level ``df``.

    Returns:
        DataFrame with one row per non-baseline value of ``param_name`` and columns
        ``sid_low_mean``, ``sid_low_std``, ``sid_low_count`` (likewise for ``sid_high``).
        Counts exclude pairs whose baseline is missing or zero. Returns None when no
        run matches the fixed parameters. The figure is shown as a side effect.

    Raises:
        pandas.errors.MergeError: If a seed has more than one baseline run (pairing
            would otherwise silently duplicate rows).
    """
    if data is None:
        data = df

    # Metric column -> display label.
    metrics = {'sid_low': 'SID Low', 'sid_high': 'SID High'}

    # Default fixed values (median or reasonable defaults)
    default_fixed = {
        'neighbourhood_n_nodes': 7,
        'use_collider_arguments': True,
        'max_c_set_size': 5,
        'search_depth': 10
    }
    if fixed_values_dict:
        default_fixed.update(fixed_values_dict)

    # Every parameter except the varied one is held fixed.
    varying_param = param_name
    fixed_params = {k: v for k, v in default_fixed.items() if k != varying_param}

    # Keep only runs matching every fixed parameter (boolean indexing yields new frames).
    filtered_data = data
    for param, value in fixed_params.items():
        filtered_data = filtered_data[filtered_data[param] == value]

    if len(filtered_data) == 0:
        print(f"No data found for the specified fixed parameters: {fixed_params}")
        return None

    # Split into baseline runs and the variants to compare against them.
    baseline_value = default_fixed[varying_param]
    is_baseline = filtered_data[varying_param] == baseline_value
    merge_keys = ['seed', *fixed_params]
    baseline = filtered_data.loc[is_baseline, [*merge_keys, *metrics]]

    # Pair each variant with the baseline run of the same seed. A duplicated baseline
    # seed would silently multiply rows, so require the baseline side to be unique.
    paired = filtered_data.loc[~is_baseline].merge(
        baseline,
        on=merge_keys,
        suffixes=('', '_baseline'),
        how='left',
        validate='many_to_one',
    )

    for metric in metrics:
        baseline_metric = paired[f'{metric}_baseline']
        relative_change = (paired[metric] - baseline_metric) / baseline_metric
        # A zero baseline SID gives ±inf, which would turn the whole group mean into inf;
        # treat it as undefined instead (NaN is skipped by mean/std/count).
        paired[metric] = relative_change.replace([np.inf, -np.inf], np.nan)

    param_stats = paired.groupby(varying_param)[list(metrics)].agg(['mean', 'std', 'count']).reset_index()
    # Flatten (metric, stat) MultiIndex columns, e.g. ('sid_low', 'mean') -> 'sid_low_mean'.
    param_stats.columns = [col[0] if col[1] == '' else f'{col[0]}_{col[1]}' for col in param_stats.columns]

    # One subplot per metric, side by side.
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=list(metrics.values()),
        horizontal_spacing=0.1
    )

    x_label = varying_param.replace('_', ' ').title()
    for col_idx, (metric, label) in enumerate(metrics.items(), start=1):
        fig.add_trace(
            go.Scatter(
                x=param_stats[varying_param],
                y=param_stats[f'{metric}_mean'],
                error_y=dict(type='data', array=param_stats[f'{metric}_std'], visible=True),
                mode='lines+markers',
                name=label,
                line=dict(width=3),
                marker=dict(size=8)
            ),
            row=1, col=col_idx
        )
        fig.update_xaxes(title_text=x_label, row=1, col=col_idx)
        # Plotted values are relative changes versus the baseline, not raw SID.
        fig.update_yaxes(title_text=f'Relative change in {label} (Mean ± Std)', row=1, col=col_idx)

    # y = 0 marks "no change relative to the baseline" in every subplot.
    fig.add_hline(y=0, line_dash='dot', line_color='gray')

    title_text = f'Ablation Analysis: {varying_param} (relative to {varying_param}={baseline_value})'
    if fixed_params:
        fixed_str = ', '.join(f'{k}={v}' for k, v in fixed_params.items())
        title_text += f'<br><sub>Fixed parameters: {fixed_str}</sub>'

    fig.update_layout(
        title=title_text,
        height=400,
        showlegend=False
    )

    fig.show()

    return param_stats



In [31]:
# Relative-change ablations: each parameter is varied against its default-value baseline.
relative_ablation_runs = [
    ("1. Neighbourhood N Nodes Ablation", 'neighbourhood_n_nodes'),
    ("2. Use Collider Arguments Ablation", 'use_collider_arguments'),
    ("3. Max C Set Size Ablation", 'max_c_set_size'),
    ("4. Search Depth Ablation", 'search_depth'),
]
relative_stats = {}
for section_title, varied_param in relative_ablation_runs:
    print(f"\n{section_title}")
    relative_stats[varied_param] = create_ablation_plot2(varied_param)

# Expose each per-ablation summary under its own name.
stats_cycle = relative_stats['neighbourhood_n_nodes']
stats_ct_depth = relative_stats['use_collider_arguments']
stats_c_set = relative_stats['max_c_set_size']
stats_search = relative_stats['search_depth']


1. Neighbourhood N Nodes Ablation



2. Use Collider Arguments Ablation



3. Max C Set Size Ablation



4. Search Depth Ablation
