# ArCO Model Example Notebook

This notebook illustrates the usage of the ArCO model - an auto-regressive generative model for causal orders.

In [None]:
# imports
%reload_ext autoreload
%autoreload 2

import math
import random

import matplotlib.pyplot as plt
import torch

from src.config import ArCOConfig
from src.graph_models.arco import ArCO
from src.utils.causal_orders import CausalOrder, generate_all_permutations
from src.utils.plotting import init_plot_style

init_plot_style()

# Seed the RNGs used below so that "Restart Kernel -> Run All" reproduces the
# same outputs: `random` drives the shuffled target permutations, and
# torch's global RNG is presumably what ArCO draws from when initialising
# and sampling (NOTE(review): confirm ArCO does not use its own generator).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

Sanity check: total probability of all causal orders for a given number of nodes must sum to 1

In [None]:
num_nodes = 3
node_labels = ['X' + str(i) for i in range(num_nodes)]

# Causal order model with the default config and the 'simple' map
# ('mlp' is the other available map_mode).
cfg = ArCOConfig()
cfg.map_mode = 'simple'
model = ArCO(node_labels, cfg)

# Candidate orders over the node labels, as returned by generate_all_permutations.
cos = generate_all_permutations(node_labels)
print(f'Number of causal orders with {num_nodes} nodes is {len(cos)}')

# Turn the model's log-probabilities into probabilities and report their sum.
log_probs = model.log_prob(cos)
probs = log_probs.exp()
print(f'\nTotal prob: {probs.sum()}')

# One line per order with its probability under the model.
for co, prob in zip(cos, probs):
    print(f'p(<{co}>) = {prob.item()}')

### We will now train an ArCO model to fit a given set of causal orders.

First, we initialise the ArCO model.

In [None]:
num_nodes = 5
node_labels = ['X' + str(i) for i in range(num_nodes)]

# Fresh (untrained) ArCO model using the 'simple' map.
cfg = ArCOConfig()
cfg.map_mode = 'simple'
model = ArCO(node_labels, cfg)

# Draw a handful of orders from the untrained model; the second return value
# of sample() is not needed here.
num_prior_samples = 5
cos, _ = model.sample(num_prior_samples)
print('Sampled causal orders from prior model:')

# Report the model probability of each sampled order.
probs = model.log_prob(cos).exp()
for co, prob in zip(cos, probs):
    print(f'p(<{co}>) = {prob.item()}')

Option 1: Generate target orders as a pair of chain and reverse chain. The simple model will fail to fit these two orders with equal probability.

In [None]:

# Target 1: chain -- one singleton layer per node, X0 first and X{num_nodes-1} last.
forward_layers = [{f'X{i}'} for i in range(num_nodes)]
co1 = CausalOrder(forward_layers)

# Target 2: reverse chain -- the same singleton layers in the opposite order.
backward_layers = [{f'X{i}'} for i in reversed(range(num_nodes))]
co2 = CausalOrder(backward_layers)

# Both orders together form the training targets.
cos = [co1, co2]

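Option 2: Generate random permutations as target orders. Note that this cell overwrites the `cos` defined in Option 1, so skip it if you want to train on the chain / reverse-chain targets.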

In [None]:
def permutation_key(layers):
    """Return a deterministic string key identifying a layered order.

    Each layer (an iterable of string node labels) contributes its labels in
    sorted order, joined by ','; layers are then joined by '|' in the order
    given. For singleton layers this is simply the labels joined by '|'.

    Sorting inside a layer makes the key independent of set iteration order
    and keeps every label, so two orders that differ only in the non-first
    members of a multi-node layer no longer collide.

    Args:
        layers: sequence of layers, each an iterable of string labels.

    Returns:
        The key as a string ('' for an empty sequence of layers).
    """
    return '|'.join(','.join(sorted(layer)) for layer in layers)


num_permutations = 1
layers = [{f'X{i}'} for i in range(num_nodes)]

# Only num_nodes! distinct orderings of the singleton layers exist. Asking for
# more would make the rejection loop below spin forever, so fail loudly instead.
max_permutations = math.factorial(num_nodes)
if num_permutations > max_permutations:
    raise ValueError(
        f'num_permutations={num_permutations} exceeds the {max_permutations} '
        f'distinct orderings of {num_nodes} nodes'
    )

# Rejection sampling: shuffle in place and keep an ordering only if its key
# has not been seen before.
cos = []
unique_cos = set()
while len(cos) < num_permutations:
    random.shuffle(layers)
    key = permutation_key(layers)
    if key not in unique_cos:
        unique_cos.add(key)
        # Copy the list: `layers` is shuffled in place on the next iteration.
        cos.append(CausalOrder(layers.copy()))



We can now train the causal order (CO) model to fit a uniform distribution over the set of target orders.

In [None]:

# Training hyper-parameters.
learning_rate = 1e-1
num_steps = 50

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
losses = []
for _ in range(num_steps):
    optimizer.zero_grad()
    # Loss = -(summed log-probability of the target orders) - model.log_param_prior().
    ll = model.log_prob(cos).sum()
    loss = -ll - model.log_param_prior()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Probabilities of the target orders under the trained model.
print(f'Loss: {losses[-1]}')
probs = model.log_prob(cos).exp()
print(f'\nTotal prob: {probs.sum()}')
for idx, co in enumerate(cos):
    # .item() prints a plain float, matching the earlier cells, instead of a
    # tensor repr carrying a grad_fn.
    print(f'p(<{co}>) = {probs[idx].item()}')

# Fresh samples from the trained model with their probabilities.
sampled_cos, _ = model.sample(10)
probs = model.log_prob(sampled_cos).exp()
print('Samples of the trained model:')
for idx, co in enumerate(sampled_cos):
    print(f'p(<{co}>) = {probs[idx].item()}')

# Loss curve over the optimisation steps.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Optimisation step')
_ = ax.set_ylabel('Loss')

For 'simple' ArCO only: visualize model parameters (= 1D embedding) of the simple map.

In [None]:
# Parameters of the simple map, detached from the autograd graph for plotting.
theta = model.logit_map.theta.detach()
prior = model.logit_map.prior

# Horizontal range: the prior's scale plus a margin of 5 on either side.
x_min = -prior.scale - 5.
x_max = prior.scale + 5.
x = torch.linspace(x_min, x_max, 1000)

_, ax = plt.subplots(figsize=(16, 9))

# Annotate each parameter value with its node index, then mark the values
# themselves as red crosses on the zero line.
for i in range(num_nodes):
    ax.text(theta[i], 0., f'{i}')
ax.plot(theta, torch.zeros_like(theta), 'rx')

# Overlay the prior density evaluated on the grid.
ax.plot(x, prior.log_prob(x).exp(), label=r'$p(\theta)$')
ax.legend(loc='upper center')
_ = ax.set_xlim([x_min, x_max])
