In [187]:
# standard imports
import numpy as np
import pandas as pd
import pydot

# causal-learn imports
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.graph.Edge import Edge
from causallearn.utils.GraphUtils import GraphUtils

# pysat imports
from pysat.formula import CNF
from pysat.solvers import Glucose3

In [152]:
# Generate synthetic data with the structure X -> Y, X -> Z and Y -> Z:
# Y is a function of X (plus noise), and Z is a function of both X and Y (plus noise).
N_SAMPLES = 1000
SEED = 42  # fixed seed so every Restart & Run All draws the same sample
rng = np.random.default_rng(SEED)

X = rng.uniform(size=N_SAMPLES)
eps = rng.normal(size=N_SAMPLES)     # additive noise on Z
delta = rng.uniform(size=N_SAMPLES)  # additive noise on Y
Y = -7*X + 0.5*delta
Z = 2*X + Y + eps

# Create DataFrame with named variables
data = pd.DataFrame({'X': X, 'Y': Y, 'Z': Z})

In [153]:
# Remember the DataFrame column names; FCI is given only the raw array
variable_names = data.columns.tolist()

# Run FCI on the raw values
g, edges = fci(data.to_numpy(), alpha=0.05)

# Build the pydot graph that will be rendered
dot = GraphUtils.to_pydot(g)

# Replace the label of every numerically named pydot node with the matching column name
n_variables = len(variable_names)
for pydot_node in dot.get_nodes():
    node_id = pydot_node.get_name()
    if not node_id.isdigit():
        continue
    column_index = int(node_id)
    if column_index < n_variables:
        pydot_node.set_label(f'"{variable_names[column_index]}"')


dot.write_png('labeled_graph.png')

# Overwrite the names of the FCI graph nodes with the original column names, by position
for position, graph_node in enumerate(g.nodes):
    graph_node.name = variable_names[position]

Depth=0, working on node 2: 100%|██████████| 3/3 [00:00<00:00, 608.63it/s]


In [154]:
# translate the edges to a more readable format
def get_endpoint_type(endpoint: int, isFirst: bool) -> str:
    """Return the one-character symbol for a numeric edge endpoint.

    Args:
        endpoint: numeric endpoint code. -1 maps to "-", 1 maps to an
            arrowhead and 2 maps to "o". (Presumably these are
            causal-learn's TAIL / ARROW / CIRCLE values -- confirm against
            the library's Endpoint enum.)
        isFirst: True when the endpoint is the left-hand end of the edge, so
            an arrowhead is drawn as "<"; otherwise it is drawn as ">".

    Returns:
        "-", "<", ">" or "o".

    Raises:
        ValueError: if ``endpoint`` is not one of the supported codes.
            Previously the function fell through and returned None, which
            then showed up as the text "None" inside edge symbols.
    """
    if endpoint == -1:
        return "-"
    if endpoint == 1:
        return "<" if isFirst else ">"
    if endpoint == 2:
        return "o"
    raise ValueError(f"Unsupported endpoint code: {endpoint!r}")

def get_edge(edge: Edge):
    """Render an edge as a symbol such as "o-o" built from its two endpoint codes.

    The left mark comes from ``edge.numerical_endpoint_1`` and the right mark
    from ``edge.numerical_endpoint_2``; the two are joined by a "-".
    """
    left_mark = get_endpoint_type(edge.numerical_endpoint_1, True)
    right_mark = get_endpoint_type(edge.numerical_endpoint_2, False)
    return "{}-{}".format(left_mark, right_mark)


sat_clauses = []
# One (node1 name, node2 name, edge symbol) triple per edge returned by FCI
formatted_edges = [
    (fci_edge.node1.name, fci_edge.node2.name, get_edge(fci_edge))
    for fci_edge in edges
]

print(formatted_edges)

[('X', 'Y', 'o-o'), ('Y', 'Z', 'o-o')]


In [155]:
def get_unique_nodes(edges):
    """Return the set of everything found at positions 0 and 1 of the given edges.

    Args:
        edges: iterable of indexable items whose first two entries are the
            two nodes of an edge.

    Returns:
        A set containing each node once.
    """
    return {endpoint for edge in edges for endpoint in (edge[0], edge[1])}

# Distinct node names that occur in the formatted FCI edges
nodes = get_unique_nodes(formatted_edges)

In [156]:
nodes

{'X', 'Y', 'Z'}

In [157]:
# create a variable mapping for the nodes, with all possible edge types
def create_variable_mapping(nodes):
    """Assign a 1-based SAT variable id to every (n1, n2, edge_type) triple.

    A fresh dictionary is built on every call. The previous version filled a
    shared module-level dict, so calling it a second time overwrote every
    existing key with the same id (``len + 1``).

    Args:
        nodes: collection of node names. It is iterated in nested loops, so
            it must be re-iterable. Ids follow its iteration order; pass an
            ordered sequence if the numbering has to be reproducible.

    Returns:
        dict mapping (n1, n2, "direct" | "latent") to a positive int, for
        every ordered pair of nodes (including n1 == n2).
    """
    mapping = {}
    for n1 in nodes:
        for n2 in nodes:
            for edge_type in ("direct", "latent"):
                mapping[(n1, n2, edge_type)] = len(mapping) + 1
    return mapping

# Build the (n1, n2, edge_type) -> SAT variable id table for the discovered nodes
var_mapping = create_variable_mapping(nodes)

In [158]:
var_mapping

{('Z', 'Z', 'direct'): 1,
 ('Z', 'Z', 'latent'): 2,
 ('Z', 'Y', 'direct'): 3,
 ('Z', 'Y', 'latent'): 4,
 ('Z', 'X', 'direct'): 5,
 ('Z', 'X', 'latent'): 6,
 ('Y', 'Z', 'direct'): 7,
 ('Y', 'Z', 'latent'): 8,
 ('Y', 'Y', 'direct'): 9,
 ('Y', 'Y', 'latent'): 10,
 ('Y', 'X', 'direct'): 11,
 ('Y', 'X', 'latent'): 12,
 ('X', 'Z', 'direct'): 13,
 ('X', 'Z', 'latent'): 14,
 ('X', 'Y', 'direct'): 15,
 ('X', 'Y', 'latent'): 16,
 ('X', 'X', 'direct'): 17,
 ('X', 'X', 'latent'): 18}

In [159]:
# create the CNF clauses for the edge constraints
def add_edge_constraints(edges, mapping=None):
    """Translate edge symbols into CNF clauses over SAT variable ids.

    Args:
        edges: iterable of (n1, n2, edge_type) triples, where edge_type is one
            of '-->', 'o->', 'o-o' or '<->'. The mirrored spellings '<--' and
            '<-o' are accepted and handled as the same edge read from the
            other end (previously they were dropped silently).
        mapping: dict from (n1, n2, 'direct' | 'latent') to a positive
            variable id. Defaults to the module-level ``var_mapping``.

    Returns:
        A list of clauses; each clause is a list of signed variable ids
        (a negative id means the variable is negated).

    Any other edge_type produces no clause; a warning is emitted so that the
    missing constraint does not go unnoticed.
    """
    import warnings

    if mapping is None:
        mapping = var_mapping

    # Mirrored symbols and the canonical symbol they become once n1 and n2 are swapped
    mirrored = {'<--': '-->', '<-o': 'o->'}

    cnf = []
    for n1, n2, edge_type in edges:
        if edge_type in mirrored:
            n1, n2, edge_type = n2, n1, mirrored[edge_type]

        if edge_type == '-->':  # n1 is a direct cause of n2
            # Direct causation must be true
            cnf.append([mapping[(n1, n2, 'direct')]])
            # No latent common cause
            cnf.append([-mapping[(n1, n2, 'latent')]])

        elif edge_type == 'o->':  # n2 is not an ancestor of n1
            # n2 cannot be a direct cause of n1
            cnf.append([-mapping[(n2, n1, 'direct')]])

        elif edge_type == 'o-o':  # no set d-separates n1 and n2
            # Either direct causation (in one of the two directions) or a
            # latent common cause must exist
            cnf.append([
                mapping[(n1, n2, 'direct')],
                mapping[(n2, n1, 'direct')],
                mapping[(n1, n2, 'latent')],
            ])

        elif edge_type == '<->':  # latent common cause of n1 and n2
            # Must have latent common cause
            cnf.append([mapping[(n1, n2, 'latent')]])
            # No direct causation in either direction
            cnf.append([-mapping[(n2, n1, 'direct')]])
            cnf.append([-mapping[(n1, n2, 'direct')]])

        else:
            warnings.warn(
                f"No CNF encoding for edge type {edge_type!r} "
                f"between {n1!r} and {n2!r}; edge ignored",
                stacklevel=2,
            )

    return cnf

# Clauses (lists of signed variable ids) encoding the edges found by FCI
cnf = add_edge_constraints(formatted_edges)

In [160]:
cnf

[[15, 11, 16], [7, 3, 8]]

In [161]:
# Collect the id of every variable that occurs in at least one clause, ignoring its sign
variable_set = {abs(literal) for clause in cnf for literal in clause}

In [162]:
# Consecutive ids 1..k, one for each variable that occurs in the clauses
new_var = list(range(1, len(variable_set) + 1))
new_var

[1, 2, 3, 4, 5, 6]

In [163]:
# create a mapping from old variable to new variable
# Compact ids 1..k are handed out in ascending order of the original id, so the
# numbering does not depend on set iteration order. They are generated here
# directly, so this cell depends only on `variable_set`.
cnf_variable_mapping = {
    old_id: compact_id
    for compact_id, old_id in enumerate(sorted(variable_set), start=1)
}

In [164]:
cnf_variable_mapping

{3: 1, 7: 2, 8: 3, 11: 4, 15: 5, 16: 6}

In [165]:
# Rewrite every clause with the compact ids, keeping the sign of each literal.
# No name is rebound to `new_var` here: that name holds the list of compact ids
# built in an earlier cell, and overwriting it with an int would corrupt it.
new_cnf = [
    [
        (1 if literal > 0 else -1) * cnf_variable_mapping[abs(literal)]
        for literal in clause
    ]
    for clause in cnf
]

In [166]:
new_cnf

[[5, 4, 6], [2, 1, 3]]

In [167]:
# create the formula as CNF
# (each inner list of new_cnf is passed as one clause)
formula = CNF(from_clauses=new_cnf)

In [168]:
formula

CNF(from_string='p cnf 6 2\n5 4 6 0\n2 1 3 0')

In [169]:
# Create a Glucose3 solver and load the clauses of the formula into it.
# The solver is released explicitly further down with solver.delete().
solver = Glucose3()
solver.append_formula(formula)

In [170]:
# Run the solver; the result tells whether the loaded clauses are satisfiable
is_sat = solver.solve()

In [171]:
is_sat

True

In [172]:
# Satisfying assignment as a list of signed compact ids (positive = variable is true).
# NOTE(review): presumably None when the formula is unsatisfiable -- confirm before iterating.
model = solver.get_model()

In [173]:
model

[1, -2, -3, 4, -5, -6]

In [174]:
# Invert cnf_variable_mapping so that compact ids can be translated back to the original ids
reverse_cnf_variable_mapping = dict(
    zip(cnf_variable_mapping.values(), cnf_variable_mapping.keys())
)
reverse_cnf_variable_mapping

{1: 3, 2: 7, 3: 8, 4: 11, 5: 15, 6: 16}

In [175]:
# Translate the solver's model from compact ids back to the original variable ids.
# The raw model is read from the solver rather than from `model`: this cell rebinds
# `model`, so reading `model` here would re-translate already translated ids
# (KeyError or silently wrong ids) whenever the cell is executed a second time.
compact_model = solver.get_model()
if compact_model is None:
    raise RuntimeError(
        "The solver has no model to translate (formula unsatisfiable or not solved)"
    )
model = [
    (1 if literal > 0 else -1) * reverse_cnf_variable_mapping[abs(literal)]
    for literal in compact_model
]

In [176]:
model

[3, -7, -8, 11, -15, -16]

In [177]:
# Invert var_mapping for interpretation: SAT variable id -> (node1, node2, edge_type)
reverse_mapping = dict((var_id, triple) for triple, var_id in var_mapping.items())

In [178]:
reverse_mapping

{1: ('Z', 'Z', 'direct'),
 2: ('Z', 'Z', 'latent'),
 3: ('Z', 'Y', 'direct'),
 4: ('Z', 'Y', 'latent'),
 5: ('Z', 'X', 'direct'),
 6: ('Z', 'X', 'latent'),
 7: ('Y', 'Z', 'direct'),
 8: ('Y', 'Z', 'latent'),
 9: ('Y', 'Y', 'direct'),
 10: ('Y', 'Y', 'latent'),
 11: ('Y', 'X', 'direct'),
 12: ('Y', 'X', 'latent'),
 13: ('X', 'Z', 'direct'),
 14: ('X', 'Z', 'latent'),
 15: ('X', 'Y', 'direct'),
 16: ('X', 'Y', 'latent'),
 17: ('X', 'X', 'direct'),
 18: ('X', 'X', 'latent')}

In [145]:
# Decode each signed id of the model into a readable record.
# Ids that are not present in reverse_mapping are skipped.
causal_relationship = []

for literal in model:
    triple = reverse_mapping.get(abs(literal))
    if triple is None:
        continue
    source, target, kind = triple
    causal_relationship.append({
        "node1": source,
        "node2": target,
        "edge": kind,
        "exists": literal > 0,  # a positive literal means the relation holds in the model
    })

In [146]:
causal_relationship

[{'node1': 'Z', 'node2': 'Y', 'edge': 'direct', 'exists': True},
 {'node1': 'Y', 'node2': 'Z', 'edge': 'direct', 'exists': False},
 {'node1': 'Y', 'node2': 'Z', 'edge': 'latent', 'exists': False},
 {'node1': 'Y', 'node2': 'X', 'edge': 'direct', 'exists': True},
 {'node1': 'X', 'node2': 'Y', 'edge': 'direct', 'exists': False},
 {'node1': 'X', 'node2': 'Y', 'edge': 'latent', 'exists': False}]

In [147]:
# Explicitly release the solver now that its model has been extracted
solver.delete()

In [148]:
# Keep only the relationships the model asserts to hold, then split them by kind
asserted = [rel for rel in causal_relationship if rel["exists"]]
direct_causes = [rel for rel in asserted if rel["edge"] == "direct"]
latent_causes = [rel for rel in asserted if rel["edge"] == "latent"]

In [149]:
# One "cause -> effect" line per direct cause in the model
for cause in direct_causes:
    print(cause["node1"], "->", cause["node2"])

Z -> Y
Y -> X


In [150]:
# A latent common cause is symmetric, so it is printed with the bidirected
# symbol '<->' (the symbol the edge constraints use for latent confounding)
# instead of the one-directional '<-'.
for rel in latent_causes:
    print(f"{rel['node1']} <-> {rel['node2']}")

In [192]:
# create a set of variables which will be the nodes
nodes = set()
for rel in causal_relationship:
    nodes.add(rel["node1"])
    nodes.add(rel["node2"])
graph = pydot.Dot("my_graph", graph_type="digraph")
for node in nodes:
    graph.add_node(pydot.Node(node))

# Only relationships that hold in the model are drawn.
for rel in causal_relationship:
    if not rel["exists"]:
        continue
    if rel["edge"] == "direct":
        # Ordinary arrow from cause to effect (an explicit value replaces the empty arrowhead string)
        graph.add_edge(pydot.Edge(rel["node1"], rel["node2"], arrowhead="normal"))
    elif rel["edge"] == "latent":
        # Latent common cause: dashed, with a head at both ends, so it cannot be
        # mistaken for a direct cause from node2 to node1
        graph.add_edge(pydot.Edge(
            rel["node2"], rel["node1"],
            dir="both", arrowhead="normal", arrowtail="normal", style="dashed",
        ))

In [193]:
# Render the graph built from the SAT model to a PNG file in the working directory
# NOTE(review): presumably requires the Graphviz binaries to be installed -- confirm.
graph.write_png("output.png")