In [None]:
import sys
sys.path.insert(0, '../..')
sys.path.insert(0, '../../ArgCausalDisco')
sys.path.insert(0, '../../notears')

In [None]:
from src.abasp.factory import ABASPSolverFactory
from src.abasp.utils import Fact

In [None]:
from aspforaba.src.aspforaba import ABASolver
from cd_algorithms.PC import pc
from utils.data_utils import load_bnlearn_data_dag, simulate_dag

from itertools import combinations, product, chain
from utils.graph_utils import initial_strength, set_of_models_to_set_of_graphs

## Irrelevant

In [None]:
# Load the bnlearn "cancer" sample and run the PC algorithm on it.
dataset_name = 'cancer'
data_path = '../../ArgCausalDisco/datasets'
sample_size = 5000
seed = 0

# PC settings
alpha = 0.01  # significance level passed to pc()
indep_test = 'fisherz'
uc_rule = 5
stable = True

X_s, B_true = load_bnlearn_data_dag(
    dataset_name,
    data_path,
    sample_size,
    seed=seed,
    print_info=True,
    standardise=True,
)
data = X_s
n_nodes = data.shape[1]  # one node per data column

cg = pc(
    data=data,
    alpha=alpha,
    indep_test=indep_test,
    uc_rule=uc_rule,
    stable=stable,
    show_progress=True,
    verbose=True,
)

## Extract facts from PC
# Every (separating set, p-value) test PC recorded for a node pair becomes one
# fact tuple: (node1, sep_set, node2, "dep"/"indep", rendered fact, strength).
facts = []
for node1, node2 in combinations(range(n_nodes), 2):
    tests_for_pair = cg.sepset[node1, node2]
    # NOTE(review): guard in case pc() leaves some sepset entries unset (None);
    # iterating None would raise TypeError. Confirm whether this can happen.
    if tests_for_pair is None:
        continue
    for sep_set, p in tests_for_pair:
        dep_type_PC = "indep" if p > alpha else "dep"
        init_strength_value = initial_strength(p, len(sep_set), alpha, 0.5, n_nodes)
        # Conditioning set rendered as 'empty' or e.g. 's1y2' for {1, 2}.
        s_str = 'empty' if len(sep_set) == 0 else 's' + 'y'.join(str(i) for i in sep_set)
        facts.append((node1, sep_set, node2, dep_type_PC,
                      f"{dep_type_PC}({node1},{node2},{s_str}).", init_strength_value))

# Show one example fact without assuming at least four were produced.
if facts:
    print('Fact sample:', facts[min(3, len(facts) - 1)])
else:
    print('No facts extracted from PC.')

In [None]:
# Number of extracted fact tuples vs. number of distinct ones.
len(facts), len(set(facts))

In [None]:
# Deduplicate while keeping first-seen order. Iterating `set(facts)` yields an
# order that depends on str hashing, which is randomised per interpreter
# (PYTHONHASHSEED). Because `sorted` is stable, that order decided how
# equal-score facts were ranked, so the prefix search below gave different
# answers after a kernel restart.
unique_facts = list(dict.fromkeys(facts))
sorted_facts = sorted(
    (Fact.from_tuple(fact) for fact in unique_facts),
    key=lambda fact: fact.score,
    reverse=True,  # strongest facts first
)


In [None]:
# Binary search for the largest prefix of `sorted_facts` (strongest first)
# whose ABA framework still has at least one stable extension.

left_idx, right_idx = 0, len(sorted_facts) - 1
factory = ABASPSolverFactory(n_nodes=n_nodes)

# Solver cache: prefix end index (inclusive) -> list of stable extensions
# for sorted_facts[:idx + 1]; an empty list means none were found.
result_table = {}

def get_extensions(factory: ABASPSolverFactory, facts, result_table):
    """Return the stable ('ST') extensions of the framework built from `facts`.

    Results are memoised in `result_table`, keyed by ``len(facts) - 1``. The
    key only encodes the prefix length, so every call must pass a prefix of
    the same fact list. A solver result of None is stored as ``[]``.
    """
    key = len(facts) - 1
    if key in result_table:
        return result_table[key]

    found = factory.create_solver(facts).enumerate_extensions('ST')
    result_table[key] = [] if found is None else found
    return result_table[key]

# Search outputs; they stay None if no prefix admits a stable extension.
final_extensions = None
final_facts = None
final_fact_index = None

# Assumes adding facts never creates new extensions (results are monotone in
# the prefix length); under that assumption the answer lies in
# [left_idx, right_idx].
while left_idx <= right_idx:
    mid = (left_idx + right_idx) // 2

    exts_mid = get_extensions(factory, sorted_facts[:mid + 1], result_table)
    if len(exts_mid) == 0:  # overshoot: this prefix is already inconsistent
        right_idx = mid - 1
        continue

    # When mid is the last index there is no longer prefix to try. The slice
    # sorted_facts[:mid + 2] would equal sorted_facts[:mid + 1], the step below
    # would always look like an undershoot, and the loop would end with
    # nothing recorded even though every fact is consistent.
    is_last = mid + 1 >= len(sorted_facts)
    # The second (expensive) solver call only runs when the first succeeded.
    if not is_last and len(get_extensions(factory, sorted_facts[:mid + 2], result_table)) > 0:
        left_idx = mid + 1  # undershoot: one more fact is still consistent
        continue

    # mid is the largest index where extensions exist
    final_extensions = exts_mid
    final_facts = sorted_facts[:mid + 1]
    final_fact_index = mid
    break

if final_extensions is None:
    print('No prefix of sorted_facts has a stable extension.')


In [None]:
# Index (into sorted_facts) of the last fact in the largest consistent prefix.
final_fact_index

In [None]:
# Number of stable extensions for that prefix (raises TypeError if the search found none and this is None).
len(final_extensions)

In [None]:
# Total number of facts available, for comparison with final_fact_index + 1.
len(sorted_facts)

In [None]:
final_extensions

In [None]:
# For each stable extension, list only its 'arr…' and 'noe…' assumptions.
for ext in final_extensions:
    edge_assumptions = [name for name in ext.assumptions if name.startswith(('arr', 'noe'))]
    print(edge_assumptions)

In [None]:
# Facts in the largest prefix that still admits a stable extension.
final_facts

## Fact Test

In [None]:
# Hand-pasted fact log. Several lines occur twice (e.g. the indep(...,empty)
# entries); these duplicates are removed further down before building a solver.
facts3 = """True fact: dep(0,3,empty). I=1.0, truth= NA
   True fact: dep(0,4,empty). I=1.0, truth= NA
   True fact: dep(0,2,empty). I=0.9999958381326048, truth= NA
   True fact: dep(0,1,empty). I=0.9998651373118164, truth= NA
   True fact: indep(1,3,empty). I=0.9307979291435597, truth= NA
   True fact: indep(1,3,empty). I=0.9307979291435597, truth= NA
   True fact: indep(2,3,empty). I=0.8819705060456833, truth= NA
   True fact: indep(2,3,empty). I=0.8819705060456833, truth= NA
   True fact: indep(1,4,empty). I=0.8390142120581524, truth= NA
   True fact: indep(1,4,empty). I=0.8390142120581524, truth= NA
   True fact: indep(2,4,empty). I=0.7502762578454715, truth= NA
   True fact: indep(2,4,empty). I=0.7502762578454715, truth= NA
   True fact: dep(0,3,s1). I=0.6666666666666667, truth= NA
   True fact: dep(0,3,s2). I=0.6666666666666667, truth= NA
   True fact: dep(0,3,s4). I=0.6666666666666667, truth= NA
   True fact: dep(0,4,s1). I=0.6666666666666667, truth= NA
   True fact: dep(0,4,s2). I=0.6666666666666667, truth= NA
   True fact: dep(0,4,s3). I=0.6666666666666667, truth= NA
   True fact: dep(0,2,s4). I=0.6666659094811525, truth= NA
   True fact: dep(0,2,s3). I=0.6666639550266579, truth= NA
   True fact: dep(0,2,s1). I=0.6666603711883725, truth= NA
   True fact: dep(0,1,s3). I=0.6665829788960941, truth= NA
   True fact: dep(0,1,s4). I=0.6665799211462186, truth= NA
   True fact: dep(0,1,s2). I=0.6664612834576787, truth= NA"""

In [None]:
# One entry per pasted fact (facts3 has no trailing newline, so no empty last entry).
lines = facts3.split('\n')

In [None]:
# Spot-check the line format on a single pasted fact.
x = lines[1].strip()
# Remove the literal 'True fact: ' prefix. str.strip('True fact: ') would
# instead strip any of those *characters* from both ends of the line.
if x.startswith('True fact: '):
    x = x[len('True fact: '):]
dep, x = x.split('(', 1)
node1, node2, node_set_str = x.split(')', 1)[0].split(',')
node1 = int(node1)
node2 = int(node2)
if node_set_str == 'empty':
    node_set = set()
else:
    # 's1y2' -> {1, 2}: drop the leading 's' and split on the 'y' separator.
    node_set = {int(i) for i in node_set_str[1:].split('y')}
print(dep, node1, node2, node_set)

In [None]:
from src.abasp.utils import RelationEnum

In [None]:
# Check whether a RelationEnum member compares equal to its raw string value.
RelationEnum('indep') == 'indep'

In [None]:
# parse into Fact objects

def parse_fact_fields(line, prefix='True fact: '):
    """Split one pasted log line into ``(relation, node1, node2, node_set)``.

    Expected shape (surrounding whitespace and `prefix` are optional):
    ``"True fact: dep(0,3,s1y2). I=..., truth= NA"``. The third argument is
    either ``empty`` or ``s`` followed by ``y``-separated node ids.

    Raises ValueError on lines that do not follow this shape.
    """
    text = line.strip()
    # Remove the literal prefix; str.strip(prefix) would instead remove any of
    # the prefix's characters from both ends.
    if text.startswith(prefix):
        text = text[len(prefix):]
    relation, rest = text.split('(', 1)
    node1_str, node2_str, node_set_str = rest.split(')', 1)[0].split(',')
    if node_set_str == 'empty':
        node_set = set()
    else:
        node_set = {int(i) for i in node_set_str[1:].split('y')}
    return relation, int(node1_str), int(node2_str), node_set


def parse_line(line, verbose=True):
    """Parse one pasted log line into a Fact with score 0.0.

    Args:
        line: a line such as ``"   True fact: indep(1,3,empty). I=0.93, truth= NA"``.
        verbose: echo the raw line before parsing (the original behaviour).

    Returns:
        Fact: the parsed fact. The ``I=...`` value in the line is not parsed.
    """
    if verbose:
        print(line)
    relation, node1, node2, node_set = parse_fact_fields(line)
    return Fact(
        relation=RelationEnum(relation),
        node1=node1,
        node2=node2,
        node_set=node_set,
        score=0.0,  # or whatever score you want to assign
    )


In [None]:
# Parse every pasted line (duplicates included) into a Fact.
facts4 = [parse_line(line) for line in lines]

In [None]:
# The pasted facts only mention nodes 0-4, hence 5 nodes.
# NOTE: this overwrites the n_nodes computed from the PC data earlier in the notebook.
n_nodes = 5

In [None]:
len(facts4)

In [None]:
# Drop repeated facts while preserving order. This uses list membership (==)
# rather than a set, presumably because a Fact holding a set-valued node_set
# may not be hashable.
facts5 = []
for fact in facts4:
    if fact in facts5:
        continue
    facts5.append(fact)

In [None]:
len(facts5)

In [None]:
# Build the ABA framework from the deduplicated facts and enumerate its stable ('ST') extensions.
factory = ABASPSolverFactory(n_nodes=n_nodes)
solver = factory.create_solver(facts5)
solver.enumerate_extensions('ST')

## Start Here

In [None]:
from src.abasp.utils import RelationEnum, Fact

# Minimal framework: two nodes and a single fact, dep(0, 1) given the empty set.
factory = ABASPSolverFactory(n_nodes=2)
solver = factory.create_solver([
    Fact(relation=RelationEnum('dep'), node1=0, node2=1, node_set=set(), score=0.0),
])

In [None]:
# Stable extensions of the single-fact framework.
solver.enumerate_extensions('ST')

In [None]:
solver.rules

In [None]:
len(solver.rules)

In [None]:
# Assumption name -> index mapping inside the ABA framework.
solver.abaf.asmpt_to_idx

In [None]:
# Atom name -> index mapping inside the ABA framework.
solver.abaf.atom_to_idx

In [None]:
from src.abasp.utils import RelationEnum, Fact

# Five nodes; 0 and 1 are dependent marginally but independent given {2}.
factory = ABASPSolverFactory(n_nodes=5)
solver = factory.create_solver([
    Fact(relation=RelationEnum('dep'), node1=0, node2=1, node_set=set(), score=0.0),
    Fact(relation=RelationEnum('indep'), node1=0, node2=1, node_set={2}, score=0.0),
])

In [None]:
len(solver.rules)

In [None]:
# Hand-picked window into the rule list for manual inspection.
solver.rules[70:85]

In [None]:
solver.assumptions

In [None]:
# Print every atom name that appears under more than one index in the ABA
# framework's atom table. Atom names are hashable (they are put in a set when
# the table size is compared with its distinct count), so a set gives O(1)
# membership instead of the former O(n) list scan per atom.
seen_atoms = set()
for a in solver.abaf.idx_to_atom.values():
    if a in seen_atoms:
        print(a)
    else:
        seen_atoms.add(a)
print('done')

In [None]:
len(solver.atoms)

In [None]:
# Atom-table size vs. number of distinct atom names; unequal values mean duplicates.
len(solver.abaf.idx_to_atom), len(set(solver.abaf.idx_to_atom.values()))

In [None]:
solver.enumerate_extensions('ST')

In [None]:
def f(r):
    """Rule filter used while debugging the two-fact framework.

    ``r`` is indexed as ``r[0]`` (presumably the rule head, a string) and
    ``r[1]`` (presumably the body). Returns True when the
    'blocked_path_0_1__1__' pattern is in either part, or when the head
    mentions 'nb_2', 'not_collider_0_2_1' or 'indep'.
    """
    head = r[0]
    if 'blocked_path_0_1__1__' in head or 'blocked_path_0_1__1__' in r[1]:
        return True
    return any(token in head for token in ('nb_2', 'not_collider_0_2_1', 'indep'))

# Rules selected by the debugging predicate f.
[r for r in solver.rules if f(r)]

In [None]:
# Is not_collider_0_2_1 credulously accepted under stable semantics?
solver.decide_credulous('ST', ['not_collider_0_2_1'])
# solver.decide_credulous('ST', 'arr_2_0')

In [None]:
# Rules whose r[0] contains the substring (first three cells) or equals the name exactly (later cells).
[r for r in solver.rules if 'indep_0_1__' in r[0]]

In [None]:
[r for r in solver.rules if '-blocked_path_0_1__0__' in r[0]]

In [None]:
[r for r in solver.rules if '-blocked_path_0_1__1__' in r[0]]

In [None]:
[r for r in solver.rules if 'nb_2__0_1__'==r[0]]

In [None]:
[r for r in solver.rules if 'path_0_1__1'==r[0]]

In [None]:
[r for r in solver.rules if 'not_collider_0_2_1'==r[0]]

In [None]:
[r for r in solver.rules if 'collider_0_2_1'==r[0]]

In [None]:
[r for r in solver.rules if '-noe_0_2'==r[0]]

In [None]:
[r for r in solver.rules if '-noe_1_2'==r[0]]

In [None]:
# Rules whose r[1] contains both arr_2_1 and arr_0_2.
[r for r in solver.rules if 'arr_2_1' in r[1] and 'arr_0_2' in r[1]]

In [None]:
from aspforaba.src.aspforaba import ABASolver

In [None]:
# Unused: the next cell immediately replaces `solver` with a fresh ABASolver().
solver = ABASolver()

In [None]:
# Hand-built toy ABA framework: every assumption X gets the contrary '-X';
# the rules derive '-indep' once a path exists and is not blocked.
solver = ABASolver()

_toy_assumptions = ['a1', 'a2', 'noe1', 'noe2']
for name in _toy_assumptions:
    solver.add_assumption(name)
for name in _toy_assumptions:
    solver.add_contrary(name, '-' + name)
for name in ['blocked_path', 'indep']:
    solver.add_assumption(name)
    solver.add_contrary(name, '-' + name)

_toy_rules = [
    ('not_collider', ['a1', 'a2']),
    ('not_blocked', ['not_collider']),
    ('-noe1', ['a1']),
    # NOTE(review): '-noe2' is derived from 'a1' (not 'a2') - confirm this is intended.
    ('-noe2', ['a1']),
    ('path', ['-noe1', '-noe2']),
    ('-blocked_path', ['path', 'not_blocked']),
    ('-indep', ['-blocked_path']),
]
for head, body in _toy_rules:
    solver.add_rule(head, body)

solver.enumerate_extensions('ST')

In [None]:
def model_to_graph(model):
    """Convert a solver model/extension into a set of directed edges.

    Every assumption named ``arr_<i>_<j>`` becomes the edge ``(i, j)``; all
    other assumptions are ignored. (A former ``noe`` branch was unreachable
    because the assumptions were already filtered to ``arr`` ones, so it has
    been removed.)

    Args:
        model: any object with an iterable ``assumptions`` of strings.

    Returns:
        set[tuple[int, int]]: the directed edges.

    Raises:
        ValueError: if an ``arr`` assumption is not of the form ``arr_<int>_<int>``.
    """
    graph = set()
    for assumption in model.assumptions:
        if not assumption.startswith('arr'):
            continue
        _, node1, node2 = assumption.split('_')
        graph.add((int(node1), int(node2)))
    return graph

In [None]:
import unittest