In [40]:
# standard imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# causal-learn imports
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.graph.Edge import Edge
from causallearn.utils.GraphUtils import GraphUtils

# pysat imports
from pysat.formula import CNF
from pysat.solvers import Glucose3

from main import direct_causes, latent_causes

In [41]:
# Generate data where X -> Y <- Z
X = np.random.uniform(size=1000)
eps = np.random.normal(size=1000)
delta = np.random.uniform(size=1000)
Z = -7*X + 0.5*delta
Y = 2*X + Z + eps

# Create DataFrame with named variables
data = pd.DataFrame({'X': X, 'Y': Y, 'Z': Z})

In [42]:
# Store original column names
variable_names = list(data.columns)

# Run FCI
g, edges = fci(data.to_numpy(), alpha=0.2)

# Get the dot representation to be plotted
dot = GraphUtils.to_pydot(g)

# Relabel the nodes with original variable names
for i, node in enumerate(dot.get_nodes()):
    if node.get_name().isdigit() and int(node.get_name()) < len(variable_names):
        # Replace the node label with the original variable name
        node.set_label(f'"{variable_names[int(node.get_name())]}"')


dot.write_png('labeled_graph.png')

# rename the nodes with the original variable names since FCI do whatever he likes
for i, node in enumerate(g.nodes):
    node.name = variable_names[i]

Depth=0, working on node 2: 100%|██████████| 3/3 [00:00<?, ?it/s]


In [43]:
# translate the edges to a more readable format
def get_endpoint_type(endpoint: int, isFirst: bool):
    if endpoint == -1:
        return "-"
    elif endpoint == 1:
        return "<" if isFirst else ">"
    elif endpoint == 2:
        return "o"

def get_edge(edge: Edge):
    start = edge.numerical_endpoint_1
    end = edge.numerical_endpoint_2

    return f"{get_endpoint_type(start, True)}-{get_endpoint_type(end, False)}"


sat_clauses = []
formatted_edges = []
for edge in edges:
    formatted_edges.append((edge.node1.name, edge.node2.name, get_edge(edge)))

print(formatted_edges)

[('X', 'Z', 'o-o'), ('Y', 'Z', 'o-o')]


In [44]:
def get_unique_nodes(edges):
    nodes = set()
    for edge in edges:
        nodes.add(edge[0])
        nodes.add(edge[1])
    return nodes

nodes = get_unique_nodes(formatted_edges)

In [45]:
nodes

{'X', 'Y', 'Z'}

In [46]:
# create a variable mapping for the nodes, with all possible edge types
var_mapping = {}
def create_variable_mapping(nodes):
    for n1 in nodes:
        for n2 in nodes:
            for edge_type in ["direct", "latent"]:
                var_mapping[(n1, n2, edge_type)] = len(var_mapping) + 1
    return var_mapping

var_mapping = create_variable_mapping(nodes)

In [47]:
var_mapping

{('Z', 'Z', 'direct'): 1,
 ('Z', 'Z', 'latent'): 2,
 ('Z', 'X', 'direct'): 3,
 ('Z', 'X', 'latent'): 4,
 ('Z', 'Y', 'direct'): 5,
 ('Z', 'Y', 'latent'): 6,
 ('X', 'Z', 'direct'): 7,
 ('X', 'Z', 'latent'): 8,
 ('X', 'X', 'direct'): 9,
 ('X', 'X', 'latent'): 10,
 ('X', 'Y', 'direct'): 11,
 ('X', 'Y', 'latent'): 12,
 ('Y', 'Z', 'direct'): 13,
 ('Y', 'Z', 'latent'): 14,
 ('Y', 'X', 'direct'): 15,
 ('Y', 'X', 'latent'): 16,
 ('Y', 'Y', 'direct'): 17,
 ('Y', 'Y', 'latent'): 18}

In [48]:
# create the CNF clauses for the edge constraints
def add_edge_constraints(edges):
    cnf = []
    for n1, n2, edge_type in edges:
            if edge_type == '-->':

                # Direct causation must be true
                cnf.append([var_mapping[(n1, n2, 'direct')]])

                # No latent common cause
                cnf.append([-var_mapping[(n1, n2, 'latent')]])

            elif edge_type == 'o->':

                # n2 cannot be ancestor of n1
                cnf.append([-var_mapping[(n2, n1, 'direct')]])

            elif edge_type == 'o-o':

                # Either direct causation or latent common cause must exist
                cnf.append([
                    var_mapping[(n1, n2, 'direct')],
                    var_mapping[(n1, n2, 'latent')]
                ])

            elif edge_type == '<->':

                # Must have latent common cause
                cnf.append([var_mapping[(n1, n2, 'latent')]])

                # No direct causation in either direction
                cnf.append([-var_mapping[(n2, n1, 'direct')]])
                cnf.append([-var_mapping[(n1, n2, 'direct')]])

    return cnf

cnf = add_edge_constraints(formatted_edges)

In [49]:
cnf

[[7, 8], [13, 14]]

In [69]:
# create the formula as CNF
formula = CNF(from_clauses=cnf)

In [70]:
formula

CNF(from_string='p cnf 14 2\n7 8 0\n13 14 0')

In [71]:
solver = Glucose3()
solver.append_formula(formula)

In [72]:
is_sat = solver.solve()

In [73]:
is_sat

True

In [74]:
model = solver.get_model()

In [75]:
model

[-1, -2, -3, -4, -5, -6, 7, -8, -9, -10, -11, -12, 13, -14]

In [57]:
# Create reverse mapping for interpretation
reverse_mapping = {v: k for k, v in var_mapping.items()}

In [58]:
reverse_mapping

{1: ('Z', 'Z', 'direct'),
 2: ('Z', 'Z', 'latent'),
 3: ('Z', 'X', 'direct'),
 4: ('Z', 'X', 'latent'),
 5: ('Z', 'Y', 'direct'),
 6: ('Z', 'Y', 'latent'),
 7: ('X', 'Z', 'direct'),
 8: ('X', 'Z', 'latent'),
 9: ('X', 'X', 'direct'),
 10: ('X', 'X', 'latent'),
 11: ('X', 'Y', 'direct'),
 12: ('X', 'Y', 'latent'),
 13: ('Y', 'Z', 'direct'),
 14: ('Y', 'Z', 'latent'),
 15: ('Y', 'X', 'direct'),
 16: ('Y', 'X', 'latent'),
 17: ('Y', 'Y', 'direct'),
 18: ('Y', 'Y', 'latent')}

In [59]:
causal_relationship = []

for item in model:
    absolute_value = abs(item)
    if absolute_value in reverse_mapping:
        node1, node2, edge = reverse_mapping[absolute_value]
        causal_relationship.append({
            "node1": node1,
            "node2": node2,
            "edge": edge,
            "exists": True if item > 0 else False
        })

In [60]:
causal_relationship

[{'node1': 'Z', 'node2': 'Z', 'edge': 'direct', 'exists': False},
 {'node1': 'Z', 'node2': 'Z', 'edge': 'latent', 'exists': False},
 {'node1': 'Z', 'node2': 'X', 'edge': 'direct', 'exists': False},
 {'node1': 'Z', 'node2': 'X', 'edge': 'latent', 'exists': False},
 {'node1': 'Z', 'node2': 'Y', 'edge': 'direct', 'exists': False},
 {'node1': 'Z', 'node2': 'Y', 'edge': 'latent', 'exists': False},
 {'node1': 'X', 'node2': 'Z', 'edge': 'direct', 'exists': True},
 {'node1': 'X', 'node2': 'Z', 'edge': 'latent', 'exists': False},
 {'node1': 'X', 'node2': 'X', 'edge': 'direct', 'exists': False},
 {'node1': 'X', 'node2': 'X', 'edge': 'latent', 'exists': False},
 {'node1': 'X', 'node2': 'Y', 'edge': 'direct', 'exists': False},
 {'node1': 'X', 'node2': 'Y', 'edge': 'latent', 'exists': False},
 {'node1': 'Y', 'node2': 'Z', 'edge': 'direct', 'exists': True},
 {'node1': 'Y', 'node2': 'Z', 'edge': 'latent', 'exists': False}]

In [61]:
solver.delete()

In [62]:
direct_causes = [rel for rel in causal_relationship if rel["edge"] == "direct" and rel["exists"]]
latent_causes = [rel for rel in causal_relationship if rel["edge"] == "latent" and rel["exists"]]

In [65]:
for rel in direct_causes:
    print(f"{rel['node1']} -> {rel['node2']}")

X -> Z
Y -> Z


In [67]:
for rel in latent_causes:
    print(f"{rel['node1']} <- {rel['node2']}")