In [1]:
import causaldag as cd
from causaldag import unknown_target_igsp
from causaldag.utils.ci_tests import gauss_ci_test, MemoizedCI_Tester, gauss_ci_suffstat
from causaldag.utils.invariance_tests import MemoizedInvarianceTester, gauss_invariance_test, gauss_invariance_suffstat
import numpy as np
from pprint import pprint
import random
from tqdm import tqdm
from scipy.stats import describe

In [2]:
# Fix both random sources so the whole experiment is reproducible:
# - the stdlib `random` module is what `setup_dag` uses to pick targets;
# - NumPy's legacy global generator is presumably what causaldag's random
#   graph/weight/sample helpers draw from (verify against causaldag).
random.seed(827642)
np.random.seed(375296)

In [3]:
def setup_dag(gdag, num_known, num_unknown, num_settings,
              candidate_nodes=None, n_samples=None, alpha=1e-5):
    """Simulate targets, data and testers for one unknown-target IGSP run.

    For each of ``num_settings`` interventional settings this picks
    ``num_known`` known targets and ``num_unknown`` unknown targets (disjoint
    from that setting's known targets), applies ``cd.ShiftIntervention(1)`` to
    all of them, samples observational and interventional data from ``gdag``,
    and wraps the resulting sufficient statistics in memoized testers.

    Parameters
    ----------
    gdag : weighted DAG to sample from (here produced by ``cd.rand.rand_weights``).
    num_known : number of known targets per setting.
    num_unknown : number of extra, unknown targets per setting.
    num_settings : number of interventional settings.
    candidate_nodes : iterable of nodes eligible as targets. Defaults to the
        notebook-level ``nodes`` (the previous, implicit behaviour).
    n_samples : number of samples drawn for the observational data and for
        each setting. Defaults to the notebook-level ``nsamples``.
    alpha : significance level for both the CI and invariance testers.

    Returns
    -------
    tuple
        ``(setting_list, ci_tester, invariance_tester,
        known_targets_list, unknown_targets_list)``, where ``setting_list``
        holds only the *known* targets for each setting.
    """
    # Backward compatibility: this function used to read the notebook globals
    # `nodes` and `nsamples` implicitly. They remain the defaults, but callers
    # can now pass the values explicitly and avoid the hidden dependency.
    if candidate_nodes is None:
        candidate_nodes = nodes
    if n_samples is None:
        n_samples = nsamples
    candidate_nodes = set(candidate_nodes)

    # RANDOMLY PICK TARGETS
    # `random.sample` requires a sequence; passing a set raises TypeError on
    # Python >= 3.11. Sampling from a sorted list fixes that and makes the draw
    # independent of set iteration order, so seeded runs stay reproducible.
    # Assumes node labels are mutually orderable (ints in this notebook).
    known_targets_list = [
        set(random.sample(sorted(candidate_nodes), num_known))
        for _ in range(num_settings)
    ]
    # Unknown targets are drawn only from nodes that are not known targets in
    # the same setting, so the two target sets are disjoint per setting.
    unknown_targets_list = [
        set(random.sample(sorted(candidate_nodes - known_targets), num_unknown))
        for known_targets in known_targets_list
    ]
    # Only the known targets are handed to the learning algorithm.
    setting_list = [
        dict(known_interventions=targets)
        for targets in known_targets_list
    ]

    # TURN TARGETS INTO INTERVENTIONS
    # Every target, known or unknown, receives the same shift intervention.
    ivs = [
        {target: cd.ShiftIntervention(1) for target in known_targets | unknown_targets}
        for known_targets, unknown_targets in zip(known_targets_list, unknown_targets_list)
    ]

    # GET SAMPLES
    obs_samples = gdag.sample(n_samples)
    iv_samples_list = [gdag.sample_interventional(iv, n_samples) for iv in ivs]

    # CREATE SUFFICIENT STATISTICS
    suffstat = gauss_ci_suffstat(obs_samples)
    inv_suffstat = gauss_invariance_suffstat(obs_samples, iv_samples_list)

    # CREATE CI TESTERS
    ci_tester = MemoizedCI_Tester(gauss_ci_test, suffstat, alpha=alpha)
    invariance_tester = MemoizedInvarianceTester(gauss_invariance_test, inv_suffstat, alpha=alpha)

    return setting_list, ci_tester, invariance_tester, known_targets_list, unknown_targets_list

In [4]:
# Experiment configuration.
# Graphs: `ngraphs` random DAGs on `nnodes` nodes; `exp_nbrs / (nnodes - 1)`
# is used as the edge probability when the DAGs are generated.
nnodes, exp_nbrs, ngraphs = 100, 1.5, 50
nodes = {*range(nnodes)}
# Data: samples per (observational or interventional) setting.
nsamples = 1000
# Interventions: settings per graph and known / unknown targets per setting.
num_settings, num_known, num_unknown = 5, 1, 3

# Draw random DAGs, keep their CPDAGs as ground truth, attach random edge
# weights, then simulate targets/data/testers for every weighted DAG.
# Order matters: each step consumes random state from the seeded RNGs.
dags = cd.rand.directed_erdos(nnodes, exp_nbrs / (nnodes - 1), ngraphs)
cpdags = list(map(lambda true_dag: true_dag.cpdag(), dags))
gdags = list(map(cd.rand.rand_weights, dags))
alg_info_list = [
    setup_dag(weighted_dag, num_known, num_unknown, num_settings)
    for weighted_dag in gdags
]
# Transpose the per-graph tuples; only the true target lists are kept here.
_, _, _, dags2known_target_list, dags2unknown_target_list = zip(*alg_info_list)

In [5]:
# Run unknown-target IGSP on every simulated instance, collecting the
# estimated DAG and the estimated intervention targets for each graph.
est_dags, est_targets_list = [], []
for settings, ci_tst, inv_tst, _, _ in tqdm(alg_info_list):
    dag_hat, targets_hat = unknown_target_igsp(settings, nodes, ci_tst, inv_tst)
    est_dags.append(dag_hat)
    est_targets_list.append(targets_hat)

100%|██████████| 50/50 [19:39<00:00, 23.59s/it]


In [6]:
# Compare estimated and true graphs through their CPDAGs using the
# structural Hamming distance (one SHD value per simulated graph).
est_cpdags = list(map(lambda estimated_dag: estimated_dag.cpdag(), est_dags))
shds = [estimated.shd(truth) for estimated, truth in zip(est_cpdags, cpdags)]

In [7]:
# Summary statistics (n, min/max, mean, variance, skewness, kurtosis) of the SHDs.
describe(np.array(shds))

DescribeResult(nobs=50, minmax=(5, 35), mean=18.28, variance=91.9608163265306, skewness=0.38337613969156975, kurtosis=-1.207439115817538)