In [1]:
# Run ABASP (New) on the bnlearn datasets (cancer, earthquake, survey) over many seeds and measure DAG / CPDAG metrics.

In [2]:
import sys
sys.path.insert(0, '../..')
sys.path.insert(0, '../../ArgCausalDisco')
sys.path.insert(0, '../../notears')



In [3]:
from src.abapc import get_dataset, get_stable_arrow_sets, get_best_model
from src.abasp.utils import get_graph_matrix

from ArgCausalDisco.utils.graph_utils import DAGMetrics, dag2cpdag

  from .autonotebook import tqdm as notebook_tqdm
INFO:root:You can use `os.environ['CASTLE_BACKEND'] = backend` to set the backend(`pytorch` or `mindspore`).
INFO:root:You are using ``pytorch`` as the backend.


In [4]:
from pathlib import Path
from time import time
import numpy as np
import pandas as pd

In [5]:
import os

import cdt

# Path to the Rscript executable used by cdt. Defaults to /usr/local/bin/Rscript;
# set the RSCRIPT_PATH environment variable to point elsewhere without editing
# the notebook. (Presumably needed by the R-backed SID metric computed in
# DAGMetrics — TODO confirm.)
cdt.SETTINGS.rpath = os.environ.get('RSCRIPT_PATH', '/usr/local/bin/Rscript')

In [6]:
def _sid_bounds(sid):
    """Split an SID value into ``(low, high)`` bounds.

    A pair-like value (tuple, list, 1-D array) is read as ``(low, high)``;
    any other value is treated as an exact SID, so both bounds are equal.
    """
    if isinstance(sid, (tuple, list)) or np.ndim(sid) == 1:
        return sid[0], sid[1]
    return sid, sid


def run_bnlearn_experiment(dataset_name, seed, model_name='ABASP (New)'):
    """Run ABASP on one bnlearn dataset/seed and score the estimated graph.

    Parameters
    ----------
    dataset_name : str
        Dataset name passed to ``get_dataset`` (e.g. 'cancer', 'earthquake', 'survey').
    seed : int
        Seed passed to both ``get_dataset`` and ``get_stable_arrow_sets``.
    model_name : str, optional
        Label stored in the ``'model'`` field of both result records.

    Returns
    -------
    (dict, dict)
        ``(method_res, method_res_cpdag)``. Both start with ``dataset``, ``model``
        and ``elapsed`` (seconds spent in ``get_stable_arrow_sets`` +
        ``get_best_model``). These are followed by the ``DAGMetrics`` of the
        estimated DAG and of its CPDAG, respectively. In the CPDAG record the
        ``'sid'`` entry is replaced by ``'sid_low'`` and ``'sid_high'``, appended last.
    """
    from time import perf_counter  # monotonic clock, intended for measuring durations

    X_s, B_true = get_dataset(dataset_name, seed)

    start = perf_counter()
    models, cg = get_stable_arrow_sets(X_s, seed=seed)
    n_nodes = X_s.shape[1]
    _, B_est, _ = get_best_model(models, n_nodes, cg)
    elapsed = perf_counter() - start
    print(f"Elapsed time: {elapsed:.2f} seconds")

    # Binarise the estimate: any non-zero entry counts as an edge.
    B_est = (B_est != 0).astype(int)
    mt_cpdag = DAGMetrics(dag2cpdag(B_est), B_true).metrics
    mt_dag = DAGMetrics(B_est, B_true).metrics

    header = {'dataset': dataset_name, 'model': model_name, 'elapsed': elapsed}
    method_res = {**header, **mt_dag}

    # Build a new record instead of popping from the metrics dict. 'sid' is
    # dropped and its bounds are appended at the end, which keeps the summary
    # table's column order.
    sid_low, sid_high = _sid_bounds(mt_cpdag['sid'])
    method_res_cpdag = {
        **header,
        **{key: value for key, value in mt_cpdag.items() if key != 'sid'},
        'sid_low': sid_low,
        'sid_high': sid_high,
    }
    return method_res, method_res_cpdag


In [7]:
from tqdm import tqdm

In [None]:
# Every dataset is evaluated on the same n_runs seeds, drawn once from a fixed
# global seed so the whole sweep is reproducible. The global NumPy state stays
# seeded on purpose; downstream code may draw from it (TODO confirm).
DATASETS = ['cancer', 'earthquake', 'survey']
n_runs = 50

np.random.seed(42)
seeds = np.random.randint(0, 10000, n_runs)

method_res_all = []
method_res_cpdag_all = []

for dataset_name in DATASETS:
    for seed in tqdm(seeds, desc=f"Running {dataset_name} experiments", leave=False):
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"Running experiment with seed {seed}")
        method_res, method_res_cpdag = run_bnlearn_experiment(dataset_name, seed)
        method_res_all.append(method_res)
        method_res_cpdag_all.append(method_res_cpdag)




In [9]:

def _summarize_runs(records):
    """Aggregate per-run result dicts into mean/std per (dataset, model).

    Returns a DataFrame with flat column names (e.g. ``elapsed_mean``,
    ``elapsed_std``), values rounded to 2 decimals, and a fresh 0..n-1 index.
    """
    summary = (
        pd.DataFrame(records)
        .groupby(['dataset', 'model'], as_index=False)
        .agg(['mean', 'std'])
        .round(2)
        .reset_index(drop=True)
    )
    # Flatten the (metric, stat) column MultiIndex. The group-key columns have
    # an empty second level, so strip the trailing '_' the join leaves on them.
    summary.columns = summary.columns.map('_'.join).str.strip('_')
    return summary


# Build each table directly; concatenating onto an empty DataFrame added nothing.
mt_res = _summarize_runs(method_res_all)
mt_res_cpdag = _summarize_runs(method_res_cpdag_all)

In [10]:
# Metrics on the estimated DAGs: mean/std over all runs per dataset, rounded to 2 decimals.
mt_res

Unnamed: 0,dataset,model,elapsed_mean,elapsed_std,nnz_mean,nnz_std,fdr_mean,fdr_std,tpr_mean,tpr_std,...,precision_mean,precision_std,recall_mean,recall_std,F1_mean,F1_std,shd_mean,shd_std,sid_mean,sid_std
0,cancer,ABASP (New),2.46,0.91,4.12,0.69,0.53,0.15,0.48,0.15,...,0.47,0.15,0.48,0.15,0.5,0.11,2.36,0.8,9.96,2.73
1,earthquake,ABASP (New),1.81,1.6,4.74,0.66,0.17,0.19,0.96,0.16,...,0.83,0.19,0.96,0.16,0.89,0.17,0.9,1.13,0.74,3.02
2,survey,ABASP (New),8.92,7.13,4.36,0.9,0.58,0.16,0.32,0.14,...,0.42,0.16,0.32,0.14,0.37,0.14,4.52,1.23,15.24,2.6


In [11]:
# The same summary computed on the CPDAGs of the estimates; SID is reported as sid_low / sid_high bounds.
mt_res_cpdag

Unnamed: 0,dataset,model,elapsed_mean,elapsed_std,nnz_mean,nnz_std,fdr_mean,fdr_std,tpr_mean,tpr_std,...,recall_mean,recall_std,F1_mean,F1_std,shd_mean,shd_std,sid_low_mean,sid_low_std,sid_high_mean,sid_high_std
0,cancer,ABASP (New),2.46,0.91,4.12,0.69,0.51,0.12,0.5,0.12,...,0.48,0.15,0.5,0.11,2.36,0.8,9.26,1.88,11.26,2.08
1,earthquake,ABASP (New),1.81,1.6,4.74,0.66,0.14,0.12,1.0,0.0,...,0.96,0.14,0.89,0.16,0.9,1.07,0.12,0.59,10.5,4.7
2,survey,ABASP (New),8.92,7.13,4.36,0.9,0.49,0.18,0.37,0.13,...,0.35,0.12,0.41,0.13,4.86,1.32,14.12,3.31,15.76,2.6


In [12]:
# np.save stores the DataFrame as a bare object array: the column names are
# dropped, and the string columns force pickling, so np.load needs
# allow_pickle=True. The CSV written next to it keeps the headers.
np.save("./ABASP.npy", mt_res_cpdag)
mt_res_cpdag.to_csv("./ABASP.csv", index=False)