In [1]:
import causaldag as cd
from causaldag.inference.structural import gsp
from causaldag.utils.ci_tests import gauss_ci_test
from causaldag.utils.ci_tests import hsic_test
import numpy as np
from pprint import pprint
import random

Create the graph $0 \rightarrow 1 \leftarrow 2$, in which node 1 is a collider.

In [2]:
dag = cd.DAG(arcs={(0, 1), (2, 1)})

Turn the graph into a GaussDAG, which allows us to sample from it. Use random edge weights to avoid faithfulness violations. By default, edge weights are sampled uniformly from $\pm[0.25, 1]$.

In [3]:
gdag = cd.rand.rand_weights(dag)
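
To make concrete what sampling from a linear Gaussian DAG involves, here is a minimal NumPy sketch for the collider above. It does not use causaldag: the helpers `draw_weight` and `sample_collider` are hypothetical names, and unit-variance noise is an assumption rather than something confirmed for `rand_weights`.

```python
import numpy as np

def draw_weight(rng):
    """Draw one edge weight uniformly from [-1, -0.25] U [0.25, 1]."""
    magnitude = rng.uniform(0.25, 1.0)
    sign = rng.choice([-1.0, 1.0])
    return sign * magnitude

def sample_collider(nsamples, rng):
    """Sample from a linear Gaussian model on 0 -> 1 <- 2 (assumed unit noise)."""
    w01 = draw_weight(rng)  # weight on the arc 0 -> 1
    w21 = draw_weight(rng)  # weight on the arc 2 -> 1
    x0 = rng.standard_normal(nsamples)  # source node: noise only
    x2 = rng.standard_normal(nsamples)  # source node: noise only
    # The collider is a weighted sum of its parents plus its own noise.
    x1 = w01 * x0 + w21 * x2 + rng.standard_normal(nsamples)
    return np.column_stack([x0, x1, x2]), (w01, w21)

rng = np.random.default_rng(1729)
data, weights = sample_collider(500, rng)
```

Keeping the weights away from zero is what guards against near-unfaithful distributions: an arc with a tiny weight would be almost undetectable by a CI test.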

Take $n = 500$ samples.

In [4]:
nsamples = 500
np.random.seed(1729)
random.seed(1729)
samples = gdag.sample(nsamples)

Form the sufficient statistics dictionary for the CI test.

*`gauss_ci_test` requires a correlation matrix and the number of samples.*

In [5]:
corr = np.corrcoef(samples, rowvar=False)
suffstat = dict(C=corr, n=nsamples)
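
To see why a correlation matrix and a sample count are enough, here is a sketch of a Gaussian CI test based on partial correlation and the Fisher z-transform. This is the standard textbook construction, not necessarily the exact implementation in causaldag; `partial_corr` and `fisher_z_pvalue` are hypothetical names, and the example matrix is the population correlation matrix of the collider with both weights set to 0.8 and unit noise (an illustrative choice).

```python
import math
import numpy as np

def partial_corr(corr, i, j, cond):
    """Partial correlation of i and j given cond, via the inverse submatrix."""
    idx = [i, j] + list(cond)
    prec = np.linalg.inv(corr[np.ix_(idx, idx)])
    return -prec[0, 1] / math.sqrt(prec[0, 0] * prec[1, 1])

def fisher_z_pvalue(corr, n, i, j, cond=()):
    """Two-sided p-value for the null hypothesis i _||_ j | cond."""
    r = partial_corr(corr, i, j, cond)
    r = min(max(r, -0.999999), 0.999999)  # keep the log finite
    z = 0.5 * math.log((1 + r) / (1 - r))  # Fisher z-transform
    stat = math.sqrt(n - len(cond) - 3) * abs(z)
    return math.erfc(stat / math.sqrt(2))  # equals 2 * (1 - Phi(stat))

# Population correlations for 0 -> 1 <- 2 with weights 0.8 and unit noise.
r = 0.8 / math.sqrt(2.28)
toy_corr = np.array([[1.0, r, 0.0],
                     [r, 1.0, r],
                     [0.0, r, 1.0]])

p_marginal = fisher_z_pvalue(toy_corr, 500, 0, 2)           # 0 _||_ 2
p_given_1 = fisher_z_pvalue(toy_corr, 500, 0, 2, cond=[1])  # 0 _||_ 2 | 1
```

In this toy matrix the marginal test does not reject independence of 0 and 2, while conditioning on the collider 1 makes them dependent, which is the signature that lets GSP orient both arcs into node 1.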

Run GSP (the Greedy Sparsest Permutation algorithm).

In [6]:
nnodes = 3  # this could be inferred from the sufficient statistics in the future
np.random.seed(1729)
random.seed(1729)
est_dag, summaries = gsp(suffstat, nnodes, gauss_ci_test, alpha=.05, depth=4, nruns=30)

Print the result. The convention for displaying a DAG as a string follows pcalg/bnlearn in R: `[i|j,k,l]` means that j, k, and l are parents of i, and a node with no parents is written as `[i]`.

In [7]:
print(est_dag)

[2][0][1|0,2]
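
The string convention can be unpacked mechanically. The sketch below parses such a string into a node-to-parents mapping; `parse_dag_string` is a hypothetical helper (not part of causaldag) and assumes non-negative integer node labels.

```python
import re

def parse_dag_string(s):
    """Map each node to the sorted list of its parents."""
    parents = {}
    # Each bracket holds a node, optionally followed by '|' and its parents.
    for node, parent_str in re.findall(r"\[(\d+)(?:\|([\d,]*))?\]", s):
        parents[int(node)] = sorted(int(p) for p in parent_str.split(",") if p)
    return parents

parents = parse_dag_string("[2][0][1|0,2]")
# Nodes 0 and 2 have no parents; node 1 has parents 0 and 2.
```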


GSP returns the smallest DAG found over the course of multiple runs of the algorithm. `summaries` is a list containing details about each run. 

Each run's summary lists the DAGs in the order they were visited, the number of arcs in each (its sparsity), and the depth of the depth-first search at which it was visited.

In [8]:
pprint(summaries[0])  # in this run, the starting DAG had no covered edges

[{'dag': [2][0][1|0,2], 'depth': 0, 'num_arcs': 2}]


In [9]:
pprint(summaries[3])  # this run is less trivial

[{'dag': [1][2|1][0|1,2], 'depth': 0, 'num_arcs': 3},
 {'dag': [1][0|1][2|0,1], 'depth': 1, 'num_arcs': 3},
 {'dag': [0][1|0][2|0,1], 'depth': 2, 'num_arcs': 3},
 {'dag': [2][0][1|0,2], 'depth': 0, 'num_arcs': 2}]
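
Given the structure printed above (a list of runs, each a list of dictionaries with the keys `dag`, `depth`, and `num_arcs`), the summaries can be post-processed with ordinary Python. The sketch below uses a toy copy of the two runs shown, with DAGs stored as strings for illustration; `sparsest_visited` is a hypothetical helper, not a causaldag function.

```python
# Toy copy of the two runs printed above; DAGs are plain strings here.
toy_summaries = [
    [{"dag": "[2][0][1|0,2]", "depth": 0, "num_arcs": 2}],
    [
        {"dag": "[1][2|1][0|1,2]", "depth": 0, "num_arcs": 3},
        {"dag": "[1][0|1][2|0,1]", "depth": 1, "num_arcs": 3},
        {"dag": "[0][1|0][2|0,1]", "depth": 2, "num_arcs": 3},
        {"dag": "[2][0][1|0,2]", "depth": 0, "num_arcs": 2},
    ],
]

def sparsest_visited(summaries):
    """Return the visited entry with the fewest arcs across all runs."""
    visited = [step for run in summaries for step in run]
    return min(visited, key=lambda step: step["num_arcs"])

best = sparsest_visited(toy_summaries)
steps_per_run = [len(run) for run in toy_summaries]  # how long each search was
```

In the second run, the search moves through three 3-arc DAGs before reaching the 2-arc collider, at which point the depth counter returns to 0.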


Next, use the nonparametric HSIC (Hilbert-Schmidt Independence Criterion) test as the CI test. The sufficient statistic for this test is simply the data itself (note: in the future, this should be a dictionary, for consistency with `gauss_ci_test`).

In [10]:
np.random.seed(1729)
random.seed(1729)
est_dag, summaries = gsp(samples, nnodes, hsic_test)


In [11]:
print(est_dag)

[2][0][1|0,2]
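
For intuition about what an HSIC-based test measures, here is a sketch of the (biased) empirical HSIC statistic for two scalar variables with Gaussian kernels. It covers only the unconditional statistic, with no p-value and no conditioning set, so it is not the `hsic_test` used above; the function names and the fixed bandwidth of 1.0 are assumptions made for illustration.

```python
import numpy as np

def gaussian_kernel(x, bandwidth):
    """Gaussian kernel matrix for a 1-D sample."""
    sq_dists = (x[:, None] - x[None, :]) ** 2
    return np.exp(-sq_dists / (2 * bandwidth ** 2))

def hsic_statistic(x, y, bandwidth=1.0):
    """Biased empirical HSIC: trace(K H L H) / n^2."""
    n = len(x)
    K = gaussian_kernel(x, bandwidth)
    L = gaussian_kernel(y, bandwidth)
    H = np.eye(n) - np.ones((n, n)) / n  # centering matrix
    return np.trace(K @ H @ L @ H) / n ** 2

rng = np.random.default_rng(0)
x = rng.standard_normal(300)
y_indep = rng.standard_normal(300)               # independent of x
y_dep = x ** 2 + 0.1 * rng.standard_normal(300)  # dependent but uncorrelated

stat_indep = hsic_statistic(x, y_indep)
stat_dep = hsic_statistic(x, y_dep)
```

The dependent pair is nonlinear with essentially zero linear correlation, so a correlation-based test would tend to miss it, whereas the HSIC statistic is clearly larger for `y_dep` than for `y_indep`. That is the motivation for using a kernel test when Gaussianity is in doubt.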
