# Example usage of ABCI-DiBS-GP

This notebook illustrates how to use ABCI with DiBS for approximate inference of the graph posterior
and a Gaussian process (GP) model of the causal mechanisms.

In [None]:
# imports
%reload_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import torch.distributions as dist
from matplotlib.ticker import MaxNLocator

from src.abci_dibs_gp import ABCIDiBSGP as ABCI
from src.environments.generic_environments import *
from src.models.gp_model import *

First, we generate a ground-truth environment, i.e. a structural causal model (SCM).


In [None]:
# Size of the ground-truth model and (optionally) the interventional queries to evaluate.
num_nodes = 5
# None disables interventional queries; the commented-out line shows an example query.
# interventional_queries = [InterventionalDistributionsQuery(['X2'], {'X1': dist.Uniform(2., 5.)})]
interventional_queries = None

# Build the ground-truth environment.
# NOTE(review): no random seed is set, so the sampled graph/mechanisms presumably differ between runs.
env = BarabasiAlbert(
    num_nodes,
    num_test_queries=50,
    interventional_queries=interventional_queries,
)

# Draw the true graph on a circular layout, labelling each node with its own name.
node_names = {node: node for node in env.graph.nodes}
nx.draw(env.graph, pos=nx.circular_layout(env.graph), labels=node_names)

Here, we create an ABCI instance with the desired experimental design policy.

In [None]:
# Experimental design policy used to select the next experiment.
policy = 'graph-info-gain'
# NOTE(review): num_particles / num_mc_graphs presumably set the DiBS particle count and the number
# of Monte-Carlo graph samples; see ABCIDiBSGP for their exact meaning.
abci = ABCI(
    env,
    policy,
    num_particles=5,
    num_mc_graphs=40,
    num_workers=1,
    dibs_plus=True,
    linear=False,
)

We can now run a number of ABCI loops.

In [None]:
# Experiment budget passed to ABCI.run (see ABCIDiBSGP.run for the exact semantics).
num_experiments = 2
batch_size = 3
num_initial_obs_samples = 3

abci.run(num_experiments, batch_size, num_initial_obs_samples=num_initial_obs_samples)

Next, we report how many batches were observational vs. interventional and plot the evaluation metrics over the course of the experiments.

In [None]:
# How the experiment budget was spent: observational batches vs. interventional batches per node.
num_obs_batches = sum(1 for e in abci.experiments if e.interventions == {})
print(f'Number of observational batches: {num_obs_batches}')
for node in env.node_labels:
    num_int_batches = sum(1 for e in abci.experiments if node in e.interventions)
    print(f'Number of interventional batches on {node}: {num_int_batches}')

# One figure per tracked metric, plotted against the experiment index.
metric_curves = [
    (abci.eshd_list, 'Expected SHD'),
    (abci.auroc_list, 'AUROC'),
    (abci.auprc_list, 'AUPRC'),
    (abci.query_kld_list, 'Query KLD'),
]
for values, ylabel in metric_curves:
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set_xlabel('Number of Experiments')
    ax.set_ylabel(ylabel)
    # Experiment counts are whole numbers, so suppress fractional x-ticks.
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    # Apply the layout to every figure, not only the last one created.
    fig.tight_layout()

Finally, we compare the learned mechanisms with the true ones.

In [None]:
# Plot the true vs. learned mechanism X_i -> X_j together with the gathered data.
# NOTE(review): the true mechanism is evaluated with X_i as its only input, which presumably
# assumes X_i is the sole parent of X_j in the true graph — confirm against env.graph.
i = 0
j = 1
cause, effect = f'X{i}', f'X{j}'
x_min, x_max, num_points = -7., 7., 100

xdata, ydata = gather_data(abci.experiments, effect, parents=[cause])
# Column vector of evaluation points: one input feature per row.
xrange = torch.linspace(x_min, x_max, num_points).unsqueeze(-1)
ytrue = env.mechanisms[effect](xrange).detach()
mech = abci.mechanism_model.get_mechanism(effect, parents=[cause])
# Give the mechanism all data gathered for this parent set before predicting.
mech.set_data(xdata, ydata)
ypred = mech(xrange).detach()

fig, ax = plt.subplots()
ax.plot(xdata, ydata, 'rx', label='Experimental Data')
ax.plot(xrange, ytrue, label=f'{cause}->{effect} true')
ax.plot(xrange, ypred, label=f'{cause}->{effect} prediction')
ax.set_xlabel(cause)
ax.set_ylabel(effect)
ax.legend()
fig.tight_layout()

In [None]:
# Plot the true vs. learned bivariate mechanism (parents[0], parents[1]) -> node on a grid.
# NOTE(review): assumes `parents` are exactly the true parents of `node` — confirm against env.graph.
node = 'X2'
parents = ['X0', 'X1']
num_points = 100
xrange = torch.linspace(-7., 7., num_points)
yrange = torch.linspace(-7., 7., num_points)
# torch.meshgrid defaults to 'ij' indexing: xgrid[a, b] = xrange[a], ygrid[a, b] = yrange[b].
xgrid, ygrid = torch.meshgrid(xrange, yrange)
# Flatten the grid into (num_points**2, 2) rows of (parents[0], parents[1]) inputs.
inputs = torch.stack((xgrid, ygrid), dim=2).view(-1, 2)
ztrue = env.mechanisms[node](inputs).detach().view(num_points, num_points).numpy()

mech = abci.mechanism_model.get_mechanism(node, parents=parents)
sample_inputs, sample_targets = gather_data(abci.experiments, node, parents=parents)
mech.set_data(sample_inputs, sample_targets)
zpred = mech(inputs).detach().view(num_points, num_points).numpy()

zmin = ztrue.min().item()
zmax = ztrue.max().item()
print(f'Function values for {node} in range [{zmin}, {zmax}].')

# Both panels share the true function's levels/colour scale; extend='both' colours predicted
# values outside [zmin, zmax] with the end colours instead of leaving them blank.
levels = torch.linspace(zmin, zmax, 30).numpy()
contour_kwargs = dict(cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax,
                      antialiased=False, extend='both')
fig, axes = plt.subplots(1, 2)
cp1 = axes[0].contourf(xgrid, ygrid, ztrue, **contour_kwargs)
cp2 = axes[1].contourf(xgrid, ygrid, zpred, **contour_kwargs)

axes[0].set_title(f'{node} true')
axes[1].set_title(f'{node} prediction')
axes[0].plot(sample_inputs[:, 0], sample_inputs[:, 1], 'kx')
for ax in axes:
    ax.set_xlabel(parents[0])
axes[0].set_ylabel(parents[1])
# Take colorbar space from both panels so they keep equal widths.
_ = fig.colorbar(cp2, ax=list(axes))