# Example: Inferring causal structure using a pretrained AVICI model

Amortized variational inference for causal discovery (AVICI) infers causal structure from data using a **simulator** of the domain of interest.
Because a neural network is trained to infer structure from the simulated data, the model can acquire realistic inductive biases from prior knowledge that is hard to express as score functions or conditional independence tests.


In this example, we **download a pretrained model checkpoint and perform predictions** for a simulated dataset. Inferring causal structure with a pretrained AVICI model takes only a few seconds since it amounts to one forward pass through the neural network.

Setup for Google Colab (skip this step if running locally):


```python
%pip install --quiet avici
```

**Simulate some data:**

The function `simulate_data` accepts the following main arguments:
- **d** (int) -- number of variables in the system
- **n** (int) -- number of observational data points to be sampled
- **n_interv** (int) -- number of interventional data points to be sampled
- **domain** (str) -- specifier of the domain to be simulated.
  Current options: `lin-gauss`, `lin-gauss-heterosked`, `lin-laplace-cauchy`, `rff-gauss`, `rff-gauss-heterosked`, `rff-laplace-cauchy`, `gene-ecoli` (see `avici.config.examples`)

The function returns a 3-tuple of
- **g** -- adjacency matrix of shape `[d, d]` of the causal graph
- **x** -- data matrix of shape `[n + n_interv, d]` containing `n + n_interv` observations of the `d` variables
- **interv** -- binary matrix of shape `[n + n_interv, d]` encoding which nodes were intervened upon (`None` if `n_interv=0`)
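The shapes listed above can also be checked programmatically. The helper below, `check_simulated_shapes`, is a hypothetical sketch (it is not part of the `avici` package) that asserts the documented shapes of `g`, `x`, and `interv`. So that it runs without `avici` installed, it is exercised on NumPy stand-in arrays; after simulating data in this notebook, it could be called as `check_simulated_shapes(g, x, interv, d=50, n=200, n_interv=20)`.

```python
import numpy as np


def check_simulated_shapes(g, x, interv, d: int, n: int, n_interv: int) -> None:
    """Hypothetical helper: assert the documented shapes of simulate_data's output."""
    assert g.shape == (d, d), f"g has shape {g.shape}, expected {(d, d)}"
    assert x.shape == (n + n_interv, d), f"x has shape {x.shape}, expected {(n + n_interv, d)}"
    if n_interv == 0:
        # documented behavior: without interventional data, interv is None
        assert interv is None, "interv should be None when n_interv=0"
    else:
        assert interv.shape == x.shape, f"interv has shape {interv.shape}, expected {x.shape}"
        # binary entries: 1 marks a node intervened upon in that observation
        assert set(np.unique(interv).tolist()) <= {0, 1}, "interv must be binary"


# NumPy stand-ins with the documented shapes (not real simulate_data output)
d, n, n_interv = 5, 8, 2
g_demo = np.zeros((d, d), dtype=int)
x_demo = np.random.default_rng(0).normal(size=(n + n_interv, d))
interv_demo = np.zeros((n + n_interv, d), dtype=int)
interv_demo[n:, 0] = 1  # arbitrary choice: the last rows intervene on node 0
check_simulated_shapes(g_demo, x_demo, interv_demo, d, n, n_interv)
```

The placement of the interventional rows in the stand-in arrays is arbitrary; the sketch only checks shapes and binarity, not where `simulate_data` puts interventional observations.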

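```python
import avici
from avici import simulate_data

# simulate a 50-variable system from the nonlinear "rff-gauss" domain:
# 200 observational and 20 interventional samples
g, x, interv = simulate_data(d=50, n=200, n_interv=20, domain="rff-gauss")
```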

**Download and initialize a pretrained model:**

We currently provide the following model checkpoints,
which can be selected via the `download` argument:

- `scm-v0` (**default**): linear and nonlinear SCM data, broad graph and noise distributions
- `neurips-linear`: SCM data with linear causal mechanisms
- `neurips-rff`: SCM data with nonlinear causal mechanisms drawn from GPs with a squared-exponential kernel (defined via random Fourier features)
- `neurips-grn`: synthetic scRNA-seq gene expression data using the SERGIO [simulator](https://github.com/PayamDiba/SERGIO) by [Dibaeinia and Sinha (2020)](https://www.cell.com/cell-systems/pdf/S2405-4712(20)30287-8.pdf)


```python
# load pretrained model
model = avici.load_pretrained(download="scm-v0")
```

**Predict the causal structure:**

Calling the `model` returned by `avici.load_pretrained` predicts a `[d, d]` matrix of probabilities, one for each possible edge in the causal graph. It accepts the following arguments:

- **x** (ndarray) -- real-valued data matrix of shape `[n, d]`
- **interv** (ndarray, optional) -- binary matrix of the same shape as **x** with **interv[i,j] = 1** iff node **j** was intervened upon in observation **i**. (Default is `None`)
- **return_probs** (bool, optional) -- whether to return probability estimates for each edge. If `False`, the predictions are thresholded at 0.5 to binary values 0 and 1. (Default is `True`, as the computational cost is the same.)
- **devices** (optional) -- string defining the backend to use for computation (e.g., `"cpu"`, `"gpu"`), or a list of explicit JAX devices. Defaults to JAX's default devices and backend.
- **shard_if_possible** (bool, optional) -- whether to shard the computation across the observation axis (`n`) of the input when multiple devices are available. This may reduce the memory footprint per device. Defaults to `True`.
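When applying the model to your own data, the `interv` argument has to be assembled by hand. The sketch below shows one way to build such a mask with NumPy from a per-row list of intervened nodes; `make_interv_mask` and its `targets` dictionary format are hypothetical conveniences, not part of `avici`. The `binarize` helper applies the same strict `g_prob > 0.5` threshold that this notebook uses when computing SHD and F1; whether `return_probs=False` treats a probability of exactly 0.5 the same way is an assumption.

```python
import numpy as np


def make_interv_mask(n_rows: int, d: int, targets: dict) -> np.ndarray:
    """Hypothetical helper: build a binary [n_rows, d] mask where
    mask[i, j] = 1 iff node j was intervened upon in observation i.

    `targets` maps a row index to the node indices intervened upon in that row;
    rows that are not listed are treated as purely observational.
    """
    mask = np.zeros((n_rows, d), dtype=int)
    for row, nodes in targets.items():
        mask[row, list(nodes)] = 1
    return mask


def binarize(g_prob: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Turn [d, d] edge probabilities into a 0/1 adjacency matrix (strict >)."""
    return (np.asarray(g_prob) > threshold).astype(int)


# 5 observations of 3 variables: rows 0-2 observational,
# row 3 intervenes on node 1, row 4 intervenes on nodes 0 and 2
mask = make_interv_mask(5, 3, {3: [1], 4: [0, 2]})
print(mask)
```

The resulting mask has the same shape as the data matrix, so it could be passed as `model(x=x_own, interv=mask)`, where `x_own` stands for a hypothetical `[5, 3]` data matrix of your own.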


```python
%%time
# g_prob: [d, d] predicted edge probabilities of the causal graph
g_prob = model(x=x, interv=interv)
```

```python
from avici.metrics import shd, classification_metrics, threshold_metrics

# visualize predictions against the ground-truth graph
avici.visualize(g_prob, true=g, size=0.75)

# SHD and F1 use edges binarized at 0.5; AUROC uses the raw probabilities
g_pred = (g_prob > 0.5).astype(int)
print(f"SHD:   {shd(g, g_pred)}")
print(f"F1:    {classification_metrics(g, g_pred)['f1']:.4f}")
print(f"AUROC: {threshold_metrics(g, g_prob)['auroc']:.4f}")
```
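To make the reported metrics more concrete, the NumPy functions below sketch one common way of computing them. They are illustrative re-implementations, not the code in `avici.metrics`; in particular, counting a reversed edge once in the SHD, treating every directed `(i, j)` entry as one binary label for F1, and excluding the diagonal (self-loops) for AUROC are assumptions that may differ from the library's conventions.

```python
import numpy as np


def shd_sketch(g_true: np.ndarray, g_pred: np.ndarray) -> int:
    """Structural Hamming distance between binary [d, d] adjacency matrices of DAGs:
    the number of edge insertions, deletions, and reversals needed to turn
    g_pred into g_true, with a reversed edge counted once (assumed convention)."""
    diff = np.abs(np.asarray(g_true, dtype=int) - np.asarray(g_pred, dtype=int))
    # A reversal mismatches at both (i, j) and (j, i); symmetrizing and clipping
    # makes every unordered node pair contribute at most one error.
    sym = np.minimum(diff + diff.T, 1)
    return int(sym.sum() // 2)


def edge_f1_sketch(g_true: np.ndarray, g_pred: np.ndarray) -> float:
    """F1 score treating each directed entry (i, j) as one binary label."""
    t = np.asarray(g_true).astype(bool)
    p = np.asarray(g_pred).astype(bool)
    tp = np.sum(t & p)   # edges predicted and present
    fp = np.sum(~t & p)  # edges predicted but absent
    fn = np.sum(t & ~p)  # edges present but missed
    if tp == 0:
        return 0.0  # avoids division by zero; precision or recall is 0 anyway
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return float(2 * precision * recall / (precision + recall))


def edge_auroc_sketch(g_true: np.ndarray, g_prob: np.ndarray) -> float:
    """AUROC as the probability that a random true edge receives a higher score
    than a random non-edge (ties count 1/2); self-loops are excluded (assumption).
    Requires at least one edge and one non-edge off the diagonal."""
    d = np.asarray(g_true).shape[0]
    off_diag = ~np.eye(d, dtype=bool)
    labels = np.asarray(g_true).astype(bool)[off_diag]
    scores = np.asarray(g_prob, dtype=float)[off_diag]
    pos, neg = scores[labels], scores[~labels]
    # compare every (edge, non-edge) pair via broadcasting
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return float(greater + 0.5 * ties)
```

With the notebook's variables, `shd_sketch(g, (g_prob > 0.5).astype(int))` and `edge_auroc_sketch(g, g_prob)` can be compared against the values printed by the previous cell; small differences would indicate that `avici.metrics` uses a different convention.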