# Introduction

This notebook introduces the code framework for reproducing the results of
> Malek, Aglietti, Chiappa. "Additive Causal Bandits with Unknown Graph." ICML, 2023.


Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").

In [None]:
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
#@title Install the package
!pip install git+https://github.com/deepmind/additive_cbug

In [None]:
#@title Imports
import numpy as np
import matplotlib.pyplot as plt

from cbug import base
from cbug import discrete_scm
from cbug import discrete_scm_utils as d_utils
from cbug import scm
from cbug import stoc_fn_utils as sf_utils
from cbug import run

## First Example: Chain Graph

Our implementation is based on the SCM class, which lets you specify an SCM through stochastic functions, i.e., functions that may be random. We set $X_0$ to be discrete and i.i.d. uniform, $X_1$ to be $X_0$ plus discrete noise, and the target variable $Y$ to be $X_1$ plus Gaussian noise.
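For intuition, the same chain can be written in plain Python without cbug. The sketch below is illustrative only: `sample_chain` is a hypothetical helper that samples each variable in topological order and, when `do_x1` is given, applies the intervention $do(X_1 = x)$, which is the kind of action a causal bandit chooses.

```python
import random

SUPPORT_X0 = 5  # X0 takes values 0, ..., 4, as in the cell below


def sample_chain(n_samples, do_x1=None, rng=random):
    """Draws n_samples tuples (X0, X1, Y) from the chain X0 -> X1 -> Y.

    If do_x1 is not None, X1 is clamped to that value, i.e. do(X1 = do_x1),
    which cuts its dependence on X0.
    """
    samples = []
    for _ in range(n_samples):
        x0 = rng.randrange(SUPPORT_X0)          # X0 ~ Uniform{0, ..., 4}
        if do_x1 is None:
            x1 = x0 + rng.choice([0, 1, 2])     # X1 = X0 + discrete noise
        else:
            x1 = do_x1                          # intervened value
        y = x1 + rng.gauss(0.0, 1.0)            # Y = X1 + N(0, 1) noise
        samples.append((x0, x1, y))
    return samples


# Monte Carlo estimate of E[Y | do(X1 = 3)], which should be close to 3.
draws = sample_chain(20000, do_x1=3, rng=random.Random(0))
mean_y = sum(y for _, _, y in draws) / len(draws)
print(f"estimated E[Y | do(X1=3)] = {mean_y:.2f}")
```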

In [None]:
# Give the support size of each X. Y must be univariate and real-valued.
support_sizes = {"X0": 5, "X1": 8}
y_mean = lambda X1: X1
y_stoc_fn = lambda X1: X1 + np.random.normal(size=X1.shape)
stoc_fns = {
    # Because X0 has no parents, it takes n_samples as its only input.
    "X0": lambda n_samples: np.random.choice(np.arange(support_sizes["X0"]), n_samples),
    "X1": lambda X0: X0 + np.random.choice([0, 1, 2], size=X0.shape),
    "Y": y_stoc_fn,
}

# Alternatively, we can define Y's stochastic function in terms of y_mean:
stoc_fns["Y"] = scm.StocFnRecipe(
    stoc_fn=sf_utils.add_gaussian_noise_to_stoc_fn(y_mean, cov=1),
    stoc_fn_inputs=["X1"],
)

model = discrete_scm.DiscreteSCM(stoc_fns=stoc_fns,
                                 support_sizes=support_sizes,
                                 outcome_variable="Y",
                                 best_action=11,
                                 best_action_value=y_mean(11),
                                 outcome_expected_value_fn=y_mean,
)
results = base.run_modl(
    delta=.5,
    epsilon=.1,
    model=model,
    cov=1,
    outcome_bound=3 + 12
)
print(results)


In [None]:
# @title Running the algorithms from the paper on our example

# Successive Elimination
se_results = base.run_se(
    delta=.5,
    epsilon=.1,
    model=model,
    cov=1,
    outcome_bound=3 + 12
)
print("Results for Successive Elimination")
print(se_results)

# Finally, run MODL as an oracle that is given the parents of Y in advance.
oracle_results = base.run_modl(
    delta=.5,
    epsilon=.1,
    model=model,
    cov=1,
    outcome_bound=3 + 12,
    opt_scope=["X1"],
)
print("Results for MODL when the true parents are provided")
print(oracle_results)
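In the causal-bandit setting, successive elimination usually treats every intervention as an independent arm, so it ignores the graph. As a rough guide to what `epsilon` and `delta` control, below is a textbook-style $(\epsilon, \delta)$-PAC successive elimination sketch over generic arms. It assumes $\sigma$-sub-Gaussian rewards and uses one standard anytime confidence radius; `successive_elimination` is a hypothetical helper and is not the code behind `base.run_se`.

```python
import math
import random


def successive_elimination(arms, epsilon, delta, sigma=1.0, max_rounds=100000):
    """(epsilon, delta)-PAC best-arm identification by successive elimination.

    arms: list of zero-argument callables, each returning a noisy reward.
    Assumes sigma-sub-Gaussian rewards. Returns (chosen arm index, total pulls).
    """
    k = len(arms)
    active = list(range(k))
    sums = [0.0] * k
    pulls = 0
    for t in range(1, max_rounds + 1):
        for a in active:  # every surviving arm has now been pulled t times
            sums[a] += arms[a]()
            pulls += 1
        means = {a: sums[a] / t for a in active}
        # Anytime radius from a union bound over k arms and all rounds t.
        radius = math.sqrt(2 * sigma**2 * math.log(4 * k * t**2 / delta) / t)
        best = max(means.values())
        # Drop arms whose upper bound is below the leader's lower bound.
        active = [a for a in active if means[a] + radius >= best - radius]
        if len(active) == 1 or radius <= epsilon / 2:
            break
    winner = max(active, key=lambda a: means[a])
    return winner, pulls


rng = random.Random(0)
true_means = [0.0, 0.5, 1.0, 3.0]
arms = [lambda m=m: m + rng.gauss(0.0, 1.0) for m in true_means]
winner, pulls = successive_elimination(arms, epsilon=0.1, delta=0.5)
print(winner, pulls)
```

The radius makes the best arm survive every round with probability at least $1-\delta$; stopping once the radius falls to $\epsilon/2$ then guarantees the returned arm is within $\epsilon$ of optimal. Because every arm is sampled separately, the sample count grows with the number of arms, i.e., with the number of joint interventions.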

## Introduction to the SCMs used in the paper

The d_utils module provides methods for randomly generating the SCMs used in the paper; the generating process is documented there and summarized below.

Using $N_V$ to denote `num_variables` and $N_{Pa}$ to denote `num_parents`, the generating process is as follows. An Erdos-Renyi graph with average degree `degree` is sampled over the variables $X_1,\ldots, X_{N_V}$; $N_{Pa}$ of these variables are then chosen uniformly at random as the parents of $Y$, and w.l.o.g. we call them $X_1,\ldots, X_{N_{Pa}}$. Finally, every $X_j$ that is not an ancestor of $Y$ is made a child of $Y$ with probability `prob_outcome_child`.
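The graph part of this process can be sketched in plain Python. This is a hypothetical reimplementation for illustration, not the code in `d_utils`: in particular, it assumes edges are oriented from lower to higher variable index to keep the graph acyclic, and it sets the edge probability to `degree / (num_variables - 1)` so that each node has expected degree `degree`.

```python
import random


def sample_graph(num_variables, degree, num_parents, prob_outcome_child, rng=random):
    """Samples (edges among the X's, parents of Y, children of Y).

    Variables are 0, ..., num_variables - 1. Edges point from lower to higher
    index, which keeps the graph acyclic (an assumption of this sketch).
    """
    n = num_variables
    # Edge probability chosen so that each node has expected degree `degree`.
    p_edge = min(1.0, degree / (n - 1)) if n > 1 else 0.0
    edges = {(i, j) for i in range(n) for j in range(i + 1, n)
             if rng.random() < p_edge}

    # Parents of Y: num_parents variables chosen uniformly without replacement.
    parents_of_y = set(rng.sample(range(n), num_parents))

    # Ancestors of Y: its parents plus every variable with a path into them.
    ancestors = set(parents_of_y)
    changed = True
    while changed:
        changed = False
        for i, j in edges:
            if j in ancestors and i not in ancestors:
                ancestors.add(i)
                changed = True

    # Only non-ancestors can become children of Y without creating a cycle.
    children_of_y = {j for j in range(n)
                     if j not in ancestors and rng.random() < prob_outcome_child}
    return edges, parents_of_y, children_of_y


edges, parents_of_y, children_of_y = sample_graph(
    num_variables=20, degree=3.0, num_parents=4, prob_outcome_child=0.5,
    rng=random.Random(0))
print(len(edges), sorted(parents_of_y), sorted(children_of_y))
```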

Each variable $X_i$ has support $\{1, \ldots, M_i\}$, where $M_i$ is sampled uniformly between `support_size_min` and `support_size_max`. If $X_i$ has no parents, it is categorical with a distribution sampled from a Dirichlet with parameter `dir_alpha`. Otherwise, the conditional distribution $p(X_i \mid X_{i_1},\ldots, X_{i_m})$, where $X_{i_1},\ldots, X_{i_m}$ are the parents of $X_i$, is sampled independently for each value of $(X_{i_1},\ldots, X_{i_m})$ by drawing $p_1,\ldots, p_{M_i}\stackrel{i.i.d.}{\sim}$ Beta(`alpha`, `beta`) and then setting the conditional probability of $\{X_i = j\}$ proportional to $p_j$. If $Y$ is a parent of $X_i$, the same construction is applied to a discretized version of $Y$.
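The conditional distributions can be sketched in the same way. The helpers below are hypothetical (not part of `d_utils`) and use 0-based values $0,\ldots,M_i-1$ rather than $1,\ldots,M_i$; the Dirichlet draw uses the standard construction of normalizing independent Gamma(`dir_alpha`, 1) variables.

```python
import itertools
import random


def sample_root_distribution(support_size, dir_alpha, rng=random):
    """Categorical probabilities drawn from a symmetric Dirichlet(dir_alpha)."""
    # A Dirichlet draw is a vector of independent Gamma(dir_alpha, 1) draws, normalized.
    g = [rng.gammavariate(dir_alpha, 1.0) for _ in range(support_size)]
    total = sum(g)
    return [x / total for x in g]


def sample_conditional_table(support_size, parent_supports, alpha, beta, rng=random):
    """Maps each parent configuration to p(X_i = j | parents), j = 0, ..., M_i - 1."""
    table = {}
    for config in itertools.product(*(range(m) for m in parent_supports)):
        # Independent Beta(alpha, beta) weights per value, normalized to sum to one.
        p = [rng.betavariate(alpha, beta) for _ in range(support_size)]
        total = sum(p)
        table[config] = [x / total for x in p]
    return table


rng = random.Random(0)
root = sample_root_distribution(4, dir_alpha=0.5, rng=rng)
table = sample_conditional_table(3, parent_supports=[2, 4], alpha=2.0, beta=5.0, rng=rng)
print(len(table), table[(1, 3)])
```

Note that one distribution is drawn per parent configuration, so the table size is the product of the parents' support sizes.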

The stochastic function for the response $Y$ is generated as follows. We sample $\beta_1,\ldots, \beta_{N_{Pa}}$, where each term is i.i.d. Uniform(-`mean_bound`, `mean_bound`); these parameters specify the linear term, and `cov` specifies the covariance of the additive Gaussian noise. The `interactions` list specifies the number and size of the interaction terms in the non-linear component of $Y$'s stochastic function. An interaction term of size $m$ is equal to $\prod_{i=1}^m X_{\ell_i}$, where each $\ell_i$ is an index chosen uniformly at random from $Y$'s parents. Altogether, with $Z \sim \mathcal{N}(0, \texttt{cov})$ and $\{m_1,\ldots, m_k\}=$ `interactions`,
$$
  Y \mid x_1,\ldots, x_{N_{Pa}} = \sum_{i=1}^{N_{Pa}} \beta_i[x_i] + \gamma \sum_{j=1}^{k} \prod_{i=1}^{m_j} x_{\ell_{j,i}} + Z,
$$
where $\gamma=$ `interaction_magnitude` and $\ell_{j,i}$ is the $i$-th parent index drawn for the $j$-th interaction term.
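A sketch of this outcome model follows. It reads $\beta_i[x_i]$ as a per-value coefficient table, as the bracket notation suggests, and draws interaction indices with replacement; both are assumptions, and `sample_outcome_model` is a hypothetical helper rather than part of `d_utils`. Treating `cov` as the variance of $Z$, the noise standard deviation is its square root.

```python
import math
import random


def sample_outcome_model(parent_supports, mean_bound, interactions,
                         interaction_magnitude, cov, rng=random):
    """Returns (mean_fn, sample_fn, beta, terms) for the outcome Y.

    The argument x of mean_fn / sample_fn is a tuple of 0-based parent values.
    """
    n_pa = len(parent_supports)
    # beta[i][v]: contribution of parent i at value v, i.i.d. Uniform(-mean_bound, mean_bound).
    beta = [[rng.uniform(-mean_bound, mean_bound) for _ in range(m)]
            for m in parent_supports]
    # Each interaction term of size m multiplies m parent indices drawn uniformly
    # (with replacement; an assumption of this sketch).
    terms = [[rng.randrange(n_pa) for _ in range(m)] for m in (interactions or [])]

    def mean_fn(x):
        additive = sum(beta[i][x[i]] for i in range(n_pa))
        nonlinear = sum(math.prod(x[idx] for idx in term) for term in terms)
        return additive + interaction_magnitude * nonlinear

    def sample_fn(x):
        # Add Z ~ N(0, cov) to the mean.
        return mean_fn(x) + rng.gauss(0.0, math.sqrt(cov))

    return mean_fn, sample_fn, beta, terms


mean_fn, sample_fn, beta, terms = sample_outcome_model(
    parent_supports=[3, 4], mean_bound=1.0, interactions=[2, 2],
    interaction_magnitude=1e-3, cov=1.0, rng=random.Random(0))
print(terms, mean_fn((2, 1)), sample_fn((2, 1)))
```

With `interaction_magnitude` set to 0, the model is purely additive, which is the setting assumed by the algorithms; small positive values give the misspecified models used in the experiments further down.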

In [None]:
np.random.seed(0)  # Ensures Y has two parents and one child
model = d_utils.sample_discrete_additive_scm(
    num_variables=3,  # Number of variables, excluding Y
    degree=2,  # Average degree of the Erdos-Renyi graph
    num_parents=2,  # The number of parents of Y: parents are chosen i.i.d. from all variables
    prob_outcome_child=.5,  # Each variable topologically after Y is made a child with this probability
    cov=1,  # Y is a linear Gaussian with this covariance
    mean_bound=1,  # Y has linear coefficients sampled in [-mean_bound, mean_bound].
    # Other variables are discrete with number of values in [support_size_min, support_size_max]
    support_size_min=2,
    support_size_max=3,
    interactions=[2, 2],  # Specifies the non-linear interaction terms for Y.
    interaction_magnitude=1e-3,  # The scale of the non-linear terms.
    alpha=2,  # Beta hyperparameter for generating the X conditional distributions.
    beta=5,  # Beta hyperparameter for generating the X conditional distributions.
    dir_alpha=.5,  # Dirichlet parameter for the distributions of parentless X variables.
)
# Draw the model's DAG
model.draw()

## Running all algorithms for a fixed set of SCM-generating parameters with the `run` module

In [None]:
default_scm_params = {
    "num_variables": 20,
    "num_parents": 4,
    "degree": 3.0,
    "mean_bound": 1.0,
    "cov": 1.0,
    "interaction_magnitude": 0.0,
    "interactions": None,
    "alpha": 2.0,
    "beta": 5.0,
    "support_size_min": 3,
    "support_size_max": 6,
}


results = run.single_experiment(
    scm_params=default_scm_params,  # default SCM parameters.
    epsilon=.5,  # error tolerance
    delta=.1,  # allowed probability of failure
    include_se=False,  # whether to include the successive elimination baseline
)
print(results)

# Generating the Plots from the paper

Using the utilities above, we can generate the plots from the paper automatically. We first define a helper that plots the output of `run.sweep`.

In [None]:
def plot_samples_and_gap(results, x_values, x_label, axs, include_se=False):
  """Plots average sample complexity (top axis) and average gap (bottom axis)."""
  ax1, ax2 = axs

  modl_samples = [results[i]["MODL_mean_samples"] for i in range(len(x_values))]
  oracle_samples = [results[i]["oracle_mean_samples"] for i in range(len(x_values))]
  parents_first_samples = [results[i]["parents_first_mean_samples"] for i in range(len(x_values))]

  modl_gaps = [results[i]["MODL_mean_gap"] for i in range(len(x_values))]
  oracle_gaps = [results[i]["oracle_mean_gap"] for i in range(len(x_values))]
  parents_first_gaps = [results[i]["parents_first_mean_gap"] for i in range(len(x_values))]

  ax1.plot(x_values, modl_samples, label="MODL")
  ax1.plot(x_values, parents_first_samples, label="two stage")
  ax1.plot(x_values, oracle_samples, label="parent oracle")
  ax1.set(xlabel=x_label, ylabel="average sample complexity")

  ax2.plot(x_values, modl_gaps, label="MODL")
  ax2.plot(x_values, parents_first_gaps, label="two stage")
  ax2.plot(x_values, oracle_gaps, label="parent oracle")
  ax2.set(xlabel=x_label, ylabel="average gap")

  if include_se:
    samples = [results[i]["se_mean_samples"] for i in range(len(x_values))]
    gaps = [results[i]["se_mean_gap"] for i in range(len(x_values))]
    ax1.plot(x_values, samples, label="successive elimination")
    ax2.plot(x_values, gaps, label="successive elimination")
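Judging from the `num_scms` and `num_seeds` arguments and the `*_mean_*` result keys, each sweep below appears to average results over several random SCMs and several seeds per SCM. That protocol can be sketched as follows; `toy_sweep` and `toy_experiment` are hypothetical stand-ins (the latter returns made-up numbers, not measurements), and this is not the implementation of `run.sweep`. Copying the parameter dictionary before overriding the swept key matters because the cells below reuse and mutate `scm_params` between sweeps.

```python
import copy
import statistics


def toy_sweep(base_params, key, values, experiment, num_scms, num_seeds):
    """Averages experiment(params, scm_index, seed) over SCMs and seeds for each value."""
    averages = []
    for value in values:
        params = copy.deepcopy(base_params)  # leave the caller's dict untouched
        params[key] = value
        outcomes = [experiment(params, scm_index, seed)
                    for scm_index in range(num_scms)
                    for seed in range(num_seeds)]
        averages.append(statistics.mean(outcomes))
    return averages


def toy_experiment(params, scm_index, seed):
    # Made-up "sample complexity" that grows with the number of parents.
    return 10 * params["num_parents"] + scm_index


base_params = {"num_variables": 10, "num_parents": 1}
print(toy_sweep(base_params, "num_parents", [1, 2, 3], toy_experiment,
                num_scms=2, num_seeds=3))
```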


In [None]:
#@title Plot sample complexity vs. number of parents.
scm_params = {
    "num_variables": 10,
    "num_parents": 1,
    "degree": 3.0,
    "mean_bound": 5.0,
    "cov": 1.0,
    "interaction_magnitude": 0.0,
    "interactions": None,
    "alpha": 2.0,
    "beta": 5.0,
    "support_size_min": 3,
    "support_size_max": 6,
}

num_scms = 20  # Set to 20 to match the settings in the paper.
num_seeds = 5  # Set to 5 to match the settings in the paper.
# Sweep over the number of parents of Y.
fig, axs = plt.subplots(2, 4, figsize=(20, 8))

num_parents = np.arange(1, 11)
var_10_num_par_known = run.sweep(scm_params, "num_parents", num_parents,
                                 known_num_parents=True,
                                 num_scms=num_scms,
                                 num_seeds=num_seeds)
var_10_num_par_unknown = run.sweep(scm_params, "num_parents", num_parents,
                                   known_num_parents=False,
                                   num_scms=num_scms,
                                   num_seeds=num_seeds)
plot_samples_and_gap(var_10_num_par_known, num_parents, "num_parents", axs[:, 0])
axs[0][0].set(title="10 Variables, Num Parents Known")
plot_samples_and_gap(var_10_num_par_unknown, num_parents, "num_parents", axs[:, 1])
axs[0][1].set(title="10 Variables, Num Parents Unknown")

scm_params["num_variables"] = 30
num_parents = np.arange(1, 31)
var_30_num_par_known = run.sweep(scm_params, "num_parents", num_parents,
                                 known_num_parents=True,
                                 num_scms=num_scms,
                                 num_seeds=num_seeds)
var_30_num_par_unknown = run.sweep(scm_params, "num_parents", num_parents,
                                   known_num_parents=False,
                                   num_scms=num_scms,
                                   num_seeds=num_seeds)
plot_samples_and_gap(var_30_num_par_known, num_parents, "num_parents", axs[:, 2])
axs[0][2].set(title="30 Variables, Num Parents Known")
plot_samples_and_gap(var_30_num_par_unknown, num_parents, "num_parents", axs[:, 3])
axs[0][3].set(title="30 Variables, Num Parents Unknown")


plt.rcParams.update({"font.size": 14})
plt.legend(loc="upper right")

In [None]:
#@title Compare with successive elimination on sample complexity vs. number of parents.

scm_params = {
    "num_variables": 4,
    "num_parents": 1,
    "degree": 3.0,
    "mean_bound": 5.0,
    "cov": 1.0,
    "interaction_magnitude": 0.0,
    "interactions": None,
    "alpha": 2.0,
    "beta": 5.0,
    "support_size_min": 3,
    "support_size_max": 6,
}
num_scms = 20  # Set to 20 to match the settings in the paper.
num_seeds = 5  # Set to 5 to match the settings in the paper.

# Sweep over the number of parents of Y.
num_parents = np.arange(1, 4)
var_4_num_par_unknown = run.sweep(scm_params, "num_parents", num_parents,
                                  include_se=True,
                                  known_num_parents=False,
                                  num_scms=num_scms,
                                  num_seeds=num_seeds)
scm_params["num_variables"] = 6
var_6_num_par_unknown = run.sweep(scm_params, "num_parents", num_parents,
                                  include_se=True,
                                  known_num_parents=False,
                                  num_scms=num_scms,
                                  num_seeds=num_seeds)

fig, axs = plt.subplots(2, 2, figsize=(14, 8))
plot_samples_and_gap(var_4_num_par_unknown,
                     num_parents,
                     "num_parents",
                     axs[:, 0],
                     include_se=True)
axs[0][0].set(title="4 Variables, Num Parents Unknown")
plot_samples_and_gap(var_6_num_par_unknown,
                     num_parents,
                     "num_parents",
                     axs[:, 1],
                     include_se=True)
axs[0][1].set(title="6 Variables, Num Parents Unknown")

plt.rcParams.update({"font.size": 10})
plt.legend(loc="upper right")

In [None]:
#@title Plot sample complexity vs. support size
sizes = np.arange(6, 24)
scm_params = {
    "num_variables": 30,
    "num_parents": 10,
    "degree": 3.0,
    "mean_bound": 1.0,
    "cov": 1.0,
    "interaction_magnitude": 0.0,
    "interactions": None,
    "alpha": 2.0,
    "beta": 5.0,
    "support_size_min": 3,
}
num_scms = 20  # Set to 20 to match the settings in the paper.
num_seeds = 5  # Set to 5 to match the settings in the paper.

fig, axs = plt.subplots(2, 4, figsize=(12, 4))

results_10_par_num_par_known = run.sweep(scm_params,
                                         "support_size_max",
                                         sizes,
                                         known_num_parents=True,
                                         num_scms=num_scms,
                                         num_seeds=num_seeds)
plot_samples_and_gap(results_10_par_num_par_known,
                     sizes,
                     "support_size_max",
                     axs[:, 0])
axs[0][0].set(title="10 Parents, Num Parents Known")
results_10_par_num_par_unknown = run.sweep(scm_params,
                                           "support_size_max",
                                           sizes,
                                           known_num_parents=False,
                                           num_scms=num_scms,
                                           num_seeds=num_seeds)
plot_samples_and_gap(results_10_par_num_par_unknown,
                     sizes,
                     "support_size_max",
                     axs[:, 1])
axs[0][1].set(title="10 Parents, Num Parents Unknown")

scm_params["num_parents"] = 30
results_30_par_num_par_known = run.sweep(scm_params,
                                         "support_size_max",
                                         sizes,
                                         known_num_parents=True,
                                         num_scms=num_scms,
                                         num_seeds=num_seeds)
plot_samples_and_gap(results_30_par_num_par_known,
                     sizes,
                     "support_size_max",
                     axs[:, 2])
axs[0][2].set(title="30 Parents, Num Parents Known")
results_30_par_num_par_unknown = run.sweep(scm_params,
                                           "support_size_max",
                                           sizes,
                                           known_num_parents=False,
                                           num_scms=num_scms,
                                           num_seeds=num_seeds)
plot_samples_and_gap(results_30_par_num_par_unknown,
                     sizes,
                     "support_size_max",
                     axs[:, 3])
axs[0][3].set(title="30 Parents, Num Parents Unknown")


plt.rcParams.update({"font.size": 14})
plt.legend(loc="upper right")

In [None]:
#@title Performance vs. model misspecification

interaction_magnitude_sweep = np.linspace(.001, .1, num=10)
scm_params = {
    "num_variables": 30,
    "num_parents": 10,
    "degree": 3.0,
    "mean_bound": 1.0,
    "cov": 1.0,
    "interactions": [2, 3, 3],
    "alpha": 2.0,
    "beta": 5.0,
    "support_size_min": 3,
    "support_size_max": 6,
}
num_scms = 20  # Set to 20 to match the settings in the paper.
num_seeds = 5  # Set to 5 to match the settings in the paper.

fig, axs = plt.subplots(2, 2, figsize=(12, 4))

results_10_par_num_par_known = run.sweep(scm_params, "interaction_magnitude",
                                         interaction_magnitude_sweep,
                                         known_num_parents=True,
                                         num_scms=num_scms,
                                         num_seeds=num_seeds)
plot_samples_and_gap(results_10_par_num_par_known, interaction_magnitude_sweep,
                     "interaction_magnitude",
                     axs[:, 0])
axs[0][0].set(title="10 Parents, Num Parents Known",
              xlabel="model misspecification")
results_10_par_num_par_unknown = run.sweep(scm_params, "interaction_magnitude",
                                           interaction_magnitude_sweep,
                                           known_num_parents=False,
                                           num_scms=num_scms,
                                           num_seeds=num_seeds)
plot_samples_and_gap(results_10_par_num_par_unknown,
                     interaction_magnitude_sweep,
                     "interaction_magnitude",
                     axs[:, 1])
axs[0][1].set(title="10 Parents, Num Parents Unknown")

plt.rcParams.update({"font.size": 14})
plt.legend(loc="upper right")

In [None]:
#@title Performance vs. average graph degree
degree_values = np.linspace(1, 8, num=2)
scm_params = {
    "num_variables": 30,
    "num_parents": 10,
    "mean_bound": 1.0,
    "cov": 1.0,
    "interactions": [2, 3, 3],
    "alpha": 2.0,
    "beta": 5.0,
    "support_size_min": 3,
    "support_size_max": 6,
    "interaction_magnitude": 0.0,
}
num_scms = 20  # Set to 20 to match the settings in the paper.
num_seeds = 5  # Set to 5 to match the settings in the paper.

fig, axs = plt.subplots(2, 2, figsize=(12, 4))

par_known = run.sweep(scm_params, "degree", degree_values,
                      known_num_parents=True,
                      num_scms=num_scms,
                      num_seeds=num_seeds)
plot_samples_and_gap(par_known, degree_values, "degree", axs[:, 0])
axs[0][0].set(title="Num Parents Known", xlabel="average graph degree")

par_unknown = run.sweep(scm_params, "degree", degree_values,
                        known_num_parents=False,
                        num_scms=num_scms,
                        num_seeds=num_seeds)
plot_samples_and_gap(par_unknown, degree_values, "degree", axs[:, 1])
axs[0][1].set(title="Num Parents Unknown", xlabel="average graph degree")


plt.rcParams.update({"font.size": 14})
plt.legend(loc="upper right")