# Causal Methods

In this notebook, we check whether a predictive model has picked up true causal relationships by comparing its SHAP values with the true causal effects of each feature. If the two are equivalent, interventions can be based purely on the SHAP values.
If they are not equivalent, the predictive model is not learning true causal effects, only correlations.

Adapted from: https://shap.readthedocs.io/en/latest/example_notebooks/overviews/Be%20careful%20when%20interpreting%20predictive%20models%20in%20search%20of%20causal%C2%A0insights.html

Dependencies:

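numpy, pandas, scipy, scikit-learn, xgboost, shap, networkx, matplotlib, castle (installed as `gcastle`; used here with the PyTorch backend)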

### Imports

In [1]:
import numpy as np
import pandas as pd
import scipy.stats
import sklearn
import xgboost

import os
os.environ['CASTLE_BACKEND'] = 'pytorch'
from collections import OrderedDict
import networkx as nx
from castle.algorithms import PC, GES, ICALiNGAM, GOLEM
import matplotlib.pyplot as plt

2023-06-22 08:15:37,622 - /opt/miniconda3/envs/env_gal/lib/python3.9/site-packages/castle/algorithms/__init__.py[line:36] - INFO: You are using ``pytorch`` as the backend.


### Simulate a fake dataset with a binary outcome

In [2]:
class FixableDataFrame(pd.DataFrame):
    """DataFrame in which selected columns are pinned to fixed values.

    Any assignment ``df[name] = value`` where ``name`` is a string key of
    ``fixed`` ends up storing ``fixed[name]`` instead of ``value``. This lets
    a generative model be re-run with one variable held constant (an
    intervention) without changing the model code.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``pd.DataFrame``.
    fixed : dict, optional
        Mapping of column name -> value to pin that column to. ``None``
        (the default) means no column is pinned.
    """

    def __init__(self, *args, fixed=None, **kwargs):
        # A fresh dict per instance avoids sharing one mutable default
        # between all frames. Written directly into the instance __dict__
        # (presumably to bypass pandas' own attribute handling).
        self.__dict__["__fixed_var_dictionary"] = {} if fixed is None else fixed
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        # Assign the requested value first so the column (and, on an empty
        # frame, the index) exists with the right length; only then
        # overwrite it with the pinned value.
        out = super().__setitem__(key, value)
        fixed = self.__dict__.get("__fixed_var_dictionary", {})
        if isinstance(key, str) and key in fixed:
            out = super().__setitem__(key, fixed[key])
        return out

# generate the data
def generator(n, fixed=None, seed=0):
    """Generative model for the simulated flight-conversion dataset.

    Parameters
    ----------
    n : int
        Number of rows (customers) to simulate.
    fixed : dict, optional
        Mapping of column name -> value. Each listed column is held at that
        value for every row (an intervention); columns computed afterwards
        see the fixed value. ``None`` means nothing is fixed.
    seed : int or None
        Seed passed to ``np.random.seed``. ``None`` leaves the global NumPy
        random state untouched.

    Returns
    -------
    FixableDataFrame
        Columns "dtg", "Destination Type", "Flight Price",
        "Range of Prices", "Alternative Flights" and the 0/1 outcome
        "Did convert".
    """
    # Imported locally: the file-level imports only bring in ``scipy.stats``,
    # which does not guarantee that ``scipy.special`` is reachable as an
    # attribute of ``scipy``.
    from scipy.special import expit

    if fixed is None:
        fixed = {}
    if seed is not None:
        np.random.seed(seed)
    X = FixableDataFrame(fixed=fixed)

    # days to go before flight departure (whole days, then divided by 1 or 2)
    dtg = np.random.uniform(0, 365, size=(n,)).round()
    X["dtg"] = (1/np.random.uniform(1, 2, size=(n,)).round()) * dtg

    # Destination type, assume there are 3 categorical values: Beach, City, Domestic
    X["Destination Type"] = np.random.uniform(0, 2, size=(n,)).round()

    # The price of the flight depends on destination type
    X["Flight Price"] = (X["Destination Type"]+1)/2 *  np.random.uniform(45, 250, size=(n,))

    # Range of prices shown to the customer depends on destination type
    X["Range of Prices"] = (X["Destination Type"]+1) +  np.random.uniform(0, 50, size=(n,)).round()

    # number of alternative flights shown depends on destination type
    X["Alternative Flights"] = (X["Destination Type"]+1)  +  np.random.uniform(0, 20, size=(n,)).round()

    # Conversion probability: a logistic function of the features plus noise.
    # NOTE(review): "dtg" can be exactly 0, in which case 0.5 / dtg is
    # infinite and the probability saturates at 1 - confirm this is intended.
    X["Did convert"] = expit(
        0.5 / X["dtg"]
        + 0.5 / X["Flight Price"]
        + 0.15 * X["Destination Type"]
        + 0.2 / X["Range of Prices"]
        + 0.05 * X["Alternative Flights"]
        + 0.1 * np.random.normal(0, 1, size=(n,))
    )

    # Make a random draw to get either 0 or 1 for whether the customer did or
    # did not convert, as would be observed in real life. Commenting this line
    # out leaves the label as the probability, which gives less noisy causal
    # effect lines in the plots but the same basic results.
    X["Did convert"] = scipy.stats.bernoulli.rvs(X["Did convert"])

    return X

def user_retention_dataset():
    """Simulate the observed training data.

    Draws 10000 rows from ``generator`` and separates the "Did convert"
    outcome from the remaining feature columns.

    Returns
    -------
    tuple
        ``(X, y)``: the feature columns and the "Did convert" column.
    """
    sample_size = 10000
    simulated = generator(sample_size)
    target = simulated["Did convert"]
    features = simulated.drop(columns=["Did convert"])
    return features, target

def fit_xgboost(X, y):
    """Train an XGBoost model with early stopping.

    The data is split into train and test parts with scikit-learn's
    ``train_test_split`` (default split, no fixed random state, so the split
    differs between runs). Boosting runs for at most 200000 rounds and stops
    once the test metric has not improved for 20 rounds.

    Parameters
    ----------
    X : feature table accepted by ``xgboost.DMatrix``
    y : labels aligned with the rows of ``X``

    Returns
    -------
    The booster returned by ``xgboost.train``.
    """
    # ``import sklearn`` alone does not load submodules, so
    # ``sklearn.model_selection`` is only reachable if something else has
    # already imported it. Import what is needed explicitly.
    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(X, y)
    dtrain = xgboost.DMatrix(X_train, label=y_train)
    dtest = xgboost.DMatrix(X_test, label=y_test)

    # Small learning rate and shallow trees; the number of rounds is in
    # practice decided by early stopping on the held-out part.
    params = {"eta": 0.001, "subsample": 0.5, "max_depth": 2, "objective": "reg:logistic"}
    model = xgboost.train(
        params, dtrain, num_boost_round=200000,
        evals=((dtest, "test"),), early_stopping_rounds=20, verbose_eval=False
    )
    return model

### Fit Prediction Model

In [3]:
# Simulate the observed data, then fit the predictive XGBoost model on it.
X, y = user_retention_dataset()
model = fit_xgboost(X, y)

### SHAP values from prediction model

In [None]:
import shap

# Explain the fitted model's predictions for every row of X.
explainer = shap.Explainer(model)
shap_values = explainer(X)

# Cluster the features with shap.utils.hclust (single linkage) and hand the
# clustering to the bar plot of the SHAP values.
clust = shap.utils.hclust(X, y, linkage="single")
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)

In [None]:
# Scatter plot of the SHAP values computed above.
shap.plots.scatter(shap_values, ylabel="SHAP value\n(higher means more likely to convert)")

### Causal Approach

The SHAP values suggest that the flight price has some causal impact on the conversion rate, but the true causal impact is flat.

In [None]:
def marginal_effects(generative_model, num_samples=1000, columns=None, max_points=100, logit=True, seed=0):
    """Helper function to compute the true marginal causal effects.

    For each column, a grid of values is taken from the percentiles of one
    simulated sample. The generative model is then re-run once per grid
    value with that column fixed to it, and the mean of "Did convert" is
    recorded. Each column's curve is centred by subtracting its own mean.

    Parameters
    ----------
    generative_model : callable
        Called as ``generative_model(num_samples)`` and
        ``generative_model(num_samples, fixed={column: value}, seed=seed)``;
        must return a table with a "Did convert" column.
    num_samples : int
        Number of rows simulated per call.
    columns : sequence of str, optional
        Columns to intervene on. Defaults to every column of the simulated
        table.
    max_points : int
        Maximum number of grid values per column (duplicates are dropped).
    logit : bool
        If true, the mean outcome is mapped through the logit function.
    seed : int or None
        Seed forwarded to the intervened runs of ``generative_model``.

    Returns
    -------
    list of (xs, ys) tuples, one per column, where ``xs`` are the grid values
    and ``ys`` the centred effects. An empty ``columns`` gives an empty list.
    """
    # Imported locally under another name: the ``logit`` parameter shadows
    # the function name, and the file-level imports only load ``scipy.stats``.
    from scipy.special import logit as _logit_transform

    X = generative_model(num_samples)
    if columns is None:
        columns = X.columns
    if len(columns) == 0:
        return []

    # One sorted row of observed values per column.
    sorted_values = np.sort([X[c].values for c in columns], axis=1)
    percentiles = np.linspace(0, 100, max_points)

    xs = []
    ys = []
    for i, c in enumerate(columns):
        print(i, c)
        # 'nearest' keeps the grid on values that actually occur in the data.
        grid = np.unique(np.nanpercentile(sorted_values[i], percentiles, method='nearest'))
        effects = []
        for x in grid:
            Xnew = generative_model(num_samples, fixed={c: x}, seed=seed)
            val = Xnew["Did convert"].mean()
            if logit:
                val = _logit_transform(val)

            # Replace +/-inf (logit of exactly 0 or 1) and NaN with 0.
            # ``val == np.nan`` is always False, so NaN has to be detected
            # with isfinite instead.
            if not np.isfinite(val):
                val = 0
            effects.append(val)
        effects = np.array(effects)
        xs.append(grid)
        ys.append(effects - effects.mean())
    return list(zip(xs, ys))

# Same SHAP scatter plot as before, now with the true causal effect of each
# feature (10000 simulated rows per grid point) passed in as an overlay.
shap.plots.scatter(shap_values, ylabel="SHAP value\n(higher means more likely to convert)", overlay={
    "True causal effects": marginal_effects(generator, 10000, X.columns)
})

### Causal Discovery
This section automatically generates causal DAGs from observed data. You pass in your features and target, and the algorithms try a number of different DAGs and score how well each fits. Several different methods are used because the best one depends on the data, so it is worth trying them all and seeing which result seems most reasonable.

Use with caution. Try first with simulated data where you know the outcome. Make this simulated data as similar to what you expect from your true data as possible.

In [None]:
# Stack four features and the outcome as columns, in this order:
# Destination Type, Alternative Flights, Flight Price, dtg, outcome.
# 'Range of Prices' is not included.
pc_dataset = np.vstack([X['Destination Type'], X['Alternative Flights'], X['Flight Price'], X['dtg'], y.values]).T

In [None]:
# Display the shape of the stacked array.
pc_dataset.shape

This gives us the adjacency matrix of our features, i.e. if the entry in row 0, column 1 is 1, then feature 0 causally impacts feature 1. We can then use networkx to draw the result.

In [None]:
# Run the PC algorithm with its default settings on the stacked data.
pc = PC()
pc.learn(pc_dataset)

# Print out the learned matrix
print(pc.causal_matrix)

In [None]:
# Node index -> display name, in the same order as the columns of pc_dataset
MAPPING = dict(enumerate(['Destination Type', 'Alternative Flights', 'Flight Price', 'dtg', 'CvR']))

# Build a directed graph from the matrix learned by PC and give the nodes
# readable names
learned_graph = nx.relabel_nodes(nx.DiGraph(pc.causal_matrix), MAPPING, copy=True)

# Draw the labelled graph
draw_options = dict(with_labels=True, node_size=1800, font_size=18, font_color='white')
nx.draw_networkx(learned_graph, **draw_options)
plt.show()

### Now try more methods for completeness

We know that the results from GOLEM and LiNGAM are quite incorrect. However, PC and GES are very close to what the simulation intended. The bidirectional arrow is alarming, but this does give us a good indication of the causal impacts of our features.

In [None]:
# Causal-discovery algorithms to compare, keyed by the name used as plot title.
methods = OrderedDict({
    'PC': PC,
    'GES': GES,
    'LiNGAM': ICALiNGAM,
    'GOLEM': GOLEM
})

# Node index -> display name, in the same order as the columns of pc_dataset.
# Built once here because it does not change between methods.
MAPPING = dict(enumerate(['Destination Type', 'Alternative Flights', 'Flight Price', 'dtg', 'CvR']))

for method, algorithm in methods.items():
    plt.figure(figsize=(6,4))
    plt.title(f'{method}')

    # Named ``discovery_model`` rather than ``model`` so that the trained
    # XGBoost model used by the SHAP cells is not overwritten.
    if method == 'GOLEM':
        # GOLEM is the only method given a non-default setting here.
        discovery_model = algorithm(num_iter=2.5e4)
    else:
        discovery_model = algorithm()

    # Fit the model
    discovery_model.learn(pc_dataset)

    # Get the DAG
    pred_dag = discovery_model.causal_matrix

    # Get learned graph and relabel the nodes
    learned_graph = nx.relabel_nodes(nx.DiGraph(pred_dag), MAPPING, copy=True)

    # Plot the graph
    nx.draw_networkx(
        learned_graph,
        with_labels=True,
        node_size=1800,
        font_size=8,
        font_color='black'
    )
    plt.show()