## Tutorial: **Linear regression**

This tutorial demonstrates modeling and running inference on a simple univariate linear
regression model in Bean Machine. This should offer an accessible introduction to models
that use PyTorch tensors and Newtonian Monte Carlo inference in Bean Machine. It will
also teach you effective practices for prediction on new datasets with Bean Machine.

## Problem

In this classical linear regression problem, the goal is to estimate some unobserved
response variable from an observed covariate. We'll construct a Bayesian model for this
problem, which will yield not only point estimates but also measures of uncertainty in
our predictions.

We'll restrict this tutorial to the univariate case, to aid with clarity and
visualization.

## Prerequisites

We will be using the following packages within this tutorial.

* [arviz](https://arviz-devs.github.io/arviz/) and
  [bokeh](https://docs.bokeh.org/en/latest/docs/) for interactive visualizations; and
* [pandas](https://pandas.pydata.org/), [numpy](https://numpy.org/), and
  [scikit-learn](https://scikit-learn.org/) for data manipulation.

Let's code this in Bean Machine! Import the Bean Machine library and some fundamental
PyTorch classes.

In [1]:
# Install Bean Machine in Colab if using Colab.
import sys


if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine

In [2]:
import logging
import os
import warnings
from typing import List, Union

import arviz as az
import beanmachine.ppl as bm
import numpy as np
import pandas as pd
import sklearn.model_selection
import torch
import torch.distributions as dist
from beanmachine.ppl.inference.bmg_inference import BMGInference
from beanmachine.tutorials.utils import plots
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource
from bokeh.palettes import Colorblind3
from bokeh.plotting import gridplot, show
from IPython.display import Markdown
from torch import tensor

The next cell applies configuration settings that improve the notebook presentation and
sets a manual seed for `torch` for reproducibility.

In [3]:
# Eliminate excess inference logging from Bean Machine, except for progress bars, and
# UserWarnings from Python.
logging.getLogger("beanmachine").setLevel(50)
warnings.filterwarnings("ignore")

# Plotting settings
az.rcParams["plot.backend"] = "bokeh"
az.rcParams["stats.hdi_prob"] = 0.89

# Manual seed
torch.manual_seed(12)

# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ

## Model

We're interested in predicting a response variable $y$ given an observed covariate $x$:

* $y = \beta_1 x + \beta_0 + \text{error}$

We can reframe this as:

* $y \sim \mathcal{N}(\beta_1 x + \beta_0, \epsilon)$

Here, $\beta_1$ is a coefficient for $x$, $\beta_0$ is a bias term, and $\epsilon$ is a
noise term. Specifically:

* $N \in \mathbb{Z}^+$ is the size of the training data.
* $x_i \in \mathbb{R}$ is the observed covariate.
* $\beta_1 \in \mathbb{R}$ is the coefficient for $x$. We'll use a prior of
  $\mathcal{N}(0,10)$.
* $\beta_0 \in \mathbb{R}$ is the bias term. We'll use a prior of $\mathcal{N}(0,10)$.
* $\epsilon \in \mathbb{R}^+$ is the scale (standard deviation) of the error. We'll use a
  prior of $\text{Gamma}(1, 1)$.
* $y_i \stackrel{iid}{\sim} \mathcal{N}(\beta_1 x_i + \beta_0, \epsilon) \in \mathbb{R}$
  is the response.

We are interested in fitting posterior distributions for $\beta_1$, $\beta_0$, and
$\epsilon$ given a collection of training data $\{(x_i, y_i)\}_{i=1}^N$.
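
Written out, the posterior we are after follows from Bayes' rule. The expression below is a sketch under the assumptions listed above, where $\mathcal{N}(\cdot \mid \mu, \sigma)$ and $\text{Gamma}(\cdot \mid a, b)$ denote densities (a notation introduced here for convenience):

$$
p(\beta_1, \beta_0, \epsilon \mid x, y) \propto \mathcal{N}(\beta_1 \mid 0, 10)\,\mathcal{N}(\beta_0 \mid 0, 10)\,\text{Gamma}(\epsilon \mid 1, 1)\prod_{i=1}^N \mathcal{N}(y_i \mid \beta_1 x_i + \beta_0, \epsilon)
$$

The first three factors are the priors and the product is the likelihood of the training data. The normalizing constant is omitted because the sampling methods used later only compare values of this product.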

Let's visualize the Gamma distribution that we used as our prior for $\epsilon$:

In [4]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

concentration = 1
rate = 1
x = torch.arange(0, 5, 0.01)
y = dist.Gamma(concentration, rate).log_prob(x).exp()
cds = ColumnDataSource({"x": x.tolist(), "y": y.tolist()})

gamma_prior_plot = plots.line_plot(
    plot_sources=[cds],
    tooltips=[[("Density", "@y{0.000}"), ("ϵ", "@x{0.000}")]],
    figure_kwargs={
        "x_axis_label": "epsilon",
        "y_axis_label": "density",
        "title": f"Γ({concentration}, {rate}) prior",
    },
    plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
)
show(gamma_prior_plot)

We can implement this model in Bean Machine by defining random variable objects with the
`@bm.random_variable` decorator. These functions behave differently than ordinary Python
functions.

<div
  style={
    {
      background: "#daeaf3",
      border_left: "3px solid #2980b9",
      display: "block",
      margin: "16px 0",
      padding: "12px",
    }
  }
>
  Semantics for <code>@bm.random_variable</code> functions:
  <ul>
    <li>
      They must return PyTorch <code>Distribution</code> objects.
    </li>
    <li>
      Though they return distributions, callers actually receive <i>samples</i> from the
      distribution. The machinery for obtaining samples from distributions is handled
      internally by Bean Machine.
    </li>
    <li>
      Inference runs the model through many iterations. During a particular inference
      iteration, a distinct random variable will correspond to exactly one sampled
      value: <b>calls to the same random variable function with the same arguments will
      receive the same sampled value within one inference iteration</b>. This makes it
      easy for multiple components of your model to refer to the same logical random
      variable.
    </li>
    <li>
      Consequently, to define distinct random variables that correspond to different
      sampled values during a particular inference iteration, an effective practice is
      to add a dummy "indexing" parameter to the function. Distinct random variables
      can be referred to with different values for this index.
    </li>
    <li>
      Please see the documentation for more information about this decorator.
    </li>
  </ul>
</div>
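
To make the "same arguments, same sampled value" rule concrete, here is a minimal plain-Python sketch of the idea. It is not how Bean Machine is implemented: the `World` class, the toy `random_variable` decorator, and the `slope` and `noise` functions are all hypothetical names invented for this illustration.

```python
import random


class World:
    """Hypothetical stand-in for the values sampled during one inference iteration."""

    def __init__(self, seed):
        self.rng = random.Random(seed)
        self.values = {}


def random_variable(sampler):
    """Toy decorator: cache one sample per (function name, arguments) within a world."""

    def wrapper(world, *args):
        key = (sampler.__name__, args)
        if key not in world.values:
            world.values[key] = sampler(world, *args)
        return world.values[key]

    return wrapper


@random_variable
def slope(world):
    return world.rng.gauss(0.0, 10.0)


@random_variable
def noise(world, i):
    # The index `i` is a dummy argument that creates one random variable per value.
    return world.rng.gauss(0.0, 1.0)


world = World(seed=0)
first, second = slope(world), slope(world)  # same function, same arguments
noise(world, 0)
noise(world, 1)  # a different index is a different random variable

print("repeated call returns the same value:", first == second)
print("distinct random variables in this world:", sorted(world.values))
```

Calling `slope(world)` twice yields one cached value, while `noise(world, 0)` and `noise(world, 1)` occupy separate entries, which is the role the "indexing" parameter plays in the real decorator.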

In [5]:
@bm.random_variable
def beta_1():
    return dist.Normal(0, 10)


@bm.random_variable
def beta_0():
    return dist.Normal(0, 10)


@bm.random_variable
def epsilon():
    return dist.Gamma(1, 1)


@bm.random_variable
def y(X):
    return dist.Normal(beta_0() + beta_1() * X, epsilon())

## Data

With the model defined, we need to collect some observed data in order to learn about
values of interest in our model.

In this case, we will observe a few samples of inputs and outputs. For demonstrative
purposes, we will use a synthetically generated dataset of observed values. In practice,
you would gather a collection of covariate and response variables, and then you could
construct a model to predict a new, unobserved response variable from a new, observed
covariate.

For our synthetic dataset, we will assume the following parameters to the relationship
between inputs and outputs.

In [6]:
true_beta_1 = 2.0
true_beta_0 = 5.0
true_epsilon = 1.0
N = 200

X = dist.Normal(0, 1).expand([N, 1]).sample()
Y = dist.Normal(true_beta_1 * X + true_beta_0, true_epsilon).sample()

We can visualize the data as follows:

In [7]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

synthetic_data_plot = plots.scatter_plot(
    plot_sources=ColumnDataSource(
        {"x": X.flatten().tolist(), "y": Y.flatten().tolist()}
    ),
    tooltips=[("y", "@y{0.000}"), ("x", "@x{0.000}")],
    figure_kwargs={"title": "Synthetic data", "x_axis_label": "x", "y_axis_label": "y"},
)
show(synthetic_data_plot)

Let's split the dataset into a training and test set, which we'll use later to evaluate
predictive performance.

In [8]:
X_train, X_test, Y_train, Y_test = sklearn.model_selection.train_test_split(X, Y)

Our inference algorithms expect observations in the form of a dictionary. This
dictionary should consist of `@bm.random_variable` invocations as keys, and tensor data
as values.

You can see this in the code snippet below, where we bind the observed values to a key
representing the random variable that was observed.

In [9]:
observations = {y(X_train): Y_train}

## Inference: Take 1

Inference is the process of combining _model_ with _data_ to obtain _insights_, in the
form of probability distributions over values of interest. Bean Machine offers a
powerful and general inference framework to enable fitting arbitrary models to data.

As a starting point for running inference, we will use the basic Metropolis-Hastings
inference algorithm. Ancestral Metropolis-Hastings is a simple inference algorithm,
which proposes child random variables conditional on values for the parent random
variables. The most ancestral random variables are simply sampled from the prior
distribution.
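
As a rough illustration of proposing from the prior, the sketch below runs Metropolis-Hastings on a toy one-parameter model (a single mean `mu` with a $\mathcal{N}(0, 10)$ prior and unit-scale noise). The model, the data values, and the function names are assumptions made for this example; it does not reproduce Bean Machine's implementation.

```python
import math
import random


def log_likelihood(mu, data, scale=1.0):
    """Log density of the data under Normal(mu, scale), up to an additive constant."""
    return sum(-0.5 * ((y - mu) / scale) ** 2 for y in data)


def prior_proposal_mh(data, num_samples, prior_scale=10.0, seed=0):
    """Metropolis-Hastings whose proposals are independent draws from the prior."""
    rng = random.Random(seed)
    current = rng.gauss(0.0, prior_scale)
    current_ll = log_likelihood(current, data)
    samples, accepted = [], 0
    for _ in range(num_samples):
        proposal = rng.gauss(0.0, prior_scale)
        proposal_ll = log_likelihood(proposal, data)
        # With the prior as the proposal, the prior and proposal densities cancel,
        # so the acceptance ratio reduces to the likelihood ratio.
        log_uniform = math.log(1.0 - rng.random())
        if log_uniform < proposal_ll - current_ll:
            current, current_ll = proposal, proposal_ll
            accepted += 1
        samples.append(current)
    return samples, accepted / num_samples


data = [2.6, 3.1, 3.4, 2.9, 3.0]  # hypothetical observations with mean 3.0
samples, acceptance_rate = prior_proposal_mh(data, num_samples=5000)
kept = samples[len(samples) // 2 :]  # discard the first half as warm-up
posterior_mean = sum(kept) / len(kept)

print(f"posterior mean estimate: {posterior_mean:.2f}")
print(f"acceptance rate: {acceptance_rate:.3f}")
```

Because the prior is much wider than the posterior, most proposals land in low-likelihood regions and are rejected, so the chain repeats its current value for long stretches.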

Running inference requires a few arguments:

| Name           | Usage                                                                                                    |
| -------------- | -------------------------------------------------------------------------------------------------------- |
| `queries`      | List of `@bm.random_variable` targets to fit posterior distributions for.                                |
| `observations` | A dictionary of observations, as built above.                                                            |
| `num_samples`  | Number of Monte Carlo samples to approximate the posterior distributions for the variables in `queries`. |
| `num_chains`   | Number of separate inference runs to use. Multiple chains can help verify that inference ran correctly.  |

Let's run inference:

In [10]:
queries = [beta_1(), beta_0(), epsilon()]
num_samples = 2 if smoke_test else 2000
num_chains = 1 if smoke_test else 4

samples_mh = bm.SingleSiteAncestralMetropolisHastings().infer(
    queries=queries,
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)


## Analysis: Take 1

`samples_mh` now contains our inference results.

In [11]:
beta_0_marginal_mh = samples_mh[beta_0()].flatten(start_dim=0, end_dim=1).detach()
beta_1_marginal_mh = samples_mh[beta_1()].flatten(start_dim=0, end_dim=1).detach()
epsilon_marginal_mh = samples_mh[epsilon()].flatten(start_dim=0, end_dim=1).detach()

print(
    f"β0 marginal: {beta_0_marginal_mh}\n"
    f"β1 marginal: {beta_1_marginal_mh}\n"
    f" ϵ marginal: {epsilon_marginal_mh}"
)

β0 marginal: tensor([3.5978, 5.6188, 5.6188,  ..., 4.9685, 4.9685, 4.9685])
β1 marginal: tensor([0.5615, 0.5615, 0.5615,  ..., 1.9349, 1.9349, 1.9587])
 ϵ marginal: tensor([1.7924, 1.7924, 1.7924,  ..., 1.0664, 1.0664, 1.0664])


Next, let's visualize the inferred random variables.

In [12]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

beta_joint_plot = plots.plot_marginal(
    queries=[beta_0(), beta_1()],
    samples=samples_mh,
    true_values=[true_beta_0, true_beta_1],
    bandwidth=0.1,
    joint_plot_title="β marginal",
    figure_kwargs={"x_range": [2, 8], "y_range": [1, 3]},
)
show(beta_joint_plot)

We seem to have faithfully recovered $\beta_0$ but not $\beta_1$. It's possible that our
prior was too strong relative to the small amount of data.

In [13]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

epsilon_plot = plots.plot_marginal(
    queries=[epsilon()],
    samples=samples_mh,
    true_values=[true_epsilon],
    n_bins=100,
    bandwidth=0.025,
    figure_kwargs={"x_range": [0.8, 1.4]},
)
show(epsilon_plot)

We seem to have recovered a reasonably good understanding of the noise term $\epsilon$.

We can also compute log probability on the held-out test data. This isn't particularly
useful on its own, but is useful for comparing different approaches. Thus, here, we will
also plot a baseline for comparison: the log probability implied on the test dataset
using the ground truth parameters.
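
For a single parameter setting, the quantity being plotted is just a sum of normal log densities over the held-out points. The sketch below spells that out in plain Python; the three data points and the "poor" parameter values are hypothetical, and `held_out_log_prob` is a helper named for this example only.

```python
import math


def normal_log_prob(value, loc, scale):
    """Log density of Normal(loc, scale) at `value`; `scale` is a standard deviation."""
    z = (value - loc) / scale
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * z**2


def held_out_log_prob(beta_1, beta_0, epsilon, xs, ys):
    """Sum of per-point log densities for one setting of the parameters."""
    return sum(normal_log_prob(y, beta_1 * x + beta_0, epsilon) for x, y in zip(xs, ys))


# Hypothetical held-out points lying close to the line y = 2x + 5.
xs = [-1.0, 0.0, 1.0]
ys = [3.2, 4.9, 7.1]

good = held_out_log_prob(2.0, 5.0, 1.0, xs, ys)
poor = held_out_log_prob(0.5, 3.5, 1.0, xs, ys)

print(f"parameters near the truth: {good:.2f}")
print(f"poorly fitting parameters: {poor:.2f}")
```

Higher (less negative) values mean the parameters explain the held-out data better, which is why the ground truth parameters make a natural baseline.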

In [14]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

test_data_y = (
    dist.Normal(
        (
            X_test @ beta_1_marginal_mh[:num_samples].unsqueeze(0)
            + beta_0_marginal_mh[:num_samples].unsqueeze(0)
        ),
        epsilon_marginal_mh[:num_samples],
    )
    .log_prob(Y_test)
    .sum(dim=0)
)
test_data_x = list(range(len(test_data_y.tolist())))
test_data_cds = ColumnDataSource({"x": test_data_x, "y": test_data_y.tolist()})

ground_truth_y = (
    dist.Normal(X_test * true_beta_1 + true_beta_0, true_epsilon)
    .log_prob(Y_test)
    .sum(dim=0)
    .item()
)
ground_truth_cds = ColumnDataSource(
    {"x": test_data_x, "y": [ground_truth_y] * len(test_data_x)}
)

test_data_mh_plot = plots.line_plot(
    plot_sources=[ground_truth_cds, test_data_cds],
    labels=[f"Ground truth = {ground_truth_y:.2f}", "MH"],
    figure_kwargs={
        "y_axis_label": "Log probability",
        "title": "Log probability on test data",
    },
    plot_kwargs={"line_width": 3, "line_alpha": 0.7},
)
test_data_mh_plot.legend.location = "bottom_right"

show(test_data_mh_plot)

Although the chains do not look healthy overall, the log probability they assign to the
test data does at least reach the level implied by the ground truth parameters.

ArviZ provides helpful statistics about the results of inference, which we show below.

In [15]:
summary_mh_df = az.summary(samples_mh.to_xarray(), round_to=3)
Markdown(summary_mh_df.to_markdown())

|           |   mean |    sd |   hdi_5.5% |   hdi_94.5% |   mcse_mean |   mcse_sd |   ess_bulk |   ess_tail |   r_hat |
|:----------|-------:|------:|-----------:|------------:|------------:|----------:|-----------:|-----------:|--------:|
| epsilon() |  1.075 | 0.115 |      0.954 |       1.156 |       0.006 |     0.004 |    356.342 |    380.583 |   1.012 |
| beta_0()  |  5.034 | 0.172 |      4.888 |       5.203 |       0.013 |     0.011 |     78.559 |     52.15  |   1.275 |
| beta_1()  |  1.809 | 0.126 |      1.704 |       2.005 |       0.012 |     0.009 |     73.921 |     67.801 |   1.075 |

#### $\hat{R}$ diagnostic

$\hat{R}$ is a convergence diagnostic that compares the variance between multiple chains
to the variance within each chain. If every chain is successfully exploring the full
space, the between-chain and within-chain variances should agree, so that
$\hat{R}\approx 1$; values noticeably above 1 indicate a lack of convergence. $\hat{R}$
is calculated as

$$
\hat{R}=\sqrt{\frac{\hat{V}}{W}}
$$

where $W$ is the within-chain variance and $\hat{V}$ is the posterior variance estimate
for the pooled rank-traces. The take-away here is that $\hat{R}$ converges towards 1
as each of the Markov chains approaches the true posterior distribution. We do not
recommend using inference results if $\hat{R}>1.01$. More
information about $\hat{R}$ can be found in the [Vehtari _et al_](#references) paper.
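
The sketch below computes the classic Gelman-Rubin form of $\hat{R}$, which takes the square root of the ratio of the pooled variance estimate to the within-chain variance. It is a simplification for intuition only: it omits the rank normalization and chain splitting described by Vehtari _et al_, so its values will not match ArviZ exactly. The function name and the simulated chains are assumptions of this example.

```python
import math
import random
import statistics


def classic_r_hat(chains):
    """Gelman-Rubin R-hat for equal-length chains (no rank normalization or splitting)."""
    n = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    # W: the mean of the within-chain sample variances.
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    # B / n: the sample variance of the chain means.
    between_over_n = statistics.variance(chain_means)
    pooled = (n - 1) / n * within + between_over_n
    return math.sqrt(pooled / within)


rng = random.Random(0)
# Four chains drawn from the same distribution, and two chains stuck in different regions.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
stuck = [[rng.gauss(shift, 1.0) for _ in range(1000)] for shift in (0.0, 3.0)]

r_hat_mixed = classic_r_hat(mixed)
r_hat_stuck = classic_r_hat(stuck)
print(f"chains from the same distribution: {r_hat_mixed:.3f}")
print(f"chains in different regions:       {r_hat_stuck:.3f}")
```

When the chains disagree about where the mass is, the variance of the chain means inflates the pooled estimate and pushes $\hat{R}$ well above 1.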

#### Effective sample size $ess$ diagnostic

MCMC samplers do not draw independent samples from the target distribution, which means
that our samples are correlated. In an ideal situation all samples would be independent,
but we do not have that luxury. We can, however, measure the number of _effectively
independent_ samples we draw, which is called the effective sample size. Briefly, it is
a measure that combines information from the $\hat{R}$ value with the autocorrelation
estimates within the chains. There are many ways to estimate effective sample sizes; we
will be using the method defined in the [Vehtari _et al_](#references) paper, which also
describes the calculation in detail.

The rule of thumb for `ess_bulk` is for this value to be greater than 100 per chain on
average. Since we ran four chains, we need `ess_bulk` to be greater than 400 for each
parameter. The `ess_tail` is an estimate for effectively independent samples considering
the more extreme values of the posterior. This is not the number of samples that landed
in the tails of the posterior. It is a measure of the number of effectively independent
samples if we sampled the tails of the posterior. The rule of thumb for this value is
also to be greater than 100 per chain on average.
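
To build intuition for how correlation shrinks the effective sample size, the sketch below uses a simple single-chain estimator: the chain length divided by one plus twice the sum of autocorrelations, truncated at the first negative lag. This is a textbook simplification chosen for this example, not the rank-normalized multi-chain estimator that ArviZ reports as `ess_bulk`.

```python
import random


def autocorrelation(chain, lag):
    """Sample autocorrelation of a single chain at the given lag."""
    n = len(chain)
    mean = sum(chain) / n
    variance = sum((v - mean) ** 2 for v in chain) / n
    covariance = sum(
        (chain[t] - mean) * (chain[t + lag] - mean) for t in range(n - lag)
    ) / n
    return covariance / variance


def effective_sample_size(chain):
    """N / (1 + 2 * sum of autocorrelations), truncated at the first negative lag."""
    total = 0.0
    for lag in range(1, len(chain)):
        rho = autocorrelation(chain, lag)
        if rho < 0:
            break
        total += rho
    return len(chain) / (1 + 2 * total)


rng = random.Random(0)
independent = [rng.gauss(0.0, 1.0) for _ in range(2000)]

# An AR(1) chain: each draw keeps 90% of the previous one, so neighbors are correlated.
correlated = [0.0]
for _ in range(1999):
    correlated.append(0.9 * correlated[-1] + rng.gauss(0.0, 1.0))

ess_independent = effective_sample_size(independent)
ess_correlated = effective_sample_size(correlated)
print(f"independent draws: ess of about {ess_independent:.0f} out of 2000")
print(f"correlated draws:  ess of about {ess_correlated:.0f} out of 2000")
```

Both chains contain 2000 values, but the strongly correlated one carries far fewer effectively independent samples.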

In this case, both $\hat{R}$ and $ess$ leave something to be desired.
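
The rules of thumb can be checked directly against the summary table. The snippet below copies the `r_hat` and `ess_bulk` values from that table by hand and applies the thresholds stated above; the dictionary layout and variable names are choices made for this example.

```python
# Values copied from the Metropolis-Hastings summary table: (r_hat, ess_bulk).
summary = {
    "epsilon()": (1.012, 356.342),
    "beta_0()": (1.275, 78.559),
    "beta_1()": (1.075, 73.921),
}
num_chains = 4
r_hat_threshold = 1.01
ess_threshold = 100 * num_chains  # 100 effective samples per chain on average

checks = {}
for name, (r_hat, ess_bulk) in summary.items():
    r_hat_ok = r_hat <= r_hat_threshold
    ess_ok = ess_bulk > ess_threshold
    checks[name] = (r_hat_ok, ess_ok)
    print(f"{name:<10} r_hat ok: {r_hat_ok}   ess_bulk ok: {ess_ok}")
```

None of the three parameters passes either check, with `epsilon()` coming closest on both.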

We can plot diagnostic information to assess model fit using ArviZ. Let's take a look:

In [16]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

samples_mh_diagnostic_plots = gridplot(plots.plot_diagnostics(samples_mh))
show(samples_mh_diagnostic_plots)

The diagnostics output shows two diagnostic plots for individual random variables: rank
plots and autocorrelation plots.

* **Rank plots** are a histogram of the samples over time. All samples across
  all chains are ranked and then we plot the average rank for each chain on
  regular intervals. If the chains are mixing well this histogram should look
  roughly uniform. If it looks highly irregular that suggests chains might be
  getting stuck and not adequately exploring the sample space.

* **Autocorrelation plots** measure how predictive the last several samples are
  of the current sample. Autocorrelation may vary between -1.0
  (deterministically anticorrelated) and 1.0 (deterministically correlated). We
  compute autocorrelation approximately, so it may sometimes exceed these
  bounds. In an ideal world, the current sample is chosen independently of the
  previous samples: an autocorrelation of zero. This is not possible in
  practice, due to stochastic noise and the mechanics of how inference works.
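
The idea behind a rank plot can be sketched in a few lines: pool the samples from all chains, rank them, and count how many of each chain's samples fall into each bin of ranks. The function name, the bin count, and the tiny hand-made chains below are assumptions for illustration; this is not the plotting code used by ArviZ.

```python
def rank_histograms(chains, num_bins):
    """Count, per chain, how many of its samples fall into each bin of pooled ranks."""
    pooled = sorted(
        (value, chain_index)
        for chain_index, chain in enumerate(chains)
        for value in chain
    )
    counts = [[0] * num_bins for _ in chains]
    for rank, (_, chain_index) in enumerate(pooled):
        bin_index = rank * num_bins // len(pooled)
        counts[chain_index][bin_index] += 1
    return counts


# Two chains that interleave over the same range, and two that sit in separate regions.
mixing = [[0.1, 2.3, 4.2, 6.8], [1.5, 3.1, 5.9, 7.4]]
stuck = [[0.1, 1.5, 2.3, 3.1], [4.2, 5.9, 6.8, 7.4]]

print(rank_histograms(mixing, num_bins=2))  # [[2, 2], [2, 2]]
print(rank_histograms(stuck, num_bins=2))  # [[4, 0], [0, 4]]
```

Uniform counts per chain correspond to the flat histograms of well-mixed chains, while lopsided counts show each chain occupying its own part of the space.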

From these plots, we see that each of the chains is relatively healthy: they don't get
stuck, and they do not explore a chain-specific subset of the space. However, the chains
are fairly "blocky", repeating the same value over many consecutive iterations. This
indicates that many samples are wasted because the acceptance rate for new parameter
values is too low.

Let's see if we can do better by using another inference approach.

## Inference: Take 2

To improve upon our first attempt, let's use gradient information to help guide the
sampling process. Since this model is composed entirely of differentiable random
variables, we'll make use of the Newtonian Monte Carlo (NMC) inference method. NMC is a
second-order method, which uses the Hessian to automatically scale the step size in each
dimension.
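
The following sketch shows only the scaling idea, not NMC itself: on a toy log density with one very narrow and one very wide dimension, a Newton step divides each gradient by its curvature, whereas a plain gradient step applies one step size everywhere. The toy density, its parameters, and the function names are assumptions made for this example.

```python
def grad_and_hessian(point, means, scales):
    """Gradient and diagonal Hessian of an independent Gaussian log density."""
    grad = [-(p - m) / s**2 for p, m, s in zip(point, means, scales)]
    hess = [-1.0 / s**2 for s in scales]
    return grad, hess


def newton_step(point, grad, hess):
    """Move each coordinate by its gradient divided by its (negative) curvature."""
    return [p - g / h for p, g, h in zip(point, grad, hess)]


def gradient_step(point, grad, step_size):
    """Move every coordinate by the same multiple of its gradient."""
    return [p + step_size * g for p, g in zip(point, grad)]


means, scales = [5.0, 2.0], [0.1, 10.0]  # one narrow and one wide dimension
start = [0.0, 0.0]
grad, hess = grad_and_hessian(start, means, scales)

newton_result = newton_step(start, grad, hess)
gradient_result = gradient_step(start, grad, step_size=0.01)
print("Newton step:  ", newton_result)
print("gradient step:", gradient_result)
```

The Newton step lands on the mode in both dimensions. The fixed step size of 0.01 happens to suit the narrow dimension but barely moves the wide one, which is the per-dimension tuning that curvature information provides automatically.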

In [17]:
samples_nmc = bm.SingleSiteNewtonianMonteCarlo().infer(
    queries=queries,
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)


## Analysis: Take 2

In [18]:
beta_0_marginal_nmc = samples_nmc[beta_0()].flatten(start_dim=0, end_dim=1).detach()
beta_1_marginal_nmc = samples_nmc[beta_1()].flatten(start_dim=0, end_dim=1).detach()
epsilon_marginal_nmc = samples_nmc[epsilon()].flatten(start_dim=0, end_dim=1).detach()

In [19]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

beta_nmc_plot = plots.plot_marginal(
    queries=[beta_0(), beta_1()],
    samples=samples_nmc,
    true_values=[true_beta_0, true_beta_1],
    joint_plot_title="β marginal",
)
show(beta_nmc_plot)

In [20]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

epsilon_nmc_plot = plots.plot_marginal(
    queries=[epsilon()],
    samples=samples_nmc,
    true_values=[true_epsilon],
    n_bins=100,
)
show(epsilon_nmc_plot)

The marginal distributions look much healthier than with the Metropolis-Hastings
approach. However, $\beta_1$ still attenuates to zero more than expected. This is
possibly good evidence that our prior is having a strong influence.

In [21]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

test_log_prob_nmc_y = (
    dist.Normal(
        (
            X_test @ beta_1_marginal_nmc[:num_samples].unsqueeze(0)
            + beta_0_marginal_nmc[:num_samples].unsqueeze(0)
        ),
        epsilon_marginal_nmc[:num_samples],
    )
    .log_prob(Y_test)
    .sum(dim=0)
)
test_log_prob_nmc_x = list(range(len(test_log_prob_nmc_y.tolist())))
test_log_prob_nmc_cds = ColumnDataSource(
    {"x": test_log_prob_nmc_x, "y": test_log_prob_nmc_y.tolist()}
)

ground_truth_y = (
    dist.Normal(X_test * true_beta_1 + true_beta_0, true_epsilon)
    .log_prob(Y_test)
    .sum(dim=0)
    .item()
)
ground_truth_cds = ColumnDataSource(
    {"x": test_data_x, "y": [ground_truth_y] * len(test_data_x)}
)

test_data_nmc_plot = plots.line_plot(
    plot_sources=[ground_truth_cds, test_data_cds, test_log_prob_nmc_cds],
    labels=[f"Ground truth = {ground_truth_y:.2f}", "MH", "NMC"],
    figure_kwargs={
        "y_axis_label": "Log probability",
        "title": "Log probability on test data",
    },
    plot_kwargs={"line_width": 3, "line_alpha": 0.7},
)
test_data_nmc_plot.legend.location = "bottom_right"

show(test_data_nmc_plot)

NMC seems to have a very healthy log probability, and it successfully captures the log
probability implied by the ground truth parameters on the test dataset.

Lastly, let's look at the diagnostics.

In [22]:
summary_nmc_df = az.summary(samples_nmc.to_inference_data())
Markdown(summary_nmc_df.to_markdown())

|           |   mean |    sd |   hdi_5.5% |   hdi_94.5% |   mcse_mean |   mcse_sd |   ess_bulk |   ess_tail |   r_hat |
|:----------|-------:|------:|-----------:|------------:|------------:|----------:|-----------:|-----------:|--------:|
| epsilon() |  1.095 | 0.532 |      0.378 |       1.862 |       0.264 |     0.202 |          5 |          4 |    2.77 |
| beta_0()  |  5.048 | 0.131 |      4.897 |       5.22  |       0.002 |     0.001 |       5976 |         91 |    1.16 |
| beta_1()  |  1.804 | 0.104 |      1.654 |       1.966 |       0.001 |     0.001 |       6029 |         96 |    1.17 |

In [23]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

nmc_diagnostic_plots = gridplot(plots.plot_diagnostics(samples_nmc))
show(nmc_diagnostic_plots)

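The diagnostics for $\beta_0$ and $\beta_1$ improve markedly over Metropolis-Hastings:
`ess_bulk` rises from under 80 to roughly 6,000. The summary table is less flattering
for $\epsilon$, however, with $\hat{R} = 2.77$ and an `ess_bulk` of 5, and $\hat{R}$ for
both coefficients still exceeds the 1.01 threshold discussed earlier. NMC looks like the
better of the two samplers, but the $\epsilon$ chains deserve a closer look before
relying on these results.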

## Prediction

We've built and evaluated our model. Lastly, let's take a quick look at how to predict
with it.

In [24]:
def predict(x):
    if not isinstance(x, torch.Tensor):
        x = tensor(x).float()
    return pd.DataFrame(
        np.percentile(
            dist.Normal(
                x.view([-1, 1]) @ beta_1_marginal_nmc.unsqueeze(0)
                + beta_0_marginal_nmc.unsqueeze(0),
                epsilon_marginal_nmc.unsqueeze(0),
            )
            .sample([10])
            .transpose(0, 1)
            .flatten(1),
            [2.5, 50, 97.5],
            axis=1,
        ).T,
        index=x.view(-1).numpy(),
        columns=["2.5%", "50%", "97.5%"],
    )
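
The same computation can be unpacked into plain Python to show what `predict` does step by step: for every posterior draw, simulate a few responses at the given covariate, then read percentiles off the pooled simulations. The posterior draws below are hypothetical stand-ins generated for this example, and `percentile` is a simple index-based helper rather than `np.percentile`.

```python
import random


def percentile(sorted_values, q):
    """Index-based percentile of an already sorted list, with q in [0, 100]."""
    index = min(len(sorted_values) - 1, int(q / 100 * len(sorted_values)))
    return sorted_values[index]


def predictive_interval(x, draws, draws_per_sample=10, seed=0):
    """2.5%, 50% and 97.5% percentiles of simulated responses at covariate x."""
    rng = random.Random(seed)
    simulated = sorted(
        rng.gauss(beta_1 * x + beta_0, epsilon)
        for beta_1, beta_0, epsilon in draws
        for _ in range(draws_per_sample)
    )
    return [percentile(simulated, q) for q in (2.5, 50, 97.5)]


# Hypothetical posterior draws of (beta_1, beta_0, epsilon), standing in for real samples.
rng = random.Random(1)
draws = [
    (rng.gauss(2.0, 0.1), rng.gauss(5.0, 0.1), abs(rng.gauss(1.0, 0.05)))
    for _ in range(500)
]

low, median, high = predictive_interval(4.0, draws)
print(f"x = 4: 2.5% = {low:.2f}, 50% = {median:.2f}, 97.5% = {high:.2f}")
```

The resulting interval reflects both the uncertainty in the parameters (the spread across draws) and the observation noise (the spread of each simulated response).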

Predict for a single value:

In [25]:
Markdown(predict(4).to_markdown())

|    |    2.5% |     50% |   97.5% |
|---:|--------:|--------:|--------:|
|  4 | 9.55242 | 12.2677 | 15.0144 |

Or for a range:

In [26]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

predicted_df = predict(torch.linspace(-10, 10, 100))
low_cds = ColumnDataSource(
    {"x": predicted_df.index.tolist(), "y": predicted_df["2.5%"].tolist()}
)
med_cds = ColumnDataSource(
    {"x": predicted_df.index.tolist(), "y": predicted_df["50%"].tolist()}
)
high_cds = ColumnDataSource(
    {"x": predicted_df.index.tolist(), "y": predicted_df["97.5%"].tolist()}
)
prediction_plot = plots.line_plot(
    plot_sources=[low_cds, med_cds, high_cds],
    labels=predicted_df.columns.tolist(),
    figure_kwargs={"x_axis_label": "x", "y_axis_label": "y", "title": "Predictions"},
    plot_kwargs={"line_width": 2, "line_alpha": 0.6},
)
prediction_plot.legend.location = "top_left"
prediction_plot.circle(
    x=X.flatten().tolist(),
    y=Y.flatten().tolist(),
    fill_color="brown",
    line_color="white",
    fill_alpha=0.6,
    size=7,
    legend_label="Simulated data",
)

show(prediction_plot)

## BMGInference

Bean Machine Graph (BMG) Inference is an experimental feature of the Bean Machine
framework that aims to deliver higher performance for specialized models. The model used
in this tutorial represents a static probabilistic graph model and happens to use only
features within the language subset supported by BMGInference. As a reference point, the
following code reports the time it takes for our basic implementation of NMC to compute
the posterior:

In [27]:
%%time
samples_nmc = bm.SingleSiteNewtonianMonteCarlo().infer(
    queries=queries,
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)


CPU times: user 52.9 s, sys: 123 ms, total: 53 s
Wall time: 53.1 s


To run our model using BMGInference, the only change needed is the following:

In [28]:
%%time
samples_bmg = BMGInference().infer(
    queries=queries,
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)


CPU times: user 1.79 s, sys: 9.95 ms, total: 1.8 s
Wall time: 1.31 s


Wall time numbers will naturally vary on different platforms, but with these
parameters (model, observations, queries, sample size, and number of chains) the wall
times shown above (53.1 s versus 1.31 s) correspond to a speedup of about 40x on the
author's machine. Generally speaking, larger speedups are expected with larger sample
sizes. More information about `BMGInference` can be found in the "Advanced" section of
the documentation on the website.

In [29]:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

bmg_diagnostic_plots = gridplot(plots.plot_diagnostics(samples_bmg))
show(bmg_diagnostic_plots)

<a id="references"></a>

## References

* Vehtari A, Gelman A, Simpson D, Carpenter B, Bürkner PC (2021)
  **Rank-Normalization, Folding, and Localization: An Improved $\hat{R}$ for
  Assessing Convergence of MCMC (with Discussion)**. Bayesian Analysis 16(2)
  667–718. [doi: 10.1214/20-BA1221](https://dx.doi.org/10.1214/20-BA1221).