# Example usage of ARCO-GP on the Sachs dataset.

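This notebook shows how to use ARCO-GP on the Sachs dataset: we load the data, configure and run ARCO-GP, print the resulting structure statistics, and inspect one of the learned mechanisms.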

In [None]:

# imports
%reload_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import networkx as nx
import torch
from matplotlib.ticker import MaxNLocator

from src.abci_arco_gp import ABCIArCOGP as ABCI
from src.config import ABCIArCOGPConfig
from src.environments.experiment import gather_data
from src.environments.sachs import Sachs
from src.mechanism_models.mechanisms import get_mechanism_key

Load the Sachs dataset.

In [None]:
# Build the Sachs environment from the CSV file, with normalise=True passed to Sachs.
# Note: this needs the data provided by Sachs et al.
env = Sachs(data_file='../data/sachs.csv', normalise=True)

# Draw the true graph on a circular layout, labelling every node with its own name.
true_graph = env.graph
node_labels = {node: node for node in true_graph.nodes}
nx.draw(true_graph, nx.circular_layout(true_graph), labels=node_labels)


Next, we create an instance of ARCO-GP with the desired configuration.

In [None]:

# Start from the default configuration and override a few options for this example.
cfg = ABCIArCOGPConfig()
cfg.num_arco_steps = 100  # presumably the number of ARCO update steps -- confirm in src.config
cfg.max_ps_size = 2  # presumably the maximum parent-set size -- confirm in src.config
cfg.policy = 'static-obs-dataset'  # NOTE(review): looks like a fixed observational dataset policy -- confirm

# Create the ARCO-GP instance from the environment and the configuration.
abci = ABCI(env, cfg)


We can now run ARCO-GP.

In [None]:

# Run ARCO-GP; the loss values plotted below are read from abci.stats afterwards.
abci.run()

# Plot the recorded 'arco_loss' values against their index, using integer-only x ticks.
fig, ax = plt.subplots()
ax.plot(abci.stats['arco_loss'], label='arco_loss')
ax.set_xlabel('Number of Steps')
ax.set_ylabel('Loss')
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.legend()

Here we print the structure stats of the learned ARCO-GP model.

In [None]:

# Print the entries of abci.stats side by side with their '_cpdag' counterparts where
# one exists. Every value is read straight from abci.stats; nothing is recomputed here.

# could compute the edge probabilities like so
# abci.cfg.num_mc_graphs = 10
# mc_cos, _ = abci.sample_mc_cos(num_cos=500)
# edge_probs = abci.compute_posterior_edge_probs(mc_cos)
# print(edge_probs)

# could re-compute the stats like so
# abci.stats.clear()
# abci.compute_stats()

# structure stats
print()
print(f"ESHD {abci.stats['eshd']} vs. ESHD CPDAG {abci.stats['eshd_cpdag']}")
# a space separates the 'E-NUM Edges' label from its value, as in every other line
print(f"True Num E {env.graph.number_of_edges()} vs. E-NUM Edges {abci.stats['enum_edges']}")
print(f"A-AID {abci.stats['aaid']}   vs. A-AID cpdag {abci.stats['aaid_cpdag']}")
print(f"P-AID {abci.stats['paid']}   vs. P-AID cpdag {abci.stats['paid_cpdag']}")
print(f"OSET-AID {abci.stats['oset_aid']} vs. OSET-AID cpdag {abci.stats['oset_aid_cpdag']} ")
print(f"ORDER-AID {abci.stats['order_aid']} ")

# edge-level classification stats
print()
print(f"F1 {abci.stats['ef1']}     vs. F1 cpdag {abci.stats['ef1_cpdag']}")
print(f"TPR {abci.stats['etpr']} vs. TPR cpdag {abci.stats['etpr_cpdag']}")
print(f"TNR {abci.stats['etnr']} vs. TNR cpdag {abci.stats['etnr_cpdag']}")
print(f"FNR {abci.stats['efnr']} vs. FNR cpdag {abci.stats['efnr_cpdag']}")
print(f"FPR {abci.stats['efpr']} vs. FPR cpdag {abci.stats['efpr_cpdag']}")
print(f"AUROC {abci.stats['auroc']}  vs. AUROC CPDAG {abci.stats['auroc_cpdag']}")
print(f"AUPRC {abci.stats['auprc']}  vs. AUPRC CPDAG {abci.stats['auprc_cpdag']}")
print()


Finally, we can have a look at the learned mechanisms.

In [None]:
# setup: which parent -> target mechanism to inspect and how many grid points to evaluate
parent = 'Plcg'  # 'PKC'
target = 'PIP3'  # 'PKA'
num_points = 500

# gather the data for this parent/target pair from the experiments stored on the model
xdata, ydata = gather_data(abci.experiments, target, parents=[parent])

# evenly spaced grid between the smallest and largest gathered parent value, as a column
xrange = torch.linspace(xdata.min(), xdata.max(), num_points).unsqueeze(1)

# compute predicted mechanism values; the grid column is expanded to env.num_nodes columns
# before it is handed to the mechanism model
mechanism_inputs = xrange.expand(-1, env.num_nodes)
mechanism_key = get_mechanism_key(target, [parent])
ypred = abci.mechanism_model.apply_mechanism(mechanism_inputs, mechanism_key).detach()

# overlay the predicted mechanism on the gathered data
mech_fig, mech_ax = plt.subplots()
mech_ax.plot(xdata, ydata, 'rx', label='Experimental Data')
mech_ax.plot(xrange, ypred, label=f'{parent}->{target} prediction')
mech_ax.set_xlabel(parent)
mech_ax.set_ylabel(target)
mech_ax.legend()
mech_fig.tight_layout()

