In [1]:
import causaldag as cd

from causaldag import igsp
from causaldag import partial_correlation_test, MemoizedCI_Tester, partial_correlation_suffstat
from causaldag import MemoizedInvarianceTester, gauss_invariance_test, gauss_invariance_suffstat

import numpy as np
from pprint import pprint
import random

## UT-IGSP on random graph and intervention targets

In [2]:
np.random.seed(37645296)

Generate a random DAG.

In [3]:
nnodes = 10
nodes = set(range(nnodes))
exp_nbrs = 2
d = cd.rand.directed_erdos(nnodes, exp_nbrs/(nnodes-1))

Randomly assign edge weights.

In [4]:
g = cd.rand.rand_weights(d)

Choose random intervention targets.

In [5]:
num_targets = 2
num_settings = 2
targets_list = [random.sample(nodes, num_targets) for _ in range(num_settings)]
print(targets_list)

[[5, 8], [4, 5]]


since Python 3.9 and will be removed in a subsequent version.
  targets_list = [random.sample(nodes, num_targets) for _ in range(num_settings)]


Generate observational data.

In [6]:
nsamples_obs = 1000
obs_samples = g.sample(nsamples_obs)

Generate interventional data.

In [7]:
iv_mean = 1
iv_var = .1
nsamples_iv = 1000
ivs = [{target: cd.GaussIntervention(iv_mean, iv_var) for target in targets} for targets in targets_list]
iv_samples_list = [g.sample_interventional(iv, nsamples_iv) for iv in ivs]

Form sufficient statistics.

In [8]:
obs_suffstat = partial_correlation_suffstat(obs_samples)
invariance_suffstat = gauss_invariance_suffstat(obs_samples, iv_samples_list)

Create CI and invariance tester objects.

In [9]:
alpha = 1e-3
alpha_inv = 1e-3
ci_tester = MemoizedCI_Tester(partial_correlation_test, obs_suffstat, alpha=alpha)
invariance_tester = MemoizedInvarianceTester(gauss_invariance_test, invariance_suffstat, alpha=alpha_inv)

Run IGSP.

In [10]:
setting_list = [dict(interventions=targets) for targets in targets_list]
est_dag = igsp(setting_list, nodes, ci_tester, invariance_tester)

Check performance.

In [11]:
true_icpdag = d.interventional_cpdag(targets_list, cpdag=d.cpdag())
est_icpdag = est_dag.interventional_cpdag(targets_list, cpdag=est_dag.cpdag())
print(true_icpdag.shd(est_icpdag))
print(true_icpdag.shd_skeleton(est_icpdag))

0
0


  warn(s)
