# Goal: Implement Causal Discovery ABAF using the ABASP library

## Preliminaries

### Imports

In [None]:
import sys
from dataclasses import dataclass
from enum import Enum
import networkx as nx

In [None]:
# Put the directory two levels up, plus its ArgCausalDisco and notears
# sub-directories, on the import path for the project imports that follow.
# NOTE(review): '../..' is both appended and inserted at position 0; the
# append looks redundant next to the insert - confirm before removing it.
sys.path.append('../..')
sys.path.insert(0, '../..')
sys.path.insert(0, '../../ArgCausalDisco')
sys.path.insert(0, '../../notears')

In [None]:
from aspforaba.src.aspforaba import ABASolver
from cd_algorithms.PC import pc
from utils.data_utils import load_bnlearn_data_dag, simulate_dag

In [None]:
from itertools import combinations, product, chain
from utils.graph_utils import initial_strength, set_of_models_to_set_of_graphs

### Get dataset

In [None]:
# Settings handed to load_bnlearn_data_dag in the next cell.
dataset_name = 'cancer'
data_path='../../ArgCausalDisco/datasets'
sample_size = 5000
seed = 42

In [None]:
# Load the selected dataset; X_s is used as the data matrix for PC further down.
# NOTE(review): B_true is presumably the ground-truth graph of the dataset -
# confirm against load_bnlearn_data_dag.
X_s, B_true = load_bnlearn_data_dag(dataset_name, 
                                    data_path, 
                                    sample_size, 
                                    seed=seed, 
                                    print_info=True, 
                                    standardise=True)

### Get Facts from PC

In [None]:
# Arguments for the PC run in the next cell.
data = X_s
# tests with p > alpha are later recorded as "indep", all others as "dep"
alpha = 0.01
indep_test = 'fisherz'
# NOTE(review): the meaning of rule id 5 is defined by cd_algorithms.PC - confirm there
uc_rule = 5
stable = True

In [None]:
# Run PC; the `sepset` attribute of the returned object is read below to extract facts.
cg = pc(data=data, alpha=alpha, indep_test=indep_test, uc_rule=uc_rule, stable=stable, show_progress=True, verbose=True)

In [None]:
# One node per column of the data matrix.
n_nodes = data.shape[1]
print('Number of nodes:', n_nodes)

In [None]:
# Turn every test recorded in the PC separating sets into a raw fact tuple:
# (node1, sep_set, node2, relation, textual fact, initial strength).
facts = []
for node1, node2 in combinations(range(n_nodes), 2):
    for sep_set, p in cg.sepset[node1, node2]:
        # a p-value above the threshold is read as independence
        dep_type_PC = "indep" if p > alpha else "dep"
        init_strength_value = initial_strength(p, len(sep_set), alpha, 0.5, n_nodes)
        # conditioning set is rendered as 'empty' or as 's' followed by the 'y'-joined node ids
        if len(sep_set) == 0:
            s_str = 'empty'
        else:
            s_str = 's' + 'y'.join(str(i) for i in sep_set)
        fact_text = f"{dep_type_PC}({node1},{node2},{s_str})."
        facts.append((node1, sep_set, node2, dep_type_PC, fact_text, init_strength_value))
print('Fact sample:', facts[3])

### Define Facts as Dataclass

In [None]:
class RelationEnum(str, Enum):
    """Kind of statement a fact makes about a pair of nodes."""

    dep = "dep"
    indep = "indep"


@dataclass
class Fact:
    """A dependence or independence statement about two nodes given a node set.

    ``node1``/``node2`` are the two nodes, ``node_set`` the conditioning set,
    ``relation`` the kind of statement and ``score`` its numeric strength.
    """

    relation: RelationEnum
    node1: int
    node2: int
    node_set: set
    score: float

    @classmethod
    def from_tuple(cls, tpl):
        """Build a ``Fact`` from a raw 6-tuple.

        The tuple layout is ``(node1, node_set, node2, relation, text, score)``;
        the fifth element is ignored. The two nodes are stored in ascending order.
        """
        first, given, second, kind, _ignored_text, strength = tpl
        # the relation is symmetric in the two nodes, so normalise their order
        if first > second:
            first, second = second, first
        conditioning = {int(member) for member in given}
        return cls(
            relation=RelationEnum(kind),
            node1=int(first),
            node2=int(second),
            node_set=conditioning,
            score=float(strength),
        )

In [None]:
# Sanity check: convert one raw tuple and display the resulting Fact.
fact = Fact.from_tuple(facts[3])
fact

In [None]:
# Keep the raw tuples under their own name; `facts` is rebound to Fact objects in the next cell.
raw_facts = facts

In [None]:
# Convert every raw tuple into a Fact.
facts = [Fact.from_tuple(fact) for fact in raw_facts]

## Define ABAF

### Helper Functions

In [None]:
# Contrary

def contrary(assumption):
    """Return the name of the contrary of *assumption*: the same name prefixed with '-'."""
    return "-{}".format(assumption)

In [None]:
# Assumptions and their contraries

def arr(X, Y):
    """Return the assumption name for the arrow X -> Y (direction matters)."""
    return "arr_{}_{}".format(X, Y)

def noe(X, Y):
    """Return the assumption name for "no edge between X and Y".

    The name is the same for both argument orders.
    """
    # no-edge is symmetric: always name the pair in ascending order
    first, second = (Y, X) if X > Y else (X, Y)
    return "noe_{}_{}".format(first, second)

def edge(X, Y):
    """Return the atom for "X and Y are adjacent", i.e. the contrary of ``noe(X, Y)``."""
    no_edge = noe(X, Y)
    return contrary(no_edge)

def indep(X, Y, S):
    """Return the assumption name for "X independent of Y given S".

    The name does not depend on the order of X and Y nor on the iteration
    order of S: the pair is sorted and the members of S are listed ascending.
    """
    # symmetric with respect to X and Y
    first, second = (Y, X) if X > Y else (X, Y)
    suffix = '_'.join(map(str, sorted(S)))
    return "indep_{}_{}__".format(first, second) + suffix

def dep(X, Y, S):
    """Return the atom for "X dependent on Y given S", i.e. the contrary of ``indep(X, Y, S)``."""
    independence = indep(X, Y, S)
    return contrary(independence)

def blocked_path(source, target, path_id: int, S: set):
    """Return the assumption name for "path ``path_id`` between source and target is blocked given S".

    The end nodes are listed in ascending order and the members of S ascending,
    so the name is independent of argument order.
    """
    # symmetric with respect to source and target
    low, high = (target, source) if source > target else (source, target)
    members = '_'.join(map(str, sorted(S)))
    return "blocked_path_{}_{}__{}__".format(low, high, path_id) + members

def active_path(source, target, path_id: int, S: set):
    """Return the atom for an S-active path: the contrary of the matching ``blocked_path`` assumption."""
    blocked = blocked_path(source, target, path_id, S)
    return contrary(blocked)


In [None]:
# Atoms

def dpath(X, Y):
    """Return the atom name for a directed path from X to Y (direction matters)."""
    return "dpath_{}_{}".format(X, Y)

def collider(X, Y, Z):
    """Return the atom name for the collider X -> Y <- Z (Y is the middle node).

    The name is the same when X and Z are exchanged.
    """
    # symmetric with respect to the two outer nodes
    left, right = (Z, X) if X > Z else (X, Z)
    return "collider_{}_{}_{}".format(left, Y, right)

def not_collider(X, Y, Z):
    """Return the atom name for "middle node Y is not a collider between X and Z".

    The name is the same when X and Z are exchanged.
    """
    # symmetric with respect to the two outer nodes
    left, right = (Z, X) if X > Z else (X, Z)
    return "not_collider_{}_{}_{}".format(left, Y, right)

def descendant_of_collider(Z, X, N, Y):
    """Return the atom name for "Z is a descendant of the collider X -> N <- Y".

    Z is the descendant node, N the middle (collider) node, and X, Y the two
    colliding nodes; exchanging X and Y yields the same name.
    """
    # symmetric with respect to the colliding nodes
    left, right = (Y, X) if X > Y else (X, Y)
    return "desc_{}_of_collider_{}_{}_{}".format(Z, left, N, right)

def non_blocking(N: int, X: int, Y: int, S: set):
    """Return the atom name for "N does not block the path through X - N - Y given S".

    X and Y are the immediate neighbours of N on the path and S is the set
    with respect to which the path is considered active. Exchanging X and Y,
    or reordering S, yields the same name.
    """
    # symmetric with respect to the two neighbours
    left, right = (Y, X) if X > Y else (X, Y)
    members = '_'.join(map(str, sorted(S)))
    return "nb_{}__{}_{}__".format(N, left, right) + members

def path(source: int, target: int, path_id: int):
    """Return the atom name for the undirected path number ``path_id`` between two nodes.

    The end nodes are listed in ascending order, so both argument orders give the same name.
    """
    # a path is symmetric with respect to its end nodes
    low, high = (target, source) if source > target else (source, target)
    return 'path_{}_{}__{}'.format(low, high, path_id)


In [None]:
# Utils

def is_unique(ary):
    """Return True when no value occurs more than once in the sized collection *ary*."""
    distinct = set(ary)
    return len(distinct) == len(ary)

def powerset(s):
    """Return an iterator over all subsets of *s* as tuples.

    Elements are sorted first; subsets are produced by increasing size,
    starting with the empty tuple.
    """
    ordered = sorted(s)
    subsets_by_size = (combinations(ordered, size) for size in range(len(ordered) + 1))
    return chain.from_iterable(subsets_by_size)

def unique_product(elements, repeat: int):
    """Yield every tuple of length ``repeat`` over *elements* whose entries are all distinct.

    Tuples are produced in the order of ``itertools.product``; any tuple that
    contains the same value more than once is skipped.
    """
    for candidate in product(elements, repeat=repeat):
        # keep only tuples without a repeated value
        if len(set(candidate)) == len(candidate):
            yield candidate

### Define ABAF Core - edges, collider rules, etc.

In [None]:
def add_graph_edge_assumptions(solver, X, Y):
    """Register the three options for the node pair (X, Y) and the rules tying them together.

    Adds the assumptions ``arr(X, Y)``, ``arr(Y, X)`` and ``noe(X, Y)`` with
    their contraries, makes each of them derive the contrary of the other two,
    and adds the rules deriving ``dpath`` and ``edge`` from the arrows.
    """
    forward, backward, absent = arr(X, Y), arr(Y, X), noe(X, Y)
    options = [forward, backward, absent]

    for option in options:
        solver.add_assumption(option)
        solver.add_contrary(option, contrary(option))

    # assuming one option derives the contrary of each other option
    for assumed, excluded in unique_product(options, repeat=2):
        solver.add_rule(contrary(excluded), [assumed])

    # an arrow is a directed path in the same direction
    solver.add_rule(dpath(X, Y), [forward])
    solver.add_rule(dpath(Y, X), [backward])

    # an arrow in either direction is an edge
    adjacent = edge(X, Y)
    for arrow in (forward, backward):
        solver.add_rule(adjacent, [arrow])

In [None]:
def add_acyclicity_rules(solver, X, Y):
    """Add the rule that a directed path from X to Y derives the contrary of the arrow Y -> X."""
    reverse_arrow = arr(Y, X)
    solver.add_rule(contrary(reverse_arrow), [dpath(X, Y)])

In [None]:
def add_non_blocking_rules(solver, X, Y, S, n_nodes):
    """Add the rules that derive ``non_blocking(N, X, Y, S)`` for every candidate middle node N.

    X and Y are the neighbours of N, S is the node set under consideration
    and ``n_nodes`` is the total number of nodes. N never equals X or Y.
    """
    endpoints = {X, Y}

    for N in S:
        if N in endpoints:
            continue
        # 1) N in S does not block when it is a collider
        solver.add_rule(non_blocking(N, X, Y, S), [collider(X, N, Y)])

    for N in set(range(n_nodes)) - set(S):  # nodes outside S
        if N in endpoints:
            continue
        head = non_blocking(N, X, Y, S)

        # 2) N outside S does not block when it is not a collider
        solver.add_rule(head, [not_collider(X, N, Y)])

        # 3) N outside S does not block when it is a collider with a descendant Z in S
        for Z in S:
            if Z in endpoints or Z == N:
                continue
            solver.add_rule(head, [collider(X, N, Y), descendant_of_collider(Z, X, N, Y)])
    

In [None]:
def add_direct_path_definition_rules(solver, X, Y, Z):
    """Add the rule: a directed path X ~> Y holds if X -> Z and a directed path Z ~> Y hold."""
    body = [arr(X, Z), dpath(Z, Y)]
    solver.add_rule(dpath(X, Y), body)

In [None]:
def add_collider_definition_rules(solver, X, Y, Z):
    """Add the rules defining ``collider`` and ``not_collider`` for middle node Y between X and Z."""
    # collider on the middle node: X->Y<-Z
    solver.add_rule(collider(X, Y, Z), [arr(X, Y), arr(Z, Y)])

    # the three arrow configurations in which Y is not a collider
    non_collider_bodies = (
        [arr(X, Y), arr(Y, Z)],  # X->Y->Z
        [arr(Y, X), arr(Y, Z)],  # X<-Y->Z
        [arr(Z, Y), arr(Y, X)],  # X<-Y<-Z
    )
    head = not_collider(X, Y, Z)
    for body in non_collider_bodies:
        solver.add_rule(head, body)

In [None]:
def add_collider_descendant_definition_rules(solver, X, Y, Z, N):
    """Add the rule: N is a descendant of the collider X -> Y <- Z if that collider holds and Y ~> N."""
    body = [collider(X, Y, Z), dpath(Y, N)]
    solver.add_rule(descendant_of_collider(N, X, Y, Z), body)

In [None]:
def define_abaf_graph(n_nodes):
    """Build the core framework over ``n_nodes`` nodes and return the populated ``ABASolver``.

    For every ordered pair of distinct nodes this adds the acyclicity rule and
    the directed-path, collider and collider-descendant definitions. For every
    unordered pair it adds the edge assumptions and the non-blocking rules for
    each subset of the nodes.
    """
    solver = ABASolver()
    nodes = range(n_nodes)

    for X, Y in unique_product(nodes, repeat=2):
        add_acyclicity_rules(solver, X, Y)

        # symmetric constructs are added once per unordered pair
        if X < Y:
            add_graph_edge_assumptions(solver, X, Y)
            for S in powerset(nodes):
                add_non_blocking_rules(solver, X, Y, S, n_nodes)

        for Z in nodes:
            if Z == X or Z == Y:  # X, Y, Z must be distinct
                continue
            add_direct_path_definition_rules(solver, X, Y, Z)

            # colliders are symmetric in their outer nodes, so X < Z avoids duplicates
            if X >= Z:
                continue
            add_collider_definition_rules(solver, X, Y, Z)

            for N in nodes:
                if N not in (X, Y, Z):  # X, Y, Z, N must be distinct
                    add_collider_descendant_definition_rules(solver, X, Y, Z, N)

    return solver



### Test Run with Only Core

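Must get 25 extensions for 3 nodes:  
- 3 nodes give 3 node pairs, each with 3 options (arrow in either direction, or no edge): 3 * 3 * 3 = 27 
- minus the 2 assignments that form a directed cycle over the 3 nodes (one per orientation): 27 - 2 = 25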

In [None]:
# Build the core framework for 3 nodes and count the extensions returned for 'ST'.
# NOTE(review): 'ST' presumably selects stable semantics - confirm in the ABASolver docs.
solver = define_abaf_graph(3)
len(solver.enumerate_extensions('ST'))  

### Define Active/Blocked Path, Dep/Indep Assumptions and Corresponding Rules.

In [None]:
def add_path_definition_rules(solver, paths, X, Y):
    """Add one rule per path: ``path(X, Y, id)`` holds if every consecutive pair of its nodes has an edge.

    ``paths`` is a sequence of node sequences; the position of a path in it is its id.
    """
    for path_id, my_path in enumerate(paths):
        consecutive = zip(my_path, my_path[1:])
        body = [edge(start, end) for start, end in consecutive]
        solver.add_rule(path(X, Y, path_id), body)

In [None]:
def add_indep_assumptions(solver, X, Y, S):
    """Register ``indep(X, Y, S)`` as an assumption together with its contrary."""
    assumption = indep(X, Y, S)
    solver.add_assumption(assumption)
    solver.add_contrary(assumption, contrary(assumption))

In [None]:
def add_independence_rules(solver, paths, X, Y, S):
    """Add the rule: ``indep(X, Y, S)`` holds if every listed path between X and Y is blocked given S.

    Nothing is added when ``paths`` is empty, because a rule with an empty
    body would assert the independence as a fact.
    """
    all_blocked = [blocked_path(X, Y, path_id, S) for path_id in range(len(paths))]
    if not all_blocked:
        return
    solver.add_rule(indep(X, Y, S), all_blocked)

In [None]:
def add_blocked_path_assumptions(solver, path_id, X, Y, S):
    """Register ``blocked_path(X, Y, path_id, S)`` as an assumption together with its contrary."""
    assumption = blocked_path(X, Y, path_id, S)
    solver.add_assumption(assumption)
    solver.add_contrary(assumption, contrary(assumption))

In [None]:
def add_dependence_rules(solver, path_id, path_nodes, X, Y, S):
    """Add the rules that make a path S-active and derive dependence from it.

    The path is active when it exists and every interior node is non-blocking
    with respect to its two neighbours on the path; an active path between
    X and Y derives ``dep(X, Y, S)``.
    """
    # every interior node together with its predecessor and successor on the path
    neighbourhoods = zip(path_nodes, path_nodes[1:], path_nodes[2:])
    interior_open = [non_blocking(middle, before, after, S)
                     for before, middle, after in neighbourhoods]

    solver.add_rule(active_path(X, Y, path_id, S), [path(X, Y, path_id), *interior_open])
    solver.add_rule(dep(X, Y, S), [active_path(X, Y, path_id, S)])

In [None]:
def add_facts(solver, facts):
    """Add each fact as a rule with an empty body, using ``dep`` or ``indep`` according to its relation."""
    for fact in facts:
        make_atom = dep if fact.relation == RelationEnum.dep else indep
        solver.add_rule(make_atom(fact.node1, fact.node2, fact.node_set), [])

In [None]:

def append_path_rules_and_knowledge(solver, facts, n_nodes):
    """Append path rules and external knowledge (facts) to the solver.

    For every node pair mentioned in ``facts`` this enumerates the simple
    paths between the two nodes and, for each subset S of the remaining nodes,
    adds the independence assumption, the independence rule, and the
    blocked-path assumptions and dependence rules of every path. Finally the
    facts themselves are added.

    NOTE 1: path assumptions and rules are added only for node pairs present in facts.
    NOTE 2: paths using an edge whose end nodes are independent according to
            some fact (for any set S) are not considered.

    Each node pair is processed once, even when several facts (with different
    conditioning sets) mention it; otherwise the same rules and assumptions
    would be generated repeatedly.

    Returns the same ``solver`` object.
    """
    # materialise once so an iterator argument is not exhausted before add_facts
    facts = list(facts)

    # a dict is used as an insertion-ordered set of node pairs
    node_pairs_considered = {}
    # edges between nodes that some fact declares independent
    edges_to_remove = set()
    for fact in facts:
        pair = (fact.node1, fact.node2)
        node_pairs_considered[pair] = None
        if fact.relation == RelationEnum.indep:
            edges_to_remove.add(pair)

    graph = nx.complete_graph(n_nodes)
    # remove edges that are independent according to external facts
    graph.remove_edges_from(edges_to_remove)

    all_nodes = set(range(n_nodes))
    for X, Y in node_pairs_considered:
        paths = [tuple(p) for p in nx.all_simple_paths(graph, source=X, target=Y)]
        add_path_definition_rules(solver, paths, X, Y)

        for S in powerset(all_nodes - {X, Y}):
            add_indep_assumptions(solver, X, Y, S)

            # independence holds when every path is blocked
            add_independence_rules(solver, paths, X, Y, S)

            for path_id, my_path in enumerate(paths):
                # blocked-path assumption for this path and set S
                add_blocked_path_assumptions(solver, path_id, X, Y, S)
                # S-active path rule and the dependence it derives
                add_dependence_rules(solver, path_id, my_path, X, Y, S)

    # finally add facts
    add_facts(solver, facts)

    return solver

## Test Run

In [None]:
# Build the core framework for 5 nodes, then add path rules and the first three facts.
solver = define_abaf_graph(5)
solver = append_path_rules_and_knowledge(solver, facts=facts[:3], n_nodes=5)

In [None]:

# Count the extensions of the full framework returned for 'ST'.
len(solver.enumerate_extensions('ST'))