# Deep End-to-end Causal Inference: Demo Notebook

This notebook demonstrates the main features of our open-source code for Deep End-to-end Causal Inference (DECI).

 - We begin with a simple two-node example, showing how DECI can orient an edge correctly when non-Gaussian noise is present and how it can then be used for treatment effect estimation
 - We show how different graph constraints can be incorporated into DECI
 - We showcase DECI on a larger graph example
 
### Dataset availability
To use the notebook, the CSuite datasets need to be available. Ensure that you have run the CSuite data generation script in `causica/data_generation/csuite/simulate.py` before attempting to load datasets.

For Microsoft internal users, the datasets will be automatically downloaded from storage.

In [None]:
import os
# Use this to set the notebook's working directory to the top-level directory, where ./data is located
os.chdir("../../..")

In [None]:
from open_source.causica.experiment.steps.step_func import load_data
from open_source.causica.models.deci.deci import DECI
import seaborn as sns
import pandas as pd
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def plot_true_graph(dataset):
    """Draw the ground-truth causal graph stored with the dataset."""
    true_graph = nx.from_numpy_array(dataset.get_adjacency_data_matrix(), create_using=nx.DiGraph)
    nx.draw_networkx(true_graph, arrows=True, with_labels=True)

In [None]:
%matplotlib inline

In [None]:
dataset_config = {'dataset_format': 'causal_csv', 'use_predefined_dataset': True, 'test_fraction': 0.1, 
                  'val_fraction': 0.1, 'random_seed': 0, 'negative_sample': False}
model_config = {'tau_gumbel': 0.25, 'lambda_dag': 100.0, 'lambda_sparse': 5.0, 'spline_bins': 8, 
                'var_dist_A_mode': 'enco', 'mode_adjacency': 'learn', 
                'norm_layers': True, 'res_connection': True, 'base_distribution_type': 'spline'}
# To speed up training you can try:
#  increasing learning_rate
#  increasing batch_size (reduces noise when using higher learning rate)
#  decreasing max_steps_auglag (go as low as you can and still get a DAG)
#  decreasing max_auglag_inner_epochs
training_params = {'learning_rate': 0.05, 'batch_size': 256, 'stardardize_data_mean': False, 
                   'stardardize_data_std': False, 'rho': 1.0, 'safety_rho': 10000000000000.0, 
                   'alpha': 0.0, 'safety_alpha': 10000000000000.0, 'tol_dag': 1e-04, 'progress_rate': 0.65, 
                   'max_steps_auglag': 5, 'max_auglag_inner_epochs': 2000, 'max_p_train_dropout': 0.6, 
                   'reconstruction_loss_factor': 1.0, 'anneal_entropy': 'noanneal'}

## Simplest example of end-to-end causal inference

In [None]:
try:
    from evaluation_pipeline.aml_run_context import setup_run_context_in_aml
    run_context = setup_run_context_in_aml()
except ImportError:
    from causica.experiment.run_context import RunContext
    run_context = RunContext()

To load the dataset, make sure that you have run the CSuite data generation script in `causica/data_generation/csuite/simulate.py`, that the CSuite datasets have been created under `./data`, and that the notebook's working directory has been set correctly.

In [None]:
dataset = load_data("csuite_linexp", "./data", 0, dataset_config, model_config, False, run_context.download_dataset)

In [None]:
train_data = pd.DataFrame(dataset._train_data, columns=["A", "B"])

In [None]:
train_data.head()

In [None]:
%matplotlib inline
sns.scatterplot(x=train_data["A"], y=train_data["B"])

Initially, it is unclear what the causal relationship between A and B is.

In [None]:
df = pd.DataFrame({'from': ['A'], 'to': ['B']})
G = nx.from_pandas_edgelist(df, 'from', 'to')
nx.draw_networkx(G, arrows=False, with_labels=True)

In [None]:
# `device` was set above: "cuda" if a GPU is available, otherwise "cpu"
model = DECI("mymodel", dataset.variables, "mysavedir", device, **model_config)

In [None]:
model.run_train(dataset, run_context.metrics_logger, training_params)

## Causal discovery results

In [None]:
graph = model.networkx_graph()
print(graph.edges)

In [None]:
nx.draw_networkx(graph, arrows=True, with_labels=True)

We can compare this with the true graph:

In [None]:
plot_true_graph(dataset)

## Causal inference results

DECI has also fitted an SCM that captures the functional relationship and error distribution of this dataset.

We can estimate the average treatment effect (ATE) and compare it with the ATE computed from ground-truth interventional data. Here we compute E[B|do(A=1)] - E[B|do(A=-1)].
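To make this quantity concrete, here is a minimal sketch on a toy linear SCM. The structural equation `B = 0.5 * A + noise` and the centred exponential noise are assumptions chosen for illustration (they give an ATE of exactly 1). They are not a description of how the CSuite dataset was generated, and `sample_b_under_do` is a hypothetical helper rather than part of DECI.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_b_under_do(a_value, n_samples):
    """Sample B from the toy SCM B = 0.5 * A + noise after setting A := a_value."""
    # The intervention fixes A, so the only remaining randomness is the
    # non-Gaussian (centred exponential) noise on B.
    noise = rng.exponential(scale=1.0, size=n_samples) - 1.0
    return 0.5 * a_value + noise

b_do_plus = sample_b_under_do(1.0, 100_000)
b_do_minus = sample_b_under_do(-1.0, 100_000)

# Monte Carlo estimate of E[B|do(A=1)] - E[B|do(A=-1)]
toy_ate = b_do_plus.mean() - b_do_minus.mean()
print("Toy ATE estimate:", toy_ate)
```

For this toy model the analytic value is 0.5 * 1 - 0.5 * (-1) = 1, so the printed estimate should be close to 1 up to Monte Carlo error.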

In [None]:
import numpy as np

intervention_idxs = np.array([0])
outcome_idx = 1

### Model-based ATE estimate
do_1 = model.sample(5000, intervention_idxs=intervention_idxs, intervention_values=np.array([1.])).cpu().numpy()
do_minus_1 = model.sample(5000, intervention_idxs=intervention_idxs, intervention_values=np.array([-1.])).cpu().numpy()
ate_estimate = do_1[:, outcome_idx].mean() - do_minus_1[:, outcome_idx].mean()
print("Estimated ATE:", ate_estimate)

In [None]:
### Interventional test data ATE
ate_true = dataset._intervention_data[0].test_data[:, outcome_idx].mean() - dataset._intervention_data[0].reference_data[:, outcome_idx].mean()
print("Interventional ATE:", ate_true)

In [None]:
print("Theoretical ATE is 1.")

In short, we can start from data, perform causal discovery and causal inference, and obtain treatment effect estimates that can inform decisions.

## Graph constraints
First, train on a new dataset with no constraints. *Note*: this is a very difficult dataset in which all variables are strongly correlated with one another.

To load the dataset, first ensure that it has been generated under `./data`.

In [None]:
simpson_data = load_data("csuite_nonlin_simpson", "./data", 0, dataset_config, model_config, False, run_context.download_dataset)

In [None]:
print(f"New dataset with {simpson_data.variables.num_groups} nodes.")

In [None]:
print("The true graph is:")
plot_true_graph(simpson_data)

In [None]:
simpson_df = pd.DataFrame(simpson_data._train_data, columns=simpson_data.variables.group_names)
simpson_df.head()

In [None]:
simpson_model = DECI("mymodel", simpson_data.variables, "mysavedir", device, **model_config)

In [None]:
# You may need more auglag steps or a higher rho to ensure that the learned graph is a DAG
training_params['max_auglag_inner_epochs'] = 2000
training_params['max_steps_auglag'] = 10

simpson_model.run_train(simpson_data, run_context.metrics_logger, training_params)
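The `max_steps_auglag`, `rho` and `lambda_dag` settings refer to an augmented-Lagrangian treatment of the acyclicity constraint. As a rough illustration of what such a constraint measures, the sketch below evaluates a truncated NOTEARS-style penalty, the sum over k = 1..d of tr((A ∘ A)^k) / k!, which is zero exactly when the adjacency matrix A describes a DAG. Whether DECI uses this exact form is an assumption here, and `dag_penalty` is a hypothetical helper, not a function from the DECI code base.

```python
import numpy as np

def dag_penalty(adjacency):
    """Truncated NOTEARS-style acyclicity penalty: sum_{k=1..d} tr((A*A)^k) / k!."""
    a_sq = adjacency * adjacency  # elementwise square keeps all entries non-negative
    d = a_sq.shape[0]
    term = np.eye(d)
    total = 0.0
    # A simple cycle in a d-node graph has length at most d,
    # so the first d terms of the series are enough to detect any cycle.
    for k in range(1, d + 1):
        term = term @ a_sq / k  # now equal to (A*A)^k / k!
        total += np.trace(term)
    return total

chain = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]], dtype=float)  # 0 -> 1 -> 2
cycle = chain.copy()
cycle[2, 0] = 1.0  # add 2 -> 0, which closes a cycle

print("Penalty for the chain:", dag_penalty(chain))
print("Penalty for the cycle:", dag_penalty(cycle))
```

The chain yields a penalty of zero and the cyclic graph a strictly positive one. Under the assumption above, more augmented-Lagrangian steps or a larger `rho` put more pressure on the optimizer to drive this quantity to zero.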

In [None]:
print(simpson_model.networkx_graph().edges)

If we are not satisfied with this graph, we can add constraints.

Constraints are encoded using an adjacency matrix where:
 - 0 indicates that there is no directed edge i → j,
 - 1 indicates that there has to be a directed edge i → j,
 - nan indicates that the directed edge i → j is learnable.
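
As a small illustration of this encoding, the sketch below builds a constraint matrix for three nodes by hand, with rows indexing the source node i and columns the sink node j. The node count and the particular constraints are made up for the example.

```python
import numpy as np

# Start with every edge learnable (nan everywhere).
example_constraint = np.full((3, 3), np.nan, dtype=np.float32)

example_constraint[1, 0] = 0.0  # forbid the edge 1 -> 0
example_constraint[2, 0] = 0.0  # forbid the edge 2 -> 0
example_constraint[0, 1] = 1.0  # force the edge 0 -> 1

print(example_constraint)
# Count how many entries are still free for the model to learn.
print("Learnable entries:", int(np.isnan(example_constraint).sum()))
```

Setting a whole column to 0 in this way marks a node as a root, and setting a whole row to 0 marks it as a leaf.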
 
The following function converts from `tabu_` format into this matrix format, for use with DECI.

In [None]:
def make_constraint_matrix(variables, tabu_child_nodes=None, tabu_parent_nodes=None, tabu_edges=None):
    """
    Makes a DECI constraint matrix from GCastle constraint format.

    Arguments:
        variables: dataset variables object providing `num_groups` and `group_names`
        tabu_child_nodes: Optional[List[str]]
            nodes that cannot be children of any other nodes (root nodes)
        tabu_parent_nodes: Optional[List[str]]
            nodes that cannot be the parent of any other node (leaf nodes)
        tabu_edges: Optional[List[Tuple[str, str]]]
            edges that cannot exist
    """

    constraint = np.full((variables.num_groups, variables.num_groups), np.nan)
    name_to_idx = {name: i for (i, name) in enumerate(variables.group_names)}
    if tabu_child_nodes is not None:
        for node in tabu_child_nodes:
            idx = name_to_idx[node]
            constraint[:, idx] = 0.0
    if tabu_parent_nodes is not None:
        for node in tabu_parent_nodes:
            idx = name_to_idx[node]
            constraint[idx, :] = 0.0
    if tabu_edges is not None:
        for source, sink in tabu_edges:
            source_idx, sink_idx = name_to_idx[source], name_to_idx[sink]
            constraint[source_idx, sink_idx] = 0.0
    return constraint.astype(np.float32)

### Adding constraint that a node is not a child
Let's suppose that 'Column 0' is not a child of anything (it's a root node).

In [None]:
training_params['max_auglag_inner_epochs'] = 1000
training_params['max_steps_auglag'] = 5

In [None]:
constraint = make_constraint_matrix(simpson_data.variables, tabu_child_nodes=['Column 0'])

In [None]:
simpson_model = DECI("mymodel", simpson_data.variables, "mysavedir", device, **model_config)
simpson_model.set_graph_constraint(constraint)

In [None]:
simpson_model.run_train(simpson_data, run_context.metrics_logger, training_params)

In [None]:
print(simpson_model.networkx_graph().edges)

### Adding constraint that a node is not a parent
Suppose we also want to specify that 'Column 3' is not a parent of anything (it's a leaf node).

In [None]:
constraint = make_constraint_matrix(
    simpson_data.variables, tabu_child_nodes=['Column 0'], tabu_parent_nodes=['Column 3']
)
simpson_model = DECI("mymodel", simpson_data.variables, "mysavedir", device, **model_config)
simpson_model.set_graph_constraint(constraint)
simpson_model.run_train(simpson_data, run_context.metrics_logger, training_params)

In [None]:
print(simpson_model.networkx_graph().edges)

### Adding constraint that an edge doesn't exist
Suppose we also want to specify that there is no edge from Column 1 to Column 3.

In [None]:
constraint = make_constraint_matrix(
    simpson_data.variables, tabu_child_nodes=['Column 0'], tabu_parent_nodes=['Column 3'], 
    tabu_edges=[('Column 1', 'Column 3')]
)
simpson_model = DECI("mymodel", simpson_data.variables, "mysavedir", device, **model_config)
simpson_model.set_graph_constraint(constraint)
simpson_model.run_train(simpson_data, run_context.metrics_logger, training_params)

In [None]:
print(simpson_model.networkx_graph().edges)

### Adding a positive constraint
It is also possible with DECI to force an edge to exist. For example, suppose we decide to enforce the existence of the edge from Column 0 to Column 2, which corresponds to entry `[0, 2]` of the constraint matrix.

In [None]:
constraint[0, 2] = 1.0
simpson_model = DECI("mymodel", simpson_data.variables, "mysavedir", device, **model_config)
simpson_model.set_graph_constraint(constraint)
simpson_model.run_train(simpson_data, run_context.metrics_logger, training_params)

In [None]:
print(simpson_model.networkx_graph().edges)

In [None]:
print("The correct graph is ", [(0, 1), (0, 2), (1, 2), (2, 3)])
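
The printed edge lists can be compared with the correct graph by eye, but a few set operations make the comparison explicit. In the sketch below, `learned_edges` is a made-up example rather than an actual DECI result, and the four categories are just one reasonable way to break down the differences.

```python
true_edges = {(0, 1), (0, 2), (1, 2), (2, 3)}
learned_edges = {(0, 1), (0, 2), (2, 1), (2, 3)}  # hypothetical model output

# Edges found with the correct orientation.
correct = learned_edges & true_edges
# Edges found between the right pair of nodes but pointing the wrong way.
reversed_edges = {(i, j) for (i, j) in learned_edges - true_edges if (j, i) in true_edges}
# Edges that do not correspond to any true edge in either direction.
extra = learned_edges - true_edges - reversed_edges
# True edges that the model missed in both directions.
missing = {(i, j) for (i, j) in true_edges - learned_edges if (j, i) not in learned_edges}

print("correct:", sorted(correct))
print("reversed:", sorted(reversed_edges))
print("extra:", sorted(extra))
print("missing:", sorted(missing))
```

With a trained model, `learned_edges` could be taken from `set(simpson_model.networkx_graph().edges)`, assuming the node labels match those used in `true_edges`.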

## A larger graph example

In [None]:
large_data = load_data("csuite_large_backdoor", "./data", 0, dataset_config, model_config, False, run_context.download_dataset)

In [None]:
[train_row, train_col] = np.shape(large_data._train_data)

In [None]:
large_train_data = pd.DataFrame(large_data._train_data, columns=[f"X{i}" for i in range(train_col)])

In [None]:
large_train_data.head()

In [None]:
if train_col < 15:
    sns.pairplot(large_train_data)

In [None]:
large_model = DECI("mymodel", large_data.variables, "mysavedir", device, **model_config)

In [None]:
training_params['max_steps_auglag'] = 15
training_params['max_auglag_inner_epochs'] = 3000
large_model.run_train(large_data, run_context.metrics_logger, training_params)

In [None]:
large_graph = large_model.networkx_graph()

In [None]:
nx.draw_networkx(large_graph, arrows=True, with_labels=True)

In [None]:
print("The true graph is:")
plot_true_graph(large_data)

In [None]:
import numpy as np

### Model-based ATE estimate
# Compare do(X7 = 2) against the reference intervention do(X7 = 0)
do_2 = large_model.sample(5000, intervention_idxs=np.array([7]), intervention_values=np.array([2.])).cpu().numpy()
do_0 = large_model.sample(5000, intervention_idxs=np.array([7]), intervention_values=np.array([0.])).cpu().numpy()
ate_estimate = do_2[:, 8].mean() - do_0[:, 8].mean()
print("Estimated ATE:", ate_estimate)

In [None]:
### Interventional test data ATE
ate_true = large_data._intervention_data[0].test_data[:, 8].mean() - large_data._intervention_data[0].reference_data[:, 8].mean()
print("Interventional ATE:", ate_true)

## Imputation results

DECI also learns an imputation network that can be used to fill in missing data.

In [None]:
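def make_missing(data):
    """Hide one randomly chosen entry per row. In the returned mask, True marks observed entries."""
    missing_data = data.copy()
    mask = np.full(missing_data.shape, fill_value=True, dtype=bool)
    n_rows, n_cols = data.shape
    for row in range(n_rows):
        i = np.random.randint(n_cols)  # column to hide in this row
        missing_data[row, i] = 0.  # placeholder value for the hidden entry
        mask[row, i] = False
    return missing_data, mask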

In [None]:
data_with_missingness, mask = make_missing(dataset._train_data)

In [None]:
imputed = model.impute(data_with_missingness, mask)

In [None]:
ax = sns.scatterplot(x=dataset._train_data[~mask], y=imputed[~mask])
ax.set(xlabel="True value", ylabel="Imputed value")
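
To complement the scatterplot with a number, one option is the root-mean-square error over the hidden entries. The sketch below shows the computation on toy data, using column-mean imputation as a stand-in for the model. The data, the baseline imputer and the variable names are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy data standing in for dataset._train_data (hypothetical values).
true_data = rng.normal(size=(200, 2))

# Same mask convention as above: True = observed, False = hidden.
mask = np.ones(true_data.shape, dtype=bool)
hidden_cols = rng.integers(0, true_data.shape[1], size=true_data.shape[0])
mask[np.arange(true_data.shape[0]), hidden_cols] = False

# Baseline imputer: fill each hidden entry with the mean of the observed
# values in the same column.
observed = np.where(mask, true_data, np.nan)
col_means = np.nanmean(observed, axis=0)
baseline_imputed = np.where(mask, true_data, col_means)

# RMSE restricted to the entries that were hidden.
rmse = np.sqrt(np.mean((baseline_imputed[~mask] - true_data[~mask]) ** 2))
print("Baseline RMSE on hidden entries:", rmse)
```

With a trained model, `baseline_imputed` would be replaced by the output of `model.impute(data_with_missingness, mask)`, and a useful imputation network should achieve a lower error than this baseline.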

## Analysing the DECI model

DECI gives us a simulator of the observational distribution.

In [None]:
simulation = pd.DataFrame(model.sample(5000).cpu().numpy(), columns=["A", "B"])

In [None]:
sns.scatterplot(x=train_data["A"], y=train_data["B"])

In [None]:
sns.scatterplot(x=simulation["A"], y=simulation["B"])

The DECI model also allows us to simulate from interventional distributions.

In [None]:
simulation_intervention = pd.DataFrame(
    model.sample(5000, intervention_idxs=np.array([0]), intervention_values=np.array([4.])).cpu().numpy(), 
    columns=["A", "B"]
)

In [None]:
simulation_intervention.min()

Intervening on A changes the distribution of B. The plot below compares B under do(A=4) with the observational B from the training data (red).

In [None]:
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
sns.kdeplot(simulation_intervention["B"], ax=ax1)
sns.kdeplot(train_data["B"].astype(np.float32), ax=ax2, color='r')

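Intervening on B does not change the distribution of A. The plot below compares A under do(B=1) with the observational A from the training data (red).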

In [None]:
simulation_intervention2 = pd.DataFrame(
    model.sample(5000, intervention_idxs=np.array([1]), intervention_values=np.array([1.])).cpu().numpy(), 
    columns=["A", "B"]
)

In [None]:
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
sns.kdeplot(simulation_intervention2["A"], ax=ax1)
sns.kdeplot(train_data["A"].astype(np.float32), ax=ax2, color='r')