# Example usage of DiBS-GP

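This notebook illustrates Bayesian causal inference with DiBS-GP, i.e., using DiBS (Differentiable Bayesian Structure Learning) for approximate posterior inference over causal graphs,
combined with a Gaussian-process (GP) model of the causal mechanisms. Inference is run through the `ABCI` interface.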

In [None]:
# imports
%reload_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

from src.abci_dibs_gp import ABCIDiBSGP as ABCI
from src.config import ABCIDiBSGPConfig
from src.environments.generic_environments import *


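First, we generate a ground-truth environment, i.e., a structural causal model (SCM) whose causal graph is a Barabási–Albert random graph, and plot that graph.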


In [None]:
# Ground-truth environment: a BarabasiAlbert environment over `num_nodes` nodes, configured with
# 30 observational training samples, 20 observational test samples and `normalise_data` enabled.
# NOTE(review): no random seed is set anywhere in this notebook; if the environment samples its
# graph/data randomly, the results below will differ between runs — seed it if reproducibility matters.
num_nodes = 5
env_cfg = EnvironmentConfig()
env_cfg.num_observational_train_samples = 30
env_cfg.num_observational_test_samples = 20
env_cfg.normalise_data = True

env = BarabasiAlbert(num_nodes, env_cfg)

# plot the ground-truth causal graph on a circular layout, labelling every node with its identifier
fig, ax = plt.subplots()
nx.draw(env.graph, nx.circular_layout(env.graph), ax=ax, with_labels=True)
ax.set_title('Ground-truth causal graph')
plt.show()

Next, we create an ABCI instance with the desired experimental design policy (here `'static-obs-dataset'`) and the DiBS-GP inference settings.

In [None]:
# Assemble the ABCI configuration: inference hyper-parameters first, then the experimental-design
# policy; finally build the ABCI instance on top of the ground-truth environment `env`.
cfg = ABCIDiBSGPConfig()
cfg.num_svgd_steps = 50   # number of SVGD update steps
cfg.num_mc_graphs = 100   # presumably the number of Monte-Carlo graph samples — confirm in ABCIDiBSGPConfig
cfg.num_particles = 5     # presumably the number of SVGD particles — confirm in ABCIDiBSGPConfig
cfg.policy = 'static-obs-dataset'
abci = ABCI(env, cfg)

We can now run ABCI and plot the SVGD loss recorded during inference.

In [None]:

abci.run()

# Plot the SVGD loss history that the run leaves in `abci.stats['svgd_loss']` against its index.
fig, ax = plt.subplots()
ax.plot(abci.stats['svgd_loss'])
ax.set(title='SVGD loss', xlabel='Number of Steps', ylabel='Loss')
# steps are whole numbers, so suppress fractional tick positions on the x-axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.show()

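Finally, we print the structure learning statistics. Each metric is reported both for the inferred graphs and for their CPDAGs (i.e., up to Markov equivalence).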

In [None]:
# Summarise the structure-learning metrics stored in `abci.stats`. Every metric `<key>` is printed
# next to its `<key>_cpdag` counterpart.
# To recompute the metrics from scratch before printing, run:
#     abci.stats.clear()
#     abci.compute_stats()

# (printed label, stats key) pairs; the CPDAG variant lives under `<key>_cpdag`
distance_metrics = [
    ('ESHD', 'eshd'),
    ('A-AID', 'aaid'),
    ('P-AID', 'paid'),
    ('OSET-AID', 'oset_aid'),
]
edge_metrics = [
    ('F1', 'ef1'),
    ('TPR', 'etpr'),
    ('TNR', 'etnr'),
    ('FNR', 'efnr'),
    ('FPR', 'efpr'),
    ('AUROC', 'auroc'),
    ('AUPRC', 'auprc'),
]

print()
for label, key in distance_metrics:
    print(f"{label} {abci.stats[key]} vs. {label} CPDAG {abci.stats[key + '_cpdag']}")
print(f"True Num Edges {env.graph.number_of_edges()} vs. E-NUM Edges {abci.stats['enum_edges']}")
print()
for label, key in edge_metrics:
    print(f"{label} {abci.stats[key]} vs. {label} CPDAG {abci.stats[key + '_cpdag']}")
print()
