In [97]:
# basic imports
import pandas as pd
import numpy as np
import pydot

# causallearn imports
from causallearn.search.ConstraintBased.PC import pc

# pysat imports
from pysat.formula import CNF
from pysat.solvers import Glucose3

Generate some synthetic data

In [57]:
# Generate data where X -> Y <- Z
X = np.random.uniform(size=10000)
eps = np.random.normal(size=10000)
delta = np.random.uniform(size=10000)
Y = -7*X + 0.5*delta
Z = 2*X + Y + eps

# Create DataFrame with named variables
data = pd.DataFrame({'X': X, 'Y': Y, 'Z': Z})

Run PC

In [58]:
# Store original column names
variable_names = list(data.columns)

g = pc(data.to_numpy())

Depth=1, working on node 2: 100%|██████████| 3/3 [00:00<00:00, 530.34it/s] 


In [59]:
for i, node in enumerate(g.G.nodes):
    node.name = variable_names[i]

In [60]:
g.G.graph

array([[ 0, -1, -1],
       [-1,  0, -1],
       [-1, -1,  0]])

In [61]:
node_mapping = {node.name: index for index, node in enumerate(g.G.nodes)}
node_mapping


{'X': 0, 'Y': 1, 'Z': 2}

In [62]:
reversed_node_mapping = {index: node.name for index, node in enumerate(g.G.nodes)}
reversed_node_mapping

{0: 'X', 1: 'Y', 2: 'Z'}

Extract edges, since some of them are not order dependent

In [63]:
edges = []
indices = np.where(g.G.graph != 0)
processed_pairs = set()

for i, j in zip(indices[0], indices[1]):
    
    node_pair = frozenset([i.item(), j.item()])
    
    if node_pair in processed_pairs:
        continue
        
    if g.G.graph[i,j] == 1 and g.G.graph[j,i] == -1:
        edges.append({
            'from': reversed_node_mapping[i.item()],
            'to': reversed_node_mapping[j.item()],
            'type': "->"
        })
    
    elif g.G.graph[i,j] == -1 and g.G.graph[j,i] == -1:
        edges.append({
            'from': reversed_node_mapping[i.item()],
            'to': reversed_node_mapping[j.item()],
            'type': "--"
        })
        processed_pairs.add(node_pair) 
    
    elif g.G.graph[i,j] == 1 and g.G.graph[j,i] == 1:
        edges.append({
            'from': reversed_node_mapping[i.item()],
            'to': reversed_node_mapping[j.item()],
            'type': "<->"
        })
        processed_pairs.add(node_pair)

In [64]:
edges

[{'from': 'X', 'to': 'Y', 'type': '--'},
 {'from': 'X', 'to': 'Z', 'type': '--'},
 {'from': 'Y', 'to': 'Z', 'type': '--'}]

In [65]:
causal_dict = {}
for node1 in node_mapping:
    for node2 in node_mapping:
        for edge in ['direct']:
            causal_dict[(node1, node2, edge)] = len(causal_dict) + 1

In [66]:
causal_dict

{('X', 'X', 'direct'): 1,
 ('X', 'Y', 'direct'): 2,
 ('X', 'Z', 'direct'): 3,
 ('Y', 'X', 'direct'): 4,
 ('Y', 'Y', 'direct'): 5,
 ('Y', 'Z', 'direct'): 6,
 ('Z', 'X', 'direct'): 7,
 ('Z', 'Y', 'direct'): 8,
 ('Z', 'Z', 'direct'): 9}

Now we need to create logical Clauses from the edges. Logical clauses should be in CNF formula, which is a set of clauses is in or, but the clauses togeter are in and.

In [67]:
SATClauses = []
index = 0

for item in edges:
    
    if item['type'] == '->':
        # there MUST be a direct edge from node1 to node2 and NO direct edge from node2 to node1
        SATClauses.append([causal_dict[(item['from'], item['to'], 'direct')]])
        SATClauses.append([-causal_dict[(item['to'], item['from'], 'direct')]])
    elif item['type'] == '--':
        # there MUST be a direct edge from node1 to node2 OR a direct edge from node2 to node1
        SATClauses.append([causal_dict[(item['from'], item['to'], 'direct')], causal_dict[(item['to'], item['from'], 'direct')]])
    elif item['type'] == '<->':
        # there MUSTN'T be a direct edge from node1 to node2 AND a direct edge from node2 to node1
        SATClauses.append([-causal_dict[(item['from'], item['to'], 'direct')]])
        SATClauses.append([-causal_dict[(item['to'], item['from'], 'direct')]])

In [68]:
SATClauses

[[2, 4], [3, 7], [6, 8]]

In [74]:
# iterate through the clauses and count the number of variables
variable_set = set()
for clause in SATClauses:
    for var in clause:
        variable_set.add(abs(var))

In [75]:
new_var = list(range(1, len(variable_set) + 1))
new_var

[1, 2, 3, 4, 5, 6]

In [76]:
# create a mapping from old variable to new variable
cnf_variable_mapping = {}
for i, var in enumerate(variable_set):
    cnf_variable_mapping[var] = new_var[i]
    
cnf_variable_mapping

{2: 1, 3: 2, 4: 3, 6: 4, 7: 5, 8: 6}

In [79]:
new_cnf = []
for clause in SATClauses:
    new_clause = []
    for var in clause:
        new_var = cnf_variable_mapping[abs(var)]
        new_clause.append(new_var if var > 0 else -new_var)
    new_cnf.append(new_clause)

new_cnf

[[1, 3], [2, 5], [4, 6]]

In [80]:
formula = CNF(from_clauses=new_cnf)
formula

CNF(from_string='p cnf 6 3\n1 3 0\n2 5 0\n4 6 0')

In [81]:
solver = Glucose3()
solver.append_formula(formula)

In [82]:
isSat = solver.solve()
isSat

True

In [83]:
model = solver.get_model()
model

[1, 2, -3, 4, -5, -6]

In [84]:
# map back with cnf_variable_mapping
# reverse mapping
reverse_cnf_variable_mapping = {v: k for k, v in cnf_variable_mapping.items()}
reverse_cnf_variable_mapping

{1: 2, 2: 3, 3: 4, 4: 6, 5: 7, 6: 8}

In [85]:
temp = []
for item in model:
    temp.append(reverse_cnf_variable_mapping[abs(item)] if item > 0 else -reverse_cnf_variable_mapping[abs(item)])
model = temp
model

[2, 3, -4, 6, -7, -8]

In [87]:
# Create reverse mapping for interpretation
reversed_causal_dict = {v: k for k, v in causal_dict.items()}

In [91]:
causal_relationship = []

for item in model:
    absolute_value = abs(item)
    if absolute_value in reversed_causal_dict:
        node1, node2, edge = reversed_causal_dict[absolute_value]
        causal_relationship.append({
            "node1": node1,
            "node2": node2,
            "edge": edge,
            "exists": True if item > 0 else False
        })
causal_relationship

[{'node1': 'X', 'node2': 'Y', 'edge': 'direct', 'exists': True},
 {'node1': 'X', 'node2': 'Z', 'edge': 'direct', 'exists': True},
 {'node1': 'Y', 'node2': 'X', 'edge': 'direct', 'exists': False},
 {'node1': 'Y', 'node2': 'Z', 'edge': 'direct', 'exists': True},
 {'node1': 'Z', 'node2': 'X', 'edge': 'direct', 'exists': False},
 {'node1': 'Z', 'node2': 'Y', 'edge': 'direct', 'exists': False}]

In [92]:
solver.delete()

In [94]:
direct_causes = [rel for rel in causal_relationship if rel["edge"] == "direct" and rel["exists"]]

In [95]:
for rel in direct_causes:
    print(f"{rel['node1']} -> {rel['node2']}")

X -> Y
X -> Z
Y -> Z


In [98]:
# create a set of variables which will be the nodes
nodes = set()
for rel in causal_relationship:
    nodes.add(rel["node1"])
    nodes.add(rel["node2"])
graph = pydot.Dot("my_graph", graph_type="digraph")
for node in nodes:
    graph.add_node(pydot.Node(node))

for rel in causal_relationship:
    if rel["edge"] == "direct" and rel["exists"]:
        graph.add_edge(pydot.Edge(rel["node1"], rel["node2"], arrowhead=""))
    elif rel["edge"] == "latent" and rel["exists"]:
        graph.add_edge(pydot.Edge(rel["node2"], rel["node1"], arrowhead="normal"))

In [99]:
graph.write_png("output/output.png")