In [126]:
# standard imports
import numpy as np
import pandas as pd
import pydot

# causal-learn imports
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.search.ConstraintBased.PC import pc
from causallearn.graph.Edge import Edge
from causallearn.utils.GraphUtils import GraphUtils

# pysat imports
from pysat.formula import CNF
from pysat.solvers import Glucose3

In [127]:
# Generate data where X -> Y <- Z
X = np.random.uniform(size=10000)
eps = np.random.normal(size=10000)
delta = np.random.uniform(size=10000)
Y = -7*X + 0.5*delta
Z = 2*X + Y + eps

# Create DataFrame with named variables
data = pd.DataFrame({'X': X, 'Y': Y, 'Z': Z})

In [128]:
# Store original column names
variable_names = list(data.columns)

g = pc(data.to_numpy())

# Get the dot representation to be plotted
dot = GraphUtils.to_pydot(g.G)

# Relabel the nodes with original variable names
# for i, node in enumerate(dot.get_nodes()):
#     if node.get_name().isdigit() and int(node.get_name()) < len(variable_names):
#         # Replace the node label with the original variable name
#         node.set_label(f'"{variable_names[int(node.get_name())]}"')

dot.write_png('output/pc_graph.png')

# Run FCI
g, edges = fci(data.to_numpy())

# Get the dot representation to be plotted
dot = GraphUtils.to_pydot(g)

# Relabel the nodes with original variable names
for i, node in enumerate(dot.get_nodes()):
    if node.get_name().isdigit() and int(node.get_name()) < len(variable_names):
        # Replace the node label with the original variable name
        node.set_label(f'"{variable_names[int(node.get_name())]}"')


dot.write_png('output/fci_graph.png')

# rename the nodes with the original variable names since FCI do whatever he likes
for i, node in enumerate(g.nodes):
    node.name = variable_names[i]

Depth=1, working on node 2: 100%|██████████| 3/3 [00:00<00:00, 126.37it/s]
Depth=0, working on node 2: 100%|██████████| 3/3 [00:00<00:00, 169.34it/s]


In [129]:
variable_names

['X', 'Y', 'Z']

In [130]:
# translate the edges to a more readable format
def get_endpoint_type(endpoint: int, isFirst: bool):
    if endpoint == -1:
        return "-"
    elif endpoint == 1:
        return "<" if isFirst else ">"
    elif endpoint == 2:
        return "o"

def get_edge(edge: Edge):
    start = edge.numerical_endpoint_1
    end = edge.numerical_endpoint_2

    return f"{get_endpoint_type(start, True)}-{get_endpoint_type(end, False)}"


sat_clauses = []
formatted_edges = []
for edge in edges:
    formatted_edges.append((edge.node1.name, edge.node2.name, get_edge(edge)))

print(formatted_edges)

[('X', 'Y', 'o-o'), ('X', 'Z', 'o-o'), ('Y', 'Z', 'o-o')]


In [131]:
def get_unique_nodes(edges):
    nodes = set()
    for edge in edges:
        nodes.add(edge[0])
        nodes.add(edge[1])
    return nodes

nodes = get_unique_nodes(formatted_edges)

In [132]:
nodes

{'X', 'Y', 'Z'}

In [133]:
# create a variable mapping for the nodes, with all possible edge types
var_mapping = {}
def create_variable_mapping(nodes):
    for n1 in nodes:
        for n2 in nodes:
            for edge_type in ["direct", "latent", "transitive"]:
                var_mapping[(n1, n2, edge_type)] = len(var_mapping) + 1
    return var_mapping

var_mapping = create_variable_mapping(nodes)

In [134]:
var_mapping

{('Y', 'Y', 'direct'): 1,
 ('Y', 'Y', 'latent'): 2,
 ('Y', 'Y', 'transitive'): 3,
 ('Y', 'Z', 'direct'): 4,
 ('Y', 'Z', 'latent'): 5,
 ('Y', 'Z', 'transitive'): 6,
 ('Y', 'X', 'direct'): 7,
 ('Y', 'X', 'latent'): 8,
 ('Y', 'X', 'transitive'): 9,
 ('Z', 'Y', 'direct'): 10,
 ('Z', 'Y', 'latent'): 11,
 ('Z', 'Y', 'transitive'): 12,
 ('Z', 'Z', 'direct'): 13,
 ('Z', 'Z', 'latent'): 14,
 ('Z', 'Z', 'transitive'): 15,
 ('Z', 'X', 'direct'): 16,
 ('Z', 'X', 'latent'): 17,
 ('Z', 'X', 'transitive'): 18,
 ('X', 'Y', 'direct'): 19,
 ('X', 'Y', 'latent'): 20,
 ('X', 'Y', 'transitive'): 21,
 ('X', 'Z', 'direct'): 22,
 ('X', 'Z', 'latent'): 23,
 ('X', 'Z', 'transitive'): 24,
 ('X', 'X', 'direct'): 25,
 ('X', 'X', 'latent'): 26,
 ('X', 'X', 'transitive'): 27}

In [135]:
# Initialize a counter for variable IDs
next_var_id = 1

def get_next_var_id():
    global next_var_id
    var_id = next_var_id
    next_var_id += 1
    return var_id


# create the CNF clauses for the edge constraints
def add_edge_constraints(edges, all_nodes):
    cnf = []
    for n1, n2, edge_type in edges:
            if edge_type == '-->': # A is a direct cause of B

                # Direct causation must be true
                cnf.append([var_mapping[(n1, n2, 'direct')]])
                cnf.append([-var_mapping[(n2, n1, 'direct')]])

                # No latent common cause
                cnf.append([-var_mapping[(n1, n2, 'latent')]])

            elif edge_type == 'o->': # B is not an ancestor of A

                cnf.append([-var_mapping[(n2, n1, 'direct')]])
                cnf.append([var_mapping((n1, n2, 'direct'))])
                cnf.append([var_mapping((n2, n1, 'latent'))])
                cnf.append([var_mapping((n1, n2, 'latent'))])

                # # For ancestral relationships, we need to prevent all paths from B to A
                # # This requires additional variables to represent transitive relationships
                # for intermediate in all_nodes:
                #     if intermediate != n1 and intermediate != n2:
                #         # If B→C and C→A, then B is an ancestor of A, which is prohibited
                #         cnf.append([
                #             -var_mapping[(n2, intermediate, 'direct')],
                #             -var_mapping[(intermediate, n1, 'direct')]
                #         ])
                #
                #         # For longer paths, we would need to recursively consider all possible paths
                #         # This is complicated in pure SAT, but can be handled more easily with auxiliary variables

            elif edge_type == 'o-o': # no set d-separate A and B

                # Either direct causation or latent common cause must exist
                cnf.append([
                    var_mapping[(n1, n2, 'direct')],
                    var_mapping[(n2, n1, 'direct')],
                    var_mapping[(n1, n2, 'latent')],
                    var_mapping[(n2, n1, 'latent')],
                ])

            elif edge_type == '<->': # There is a latent common cause of A and B

                # Must have latent common cause
                cnf.append([var_mapping[(n1, n2, 'latent')]])
                cnf.append([var_mapping[(n2, n1, 'latent')]])

                # No direct causation in either direction
                cnf.append([-var_mapping[(n2, n1, 'direct')]])
                cnf.append([-var_mapping[(n1, n2, 'direct')]])

    return cnf

def add_transitive_closure_constraints(all_nodes, var_mapping):
    cnf = []

    # Create mapping for transitive relationships
    for i, node_i in enumerate(all_nodes):
        for j, node_j in enumerate(all_nodes):
            if i != j:
                # Define: transitive(i,j) iff i is an ancestor of j through any path
                var_mapping[(node_i, node_j, 'transitive')] = get_next_var_id()

                # Direct edge implies transitive relationship
                cnf.append([-var_mapping[(node_i, node_j, 'direct')],
                           var_mapping[(node_i, node_j, 'transitive')]])

                # Build transitive relationships
                for k, node_k in enumerate(all_nodes):
                    if i != k and j != k:
                        # If i→k and k→j transitively, then i→j transitively
                        cnf.append([
                            -var_mapping[(node_i, node_k, 'transitive')],
                            -var_mapping[(node_k, node_j, 'transitive')],
                            var_mapping[(node_i, node_j, 'transitive')]
                        ])

    return cnf

def add_no_ancestor_constraints(edges, var_mapping):
    cnf = []

    for n1, n2, edge_type in edges:
        if edge_type == 'o->':  # B is not an ancestor of A
            # Use the transitive relationship variable to enforce no ancestry
            cnf.append([-var_mapping[(n2, n1, 'transitive')]])

    return cnf

In [136]:
cnf = add_edge_constraints(formatted_edges, nodes)

# Add transitive closure constraints
# cnf.extend(add_transitive_closure_constraints(nodes, var_mapping))

# Add the no-ancestor constraints
# cnf.extend(add_no_ancestor_constraints(formatted_edges, var_mapping))

In [137]:
cnf

[[19, 7, 20, 8], [22, 16, 23, 17], [4, 10, 5, 11]]

In [138]:
# iterate through the clauses and count the number of variables
variable_set = set()
for clause in cnf:
    for var in clause:
        variable_set.add(abs(var))

In [139]:
new_var = list(range(1, len(variable_set) + 1))
new_var

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

In [140]:
# create a mapping from old variable to new variable
cnf_variable_mapping = {}
for i, var in enumerate(variable_set):
    cnf_variable_mapping[var] = new_var[i]

In [141]:
cnf_variable_mapping

{4: 1,
 5: 2,
 7: 3,
 8: 4,
 10: 5,
 11: 6,
 16: 7,
 17: 8,
 19: 9,
 20: 10,
 22: 11,
 23: 12}

In [142]:
new_cnf = []
for clause in cnf:
    new_clause = []
    for var in clause:
        new_var = cnf_variable_mapping[abs(var)]
        new_clause.append(new_var if var > 0 else -new_var)
    new_cnf.append(new_clause)

In [143]:
new_cnf

[[9, 3, 10, 4], [11, 7, 12, 8], [1, 5, 2, 6]]

In [144]:
# create the formula as CNF
formula = CNF(from_clauses=new_cnf)

In [145]:
formula

CNF(from_string='p cnf 12 3\n9 3 10 4 0\n11 7 12 8 0\n1 5 2 6 0')

In [146]:
solver = Glucose3()
solver.append_formula(formula)

In [147]:
is_sat = solver.solve()

In [148]:
is_sat

True

In [149]:
model = solver.get_model()

In [150]:
model

[1, -2, 3, -4, -5, -6, 7, -8, -9, -10, -11, -12]

In [151]:
# map back with cnf_variable_mapping
# reverse mapping
reverse_cnf_variable_mapping = {v: k for k, v in cnf_variable_mapping.items()}
reverse_cnf_variable_mapping

{1: 4,
 2: 5,
 3: 7,
 4: 8,
 5: 10,
 6: 11,
 7: 16,
 8: 17,
 9: 19,
 10: 20,
 11: 22,
 12: 23}

In [152]:
temp = []
for item in model:
    temp.append(reverse_cnf_variable_mapping[abs(item)] if item > 0 else -reverse_cnf_variable_mapping[abs(item)])
model = temp

In [153]:
model

[4, -5, 7, -8, -10, -11, 16, -17, -19, -20, -22, -23]

In [154]:
# Create reverse mapping for interpretation
reverse_mapping = {v: k for k, v in var_mapping.items()}

In [155]:
reverse_mapping

{1: ('Y', 'Y', 'direct'),
 2: ('Y', 'Y', 'latent'),
 3: ('Y', 'Y', 'transitive'),
 4: ('Y', 'Z', 'direct'),
 5: ('Y', 'Z', 'latent'),
 6: ('Y', 'Z', 'transitive'),
 7: ('Y', 'X', 'direct'),
 8: ('Y', 'X', 'latent'),
 9: ('Y', 'X', 'transitive'),
 10: ('Z', 'Y', 'direct'),
 11: ('Z', 'Y', 'latent'),
 12: ('Z', 'Y', 'transitive'),
 13: ('Z', 'Z', 'direct'),
 14: ('Z', 'Z', 'latent'),
 15: ('Z', 'Z', 'transitive'),
 16: ('Z', 'X', 'direct'),
 17: ('Z', 'X', 'latent'),
 18: ('Z', 'X', 'transitive'),
 19: ('X', 'Y', 'direct'),
 20: ('X', 'Y', 'latent'),
 21: ('X', 'Y', 'transitive'),
 22: ('X', 'Z', 'direct'),
 23: ('X', 'Z', 'latent'),
 24: ('X', 'Z', 'transitive'),
 25: ('X', 'X', 'direct'),
 26: ('X', 'X', 'latent'),
 27: ('X', 'X', 'transitive')}

In [156]:
causal_relationship = []

for item in model:
    absolute_value = abs(item)
    if absolute_value in reverse_mapping:
        node1, node2, edge = reverse_mapping[absolute_value]
        causal_relationship.append({
            "node1": node1,
            "node2": node2,
            "edge": edge,
            "exists": True if item > 0 else False
        })

In [157]:
causal_relationship

[{'node1': 'Y', 'node2': 'Z', 'edge': 'direct', 'exists': True},
 {'node1': 'Y', 'node2': 'Z', 'edge': 'latent', 'exists': False},
 {'node1': 'Y', 'node2': 'X', 'edge': 'direct', 'exists': True},
 {'node1': 'Y', 'node2': 'X', 'edge': 'latent', 'exists': False},
 {'node1': 'Z', 'node2': 'Y', 'edge': 'direct', 'exists': False},
 {'node1': 'Z', 'node2': 'Y', 'edge': 'latent', 'exists': False},
 {'node1': 'Z', 'node2': 'X', 'edge': 'direct', 'exists': True},
 {'node1': 'Z', 'node2': 'X', 'edge': 'latent', 'exists': False},
 {'node1': 'X', 'node2': 'Y', 'edge': 'direct', 'exists': False},
 {'node1': 'X', 'node2': 'Y', 'edge': 'latent', 'exists': False},
 {'node1': 'X', 'node2': 'Z', 'edge': 'direct', 'exists': False},
 {'node1': 'X', 'node2': 'Z', 'edge': 'latent', 'exists': False}]

In [158]:
solver.delete()

In [159]:
direct_causes = [rel for rel in causal_relationship if rel["edge"] == "direct" and rel["exists"]]
latent_causes = [rel for rel in causal_relationship if rel["edge"] == "latent" and rel["exists"]]

In [160]:
for rel in direct_causes:
    print(f"{rel['node1']} -> {rel['node2']}")

Y -> Z
Y -> X
Z -> X


In [161]:
for rel in latent_causes:
    print(f"{rel['node1']} <- {rel['node2']}")

In [162]:
# create a set of variables which will be the nodes
nodes = set()
for rel in causal_relationship:
    nodes.add(rel["node1"])
    nodes.add(rel["node2"])
graph = pydot.Dot("my_graph", graph_type="digraph")
for node in nodes:
    graph.add_node(pydot.Node(node))

for rel in causal_relationship:
    if rel["edge"] == "direct" and rel["exists"]:
        graph.add_edge(pydot.Edge(rel["node1"], rel["node2"], arrowhead=""))
    elif rel["edge"] == "latent" and rel["exists"]:
        graph.add_edge(pydot.Edge(rel["node2"], rel["node1"], arrowhead="normal"))

In [163]:
graph.write_png("output/output.png")