In [None]:
%load_ext autoreload
%autoreload 2

import json
import re
from pathlib import Path
from typing import Callable

import numpy as np
import networkx as nx
import pandas as pd
import pyagrum as gum

from abapc import ABAPC
from utils.helpers import random_stability


# Benchmark networks, each mapped to the size folder it lives in; paths are
# built as datasets/bayesian/<size>/<name>.bif/<name>.bif.
# "child" (medium) is skipped: it has irregular characters causing pyagrum parse errors.
dataset_size = dict(
    asia="small",
    cancer="small",
    earthquake="small",
    sachs="small",
    survey="small",
    insurance="medium",
)


def load_dataset(
    bif_path: Path, sample_size: int, seed: int
) -> tuple[gum.BayesNet, np.ndarray]:
    """Load a BIF network and draw a seeded sample from it.

    Args:
        bif_path: Path to the ``.bif`` file; its stem becomes the network name.
        sample_size: Number of rows to generate.
        seed: Seed handed to ``gum.initRandom`` right before sampling.

    Returns:
        ``(network, data)`` where ``data`` is the generated sample as a float
        array whose columns follow the sorted variable names.
    """
    network = gum.loadBN(str(bif_path))
    network.name = bif_path.stem

    gum.initRandom(seed=seed)
    sample_df = gum.generateSample(
        network, sample_size, with_labels=False, random_order=False
    )[0]
    # Column order is alphabetical so variable indices match sorted names.
    ordered_columns = sorted(sample_df.columns)
    return network, sample_df[ordered_columns].to_numpy().astype(float)


def extract_facts(string: str) -> pd.DataFrame:
    """Parse external (in)dependence facts out of the text of a facts file.

    A fact is matched anywhere in the text as
    ``ext_<indep|dep>(X,Y,S). I=<score>, NA`` followed by a newline, where
    ``S`` is either ``empty`` or ``s`` followed by ``y``-separated variable
    indices (e.g. ``s2y3`` -> {2, 3}).

    Args:
        string: Full text of the facts file.

    Returns:
        A DataFrame with columns ``cit_type`` ("indep" or "dep"), ``X`` and
        ``Y`` (int variable indices), ``S`` (set of int indices) and
        ``score`` (float). Rows are sorted by ``score`` in descending order
        with a fresh 0..n-1 index. When nothing matches, an empty DataFrame
        with the same columns is returned.
    """
    columns = ["cit_type", "X", "Y", "S", "score"]
    # Dots are escaped so only a literal "." ends the fact and separates the
    # integer and fractional parts of the score.
    pattern = re.compile(r"ext_((in)?dep)\((.+)\)\. I=([01]\.\d+), NA\n")
    facts = []
    for cit_type, _, triple, score in pattern.findall(string):
        X, Y, S = triple.split(",")
        facts.append(
            {
                "cit_type": cit_type,
                "X": int(X),
                "Y": int(Y),
                "S": set() if S == "empty" else {int(var) for var in S[1:].split("y")},
                "score": float(score),
            }
        )
    # Passing `columns` keeps the schema (and the "score" sort key) even when
    # no fact matched, instead of raising KeyError on a column-less frame.
    return pd.DataFrame(facts, columns=columns).sort_values(
        by="score", ascending=False, ignore_index=True
    )

## Evaluation of the external facts ranking using Average Precision

For every dataset listed in `dataset_size`, we compare the ranking of external facts produced by the following configurations:
- S_weight = True, cit method = "fisherz"
- S_weight = False, cit method = "fisherz"
- S_weight = True, cit method = "gsq"
- S_weight = False, cit method = "gsq"

Here `S_weight` represents whether to penalise the size of the conditioning set, and `cit method` refers to the conditional independence test method used in the MPC algorithm.

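We use `Average Precision (AP)` and `Normalized Discounted Cumulative Gain (NDCG)` as evaluation metrics. Both are standard information-retrieval / recommender-system metrics that evaluate the quality of a ranked list. Here the list is the set of external facts ranked by their score, and a fact counts as relevant when its type (dep/indep) agrees with the independence reported by the ground-truth network.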

In [None]:
# %%script false --no-raise-error
from sklearn.metrics import average_precision_score, ndcg_score


# Number of random samples (seeds) per dataset size, and rows per sample.
repeats = {
    "small": 50,
    "medium": 50,
}
sample_size = 5000

ranking_metrics = []
for dataset, size in dataset_size.items():
    bif_path = Path(f"datasets/bayesian/{size}/{dataset}.bif/{dataset}.bif")
    # Reset the global RNG per dataset so all four configurations below use the
    # same seeds (load_dataset reseeds pyagrum with each one before sampling).
    random_stability(2024)
    seeds = np.random.randint(0, 10000, size=repeats[size]).tolist()
    for S_weight in [True, False]:
        for cit_method in ["fisherz", "gsq"]:
            total_average_precision = 0.0
            total_ndcg = 0.0
            for seed in seeds:
                bn_true, data = load_dataset(
                    bif_path, sample_size=sample_size, seed=seed
                )
                facts_I_path, _ = ABAPC(
                    data,
                    seed=seed,
                    alpha=0.01,
                    indep_test=cit_method,
                    S_weight=S_weight,
                    pre_grounding=True,
                    out_mode="facts_only",
                    scenario=f"ranking_test_{dataset}",
                )
                # Facts refer to variables by their index in the sorted variable
                # names, matching the column order produced by load_dataset.
                sorted_vars = sorted(bn_true.names())
                facts_df = extract_facts(Path(facts_I_path).read_text())

                facts_df["X"] = facts_df["X"].apply(lambda x: sorted_vars[x])
                facts_df["Y"] = facts_df["Y"].apply(lambda y: sorted_vars[y])
                facts_df["S"] = facts_df["S"].apply(lambda s: {sorted_vars[i] for i in s})
                # A fact is "correct" when its type agrees with what the true
                # network reports via isIndependent.
                facts_df["correct"] = facts_df.apply(
                    lambda row: (
                        row["cit_type"] == "indep"
                        if bn_true.isIndependent(row["X"], row["Y"], row["S"])
                        else (row["cit_type"] == "dep")
                    ),
                    axis=1,
                )

                total_average_precision += average_precision_score(
                    facts_df["correct"], facts_df["score"]
                )
                total_ndcg += ndcg_score([facts_df["correct"]], [facts_df["score"]])

            ranking_metrics.append(
                (
                    dataset,
                    S_weight,
                    cit_method,
                    total_average_precision / len(seeds),
                    total_ndcg / len(seeds),
                )
            )

ranking_df = pd.DataFrame(
    ranking_metrics,
    columns=[
        "dataset",
        "S_weight",
        "cit_method",
        "mean_average_precision",
        "mean_ndcg",
    ],
)
# Create the output folder up front so a long run does not fail at the very end.
ranking_csv = Path("results/2025/ranking_metrics.csv")
ranking_csv.parent.mkdir(parents=True, exist_ok=True)
ranking_df.to_csv(ranking_csv, index=False)
ranking_df

## Computation Efficiency Comparison

Let's compare the original implementation with the optimized one in terms of computation time for conflict resolution. Both implementations consume the same external facts, generated with `S_weight = False` and `cit method = gsq`, with weak constraints disabled.

### Original implementation

In [None]:
# Copied and modified from
# https://github.com/briziorusso/ArgCausalDisco/blob/579a1c95f576ecb53d97b0c62d4ceebb90368189/causalaba.py
import os
import logging
from clingo.control import Control
from clingo import Number, Function
import numpy as np
from itertools import combinations
from datetime import datetime

from utils.graph_utils import powerset, extract_test_elements_from_symbol
from priors.generate_priors import Constraints, PriorKnowledge


def original_compile_and_ground(n_nodes:int, facts_location:str="",
                skeleton_rules_reduction:bool=False,
                weak_constraints:bool=False,
                indep_facts:set=set(),
                dep_facts:set=set(),
                opt_mode:str='optN',
                show:list=['arrow'],
                *args,
                **kwargs,  # Ignore parameters introduced in the new implementation
                )->Control:
    """Build and ground the clingo program of the original (upstream) CausalABA.

    This is a near-verbatim copy of the upstream implementation, kept as the
    baseline for the timing comparison, so its code is intentionally not
    refactored.

    Args:
        n_nodes: Number of variables; nodes are the integers ``0..n_nodes-1``.
        facts_location: Path of a facts ``.lp`` file to load ("" loads none).
        skeleton_rules_reduction: If True, drop paths that traverse a pair
            contained in ``indep_facts`` (in either direction), and only emit
            an ``indep`` rule for pairs ``(X, Y)`` contained in ``dep_facts``.
        weak_constraints: If True, also load ``facts_location`` with ``.lp``
            replaced by ``_wc.lp``.
        indep_facts: Queried only with ``(X, Y) in indep_facts``; the
            CausalABA driver passes a dict keyed by ``(X, Y)``.
        dep_facts: Queried only with ``(X, Y) in dep_facts``; same as above.
        opt_mode: Value for clingo's ``solve.opt_mode`` setting.
        show: Predicates for which ``#show`` directives are added.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A grounded ``clingo.control.Control``.
    """

    logging.info(f"Compiling the program")
    ### Create Control
    # Threads and parallel mode both follow the machine's CPU count.
    ctl = Control(['-t %d' % os.cpu_count()])
    ctl.configuration.solve.parallel_mode = os.cpu_count()
    ctl.configuration.solve.models=1
    ctl.configuration.solver.seed="2024"
    ctl.configuration.solve.opt_mode = opt_mode

    ### Add set definition
    # Each subset S of nodes is named 's' + 'y'-joined indices (e.g. s0y2);
    # emit in(s, S) for every member s of every subset.
    for S in powerset(range(n_nodes)):
        for s in S:
            ctl.add("specific", [], f"in({s},{'s'+'y'.join([str(i) for i in S])}).")
            # logging.debug(f"in({s},{'s'+'y'.join([str(i) for i in S])}).")

    ### Load main program and facts
    ctl.load("./encodings/causalaba.lp")
    if facts_location != "":
        ctl.load(facts_location)
    if weak_constraints:
        ctl.load(facts_location.replace(".lp","_wc.lp"))

    # External test outcomes become indep/dep atoms for valid variable pairs.
    ctl.add("specific", [], "indep(X,Y,S) :- ext_indep(X,Y,S), var(X), var(Y), set(S), X!=Y.")
    ctl.add("specific", [], "dep(X,Y,S) :- ext_dep(X,Y,S), var(X), var(Y), set(S), X!=Y.")
    ### add nonblocker rules
    logging.info("   Adding Specific Rules...")

    ### Active paths rules
    # n_p numbers path atoms p1, p2, ... globally across all pairs.
    n_p = 0
    G = nx.complete_graph(n_nodes)
    for (X,Y) in combinations(range(n_nodes),2):
        paths = nx.all_simple_paths(G, source=X, target=Y)
        ### remove paths that contain an indep fact
        # Pad each path with None to n_nodes entries so all paths fit one matrix.
        paths_mat = np.array([np.array(list(xi)+[None]*(n_nodes-len(xi))) for xi in paths])
        if skeleton_rules_reduction:
            paths_mat_red = paths_mat[[not any([(paths_mat[i,j],paths_mat[i,j+1]) in indep_facts or
                                                (paths_mat[i,j+1],paths_mat[i,j]) in indep_facts
                                                for j in range(n_nodes-1) if paths_mat[i,j] is not None]) \
                                                    for i in range(len(paths_mat))]] ##TODO: think about interaction with weak constraints
            remaining_paths = [list(filter(lambda x: x is not None, paths_mat_red[i])) for i in range(len(paths_mat_red))]
            logging.debug(f"Paths from {X} to {Y}: {len(paths_mat)}, removing indep: {len(remaining_paths)}")#
            # NOTE(review): unlike the filter above, this only checks the (a, b)
            # direction of each edge; it only feeds the debug log below.
            excluded_paths = [list(filter(lambda x: x is not None, paths_mat[i])) for i in range(len(paths_mat)) if any([(paths_mat[i,j],paths_mat[i,j+1]) in indep_facts for j in range(n_nodes-1) if paths_mat[i,j] is not None])]
            logging.debug(f"Excluded paths from {X} to {Y}: {len(excluded_paths)}")
        else:
            remaining_paths = [list(filter(lambda x: x is not None, paths_mat[i])) for i in range(len(paths_mat))]

        indep_rule_body = []
        for path in remaining_paths:
            n_p += 1
            ### build indep rule body
            indep_rule_body.append(f" not ap({X},{Y},p{n_p},S)")

            ### add path rule
            # p<n_p> holds when every consecutive edge of the path is present.
            path_edges = [f"edge({path[idx]},{path[idx+1]})" for idx in range(len(path)-1)]
            ctl.add("specific", [], f"p{n_p} :- {','.join(path_edges)}.")
            logging.debug(f"   p{n_p} :- {','.join(path_edges)}.")

            ### add active path rule
            # The path is active given S when it exists, every interior node is a
            # non-blocker (nb/4) w.r.t. its neighbours on the path, and X, Y are not in S.
            nbs = [f"nb({path[idx]},{path[idx-1]},{path[idx+1]},S)" for idx in range(1,len(path)-1)]
            nbs_str = ','.join(nbs)+"," if len(nbs) > 0 else ""
            ctl.add("specific", [], f"ap({X},{Y},p{n_p},S) :- p{n_p}, {nbs_str} not in({X},S), not in({Y},S), set(S).")
            logging.debug(f"   ap({X},{Y},p{n_p},S) :- p{n_p}, {nbs_str} not in({X},S), not in({Y},S), set(S).")

        ### add indep rule
        # indep(X,Y,S) holds when none of the remaining paths between X and Y is active.
        indep_rule_body = "" if not indep_rule_body else f"{','.join(indep_rule_body)}, "
        if skeleton_rules_reduction:
            if (X,Y) in dep_facts:
                indep_rule = f"indep({X},{Y},S) :- {indep_rule_body}not in({X},S), not in({Y},S), set(S)."
                ctl.add("specific", [], indep_rule)
                logging.debug(   indep_rule)
        else:
            indep_rule = f"indep({X},{Y},S) :- {indep_rule_body} not in({X},S), not in({Y},S), set(S)."
            ctl.add("specific", [], indep_rule)
            logging.debug(   indep_rule)

    ### add dep rule
    # dep(X,Y,S) holds whenever some path between X and Y is active given S.
    ctl.add("specific", [], f"dep(X,Y,S):- ap(X,Y,_,S), var(X), var(Y), X!=Y, not in(X,S), not in(Y,S), set(S).")
    logging.debug(   f"dep(X,Y,S) :- ap(X,Y,_,S), var(X), var(Y), X!=Y, not in(X,S), not in(Y,S), set(S).")

    ### add show statements
    if 'arrow' in show:
        ctl.add("base", [], "#show arrow/2.")
    if 'indep' in show:
        ctl.add("base", [], "#show indep/3.")
    if 'dep' in show:
        ctl.add("base", [], "#show dep/3.")
    if 'collider' in show:
        ctl.add("base", [], "#show collider/3.")
    if 'collider_desc' in show:
        ctl.add("base", [], "#show collider_desc/4.")
    if 'nb' in show:
        ctl.add("base", [], "#show nb/4.")
    if 'ap' in show:
        ctl.add("base", [], "#show ap/4.")
    if 'dpath' in show:
        ctl.add("base", [], "#show dpath/2.")

    ### Ground
    # Ground the base, facts and generated specific parts, plus `main`
    # parameterized with the largest node index.
    logging.info("   Grounding...")
    start_ground = datetime.now()
    ctl.ground([("base", []), ("facts", []), ("specific", []), ("main", [Number(n_nodes-1)])])
    logging.info(f"   Grounding time: {str(datetime.now()-start_ground)}")

    return ctl

### Define Causal ABA Interface to call both implementations

In [None]:
def _external_symbol(fact: tuple):
    """Build the clingo symbol ``<dep_type>(X, Y, <set term>)`` for a parsed fact.

    ``fact`` is a tuple ``(X, S, Y, dep_type, statement, I, truth)`` as built
    by ``CausalABA``. The set term is the last comma-separated argument of
    ``statement`` once its trailing ``").`` is removed.
    """
    X, _, Y, dep_type, statement = fact[:5]
    set_term = statement.replace(").", "").split(",")[-1]
    return Function(dep_type, [Number(X), Number(Y), Function(set_term)])


def _solve_models(ctl, print_models: bool) -> tuple[list, int]:
    """Solve ``ctl``, log the models and timing statistics, and return ``(models, n_models)``."""
    models = []
    logging.info("   Solving...")
    with ctl.solve(yield_=True) as handle:
        for model in handle:
            models.append(model.symbols(shown=True))
            if print_models:
                logging.info(f"Answer {len(models)}: {model}")
    n_models = int(ctl.statistics["summary"]["models"]["enumerated"])
    logging.info(f"Number of models: {n_models}")
    times = {
        key: ctl.statistics["summary"]["times"][key]
        for key in ["total", "cpu", "solve"]
    }
    logging.info(f"Times: {times}")
    return models, n_models


def CausalABA(
    n_nodes: int,
    compile_and_ground: Callable,
    facts_location: str = "",
    print_models: bool = True,
    skeleton_rules_reduction: bool = False,
    weak_constraints: bool = False,
    opt_mode: str = "optN",
    show: list = ["arrow"],
    pre_grounding: bool = False,
    disable_reground: bool = False,
    prior_knowledge: PriorKnowledge | None = None,
) -> dict:
    """Solve the causal ABA program, dropping lowest-ranked facts until a model exists.

    Facts are read from ``facts_location`` (or its ``_I.lp`` variant when
    ``weak_constraints`` is set). Their externals are switched on in order of
    decreasing score ``I``, and the program built by ``compile_and_ground``
    is solved. While no model is found, facts are removed one at a time from
    the lowest-ranked end. The program is recompiled and reground when a
    removal empties a ``(X, Y)`` group, ``skeleton_rules_reduction`` is on,
    regrounding is not disabled, and either the facts are not external or the
    removed fact is ``ext_indep``.

    Args:
        n_nodes: Number of variables in the graph.
        compile_and_ground: Grounding function (original or optimized
            implementation), called with the positional arguments below.
        facts_location: Path of the facts ``.lp`` file ("" for none).
        print_models: Log every model found.
        skeleton_rules_reduction: Forwarded to ``compile_and_ground``.
        weak_constraints: Read scored facts from the ``_I.lp`` file.
        opt_mode: Forwarded to ``compile_and_ground``.
        show: Forwarded to ``compile_and_ground``.
        pre_grounding: Forwarded to ``compile_and_ground``.
        disable_reground: Never regound after removing facts.
        prior_knowledge: Forwarded to ``compile_and_ground``.

    Returns:
        A dict with keys ``remove_n`` (number of facts removed),
        ``statistics`` (the final control's ``statistics``) and ``models``
        (shown symbols of each model from the last solve).
    """
    # (X, Y) -> their condition sets S
    indep_facts: dict[tuple, set[tuple]] = {}
    dep_facts: dict[tuple, set[tuple]] = {}
    facts = []
    ext_flag = False
    if facts_location:
        facts_loc = (
            facts_location.replace(".lp", "_I.lp")
            if weak_constraints
            else facts_location
        )
        # Log the file actually opened (it differs when weak_constraints is set).
        logging.debug(f"   Loading facts from {facts_loc}")
        with open(facts_loc, "r") as file:
            for line in file:
                if "dep" not in line or line.startswith("%"):
                    continue
                line_clean = line.replace("#external ", "").replace("\n", "")
                if "ext_" in line_clean:
                    ext_flag = True
                if weak_constraints:
                    statement, Is = line_clean.split(" I=")
                    I, truth = Is.split(",")
                    X, S, Y, dep_type = extract_test_elements_from_symbol(statement)
                    facts.append((X, S, Y, dep_type, statement, float(I), truth))
                else:
                    X, S, Y, dep_type = extract_test_elements_from_symbol(line_clean)
                    facts.append((X, S, Y, dep_type, line_clean, np.nan, "unknown"))

                assert (X not in S) and (Y not in S), f"X or Y in S: {line_clean}"
                condition_set = tuple(S)

                facts_group = indep_facts if "indep" in line_clean else dep_facts
                if (X, Y) not in facts_group:
                    facts_group[(X, Y)] = set()
                assert condition_set not in facts_group[(X, Y)], (
                    f"Redundant external fact: {line_clean}"
                )
                facts_group[(X, Y)].add(condition_set)

    ctl = compile_and_ground(
        n_nodes,
        facts_location,
        skeleton_rules_reduction,
        weak_constraints,
        indep_facts,
        dep_facts,
        opt_mode,
        show,
        pre_grounding,
        ext_flag,
        prior_knowledge,
    )

    # Highest-scored facts first; removal below pops from the end (lowest score).
    facts = sorted(facts, key=lambda x: x[5], reverse=True)
    for fact in facts:
        ctl.assign_external(_external_symbol(fact), True)
        logging.debug(f"   True fact: {fact[4]} I={fact[5]}, truth={fact[6]}")
    models, n_models = _solve_models(ctl, print_models)
    remove_n = 0
    logging.info(f"Number of facts removed: {remove_n}")

    ## start removing facts if no models are found
    while n_models == 0 and remove_n < len(facts):
        remove_n += 1
        logging.info(f"Number of facts removed: {remove_n}")

        reground = False
        fact_to_remove = facts[-remove_n]
        X, S, Y, dep_type, fact_str = fact_to_remove[:5]
        logging.debug(f"Removing fact {fact_str}")

        # Same grouping criterion as when the fact was inserted ("indep" substring),
        # so non-external indep facts are removed from the right group too.
        facts_group = indep_facts if "indep" in dep_type else dep_facts
        facts_group[(X, Y)].remove(tuple(S))
        if not facts_group[(X, Y)]:
            del facts_group[(X, Y)]
            reground = (
                disable_reground is False
                and skeleton_rules_reduction
                and (ext_flag is False or dep_type == "ext_indep")
            )
        else:
            logging.debug(
                f"   Not removing fact {fact_str} because there are multiple facts with the same X and Y"
            )
        # NOTE(review): in clingo's API a truth value of None leaves the external
        # free rather than false — confirm this is the intended "removal".
        ctl.assign_external(_external_symbol(fact_to_remove), None)

        if reground:
            logging.info("Recompiling and regrounding...")
            ctl = compile_and_ground(
                n_nodes,
                facts_location,
                skeleton_rules_reduction,
                weak_constraints,
                indep_facts,
                dep_facts,
                opt_mode,
                show,
                pre_grounding=pre_grounding,
                ext_flag=ext_flag,
                prior_knowledge=prior_knowledge,
            )
            # Re-apply the external assignments on the freshly grounded program.
            for fact in facts[:-remove_n]:
                ctl.assign_external(_external_symbol(fact), True)
                logging.debug(f"   True fact: {fact[4]} I={fact[5]}, truth={fact[6]}")
            for fact in facts[-remove_n:]:
                ctl.assign_external(_external_symbol(fact), None)
                logging.debug(f"   False fact: {fact[4]} I={fact[5]}, truth={fact[6]}")
        models, n_models = _solve_models(ctl, print_models)

    return {
        "remove_n": remove_n,
        "statistics": ctl.statistics,
        "models": models,
    }


### Define the evaluation function to compare performance of both implementations

In [None]:
import timeit
from typing import Any

from tqdm import tqdm

import causalaba
from causalaba import compile_and_ground as new_compile_and_ground


causalaba.tqdm = tqdm  # Avoid using progress bar widgets


def _causal_aba_runner(
    compile_and_ground: Callable,
    bn: gum.BayesNet,
    facts_path: str,
    return_ref: list,
) -> Callable:
    """Return a zero-argument callable that runs CausalABA once and records stats.

    The callable runs ``CausalABA`` on ``facts_path`` with the given
    ``compile_and_ground``, using hard constraints only (skeleton rule
    reduction on, weak constraints off, no pre-grounding, no prior knowledge).
    It appends a dict with ``remove_n``, ``atoms``, ``rules`` and
    ``conflicts`` to ``return_ref`` and returns None, so it can be passed
    directly to ``timeit.timeit``.
    """

    def helper():
        res = CausalABA(
            n_nodes=bn.size(),
            compile_and_ground=compile_and_ground,
            facts_location=facts_path,
            print_models=False,
            skeleton_rules_reduction=True,
            weak_constraints=False,  # We only care about hard constraints
            pre_grounding=False,
            prior_knowledge=None,
        )
        stats = res["statistics"]
        return_ref.append(
            {
                "remove_n": res["remove_n"],
                "atoms": stats["problem"]["lp"]["atoms"],
                "rules": stats["problem"]["lp"]["rules"],
                "conflicts": stats["solving"]["solvers"]["conflicts"],
            }
        )

    return helper


def original_implementation(bn: gum.BayesNet, facts_path: str, return_ref: list) -> Callable:
    """Runner that grounds with the original (upstream) implementation via ``_causal_aba_runner``."""
    return _causal_aba_runner(original_compile_and_ground, bn, facts_path, return_ref)


def new_implementation(bn: gum.BayesNet, facts_path: str, return_ref: list) -> Callable:
    """Runner that grounds with the optimized ``causalaba.compile_and_ground`` via ``_causal_aba_runner``."""
    return _causal_aba_runner(new_compile_and_ground, bn, facts_path, return_ref)

def compare_performance(
    bif_path: str | Path,
    sample_size: int,
    impl1: Callable[[gum.BayesNet, str, list], Callable],
    impl2: Callable[[gum.BayesNet, str, list], Callable],
    repeats: int = 1,
) -> tuple[str, int, float, float, list, list, list, Any]:
    """Time two CausalABA runner factories on the same generated facts.

    Seeds are drawn after ``random_stability(2024)``. For each seed, a sample
    of ``sample_size`` rows is drawn from the network at ``bif_path``. ABAPC
    then writes external facts (gsq test, alpha=0.01, no S weighting), and
    the callables built by ``impl1`` and then ``impl2`` are each timed once
    with ``timeit`` on those facts.

    Args:
        bif_path: Path to the ``.bif`` network.
        sample_size: Rows per generated sample.
        impl1: Factory ``(bn, facts_path, return_ref) -> callable``; timed first.
        impl2: Same as ``impl1``; timed second on the same facts.
        repeats: Number of seeds (runs per implementation); must be >= 1.

    Returns:
        ``(name, n_nodes, impl1_mean_s, impl2_mean_s, impl1_stats, impl2_stats,
        seeds, sepsets)``:
        - ``name``: the BIF file stem.
        - ``n_nodes``: the network size.
        - ``impl1_mean_s`` / ``impl2_mean_s``: mean ``timeit`` seconds per repeat.
        - ``impl1_stats`` / ``impl2_stats``: the per-run lists the runners appended to.
        - ``seeds``: the list of seeds used.
        - ``sepsets``: ABAPC's second return value for the last seed.

    Raises:
        ValueError: If ``repeats`` < 1, since nothing would be timed.
    """
    # Without at least one run, bn_true/sepsets are never bound and the means divide by zero.
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    random_stability(2024)
    seeds = np.random.randint(0, 10000, size=repeats).tolist()

    bif_path = Path(bif_path)

    impl1_time, impl2_time = 0.0, 0.0
    impl1_return = []
    impl2_return = []

    for seed in seeds:
        bn_true, data = load_dataset(bif_path, sample_size=sample_size, seed=seed)
        facts_I_path, sepsets = ABAPC(
            data,
            seed=seed,
            alpha=0.01,
            indep_test="gsq",
            S_weight=False,
            out_mode="facts_only",
            scenario="performance_comparison",
        )
        # Both runners use hard constraints only, so they read the plain facts file.
        facts_path = facts_I_path.replace("_I.lp", ".lp")

        impl1_time += timeit.timeit(impl1(bn_true, facts_path, impl1_return), number=1)
        impl2_time += timeit.timeit(impl2(bn_true, facts_path, impl2_return), number=1)

    return bif_path.stem, bn_true.size(), impl1_time / repeats, impl2_time / repeats, impl1_return, impl2_return, seeds, sepsets

### Evaluate on small Bayesian networks

In [None]:
# %%script false --no-raise-error
avg_time_res = []
repeats = 50
datasets = {"asia", "cancer", "survey", "earthquake"}
for dataset, size in dataset_size.items():
    if dataset not in datasets:
        continue  # Due to scalability limitations of original implementation, we only test small datasets
    (
        bif_name,
        graph_order,
        impl1_avg_time,
        impl2_avg_time,
        impl1_return,
        impl2_return,
        *_,
    ) = compare_performance(
        bif_path=f"datasets/bayesian/{size}/{dataset}.bif/{dataset}.bif",
        sample_size=5000,
        impl1=new_implementation,
        # Note that original_implementation will modify the files in facts_I.lp,
        # so we call it after the new implementation
        impl2=original_implementation,
        repeats=repeats,
    )
    # Both implementations must remove the same number of facts on every run;
    # comparing only the per-dataset means could hide run-level mismatches.
    new_remove_n = [run["remove_n"] for run in impl1_return]
    org_remove_n = [run["remove_n"] for run in impl2_return]
    assert new_remove_n == org_remove_n, (
        f"{dataset}: remove_n differs between implementations: {new_remove_n} vs {org_remove_n}"
    )
    impl1_avg_stats = pd.DataFrame(impl1_return).add_prefix("new_").mean().to_dict()
    impl2_avg_stats = pd.DataFrame(impl2_return).add_prefix("org_").mean().to_dict()
    avg_time_res.append({
        "dataset": dataset,
        "num_nodes": graph_order,
        "new_time": impl1_avg_time,
        "org_time": impl2_avg_time,
        "repeats": repeats,
        **impl1_avg_stats,
        **impl2_avg_stats,
    })

avg_time_res_df = pd.DataFrame(
    avg_time_res,
    columns=[
        "dataset",
        "num_nodes",
        "org_time",
        "new_time",
        "org_remove_n",
        "new_remove_n",
        "org_atoms",
        "new_atoms",
        "org_rules",
        "new_rules",
        "org_conflicts",
        "new_conflicts",
        "repeats",
    ],
)
assert avg_time_res_df["org_remove_n"].equals(avg_time_res_df["new_remove_n"])
# Create the output folder so the expensive run is not lost on a missing directory.
small_csv = Path("results/2025/small_bnlearn_scalability.csv")
small_csv.parent.mkdir(parents=True, exist_ok=True)
avg_time_res_df.to_csv(small_csv, index=False)
avg_time_res_df

### Evaluate on random Bayesian networks

Let's compare the performance of the original and new implementations on random bayesian networks.

In [None]:
# %%script false --no-raise-error
import tempfile

# Graph orders (number of nodes) to benchmark, each evaluated with `repeats` seeds
graph_orders = [5, 6, 7, 8]
repeats = 50
avg_time_res = []
for order in graph_orders:
    # Same pyagrum seed for every order so each random network is reproducible.
    gum.initRandom(2025)
    random_bn = gum.randomBN(n=order, ratio_arc=1)
    temp_bif_path = str(Path(tempfile.gettempdir()) / f"random_bn_{order}.bif")
    random_bn.saveBIF(temp_bif_path)

    (
        bif_name,
        graph_order,
        impl1_avg_time,
        impl2_avg_time,
        impl1_return,
        impl2_return,
        *_,
    ) = compare_performance(
        bif_path=temp_bif_path,
        sample_size=5000,
        impl1=new_implementation,
        impl2=original_implementation,
        repeats=repeats,
    )
    # Both implementations must remove the same number of facts on every run;
    # comparing only the per-graph means could hide run-level mismatches.
    new_remove_n = [run["remove_n"] for run in impl1_return]
    org_remove_n = [run["remove_n"] for run in impl2_return]
    assert new_remove_n == org_remove_n, (
        f"{bif_name}: remove_n differs between implementations: {new_remove_n} vs {org_remove_n}"
    )
    impl1_avg_stats = pd.DataFrame(impl1_return).add_prefix("new_").mean().to_dict()
    impl2_avg_stats = pd.DataFrame(impl2_return).add_prefix("org_").mean().to_dict()
    avg_time_res.append({
        "dataset": bif_name,
        "num_nodes": graph_order,
        "new_time": impl1_avg_time,
        "org_time": impl2_avg_time,
        "repeats": repeats,
        **impl1_avg_stats,
        **impl2_avg_stats,
    })

avg_time_res_df = pd.DataFrame(
    avg_time_res,
    columns=[
        "dataset",
        "num_nodes",
        "org_time",
        "new_time",
        "org_remove_n",
        "new_remove_n",
        "org_atoms",
        "new_atoms",
        "org_rules",
        "new_rules",
        "org_conflicts",
        "new_conflicts",
        "repeats",
    ],
)
assert avg_time_res_df["org_remove_n"].equals(avg_time_res_df["new_remove_n"])
# Create the output folder so the expensive run is not lost on a missing directory.
rand_csv = Path("results/2025/rand_graph_scalability.csv")
rand_csv.parent.mkdir(parents=True, exist_ok=True)
avg_time_res_df.to_csv(rand_csv, index=False)
avg_time_res_df