# Introduction

This solution is adapted from the DECI implementation in the causica library, available [here](https://github.com/microsoft/causica/blob/b3c79a01f30f44ed36c582ffe2b4522058d82a73/causica/models/deci/deci.py). The original model is described in the associated paper, available [here](https://arxiv.org/abs/2202.02195).

The solution is GNN-based. DECI is a generative model that uses an additive noise structural equation model to capture the functional relationships among variables and exogenous noise, while also learning a variational distribution over causal graphs. It is designed to perform causal inference without background information about the causal graph.

The relationships are learned through flexible neural networks, while the noise can be modeled either as a Gaussian or with a spline-flow model. DECI is considered a generative method because it essentially maps exogenous noise to observations.

Both the mean-field approximate posterior distribution over graphs and the functional relationships are learned by maximizing an evidence lower bound (ELBO).
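For reference, the objective can be written out as follows. The notation follows the DECI paper linked above and is meant as a sketch, not a transcription of the causica code: $G$ is the adjacency matrix over $d$ nodes, $q_\phi(G)$ the mean-field approximate posterior, $f_i$ the learned function for node $i$, $p_{z_i}$ its noise density, and $h(G)$ a penalty that vanishes only for acyclic graphs. Identifying $\lambda_s$, $\alpha$ and $\rho$ with `prior_sparsity_lambda`, `init_alpha` and `init_rho` from the configuration is an assumption.

$$
x_i = f_i\big(x_{\mathrm{pa}(i;G)}\big) + z_i,
\qquad
p_\theta(x \mid G) = \prod_{i=1}^{d} p_{z_i}\big(x_i - f_i(x_{\mathrm{pa}(i;G)})\big)
$$

$$
\mathrm{ELBO}(\theta, \phi) = \mathbb{E}_{q_\phi(G)}\Big[\log p(G) + \sum_{n=1}^{N} \log p_\theta(x^n \mid G)\Big] + H(q_\phi)
$$

$$
p(G) \propto \exp\big(-\lambda_s \lVert G \rVert_F^2 - \rho\, h(G)^2 - \alpha\, h(G)\big),
\qquad
h(G) = \operatorname{tr}\big(e^{G \odot G}\big) - d
$$

The entropy term $H(q_\phi)$ is cheap to evaluate because the posterior factorizes over edges, and the prior term is where the sparsity and acyclicity preferences enter the loss.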

The implementation builds on the popular torch library and on the causica library, and follows these steps. It creates a prior distribution over directed acyclic graphs (DAGs), a variational posterior distribution over the adjacency matrices, a GNN representing the functional relationships, and a noise distribution for each node; the last three components are the ones being optimized. These components are gathered in a structural equation model (SEM), which is optimized with a different learning rate for each module. An augmented Lagrangian scheduler steers the optimization towards a DAG. The result of the training process is the adjacency matrix.
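The augmented Lagrangian step can be illustrated with a small pure-Python sketch. It is not the causica scheduler: `dagness`, `AugLagState`, the `progress` threshold and the exact scaling of the quadratic term are assumptions made for illustration. The update rule itself (`alpha += rho * h`, otherwise `rho *= 10`) is chosen because it is consistent with the `Updating alpha` and `Updating rho` lines in the training logs of this notebook.

```python
import math
from dataclasses import dataclass


def dagness(adj, terms=40):
    """Return h(A) = tr(exp(A * A)) - d, which is zero when A is acyclic.

    The matrix exponential is approximated by a truncated power series,
    which is adequate for the small graphs used in this notebook.
    """
    d = len(adj)
    sq = [[a * a for a in row] for row in adj]  # element-wise square of A
    power = [[float(i == j) for j in range(d)] for i in range(d)]  # identity
    trace_exp = float(d)  # k = 0 term of the series: tr(I) = d
    for k in range(1, terms):
        power = [
            [sum(power[i][m] * sq[m][j] for m in range(d)) for j in range(d)]
            for i in range(d)
        ]
        trace_exp += sum(power[i][i] for i in range(d)) / math.factorial(k)
    return trace_exp - d


@dataclass
class AugLagState:
    """Hypothetical container for the two augmented Lagrangian parameters."""

    alpha: float = 0.0  # Lagrange multiplier (assumed to match init_alpha)
    rho: float = 1.0    # quadratic penalty weight (assumed to match init_rho)

    def penalty(self, h):
        # Term added to the loss; the scaling of the quadratic part is assumed.
        return self.alpha * h + self.rho * h * h

    def update(self, h, prev_h, progress=0.65, rho_factor=10.0):
        # Assumed switching rule: raise alpha when the constraint improved
        # enough since the last outer step, otherwise tighten rho.
        if h < progress * prev_h:
            self.alpha += self.rho * h
        else:
            self.rho *= rho_factor


chain = [[0, 1], [0, 0]]  # 0 -> 1, acyclic
cycle = [[0, 1], [1, 0]]  # 0 <-> 1, contains a cycle
print(dagness(chain), dagness(cycle))

state = AugLagState()
state.update(h=dagness(cycle), prev_h=dagness(cycle))  # no progress, so rho grows
print(state)
```

As a sanity check against the logs, starting from `alpha=7.3145` and `rho=10` with a DAG-ness of `4.73215` gives `alpha = 7.3145 + 10 * 4.73215 ≈ 54.636`, which is the value reported in the `large_backdoor` run.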

Since it is already available, we also apply the discretization method used in the bnlearn solution, to check whether it helps the learning process in any way.
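The `hartemink_discretization` helper is not shown in this notebook. As a rough illustration of what discretization does to a continuous column, the sketch below performs plain equal-frequency (quantile) binning, which is the usual starting point of Hartemink's procedure before adjacent intervals are merged. The function name `quantile_bins` and the choice of three bins are hypothetical, and the merging step is deliberately left out.

```python
def quantile_bins(values, n_bins=3):
    """Map each value to an equal-frequency bin index in [0, n_bins).

    Values are ranked and the ranks are split into n_bins groups of
    (almost) equal size. Ties are broken by position, so equal values
    can end up in neighbouring bins.
    """
    order = sorted(range(len(values)), key=lambda i: values[i])
    bins = [0] * len(values)
    for rank, idx in enumerate(order):
        bins[idx] = rank * n_bins // len(values)
    return bins


column = [0.1, 5.0, 2.0, 9.0, 3.0, 7.0]
print(quantile_bins(column))  # one small integer label per observation
```

Every observation is reduced to one of a few labels, so differences within a bin are lost. That is the information loss referred to in the conclusions.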

# 0. Preliminaries

This section includes the required imports, and it also defines the hyperparameters. The notebook can be run locally after following the README instructions in the root directory.

In [1]:
import torch  # used below for torch.tensor(pred_graph)
from causica.distributions import ContinuousNoiseDist

from data.csuite.csuite_datasets import *
from data.sachs.sachs_datasets import unaltered_dataset
from evaluation.metrics import eval_all
from models.deci.causica_deci import causica_deci
from utils.solution_utils import hartemink_discretization

In [2]:
train_config = {
    "batch_size": 128,
    "epochs": 1000,
    "init_alpha": 0.0,
    "init_rho": 1.0,
    "gumbel_temp": 0.25,
    "prior_sparsity_lambda": 5.0,
    "embedding_size": 32,
    "out_dim_g": 32,
    "num_layers_g": 2,
    "num_layers_zeta": 2,
    "noise_dist": ContinuousNoiseDist.SPLINE
}

# 1. Observational Set (Sachs)

In [3]:
# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = unaltered_dataset(get_data=True, return_index_name_correlation=False, return_adjacency_graph=True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:3.0279e+05 nll:3.0279e+05 dagness:5.34188 num_edges:28 alpha:0 rho:1 step:0|1 num_lr_updates:0
epoch:100 loss:11934 nll:11934 dagness:38.23705 num_edges:53 alpha:0 rho:1 step:0|701 num_lr_updates:0
epoch:200 loss:2990.1 nll:2989.8 dagness:28.87832 num_edges:50 alpha:0 rho:1 step:0|1401 num_lr_updates:0
epoch:300 loss:1755 nll:1754.7 dagness:22.02489 num_edges:48 alpha:0 rho:1 step:0|2101 num_lr_updates:0
epoch:400 loss:1189.5 nll:1189.2 dagness:30.51480 num_edges:48 alpha:0 rho:1 step:0|2801 num_lr_updates:0
Updating alpha to: 25.203094482421875
epoch:500 loss:497.13 nll:496.89 dagness:1.17651 num_edges:36 alpha:25.203 rho:1 step:1|501 num_lr_updates:0
epoch:600 loss:492.5 nll:492.27 dagness:1.17651 num_edges:34 alpha:25.203 rho:1 step:1|1201 num_lr_updates:0
epoch:700 loss:258.21 nll:257.99 dagness:1.17651 num_edges:33 alpha:25.203 rho:1 step:1|1901 num_lr_updates:0
epoch:800 loss:197.57 nll:197.35 dagness:1.17651 num_edges:32 alpha:25.2

# 2. Synthetic Set (CSuite)

In [4]:
# ~~~~~~~~ LINGAUSS ~~~~~~~~
csuite_dataset = lingauss

# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = csuite_dataset(2000, True, True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:3.0199 nll:3.0179 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:0|1 num_lr_updates:0
Updating alpha to: 0.0
epoch:100 loss:2.771 nll:2.7686 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:1|335 num_lr_updates:0
Updating alpha to: 0.0
epoch:200 loss:2.486 nll:2.4835 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:2|611 num_lr_updates:1
Updating alpha to: 0.0
Updating alpha to: 0.0
epoch:300 loss:2.5766 nll:2.5741 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:4|131 num_lr_updates:0
Updating alpha to: 0.0
epoch:400 loss:2.7435 nll:2.741 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|438 num_lr_updates:0
epoch:500 loss:2.6145 nll:2.612 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|2038 num_lr_updates:1
epoch:600 loss:2.6867 nll:2.6842 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|3638 num_lr_updates:1
epoch:700 loss:2.727 nll:2.7246 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|5238 num_lr_updates:1
epoch:800 loss:2.7463 nll:2.7438

In [5]:
# ~~~~~~~~ LINEXP ~~~~~~~~
csuite_dataset = linexp

# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = csuite_dataset(2000, True, True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:2.9267 nll:2.9247 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:0|1 num_lr_updates:0
epoch:100 loss:2.1235 nll:2.121 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:0|1601 num_lr_updates:0
Updating alpha to: 0.0
epoch:200 loss:2.0187 nll:2.0162 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:1|531 num_lr_updates:1
Updating alpha to: 0.0
epoch:300 loss:1.9144 nll:1.9119 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:2|691 num_lr_updates:0
Updating alpha to: 0.0
epoch:400 loss:2.0911 nll:2.0886 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:3|550 num_lr_updates:1
Updating alpha to: 0.0
Updating alpha to: 0.0
epoch:500 loss:1.9914 nll:1.9889 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|23 num_lr_updates:0
epoch:600 loss:1.7264 nll:1.7239 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|1623 num_lr_updates:3
epoch:700 loss:1.9537 nll:1.9512 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|3223 num_lr_updates:3
epoch:800 loss:1.9224 nll:1.9

In [6]:
# ~~~~~~~~ NONLINGAUSS ~~~~~~~~
csuite_dataset = nonlingauss

# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = csuite_dataset(2000, True, True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:3.6213 nll:3.6193 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:0|1 num_lr_updates:0
epoch:100 loss:2.2258 nll:2.2233 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:0|1601 num_lr_updates:1
Updating alpha to: 0.0
epoch:200 loss:2.3897 nll:2.3873 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:1|1221 num_lr_updates:2
Updating alpha to: 0.0
Updating alpha to: 0.0
epoch:300 loss:2.1486 nll:2.1461 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:3|703 num_lr_updates:1
Updating alpha to: 0.0
epoch:400 loss:2.3465 nll:2.344 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:4|1155 num_lr_updates:1
Updating alpha to: 0.0
epoch:500 loss:2.1007 nll:2.0982 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|1398 num_lr_updates:2
epoch:600 loss:2.0261 nll:2.0236 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|2998 num_lr_updates:3
epoch:700 loss:2.1283 nll:2.1258 dagness:0.00000 num_edges:1 alpha:0 rho:1 step:5|4598 num_lr_updates:3
epoch:800 loss:2.3355 nll

In [7]:
# ~~~~~~~~ NONLIN SIMPSON ~~~~~~~~
csuite_dataset = nonlin_simpson

# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = csuite_dataset(2000, True, True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:6.588 nll:6.5914 dagness:0.00000 num_edges:0 alpha:0 rho:1 step:0|1 num_lr_updates:0
epoch:100 loss:1.1577 nll:1.1428 dagness:1.18629 num_edges:6 alpha:0 rho:1 step:0|1601 num_lr_updates:1
Updating alpha to: 1.1862878799438477
epoch:200 loss:0.96578 nll:0.95008 dagness:1.18629 num_edges:6 alpha:1.1863 rho:1 step:1|1205 num_lr_updates:1
Updating rho, dag penalty prev:  1.1862878799
epoch:300 loss:0.82451 nll:0.80882 dagness:1.18629 num_edges:6 alpha:1.1863 rho:10 step:2|587 num_lr_updates:0
Updating rho, dag penalty prev:  1.1862878799
epoch:400 loss:0.78335 nll:0.76763 dagness:1.18629 num_edges:6 alpha:1.1863 rho:100 step:3|69 num_lr_updates:0
Updating rho, dag penalty prev:  1.1862878799
epoch:500 loss:0.83881 nll:0.82294 dagness:1.18629 num_edges:6 alpha:1.1863 rho:1000 step:4|406 num_lr_updates:0
Updating rho, dag penalty prev:  1.1862878799
epoch:600 loss:0.98426 nll:0.9668 dagness:1.18629 num_edges:6 alpha:1.1863 rho:10000 step:5|436

In [8]:
# ~~~~~~~~ SYMPROD SIMPSON ~~~~~~~~
csuite_dataset = symprod_simpson

# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = csuite_dataset(2000, True, True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:7.8643 nll:7.8626 dagness:0.00000 num_edges:2 alpha:0 rho:1 step:0|1 num_lr_updates:0
epoch:100 loss:1.7906 nll:1.7756 dagness:1.18629 num_edges:6 alpha:0 rho:1 step:0|1601 num_lr_updates:1
Updating alpha to: 1.1862878799438477
epoch:200 loss:2.235 nll:2.2193 dagness:1.18629 num_edges:6 alpha:1.1863 rho:1 step:1|1206 num_lr_updates:1
Updating rho, dag penalty prev:  1.1862878799
epoch:300 loss:1.7948 nll:1.7791 dagness:1.18629 num_edges:6 alpha:1.1863 rho:10 step:2|822 num_lr_updates:1
Updating rho, dag penalty prev:  1.1862878799
epoch:400 loss:1.3995 nll:1.3838 dagness:1.18629 num_edges:6 alpha:1.1863 rho:100 step:3|1011 num_lr_updates:1
Updating rho, dag penalty prev:  1.1862878799
epoch:500 loss:1.5478 nll:1.5319 dagness:1.18629 num_edges:6 alpha:1.1863 rho:1000 step:4|1092 num_lr_updates:1
Updating rho, dag penalty prev:  1.1862878799
epoch:600 loss:1.9512 nll:1.9338 dagness:1.18629 num_edges:6 alpha:1.1863 rho:10000 step:5|651 num_l

In [9]:
# ~~~~~~~~ LARGE BACKDOOR ~~~~~~~~
csuite_dataset = large_backdoor

# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = csuite_dataset(2000, True, True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:15.098 nll:15.069 dagness:6.18400 num_edges:20 alpha:0 rho:1 step:0|1 num_lr_updates:0
epoch:100 loss:6.4576 nll:6.4092 dagness:6.45935 num_edges:21 alpha:0 rho:1 step:0|1601 num_lr_updates:1
Updating alpha to: 7.314506530761719
epoch:200 loss:6.6152 nll:6.5513 dagness:4.73215 num_edges:19 alpha:7.3145 rho:1 step:1|646 num_lr_updates:0
Updating rho, dag penalty prev:  7.3145065308
epoch:300 loss:6.2776 nll:6.2102 dagness:4.92532 num_edges:20 alpha:7.3145 rho:10 step:2|404 num_lr_updates:0
epoch:400 loss:6.2949 nll:6.2303 dagness:4.73215 num_edges:19 alpha:7.3145 rho:10 step:2|2004 num_lr_updates:1
Updating alpha to: 54.63604736328125
epoch:500 loss:6.6028 nll:6.4261 dagness:4.73215 num_edges:19 alpha:54.636 rho:10 step:3|1414 num_lr_updates:2
Updating rho, dag penalty prev:  4.7321538925
epoch:600 loss:6.436 nll:6.2765 dagness:4.18519 num_edges:18 alpha:54.636 rho:100 step:4|1131 num_lr_updates:1
Updating rho, dag penalty prev:  4.7321538

In [10]:
# ~~~~~~~~ WEAK ARROWS ~~~~~~~~
csuite_dataset = weak_arrows

# Without discretization
print("~~~~~ WITHOUT DISCRETIZATION ~~~~~")
df, gt_graph = csuite_dataset(2000, True, True)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

# With discretization
print("\n\n\n~~~~~ WITH DISCRETIZATION ~~~~~")
df = hartemink_discretization(df)
pred_graph = causica_deci(df, train_config)
print(eval_all(torch.tensor(pred_graph), gt_graph))

~~~~~ WITHOUT DISCRETIZATION ~~~~~
epoch:0 loss:15.5 nll:15.468 dagness:5.24694 num_edges:21 alpha:0 rho:1 step:0|1 num_lr_updates:0
epoch:100 loss:7.9156 nll:7.8507 dagness:12.01619 num_edges:27 alpha:0 rho:1 step:0|1601 num_lr_updates:1
Updating alpha to: 11.288244247436523
epoch:200 loss:7.6172 nll:7.5251 dagness:6.63588 num_edges:22 alpha:11.288 rho:1 step:1|880 num_lr_updates:0
Updating alpha to: 17.92412567138672
epoch:300 loss:7.4576 nll:7.3433 dagness:6.63588 num_edges:22 alpha:17.924 rho:1 step:2|592 num_lr_updates:0
Updating rho, dag penalty prev:  6.6358814240
epoch:400 loss:7.1753 nll:7.0609 dagness:6.63588 num_edges:22 alpha:17.924 rho:10 step:3|714 num_lr_updates:1
Updating rho, dag penalty prev:  6.6358814240
epoch:500 loss:7.0205 nll:6.9056 dagness:6.63588 num_edges:22 alpha:17.924 rho:100 step:4|652 num_lr_updates:1
Updating rho, dag penalty prev:  6.6358814240
epoch:600 loss:6.8398 nll:6.7199 dagness:6.63588 num_edges:22 alpha:17.924 rho:1000 step:5|963 num_lr_updates

# 3. Conclusions

The adjacency scores are very good across all datasets, and the orientation scores are also good for the simpler graphs. However, the model starts to struggle with orientation prediction as the complexity of the graph to be predicted increases.

As expected, the results on the discretized data are significantly worse in almost all cases, owing to the loss of information and to the fact that the network is not designed for categorical data.
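To make the distinction between adjacency and orientation scores concrete, here is a minimal sketch of how the two can differ on the same prediction. It does not reproduce `eval_all`, whose implementation is not shown here: the function names and the specific F1 definitions are assumptions. Adjacency is scored on the undirected skeleton, orientation on the directed edges.

```python
def f1(pred, true):
    """F1 score between two sets of edges."""
    tp = len(pred & true)
    if tp == 0:
        return 0.0
    precision = tp / len(pred)
    recall = tp / len(true)
    return 2 * precision * recall / (precision + recall)


def edge_scores(pred, true):
    """Compare two binary adjacency matrices, where entry [i][j] means i -> j."""
    d = len(true)
    pred_dir = {(i, j) for i in range(d) for j in range(d) if pred[i][j]}
    true_dir = {(i, j) for i in range(d) for j in range(d) if true[i][j]}
    # Dropping the direction turns each edge into an unordered pair.
    pred_skel = {frozenset(e) for e in pred_dir}
    true_skel = {frozenset(e) for e in true_dir}
    return {
        "adjacency_f1": f1(pred_skel, true_skel),
        "orientation_f1": f1(pred_dir, true_dir),
    }


true_graph = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]  # 0 -> 1 -> 2
pred_graph = [[0, 1, 0], [0, 0, 0], [0, 1, 0]]  # 0 -> 1 <- 2
print(edge_scores(pred_graph, true_graph))
```

In this toy case the predicted skeleton is exactly right while one edge is reversed, so the adjacency score is perfect and the orientation score is not. This is the pattern described above for the more complex graphs.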