# Stage 2: Model training

In this tutorial, we will walk through how to train the CASCADE model using the
preprocessed data from [stage 1](preprocessing.ipynb).

In [1]:
import networkx as nx
import pandas as pd
import scanpy as sc

from cascade.graph import acyclify, demultiplex, filter_edges, multiplex
from cascade.model import CASCADE

## Read preprocessed data

In [2]:
adata = sc.read_h5ad("adata.h5ad")

In [3]:
scaffold = nx.read_gml("scaffold.gml.gz")

In [4]:
latent_emb = pd.read_csv("latent_emb.csv.gz", index_col=0)
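
Before building the model, it can be useful to confirm that the three inputs refer to the same variables. The sketch below is a minimal, hypothetical helper (`report_missing` is not part of CASCADE) that lists names present in one collection but absent from another. It assumes the scaffold graph nodes and the index of `latent_emb` are meant to be variable names matching `adata.var_names`. Toy gene lists stand in for the real objects.

```python
from typing import Iterable, List


def report_missing(expected: Iterable[str], available: Iterable[str]) -> List[str]:
    """Return the names in `expected` that are absent from `available`, in order."""
    available_set = set(available)
    return [name for name in expected if name not in available_set]


# Toy stand-ins for adata.var_names, scaffold.nodes and latent_emb.index
var_names = ["GATA1", "TAL1", "KLF1"]
scaffold_nodes = ["GATA1", "TAL1", "KLF1", "SPI1"]
latent_index = ["GATA1", "KLF1"]

missing_in_scaffold = report_missing(var_names, scaffold_nodes)
missing_in_latent = report_missing(var_names, latent_index)
print(missing_in_scaffold)
print(missing_in_latent)
```

In the notebook, the equivalent calls would be `report_missing(adata.var_names, scaffold.nodes)` and `report_missing(adata.var_names, latent_emb.index)`. Whether CASCADE requires complete coverage is an assumption here, so consult the API documentation before treating a non-empty result as an error.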

## Build the CASCADE model

The first step is to build a CASCADE model:

In [5]:
cascade = CASCADE(
    vars=adata.var_names,
    n_covariates=adata.obsm["covariate"].shape[1],
    scaffold_graph=scaffold,
    latent_data=latent_emb,
    log_dir="log_dir",
)

This creates a CASCADE model with default settings. For advanced options, see
the [CASCADE](api/cascade.model.CASCADE.rst) documentation, which covers the
tunable hyperparameters, modules and their usage.

## Run causal discovery

> (Estimated time: 30 min – 1 hour, depending on computation device)

To run causal discovery using the CASCADE model, use the `discover` method:

In [6]:
cascade.discover(adata)
cascade.save("discover.pt")

15:00:33.493 | INFO     | 1484074:utils:autodevice - Using GPU [6] as computation device.
15:00:38.761 | INFO     | 1484074:nn:set_empirical - Using theta coefficient = 4.150
15:00:38.763 | INFO     | 1484074:nn:set_empirical - Using theta intercept = 0.819



  | Name         | Type      | Params | Mode 
---------------------------------------------------
0 | scaffold     | Edgewise  | 129 K  | train
1 | sparse       | L1        | 0      | train
2 | acyc         | SpecNorm  | 0      | train
3 | kernel       | RBF       | 0      | train
4 | latent       | EmbLatent | 6.3 K  | train
5 | lik          | NegBin    | 0      | train
6 | func         | Func      | 18.9 M | train
  | other params | n/a       | 8.5 K  | n/a  
---------------------------------------------------
19.1 M    Trainable params
0         Non-trainable params
19.1 M    Total params
76.366    Total estimated model params size (MB)
16        Modules in train mode
0         Modules in eval mode



15:45:48.805 | INFO     | 1484074:model:_extrapolate_interv - Extrapolating scale and bias of 959 non-intervened variables from 105 intervened variables.


This runs CASCADE causal discovery with default settings. For advanced
options, see the documentation of
[discover](api/cascade.model.CASCADE.discover.rst).

The same can also be achieved using the
[command line interface](cli.rst#causal-discovery),
with the following command:

```sh
cascade discover -d adata.h5ad -m discover.pt \
    --scaffold-graph scaffold.gml.gz \
    --latent-data latent_emb.csv.gz [other options]
```

> You may use `tensorboard --logdir .` to monitor the training process.

## Remove remaining cycles

Due to numerical limitations, some cycles may remain in the discovered graph.
We therefore apply graph utility functions to enforce a directed acyclic graph,
which downstream inference requires.

In [7]:
graph = cascade.export_causal_graph()
# Filter edges and remove cycles in each demultiplexed graph, then recombine
graph = multiplex(
    *[acyclify(filter_edges(g, cutoff=0.5)) for g in demultiplex(graph)]
)
nx.write_gml(graph, "discover.gml.gz")


The same can also be achieved using the
[command line interface](cli.rst#graph-acyclification),
with the following command:

```sh
cascade acyclify -m discover.pt -g discover.gml.gz [other options]
```
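
To illustrate why an explicit acyclification step is needed, the following sketch uses plain `networkx` on a toy graph. It is not CASCADE's implementation of `filter_edges` or `acyclify`. The helpers `threshold_edges` and `find_remaining_cycle`, the `weight` attribute name and the inclusive cutoff are all assumptions made for illustration. It shows that thresholding edges alone can leave a directed cycle in place, and how such a cycle can be detected.

```python
import networkx as nx


def threshold_edges(g: nx.DiGraph, cutoff: float, attr: str = "weight") -> nx.DiGraph:
    """Return a copy of `g` keeping only edges whose `attr` is at least `cutoff`."""
    kept = nx.DiGraph()
    kept.add_nodes_from(g.nodes(data=True))
    kept.add_edges_from(
        (u, v, d) for u, v, d in g.edges(data=True) if d.get(attr, 0.0) >= cutoff
    )
    return kept


def find_remaining_cycle(g: nx.DiGraph):
    """Return one directed cycle as a list of edges, or None if `g` is a DAG."""
    try:
        return nx.find_cycle(g)
    except nx.NetworkXNoCycle:
        return None


# Toy graph: A -> B -> C -> A forms a cycle of strong edges; C -> D is weak.
toy = nx.DiGraph()
toy.add_weighted_edges_from(
    [("A", "B", 0.9), ("B", "C", 0.8), ("C", "A", 0.6), ("C", "D", 0.2)]
)

filtered = threshold_edges(toy, cutoff=0.5)
cycle = find_remaining_cycle(filtered)
print(sorted(filtered.edges))                  # the weak edge C -> D is dropped
print(nx.is_directed_acyclic_graph(filtered))  # False: the cycle survives
print(cycle)
```

If the graphs returned by `demultiplex` are `networkx` directed graphs (an assumption), `nx.is_directed_acyclic_graph` could be applied to each of them as a quick check after acyclification.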

## Model tuning

> (Estimated time: 15 min – 30 min, depending on computation device)

Next, we import the acyclified graph back into the model:

In [8]:
cascade.import_causal_graph(graph)

Now we can fine-tune the structural equations in the model with the `tune`
method, so that they adapt to the edges removed during acyclification. We also
recommend enabling the counterfactual tuning mode (`tune_ctfact=True`), which
optimizes the tuning process specifically for counterfactual prediction.

In [9]:
cascade.tune(adata, tune_ctfact=True)
cascade.save("tune.pt")

15:46:28.404 | INFO     | 1484074:model:tune - Pruning model...


15:46:29.371 | INFO     | 1484074:core:fit_stage - Number of topological generations: [68, 88, 70, 101]



  | Name         | Type      | Params | Mode 
---------------------------------------------------
0 | scaffold     | Edgewise  | 49.2 K | eval 
1 | sparse       | L1        | 0      | eval 
2 | acyc         | SpecNorm  | 0      | eval 
3 | kernel       | RBF       | 0      | eval 
4 | latent       | EmbLatent | 6.3 K  | train
5 | lik          | NegBin    | 0      | eval 
6 | func         | Func      | 7.9 M  | train
  | other params | n/a       | 8.5 K  | n/a  
---------------------------------------------------
7.9 M     Trainable params
49.2 K    Non-trainable params
8.0 M     Total params
31.921    Total estimated model params size (MB)
11        Modules in train mode
5         Modules in eval mode



16:11:39.865 | INFO     | 1484074:model:_extrapolate_interv - Extrapolating scale and bias of 959 non-intervened variables from 105 intervened variables.


For advanced options, see the documentation of
[tune](api/cascade.model.CASCADE.tune.rst).

The same can also be achieved using the
[command line interface](cli.rst#model-tuning),
with the following command:

```sh
cascade tune -d adata.h5ad -g discover.gml.gz -m discover.pt -o tune.pt \
    --tune-ctfact [other options]
```
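
The tuning log above reports a "number of topological generations". Assuming the term carries its standard graph-theoretic meaning, a topological generation is a layer of nodes whose parents all lie in earlier layers, and such a layering exists only for acyclic graphs. The sketch below uses `networkx` on a toy graph with hypothetical node names to show the idea. It does not reproduce how CASCADE computes or uses these generations.

```python
import networkx as nx

# Toy DAG standing in for one acyclified causal graph (node names are hypothetical).
dag = nx.DiGraph([("A", "C"), ("B", "C"), ("C", "D"), ("B", "E")])

# Each generation contains nodes whose parents all appear in earlier generations.
generations = [sorted(layer) for layer in nx.topological_generations(dag)]
print(generations)       # [['A', 'B'], ['C', 'E'], ['D']]
print(len(generations))  # 3

# A graph with a cycle has no such layering, so networkx raises an exception.
cyclic = nx.DiGraph([("A", "B"), ("B", "A")])
try:
    list(nx.topological_generations(cyclic))
    has_order = True
except nx.NetworkXUnfeasible:
    has_order = False
print(has_order)         # False
```

This is one way to see why the acyclification step matters: without a valid layering, there is no order in which every variable can be computed after all of its parents.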

The tuned model is now ready for counterfactual prediction in
[stage 3](counterfactual.ipynb) and intervention design in
[stage 4](design.ipynb).