# Example usage of ABCI-Categorical-GP

This notebook illustrates active Bayesian causal inference (ABCI) with a categorical distribution over graphs (i.e., all graphs are exhaustively enumerated) and a Gaussian process (GP) mechanism model. This setup scales to systems with up to around four variables.

In [None]:
import networkx as nx
# imports
%reload_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

from src.abci_categorical_gp import ABCICategoricalGP as ABCI
from src.config import ABCICategoricalGPConfig
from src.environments.generic_environments import *
from src.mechanism_models.gp_model import get_graph_key, gather_data, get_mechanism_key
from src.utils.graphs import get_parents


First, we generate a ground truth environment/SCM.

In [None]:
# build the ground truth environment with data normalisation switched on
num_nodes = 4
env_cfg = EnvironmentConfig()
env_cfg.normalise_data = True
env = BiDiag(num_nodes, env_cfg)

# draw the true graph on a planar layout, labelling every node with its own name
node_positions = nx.planar_layout(env.graph)
node_names = {node_name: node_name for node_name in env.graph.nodes}
nx.draw(env.graph, node_positions, labels=node_names)


Next, we can examine the ground truth mechanisms.

In [None]:
# Plot the ground truth mechanism of a node with a single parent.
parent = 'X0'
target = 'X1'  # target node
num_points = 500

# evaluation grid as a column of shape (num_points, 1)
xrange = torch.linspace(-3., 3., num_points).reshape(-1, 1)
mechanism_key = get_mechanism_key(target, [parent])
ytrue = env.apply_mechanism(xrange, mechanism_key).detach()

# parent value on the x-axis, resulting target value on the y-axis
plt.figure()
plt.plot(xrange, ytrue)
plt.xlabel(parent)
plt.ylabel(target)
plt.tight_layout()

In [None]:
# Contour plot of the ground truth mechanism of a node with two parents.
parents = ['X0', 'X1']
node = 'X2'
num_points = 100

# create meshgrid and compute true mechanism values
# (both bounds are 3, so the grid spans [-3, 3] along each parent)
range_min = range_max = 3.
xrange = torch.linspace(-range_min, range_max, num_points)
yrange = torch.linspace(-range_min, range_max, num_points)
xgrid, ygrid = torch.meshgrid(xrange, yrange, indexing='ij')
# flatten the grid into (num_points**2, 2) rows of (x, y) inputs
inputs = torch.stack((xgrid, ygrid), dim=2).view(-1, 2)
ztrue = env.apply_mechanism(inputs, get_mechanism_key(node, parents)).detach().view(num_points, num_points).numpy()

# value range shared by the contour levels and the colour map
zmin = ztrue.min().item()
zmax = ztrue.max().item()
levels = torch.linspace(zmin, zmax, 30).cpu().numpy()
# format each bound on its own: interpolating `{zmin, zmax}` would print a tuple inside the brackets
print(f'Function values for {node} in range [{zmin}, {zmax}].')

fig, ax = plt.subplots(1, 1)
cp1 = ax.contourf(xgrid.cpu(), ygrid.cpu(), ztrue, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax,
                  antialiased=False)

ax.set_xlabel(parents[0])
ax.set_ylabel(parents[1])
ax.set_xlim([-range_min, range_max])
ax.set_ylim([-range_min, range_max])
_ = fig.colorbar(cp1)


In [None]:
# Plot a multivariate mechanism along one parent while the other parents are held fixed.
target = 'X2'
parents = get_parents(target, env.graph)
varnode = parents[0]
num_points = 100
xmin = -1.
xmax = 1.

# one uniform draw per parent, scaled to [xmin, xmax) and repeated for every grid point
span = xmax - xmin
base_point = torch.rand(1, len(parents)) * span + xmin
fixed_samples = base_point.expand(num_points, -1)

# sweep the varied parent over [xmin, xmax] and splice that column into the fixed inputs
xrange = torch.linspace(xmin, xmax, num_points).unsqueeze(1)
varidx = parents.index(varnode)
left_columns = fixed_samples[:, :varidx]
right_columns = fixed_samples[:, varidx + 1:]
inputs = torch.cat((left_columns, xrange, right_columns), dim=-1)
mechanism_key = get_mechanism_key(target, parents)
ytrue = env.apply_mechanism(inputs, mechanism_key).detach()

plt.figure()
plt.plot(xrange, ytrue)
plt.xlabel(varnode)
plt.ylabel(target)
plt.tight_layout()

Here, we create an ABCI instance with the desired experimental design policy.

In [None]:
# start from the default configuration and override the options used in this example
cfg = ABCICategoricalGPConfig()
overrides = {
    'policy': 'random',
    'num_experiments': 5,
    'batch_size': 3,
    'num_initial_obs_samples': 5,
}
for option, value in overrides.items():
    setattr(cfg, option, value)

# bind the configuration to the ground truth environment
abci = ABCI(env, cfg)

We can now run a number of ABCI loops.

In [None]:
# run the ABCI loops on the instance created from `env` and `cfg`
# (presumably `cfg.num_experiments` of them — TODO confirm against ABCI.run)
abci.run()

Print the structure learning stats.

In [None]:
# optionally: recompute the structure learning stats
# abci.stats.clear()
# abci.compute_stats()

# edge probabilities under the current graph posterior
edge_probs = abci.graph_posterior.edge_probs()
print('Posterior edge probabilities:')
print(f'{edge_probs}')
print()

# structure learning metrics recorded in abci.stats
print(f"ESHD {abci.stats['eshd']} vs. ESHD CPDAG {abci.stats['eshd_cpdag']}")
# a space separates the label from its value; without it the two are printed fused together
print(f"True Num E {env.graph.number_of_edges()} vs. E-NUM Edges {abci.stats['enum_edges']}")
print(f"A-AID {abci.stats['aaid']}        ")
print(f"P-AID {abci.stats['paid']}        ")
print(f"OSET-AID {abci.stats['oset_aid']} ")
print()

# edge prediction metrics, each printed next to its CPDAG counterpart
print(f"F1 {abci.stats['ef1']}     vs. F1 cpdag {abci.stats['ef1_cpdag']}")
print(f"TPR {abci.stats['etpr']} vs. TPR cpdag {abci.stats['etpr_cpdag']}")
print(f"TNR {abci.stats['etnr']} vs. TNR cpdag {abci.stats['etnr_cpdag']}")
print(f"FNR {abci.stats['efnr']} vs. FNR cpdag {abci.stats['efnr_cpdag']}")
print(f"FPR {abci.stats['efpr']} vs. FPR cpdag {abci.stats['efpr_cpdag']}")
print(f"AUROC {abci.stats['auroc']}  vs. AUROC CPDAG {abci.stats['auroc_cpdag']}")
print(f"AUPRC {abci.stats['auprc']}  vs. AUPRC CPDAG {abci.stats['auprc_cpdag']}")
print()


Plot some of the recorded metrics over the experiments.

In [None]:
def plot_over_experiments(values, ylabel, **plot_kwargs):
    """Plot a per-experiment series in a new figure and return the figure's axes.

    The x-axis is labelled 'Number of Experiments' and restricted to integer ticks;
    any extra keyword arguments are forwarded to the plot call.
    """
    metric_ax = plt.figure().gca()
    metric_ax.plot(values, **plot_kwargs)
    metric_ax.set_xlabel('Number of Experiments')
    metric_ax.set_ylabel(ylabel)
    metric_ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    return metric_ax


# count the experiments without interventions and, per node, those intervening on it
num_observational = sum(1 for e in abci.experiments if e.interventions == {})
print(f'Number of observational batches: {num_observational}')
for node in env.node_labels:
    num_interventional = sum(1 for e in abci.experiments if node in e.interventions)
    print(
        f'Number of interventional batches on {node}: {num_interventional}')

# plot expected SHD over experiments
ax = plot_over_experiments(abci.stats['eshd'], 'Expected SHD')

# plot ancestor aid over experiments
ax = plot_over_experiments(abci.stats['aaid'], 'A-AID')
plt.tight_layout()

# plot true graph NLL over experiments (negated log-likelihood)
ax = plot_over_experiments(-torch.tensor(abci.stats['graph_ll']), 'True Graph NLL')

# plot graph posterior entropy over experiments
ax = plot_over_experiments(abci.stats['graph_entropy'], 'Entropy of Graph Posterior', label='entropy estimate')
plt.legend()

# plot graph posterior for the first ten graphs returned by sort_by_prob
# (presumably the most probable ones — verify the sort order)
graphs = abci.graph_posterior.sort_by_prob()[:10]
probs = [abci.graph_posterior.log_prob(g).exp().detach() for g in graphs]
graph_keys = [get_graph_key(g) for g in graphs]

# the tick rotation is set before the bars are drawn; keep this order
plt.figure()
plt.xticks(rotation=90)
plt.bar(graph_keys, probs)
plt.ylabel(r'Graph Posterior, $p(G|D)$')





Finally, we can have a look at the learned vs. true mechanisms. Here, we compare univariate mechanisms.

In [None]:
# Compare the learned and the true mechanism for a node with a single parent.
parent = 'X0'
target = 'X1'  # target node
num_points = 500

# gather the experimental data and evaluate the true mechanism over the observed input range
xdata, ydata = gather_data(abci.experiments, target, parents=[parent])
xlow, xhigh = xdata.min(), xdata.max()
xrange = torch.linspace(xlow, xhigh, num_points).reshape(-1, 1)
mechanism_key = get_mechanism_key(target, [parent])
ytrue = env.apply_mechanism(xrange, mechanism_key).detach()

# hand the gathered data to the learned mechanism, then predict on the same grid
mech = abci.mechanism_model.get_mechanism(target, parents=[parent])
mech.set_data(xdata, ydata)
ypred = mech(xrange).detach()

# data points first, then the true and the predicted curve
plt.figure()
plt.plot(xdata, ydata, 'rx', label='Experimental Data')
curves = ((ytrue, f'{parent}->{target} true'), (ypred, f'{parent}->{target} prediction'))
for curve_values, curve_label in curves:
    plt.plot(xrange, curve_values, label=curve_label)
plt.xlabel(parent)
plt.ylabel(target)
plt.legend()
plt.tight_layout()

Here, we compare bivariate mechanisms.

In [None]:
# Compare the learned and the true mechanism of a node with two parents as side-by-side contour plots.
parents = ['X0', 'X1']
target = 'X2'
num_points = 100

# create meshgrid and compute true mechanism values
# (both bounds are 3, so the grid spans [-3, 3] along each parent)
range_min = range_max = 3.
xrange = torch.linspace(-range_min, range_max, num_points)
yrange = torch.linspace(-range_min, range_max, num_points)
xgrid, ygrid = torch.meshgrid(xrange, yrange, indexing='ij')
# flatten the grid into (num_points**2, 2) rows of (x, y) inputs
inputs = torch.stack((xgrid, ygrid), dim=2).view(-1, 2)
ztrue = env.apply_mechanism(inputs, get_mechanism_key(target, parents)).detach().view(num_points, num_points).numpy()

# compute predicted mechanism values
# NOTE(review): unlike the univariate comparison, no `mech.set_data(...)` call precedes the
# prediction here — confirm the mechanism already holds the experimental data after `abci.run()`.
mech = abci.mechanism_model.get_mechanism(target, parents=parents)
zpred = mech(inputs).detach().view(num_points, num_points).numpy()

# common value range so both panels share contour levels and colour scale
zmin = min(ztrue.min().item(), zpred.min().item())
zmax = max(ztrue.max().item(), zpred.max().item())
levels = torch.linspace(zmin, zmax, 100).cpu().numpy()
# format each bound on its own: interpolating `{zmin, zmax}` would print a tuple inside the brackets
print(f'Function values for {target} in range [{zmin}, {zmax}].')

# plot mechanisms: true on the left, predicted on the right
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
cp1 = axes[0].contourf(xgrid.cpu(), ygrid.cpu(), ztrue, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax,
                       antialiased=False)
cp2 = axes[1].contourf(xgrid.cpu(), ygrid.cpu(), zpred, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax,
                       antialiased=False)

# overlay the experimental input locations on both panels
sample_inputs, sample_targets = gather_data(abci.experiments, target, parents=parents)

axes[0].plot(sample_inputs[:, 0].cpu(), sample_inputs[:, 1].cpu(), 'kx')
axes[1].plot(sample_inputs[:, 0].cpu(), sample_inputs[:, 1].cpu(), 'kx')
axes[0].set_xlabel(parents[0])
axes[0].set_ylabel(parents[1])
# set the limits on this figure's own axes (not a leftover `ax` from another cell);
# sharex/sharey propagate them to the second panel
axes[0].set_xlim([-range_min, range_max])
axes[0].set_ylim([-range_min, range_max])
axes[1].set_xlabel(parents[0])
_ = fig.colorbar(cp2)