# Visual diagnostics widget proposal
---

See the accompanying [Quip](https://quip.com/ujr1AxjEGoYk/Visual-diagnostics-proposal)
for a high-level view of the following proposal. This notebook outlines in finer
detail the possible paths we can take to implement the widget.

In the process of creating the widget, we will encounter blockers associated with object
types, both internal to Bean Machine and external in the tools used to create the
widget. We will describe the pain points in greater detail and propose possible
solutions for them. We will also discuss ways to work around issues and how those
workarounds could affect Bean Machine internally and its backwards compatibility.

## Problem statement

### User experience when visualizing model diagnostics and printing of the model output

The field is converging on using ArviZ for visual displays of diagnostics information
and several tools already use it in their official documentation. See
[pymc](https://docs.pymc.io/en/v3/pymc-examples/examples/case_studies/hierarchical_partial_pooling.html),
and
[pyro](https://num.pyro.ai/en/latest/tutorials/bayesian_hierarchical_linear_regression.html#4.1.-Inspecting-the-learned-parameters)
as examples. Other tools like [stan](https://mc-stan.org/) may not have official
documentation using ArviZ, but ArviZ is committed to ensuring that tools like stan,
emcee, and pyjags can use it with minimal effort from users. See the
[documentation](https://arviz-devs.github.io/arviz/api/data.html) for examples of
converters employed by ArviZ.

Bean Machine currently returns a `MonteCarloSamples` object that is not compatible with
ArviZ, so the user must convert it into an object that is. TensorFlow requires similar
steps before a user can use model outputs with ArviZ; see
[this tutorial](https://jeffpollock9.github.io/bayesian-workflow-with-tfp-and-arviz/)
for an example. The conversion to a compatible object is not difficult, but it is a step
the user must be aware of, which means it has to be documented and communicated to
them. Bean Machine can streamline this extra step so that the user does not have to know
_a priori_ that they need to convert their model output into something ArviZ can use.
Removing the burden of object conversion from the user will enhance their experience
using Bean Machine and its perceived compatibility with other open source tools
modelers use.

There are two distinct paths we will discuss in this notebook that enhance the user
experience of using ArviZ diagnostics tools with the `MonteCarloSamples` object returned
by Bean Machine. Both methods transform the `MonteCarloSamples` object transparently for
the user into something that is consumable by ArviZ, and both methods use the exact same
code to create visual diagnostics. The only difference between the two methods is how
they are exposed to the user. One method introduces accessor capabilities to the
`MonteCarloSamples` object, while the other is a function call that takes the
`MonteCarloSamples` object as input. How the user invokes each implementation is shown
below.

**Accessor use**
```python
samples = bm.GlobalNoUTurnSampler().infer(...)
samples.accessor_name.plot()
```

**Functional use**
```python
samples = bm.GlobalNoUTurnSampler().infer(...)
bm.diagnostics.function_name(samples)
```

In this notebook we will address how to create a visual diagnostics widget using both
implementations. We will not discuss the possibility of having Bean Machine's infer
methods return an object that is directly compatible with ArviZ. That discussion is
quite deep and touches on internal mechanisms of the tool as well as internal dependency
compatibility issues, so we leave it for another notebook.

How the model output is represented to the user when they print it is also discussed.
The current user experience lags behind that of ArviZ and xarray, a gap that is easy to
close.

## Motivation for different implementations

**Accessor**

The accessor implementation was inspired by the documentation for extending
[xarray](https://xarray.pydata.org/en/stable/internals/extending-xarray.html). Internal
discussions about which object type Bean Machine should use internally led us to read
the xarray documentation, and its discussion of extensibility through accessors is why
we chose to implement one here.

**Functional**

This is a classical method for adding usability to a tool. Both the accessor
implementation and the functional one use the same code, and ultimately the same tests,
so the choice of one over the other depends purely on whether the team wants to
introduce accessor extensibility to Bean Machine or not.
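
The idea that both entry points can share one implementation can be sketched with a
small stand-in. Everything named here (`FakeSamples`, `_mean_by_query`, `summarize`,
`SummaryAccessor`) is hypothetical and exists only for illustration; none of it is part
of Bean Machine, and a plain `property` is used instead of the caching descriptor shown
later in this notebook.

```python
from typing import Dict, List


class FakeSamples:
    """Hypothetical stand-in for a MonteCarloSamples-like object."""

    def __init__(self, draws: Dict[str, List[float]]) -> None:
        self.draws = draws


def _mean_by_query(samples: FakeSamples) -> Dict[str, float]:
    """Shared implementation used by both entry points."""
    return {name: sum(values) / len(values) for name, values in samples.draws.items()}


def summarize(samples: FakeSamples) -> Dict[str, float]:
    """Functional entry point: the samples are passed in as an argument."""
    return _mean_by_query(samples)


class SummaryAccessor:
    """Accessor entry point: wraps the samples and delegates to the shared code."""

    def __init__(self, obj: FakeSamples) -> None:
        self._obj = obj

    def __call__(self) -> Dict[str, float]:
        return _mean_by_query(self._obj)


# A plain property is enough for this sketch; it builds the accessor on each access.
FakeSamples.summary = property(SummaryAccessor)

samples = FakeSamples({"alpha": [1.0, 3.0], "beta": [2.0, 4.0]})
functional_result = summarize(samples)
accessor_result = samples.summary()
```

Because both paths call `_mean_by_query`, a single set of tests on that function would
cover either design.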

## Prerequisites

We will be using the following tools in this notebook.

- [ArviZ](https://arviz-devs.github.io/arviz/)
- [Bokeh](https://docs.bokeh.org/en/latest/)

In [1]:
import warnings
from typing import (
    Any,
    Dict,
    List,
    Tuple,
    Union,
)

import arviz as az
import beanmachine.ppl as bm
import numpy as np
import torch
import torch.distributions as dist
from bokeh.io import output_notebook, show
from bokeh.layouts import column, row
from bokeh.models import (
    Circle,
    ColumnDataSource,
    GlyphRenderer,
    HoverTool,
    Line,
    Quad,
    Select,
)
from bokeh.plotting import figure, Figure
from bokeh.sampledata.penguins import data as penguin_df

The commands below improve the notebook experience and install Bean Machine in Google
Colab if that is where this notebook is run.

In [2]:
# Colab support.
import os
import sys

if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine


# Plotting settings.
output_notebook(hide_banner=True)
az.rcParams["plot.backend"] = "bokeh"
az.rcParams["stats.hdi_prob"] = 0.89

# Manual seed for torch.
torch.manual_seed(1199);

## Data

We will use the [Palmer penguin dataset](https://github.com/allisonhorst/palmerpenguins)
located in Bokeh for an example model.

In [3]:
df = penguin_df.dropna().reset_index(drop=True).copy()
df.head()

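|   | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex |
|---|---------|--------|----------------|---------------|-------------------|-------------|-----|
| 0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | MALE |
| 1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | FEMALE |
| 2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | FEMALE |
| 3 | Adelie | Torgersen | 36.7 | 19.3 | 193.0 | 3450.0 | FEMALE |
| 4 | Adelie | Torgersen | 39.3 | 20.6 | 190.0 | 3650.0 | MALE |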


The figure below shows the penguin data. We will then use Bean Machine to fit a linear
regression model to the data being displayed.

In [4]:
# Create the figure
p = figure(
    plot_width=800,
    plot_height=400,
    outline_line_color="black",
    x_axis_label="Body mass (g)",
    y_axis_label="Flipper length (mm)",
    title="Penguins",
)

# Bind data to the figure
colors = ["steelblue", "magenta", "brown"]
unique_species = sorted(df["species"].unique())
species_colors = [
    colors[unique_species.index(species)]
    for item_index, species in df["species"].items()
]
cds = ColumnDataSource(
    {
        "x": df["body_mass_g"].astype(int).tolist(),
        "y": df["flipper_length_mm"].astype(int).tolist(),
        "color": species_colors,
        "species": df["species"].tolist(),
        "island": df["island"].tolist(),
    }
)
glyph = p.circle(
    x="x",
    y="y",
    source=cds,
    size=10,
    fill_color="color",
    line_color="white",
    fill_alpha=0.6,
    line_alpha=0.6,
    hover_fill_color="orange",
    hover_line_color="black",
    hover_fill_alpha=1,
    hover_line_alpha=1,
    legend_group="species",
)
tips = HoverTool(
    renderers=[glyph],
    tooltips=[
        ("Flipper", "@y{0,}mm"),
        ("Mass", "@x{0,}g"),
        ("Species", "@species"),
        ("Island", "@island"),
    ],
)
p.add_tools(tips)

# Style the figure
p.grid.grid_line_alpha = 0.2
p.grid.grid_line_color = "grey"
p.grid.grid_line_width = 0.3
p.legend.location = "top_left"
p.legend.title = "Species"

# Show the figure
show(p)

## Model

We will use the model defined below in Bean Machine.

$$
\begin{aligned}
  \alpha &\sim \text{Normal}(135,5)       \\
  \beta  &\sim \text{HalfNormal}(0.5)     \\
  \sigma &\sim \text{HalfNormal}(5)       \\
  \mu    &=    \alpha+\beta\cdot x        \\
  y      &\sim \text{Normal}(\mu,\sigma)
\end{aligned}
$$

In [5]:
@bm.random_variable
def alpha():
    return dist.Normal(135, 5)


@bm.random_variable
def beta():
    return dist.HalfNormal(0.5)


@bm.random_variable
def sigma():
    return dist.HalfNormal(5)


@bm.random_variable
def y():
    mu = alpha() + beta() * X
    return dist.Normal(mu, sigma())

In [6]:
X = torch.tensor(df["body_mass_g"].astype(float).tolist())
Y = torch.tensor(df["flipper_length_mm"].astype(float).tolist())
queries = [alpha(), beta(), sigma()]
observations = {y(): Y}

num_samples = 2000
num_chains = 4
num_adaptive_samples = num_samples // 2

posterior = bm.GlobalNoUTurnSampler().infer(
    queries=queries,
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
    num_adaptive_samples=num_adaptive_samples,
)

Samples collected:   0%|          | 0/3000 [00:00<?, ?it/s]


## Analysis

The first step after fitting a model is to look at its diagnostics. The effective sample
size and $\hat{R}$ values in particular are crucial for judging whether the chains have
converged and mixed well. We use ArviZ to generate those statistics and display them
below in a pandas dataframe.

### User required object conversion

**Note** that we convert the `MonteCarloSamples` object from Bean Machine into a
compatible object that ArviZ can consume.

In [7]:
print(f"Bean Machine model output type: {type(posterior)}")

Bean Machine model output type: <class 'beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples'>


#### xarray conversion

In [8]:
posterior_xr = posterior.to_xarray()
print(f"Conversion of Bean Machine model to: {type(posterior_xr)}\n")
summary_df = az.summary(posterior_xr, round_to=4)
display(summary_df)

Conversion of Bean Machine model to: <class 'xarray.core.dataset.Dataset'>



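|         | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---------|------|----|----------|-----------|-----------|---------|----------|----------|-------|
| beta()  | 0.0153 | 0.0004 | 0.0146 | 0.016 | 0.0 | 0.0 | 2313.3219 | 2606.025 | 1.0021 |
| sigma() | 6.8662 | 0.2677 | 6.4158 | 7.2734 | 0.0044 | 0.0031 | 3768.5788 | 3634.2791 | 1.0012 |
| alpha() | 136.6708 | 1.8882 | 133.8061 | 139.7388 | 0.0391 | 0.0277 | 2341.7567 | 2643.3946 | 1.002 |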
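
The `hdi_5.5%` and `hdi_94.5%` column names follow from the `stats.hdi_prob` value of
0.89 set at the top of the notebook. The arithmetic can be sketched as below;
`hdi_column_labels` is a hypothetical helper, and the labeling rule is inferred from the
output above rather than taken from ArviZ's source.

```python
from typing import Tuple


def hdi_column_labels(hdi_prob: float) -> Tuple[str, str]:
    """Build summary column labels from the probability mass inside the interval."""
    # Probability mass left outside the interval on each side.
    tail = (1 - hdi_prob) / 2
    lower = f"hdi_{100 * tail:g}%"
    upper = f"hdi_{100 * (1 - tail):g}%"
    return lower, upper


labels = hdi_column_labels(0.89)
print(labels)  # ('hdi_5.5%', 'hdi_94.5%')
```

The labels only describe how much mass lies outside the interval. The highest density
interval itself is the narrowest interval containing 89% of the draws, which need not
coincide with the 5.5 and 94.5 percentiles.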


#### ArviZ `InferenceData` conversion

We could have alternatively converted the `MonteCarloSamples` object to an ArviZ
`InferenceData` object. Doing so produces the same output as above.

In [9]:
posterior_idata = posterior.to_inference_data()
print(f"Conversion of Bean Machine model to: {type(posterior_idata)}\n")
az.summary(posterior_idata, round_to=4)

Conversion of Bean Machine model to: <class 'arviz.data.inference_data.InferenceData'>



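|         | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---------|------|----|----------|-----------|-----------|---------|----------|----------|-------|
| beta()  | 0.0153 | 0.0004 | 0.0146 | 0.016 | 0.0 | 0.0 | 2313.3219 | 2606.025 | 1.0021 |
| sigma() | 6.8662 | 0.2677 | 6.4158 | 7.2734 | 0.0044 | 0.0031 | 3768.5788 | 3634.2791 | 1.0012 |
| alpha() | 136.6708 | 1.8882 | 133.8061 | 139.7388 | 0.0391 | 0.0277 | 2341.7567 | 2643.3946 | 1.002 |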


#### Error if no conversion is done

We had to convert the object returned from the Bean Machine infer method into one that
is compatible with ArviZ in order to get the diagnostics dataframe. If we assumed the
model output was compatible with ArviZ and tried to display model diagnostics we would
get the following error.

In [10]:
az.summary(posterior, round_to=4)

ValueError: Can only convert xarray dataarray, xarray dataset, dict, netcdf filename, numpy array, pystan fit, pymc3 trace, emcee fit, pyro mcmc fit, numpyro mcmc fit, cmdstan fit csv filename, cmdstanpy fit to InferenceData, not MonteCarloSamples
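
One way a thin wrapper could hide this conversion step is to look for a
`to_inference_data` method, which `MonteCarloSamples` provides as shown above, and call
it when present. This is a minimal sketch under that assumption; `as_arviz_input` and
`FakeSamples` are hypothetical names used only for illustration.

```python
from typing import Any


def as_arviz_input(obj: Any) -> Any:
    """Return something ArviZ can consume, converting when the object knows how."""
    converter = getattr(obj, "to_inference_data", None)
    if callable(converter):
        return converter()
    # Objects without a converter are passed through unchanged.
    return obj


class FakeSamples:
    """Hypothetical stand-in exposing the same method name as MonteCarloSamples."""

    def to_inference_data(self) -> dict:
        return {"posterior": {"alpha": [1.0, 2.0]}}


converted = as_arviz_input(FakeSamples())
passed_through = as_arviz_input([1, 2, 3])
```

With such a helper, a call like `az.summary(as_arviz_input(posterior), round_to=4)`
would not require the user to know about the conversion in advance.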

### User experience inspecting model output

Faced with this error we could print the output object of the model to see what it is.
Doing so introduces another user experience issue related to what is being displayed
when the model is printed.

#### Bean Machine representation

In [11]:
posterior

<beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples at 0x7fddad0325e0>

The current `__repr__` of the `MonteCarloSamples` object is cryptic and not as helpful
as it could be. A user who knows to inspect the object's `__dict__` attribute (for
example with `vars`) gets considerably more actionable information than the output
above.

In [12]:
vars(posterior)

{'num_chains': 4,
 'num_adaptive_samples': 1000,
 'adaptive_samples': {RVIdentifier(wrapper=<function beta at 0x7fddad025f70>, arguments=()): tensor([[0.5178, 0.5178, 0.3815,  ..., 0.0151, 0.0152, 0.0153],
          [2.7096, 2.7096, 1.7150,  ..., 0.0154, 0.0154, 0.0155],
          [0.4381, 0.4381, 0.2995,  ..., 0.0150, 0.0150, 0.0148],
          [1.9286, 1.9286, 1.5225,  ..., 0.0157, 0.0158, 0.0153]]),
  RVIdentifier(wrapper=<function sigma at 0x7fddad0370d0>, arguments=()): tensor([[0.9165, 0.9165, 1.2105,  ..., 6.5677, 7.1023, 6.7174],
          [2.0357, 2.0357, 3.1911,  ..., 6.8417, 6.8365, 6.9220],
          [0.3443, 0.3443, 0.4838,  ..., 7.1705, 7.1533, 6.6019],
          [4.0156, 4.0156, 5.0578,  ..., 7.0861, 7.0326, 6.5266]]),
  RVIdentifier(wrapper=<function alpha at 0x7fddad025d30>, arguments=()): tensor([[  1.7704,   1.7704,   1.7703,  ..., 137.2800, 137.2571, 136.2836],
          [ -1.0675,  -1.0675,  -1.0675,  ..., 135.5791, 136.0535, 136.0591],
          [  0.9346,   0.934

#### ArviZ or xarray representation

Compared with the output of ArviZ or xarray (shown below), there are opportunities for
Bean Machine to enhance its user experience and what is presented to the user when they
print out the model output object.

In [13]:
posterior_idata

In [14]:
posterior_xr
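
A compact text representation in the same spirit could look like the following sketch.
`FakeSamples` and its layout (`{query name: [chain][draw]}`) are assumptions made for
illustration and do not reflect how `MonteCarloSamples` stores its data.

```python
from typing import Dict, List


class FakeSamples:
    """Hypothetical stand-in holding draws as {query name: [chain][draw]}."""

    def __init__(self, samples: Dict[str, List[List[float]]]) -> None:
        self.samples = samples
        self.num_chains = len(next(iter(samples.values())))

    def __repr__(self) -> str:
        # Draws per chain, read from the first chain of the first query.
        n_draws = len(next(iter(self.samples.values()))[0])
        lines = [
            f"<{type(self).__name__}>",
            f"  chains: {self.num_chains}",
            f"  draws per chain: {n_draws:,}",
            "  queries:",
        ]
        lines += [f"    {name}" for name in self.samples]
        return "\n".join(lines)


fake = FakeSamples({"alpha()": [[1.0, 2.0], [3.0, 4.0]], "beta()": [[0.1, 0.2], [0.3, 0.4]]})
text = repr(fake)
print(text)
```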

### Enhancing user experience with accessors

The following code is boilerplate for adding accessor decorator capabilities to
`MonteCarloSamples` objects.

In [15]:
class AccessorRegistrationWarning(Warning):
    """Warning for conflicts in accessor registration."""


class _CachedAccessor:
    """Descriptor that creates an accessor on first use and caches it per object."""

    def __init__(self, name, accessor):
        self._name = name
        self._accessor = accessor

    def __get__(self, obj, cls):
        if obj is None:
            return self._accessor

        try:
            cache = obj._cache
        except AttributeError:
            cache = obj._cache = {}

        try:
            return cache[self._name]
        except KeyError:
            pass

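        try:
            accessor_obj = self._accessor(obj)
        except AttributeError as error:
            # Re-raise as RuntimeError so the failure is not mistaken for a missing
            # attribute on the data object.
            raise RuntimeError(
                f"error initializing {self._name!r} accessor."
            ) from error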

        cache[self._name] = accessor_obj
        return accessor_obj


def _register_accessor(name, cls):
    def decorator(accessor):
        if hasattr(cls, name):
            warnings.warn(
                f"registration of accessor {accessor!r} under name {name!r} for"
                f" type {cls!r} is overriding a preexisting attribute with the"
                " same name.",
                AccessorRegistrationWarning,
                stacklevel=2,
            )
        setattr(cls, name, _CachedAccessor(name, accessor))
        return accessor

    return decorator

In [16]:
# This is the actual decorator we will use.
def register_mcs_accessor(name):
    return _register_accessor(name, bm.inference.monte_carlo_samples.MonteCarloSamples)

Using the new accessor code we can register an accessor that becomes available on _any_
`MonteCarloSamples` object. Before we create and execute an accessor, recall that the
`__dict__` keys of the `MonteCarloSamples` object are the following.

In [17]:
vars(posterior).keys()

dict_keys(['num_chains', 'num_adaptive_samples', 'adaptive_samples', 'samples', 'single_chain_view'])

#### `repr` display

We will first introduce a `repr` that displays the output of the `MonteCarloSamples`
object in a more human-readable format. This output could be refined further with CSS
and richer HTML, similar to what ArviZ and xarray employ.

In [18]:
@register_mcs_accessor("display")
class BeanMachineDisplay:

    def __init__(self, obj):
        self.obj = obj

    def _repr_html_(self) -> str:
        n_chains = self.obj.num_chains
        n_adaptive_samples = self.obj.num_adaptive_samples
        n_samples = self.obj.get_num_samples()
        output = ""
        for query, samples in self.obj.adaptive_samples.items():
            output += f"<p>{str(query)}</p>"
            output += f"<pre>{samples.__repr__()}</pre>"
        return (
            f'<p><span style="font-weight:bold;">Number of chains:</span> {n_chains}</p>'
            f'<p><span style="font-weight:bold;">Number of adaptive samples:</span> {n_adaptive_samples:,}</p>'
            f'<p><span style="font-weight:bold;">Number of samples:</span> {n_samples:,}</p>'
            f"{output}"
        )

Because we are creating an accessor on the `MonteCarloSamples` object, the only way to
see this representation is to access the `display` attribute explicitly. This is not
ideal for the user, but it shows how such a representation could be implemented on the
`MonteCarloSamples` object internally. Internal changes to the tool are out of scope for
this notebook; the accessor is shown for reference only.

In [19]:
posterior.display

#### ArviZ summary accessor

The following accessor does not require changing the internals of Bean Machine. It
merely makes the summary output from ArviZ directly available to the user from the
`MonteCarloSamples` object. Recall that we had to convert the `MonteCarloSamples` object
before ArviZ could produce the dataframe below.

In [20]:
summary_df

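|         | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---------|------|----|----------|-----------|-----------|---------|----------|----------|-------|
| beta()  | 0.0153 | 0.0004 | 0.0146 | 0.016 | 0.0 | 0.0 | 2313.3219 | 2606.025 | 1.0021 |
| sigma() | 6.8662 | 0.2677 | 6.4158 | 7.2734 | 0.0044 | 0.0031 | 3768.5788 | 3634.2791 | 1.0012 |
| alpha() | 136.6708 | 1.8882 | 133.8061 | 139.7388 | 0.0391 | 0.0277 | 2341.7567 | 2643.3946 | 1.002 |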


In [21]:
@register_mcs_accessor("summary")
class BeanMachineSummary:

    def __init__(self, obj):
        # Transparently convert the MonteCarloSamples object into one that ArviZ can
        # consume.
        self.obj = obj.to_inference_data()
        display(az.summary(self.obj, round_to=4))

In [22]:
posterior.summary

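|         | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---------|------|----|----------|-----------|-----------|---------|----------|----------|-------|
| beta()  | 0.0153 | 0.0004 | 0.0146 | 0.016 | 0.0 | 0.0 | 2313.3219 | 2606.025 | 1.0021 |
| sigma() | 6.8662 | 0.2677 | 6.4158 | 7.2734 | 0.0044 | 0.0031 | 3768.5788 | 3634.2791 | 1.0012 |
| alpha() | 136.6708 | 1.8882 | 133.8061 | 139.7388 | 0.0391 | 0.0277 | 2341.7567 | 2643.3946 | 1.002 |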


<__main__.BeanMachineSummary at 0x7fddad030940>

Registering the new accessor on the `MonteCarloSamples` object gives the user direct
access to ArviZ summary statistics. The conversion is completely transparent to the user
and, more importantly, it does not modify any of the internals of Bean Machine.

We can inspect the `__dict__` representation of our `MonteCarloSamples` object and
compare it to what we had before. You will notice that there is a new key called
`_cache` on the object.

In [23]:
vars(posterior).keys()

dict_keys(['num_chains', 'num_adaptive_samples', 'adaptive_samples', 'samples', 'single_chain_view', '_cache'])

This new key is where accessor instances are cached. Registering an accessor is not the
same as adding an attribute to the object directly, _e.g._ as in the following.

```python
posterior.summary = az.summary(posterior_idata, round_to=4)
```

A discussion of the difference between registering an accessor and adding (or
overwriting) an attribute on the `MonteCarloSamples` object is out of scope for this
notebook.

In [24]:
vars(posterior)["_cache"]

{'display': <__main__.BeanMachineDisplay at 0x7fddacf78160>,
 'summary': <__main__.BeanMachineSummary at 0x7fddad030940>}
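
The caching behavior behind the `_cache` key can be seen in a stripped-down version of
the descriptor. This is a simplified sketch, not the code registered above: `Samples`
and `CountingAccessor` are hypothetical stand-ins, and error handling is omitted.

```python
class CachedAccessor:
    """Simplified descriptor: build the accessor once per instance, then reuse it."""

    def __init__(self, name, accessor):
        self._name = name
        self._accessor = accessor

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class itself, so hand back the accessor class.
            return self._accessor
        cache = obj.__dict__.setdefault("_cache", {})
        if self._name not in cache:
            cache[self._name] = self._accessor(obj)
        return cache[self._name]


class Samples:
    """Hypothetical stand-in for MonteCarloSamples."""


class CountingAccessor:
    """Hypothetical accessor that records how often it is constructed."""

    constructed = 0

    def __init__(self, obj):
        CountingAccessor.constructed += 1
        self.obj = obj


Samples.counting = CachedAccessor("counting", CountingAccessor)

samples = Samples()
first = samples.counting
second = samples.counting

print(first is second)               # True
print(CountingAccessor.constructed)  # 1
print(list(vars(samples)))           # ['_cache']
```

One consequence in this sketch is that any work done in the accessor's `__init__` runs
only on the first access; later accesses return the cached instance.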

### Adding a visual widget using an accessor

Below is the code we will use to add a visual diagnostic widget to our
`MonteCarloSamples` model output. There is a lot of it, so feel free to scroll past. In
summary, the code uses Bokeh to create the following:

- Effective sample size evolution plot
- Posterior distributions for each chain of a query
- Rank plots for each query
- Select widget to update the plots with a new display
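
The rank plots in this list rest on ranking every draw against the draws pooled from all
chains and then histogramming the ranks per chain. The idea can be sketched in plain
Python as follows. `pooled_ranks` and `rank_histograms` are hypothetical helpers, ties
are broken by position here (which may differ from what ArviZ does), and the bin-count
rule mirrors the one used in `compute_rank_data`.

```python
import math
from typing import List


def pooled_ranks(chains: List[List[float]]) -> List[List[int]]:
    """Rank every draw against the draws from all chains (1 = smallest)."""
    flat = sorted(
        (value, chain_idx, draw_idx)
        for chain_idx, chain in enumerate(chains)
        for draw_idx, value in enumerate(chain)
    )
    ranks = [[0] * len(chain) for chain in chains]
    for rank, (_, chain_idx, draw_idx) in enumerate(flat, start=1):
        ranks[chain_idx][draw_idx] = rank
    return ranks


def rank_histograms(chains: List[List[float]]) -> List[List[int]]:
    """Count, per chain, how many pooled ranks fall into each equal-width bin."""
    ranks = pooled_ranks(chains)
    n_draws = len(chains[0])
    total = sum(len(chain) for chain in chains)
    n_bins = int(math.ceil(2 * math.log2(n_draws)) + 1)
    width = total / n_bins
    histograms = []
    for chain_ranks in ranks:
        counts = [0] * n_bins
        for rank in chain_ranks:
            # Clamp so the largest rank lands in the last bin.
            counts[min(int((rank - 1) / width), n_bins - 1)] += 1
        histograms.append(counts)
    return histograms


chains = [[0.1, 0.4, 0.2, 0.9], [0.3, 0.8, 0.5, 0.7]]
ranks = pooled_ranks(chains)
histograms = rank_histograms(chains)
```

When the chains sample the same distribution, each chain's histogram should look
roughly uniform; a chain whose ranks pile up at one end is exploring a different region
than the others.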

In [25]:
class BokehESSPlotMixin:

    def compute_ess_data(self, rv_name: str) -> Dict[str, List[float]]:
        n_points = 20
        _rv_name = self._rv_names[self.rv_names.index(rv_name)]
        self.ess_x = np.linspace(
            start=self.total_draws / n_points,
            stop=self.total_draws,
            num=n_points,
        )
        draw_divisions = np.linspace(
            start=self.n_draws // n_points,
            stop=self.n_draws,
            num=n_points,
            dtype=int,
        )
        data = {
            "x": self.ess_x.tolist(),
            "bulk": [
                az.stats.diagnostics._ess_bulk(
                    self.posterior.data_vars.get(_rv_name).values[:, self.first_draw:draw_div]
                )
                for draw_div in draw_divisions
            ],
            "tail": [
                az.stats.diagnostics._ess_tail(
                    self.posterior.data_vars.get(_rv_name).values[:, self.first_draw:draw_div]
                )
                for draw_div in draw_divisions
            ],
        }
        return data

    def create_ess_data_sources(self, rv_name: str) -> Dict[str, ColumnDataSource]:
        data = self.compute_ess_data(rv_name)
        x = np.linspace(start=data["x"][0], stop=data["x"][-1])
        y = (self.n_rule_of_thumb * np.ones(len(x)))
        label = [self.n_rule_of_thumb] * len(x)
        return {
            "bulk": ColumnDataSource({"x": data["x"], "y": data["bulk"]}),
            "tail": ColumnDataSource({"x": data["x"], "y": data["tail"]}),
            "rot": ColumnDataSource({"x": x.tolist(), "y": y.tolist(), "label": label}),
        }

    def create_ess_figure(self, rv_name: str) -> Figure:
        fig = figure(
            plot_width=self.PLOT_WIDTH,
            plot_height=self.PLOT_HEIGHT,
            title=f"{rv_name} effective sample size",
            outline_line_color="black",
            x_axis_label="Total number of draws",
            y_axis_label="ESS",
        )
        self.apply_style(fig)
        fig.x_range.start = 0
        fig.y_range.start = 0
        return fig

    def create_ess_glyphs(self) -> Dict[str, List[Dict[str, Any]]]:
        ess_glyphs = {
            "bulk": [],
            "tail": [],
            "rot": [],
        }
        for i, ess_type in enumerate(["bulk", "tail"]):
            line_glyph = Line(
                x="x",
                y="y",
                line_color=self.COLORS[i],
                line_width=2,
                line_alpha=0.6,
                name=f"{ess_type}_ess_line_glyph",
            )
            ess_glyphs[ess_type].append({"glyph": line_glyph})
            circle_glyph = Circle(
                x="x",
                y="y",
                size=10,
                fill_color=self.COLORS[i],
                fill_alpha=1,
                line_color="white",
                line_alpha=1,
                line_width=1,
                name=f"{ess_type}_ess_circle_glyph",
            )
            hover_glyph = Circle(
                x="x",
                y="y",
                size=10,
                fill_color="orange",
                fill_alpha=1,
                line_color="black",
                line_alpha=1,
                line_width=2,
                name=f"{ess_type}_ess_circle_hover_glyph",
            )
            ess_glyphs[ess_type].append({"glyph": circle_glyph, "hover_glyph": hover_glyph})
        ess_glyphs["rot"].append(
            {
                "glyph": Line(
                    x="x",
                    y="y",
                    line_color="magenta",
                    line_width=3,
                    line_dash="dashed",
                    line_alpha=0.6,
                    name="rot",
                ),
                "hover_glyph": Line(
                    x="x",
                    y="y",
                    line_color="magenta",
                    line_width=3,
                    line_dash="solid",
                    line_alpha=1.0,
                    name="rot_hover",
                ),
            }
        )
        return ess_glyphs

    def bind_ess_glyphs_to_figure(
        self,
        fig: Figure,
        ess_cdses: Dict[str, ColumnDataSource],
        ess_glyphs: Dict[str, List[Dict[str, Any]]],
    ) -> None:
        for key, glyphs in ess_glyphs.items():
            for glyph in glyphs:
                source = ess_cdses[key]
                glyph_ = glyph.get("glyph", None)
                hover_glyph = glyph.get("hover_glyph", None)
                name = glyph_.name
                fig.add_glyph(
                    source_or_glyph=source,
                    glyph=glyph_,
                    hover_glyph=hover_glyph,
                    name=name,
                )

    def create_ess_tooltips(self) -> Dict[str, List[Tuple[str]]]:
        return {
            "bulk": [("Bulk ESS", "@y{0,}"), ("Draw", "@x{0,}")],
            "tail": [("Tail ESS", "@y{0,}"), ("Draw", "@x{0,}")],
            "rot": [("Rule of thumb", "@label")],
        }

    def bind_ess_tooltips_to_figure(
        self,
        fig: Figure,
        tooltips: Dict[str, List[Tuple[str]]],
    ) -> None:
        for key, tips in tooltips.items():
            if key in ["bulk", "tail"]:
                renderers = [
                    renderer
                    for renderer in fig.renderers
                    if isinstance(renderer, GlyphRenderer)
                    and renderer.name == f"{key}_ess_circle_glyph"
                ]
                fig.add_tools(HoverTool(renderers=renderers, tooltips=tips))
            elif key == "rot":
                renderers = [
                    renderer
                    for renderer in fig.renderers
                    if isinstance(renderer, GlyphRenderer)
                    and renderer.name == "rot"
                ]
                fig.add_tools(HoverTool(renderers=renderers, tooltips=tips))

    def bind_ess_to_figure(
        self,
        fig: Figure,
        data_sources: Dict[str, ColumnDataSource],
        glyphs: Dict[str, List[Dict[str, Any]]],
        tips: Dict[str, List[Tuple[str]]],
    ) -> None:
        self.bind_ess_glyphs_to_figure(
            fig=fig,
            ess_cdses=data_sources,
            ess_glyphs=glyphs,
        )
        self.bind_ess_tooltips_to_figure(fig=fig, tooltips=tips)

    def update_ess_figure(
        self,
        fig: Figure,
        rv_name: str,
        old_data_sources: Dict[str, ColumnDataSource],
    ) -> None:
        new_data_sources = self.create_ess_data_sources(rv_name)
        for key, source in old_data_sources.items():
            source.data = dict(new_data_sources[key].data)
        fig.title.text = " ".join([rv_name] + fig.title.text.split()[1:])


class BokehRankPlotMixin:

    RANK_PLOT_TICK_LABEL_OFFSET = 0.5

    def compute_rank_data(self, rv_name: str) -> Dict[str, np.ndarray]:
        rv_name = self._rv_names[self.rv_names.index(rv_name)]
        data = self.posterior.get(rv_name).values
        rank_data = az.plots.plot_utils.compute_ranks(data)
        n_bins = int(np.ceil(2 * np.log2(rank_data.shape[1])) + 1)
        bins = np.histogram_bin_edges(rank_data, bins=n_bins, range=(0, rank_data.size))
        hist = np.empty((self.n_chains, len(bins) - 1))
        normed_hist = np.empty((self.n_chains, len(bins) - 1))
        for chain in range(self.n_chains):
            _, h, _ = az.stats.density_utils.histogram(
                rank_data[chain, :],
                bins=n_bins,
            )
            hist[chain] = h
            normed_hist[chain] = h / h.max()
        return {"bins": bins, "histogram": hist, "normalized_histogram": normed_hist}

    def create_rank_data_sources(
        self,
        rv_name: str
    ) -> Dict[str, Dict[str, ColumnDataSource]]:
        data = self.compute_rank_data(rv_name)
        bins = data["bins"]
        histogram = data["histogram"]
        normalized_histogram = data["normalized_histogram"]
        data_sources = {}
        for chain in range(self.n_chains):
            bin_labels = [
                f"{int(b[0]):0,}–{int(b[1]):0,}"
                for b in zip(bins[:-1], bins[1:])
            ]
            x = np.linspace(
                start=bins[0],
                stop=bins[-1],
                num=len(normalized_histogram[chain]),
            ).tolist()
            data_sources[f"chain{chain}"] = {
                "quad": ColumnDataSource(
                    {
                        "left": bins[:-1].tolist(),
                        "top": (normalized_histogram[chain] + chain).tolist(),
                        "right": bins[1:].tolist(),
                        "bottom": (np.zeros(len(bins) - 1) + chain).tolist(),
                        "draws": bin_labels,
                        "chain": [chain + 1] * (len(bins) - 1),
                        "label": normalized_histogram[chain].tolist(),
                    }
                ),
                "mean": ColumnDataSource(
                    {
                        "x": x,
                        "y": (
                            chain + (
                                np.ones(len(normalized_histogram[chain]))
                                * normalized_histogram[chain].mean()
                            )
                        ).tolist(),
                        "chain": [chain + 1] * len(x),
                        "label": [float(normalized_histogram[chain].mean())] * len(x),
                    }
                )
            }
        return data_sources

    def create_rank_figure(self, rv_name: str) -> Figure:
        fig = figure(
            plot_width=self.PLOT_WIDTH,
            plot_height=self.PLOT_HEIGHT,
            x_axis_label="Rank from all chains",
            y_axis_label="Chain",
            outline_line_color="black",
            title=f"{rv_name} rank histograms for all chains for all draws",
        )
        self.apply_style(fig)
        ticker = (np.arange(self.n_chains) + self.RANK_PLOT_TICK_LABEL_OFFSET).tolist()
        self.override_tick_labels_and_locations(fig, ticker)
        return fig

    def override_tick_labels_and_locations(
        self,
        fig: Figure,
        ticker: List[float],
    ) -> None:
        fig.yaxis.ticker = ticker
        fig.yaxis.major_label_overrides = dict(
            zip(ticker, map(str, range(1, self.n_chains + 1)))
        )

    def create_rank_glyphs(self) -> Dict[str, Dict[str, Any]]:
        rank_glyphs = {}
        for chain in range(self.n_chains):
            glyph = Quad(
                left="left",
                top="top",
                right="right",
                bottom="bottom",
                fill_color=self.COLORS[chain],
                line_color="white",
                fill_alpha=0.6,
                name=f"quad_chain{chain}",
            )
            hover_glyph = Quad(
                left="left",
                top="top",
                right="right",
                bottom="bottom",
                fill_color=self.COLORS[chain],
                line_color="black",
                line_width=2,
                fill_alpha=1,
                name=f"quad_hover_chain{chain}",
            )
            mean_glyph = Line(
                x="x",
                y="y",
                line_dash="dashed",
                line_color="black",
                line_width=2,
                line_alpha=0.5,
                name=f"mean_chain{chain}",
            )
            mean_hover_glyph = Line(
                x="x",
                y="y",
                line_dash="solid",
                line_color="black",
                line_width=4,
                line_alpha=1,
                name=f"mean_hover_chain{chain}",
            )
            rank_glyphs[f"chain{chain}"] = {
                "quad": {"glyph": glyph, "hover_glyph": hover_glyph},
                "mean": {"glyph": mean_glyph, "hover_glyph": mean_hover_glyph},
            }
        return rank_glyphs

    def bind_rank_glyphs_to_figure(
        self,
        fig: Figure,
        rank_data_sources: Dict[str, Dict[str, ColumnDataSource]],
        rank_glyphs: Dict[str, Dict[str, Any]],
    ) -> None:
        ticker = []
        for chain, datum in rank_glyphs.items():
            for source_type, values in datum.items():
                source = rank_data_sources[chain][source_type]
                if source_type == "mean":
                    ticker.append(source.data["y"][0])
                glyph = values["glyph"]
                hover_glyph = values["hover_glyph"]
                name = glyph.name
                fig.add_glyph(
                    source_or_glyph=source,
                    glyph=glyph,
                    hover_glyph=hover_glyph,
                    name=name,
                )
        # Update the tick label locations to coincide with the mean value for the chain.
        self.override_tick_labels_and_locations(fig, ticker)

    def create_rank_tooltips(self) -> Dict[str, Dict[str, List[Tuple[str, str]]]]:
        tips = {}
        for chain in range(self.n_chains):
            tips[f"chain{chain}"] = {
                "quad": [
                    ("Chain", "@chain"),
                    ("Draws", "@draws"),
                    ("Normalized chain rank", "@label{0.000}")
                ],
                "mean": [
                    ("Chain", "@chain"),
                    ("Normalized mean for the chain", "@label{0.000}"),
                ],
            }
        return tips

    def bind_rank_tooltips_to_figure(
        self,
        fig: Figure,
        tooltips: Dict[str, Dict[str, List[Tuple[str, str]]]],
    ) -> None:
        for chain, datum in tooltips.items():
            for glyph_type, tips in datum.items():
                renderers = [
                    renderer
                    for renderer in fig.renderers
                    if isinstance(renderer, GlyphRenderer)
                    # Exact match so that "chain1" can never also match "chain10".
                    and renderer.name == f"{glyph_type}_{chain}"
                ]
                fig.add_tools(HoverTool(renderers=renderers, tooltips=tips))

    def bind_rank_to_figure(
        self,
        fig: Figure,
        data_sources: Dict[str, Dict[str, ColumnDataSource]],
        glyphs: Dict[str, Dict[str, Any]],
        tips: Dict[str, Dict[str, List[Tuple[str, str]]]],
    ) -> None:
        self.bind_rank_glyphs_to_figure(
            fig=fig,
            rank_data_sources=data_sources,
            rank_glyphs=glyphs,
        )
        self.bind_rank_tooltips_to_figure(fig=fig, tooltips=tips)

    def update_rank_figure(
        self,
        fig: Figure,
        rv_name: str,
        old_data_sources: Dict[str, Dict[str, ColumnDataSource]],
    ) -> None:
        new_data_sources = self.create_rank_data_sources(rv_name)
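        # Tick locations follow the per-chain mean line of the new data.
        new_ticker = [
            datum["mean"].data["y"][0] for datum in new_data_sources.values()
        ]
        # Swap the data in place so existing glyphs and hover tools stay bound.
        for chain, datum in old_data_sources.items():
            for glyph_type, cds in datum.items():
                cds.data = dict(new_data_sources[chain][glyph_type].data)
        # Rebuild the title rather than splitting the old one, which breaks when
        # a query name contains spaces.
        fig.title.text = f"{rv_name} rank histograms for all chains for all draws"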
        self.override_tick_labels_and_locations(fig, new_ticker)


class BokehDensityPlotMixin:

    def compute_density_data(self, rv_name: str) -> Dict[str, Dict[str, List[float]]]:
        rv_name = self._rv_names[self.rv_names.index(rv_name)]
        data = self.posterior.get(rv_name).values
        output = {}
        for chain in range(self.n_chains):
            support, density = az.stats.density_utils.kde(data[chain, :])
            normalized_density = density / density.max()
            output[f"chain{chain}"] = {
                "support": support,
                "density": normalized_density,
                "chain": [chain + 1] * len(support),
            }
        return output

    def create_density_data_sources(self, rv_name: str) -> Dict[str, ColumnDataSource]:
        data = self.compute_density_data(rv_name)
        data_sources = {}
        for chain_name, chain_data in data.items():
            data_sources[chain_name] = ColumnDataSource(chain_data)
        return data_sources

    def create_density_figure(self, rv_name: str) -> Figure:
        fig = figure(
            plot_width=self.PLOT_WIDTH,
            plot_height=self.PLOT_HEIGHT,
            outline_line_color="black",
            x_axis_label=f"{rv_name}",
            title=f"{rv_name} individual chain densities",
        )
        self.apply_style(fig)
        fig.yaxis.visible = False
        return fig

    def create_density_glyphs(self) -> Dict[str, Dict[str, Any]]:
        glyphs = {}
        for chain in range(self.n_chains):
            glyphs[f"chain{chain}"] = {
                "glyph": Line(
                    x="support",
                    y="density",
                    line_color=self.COLORS[chain],
                    line_width=2,
                    line_alpha=0.5,
                    name=f"density_chain{chain}",
                ),
                "hover_glyph": Line(
                    x="support",
                    y="density",
                    line_color="orange",
                    line_width=3,
                    line_alpha=1,
                    name=f"density_hover_chain{chain}",
                ),
            }
        return glyphs

    def bind_density_glyphs_to_figure(
        self,
        fig: Figure,
        data_sources: Dict[str, ColumnDataSource],
        glyphs: Dict[str, Dict[str, Any]],
    ) -> None:
        for chain_name, source in data_sources.items():
            glyph = glyphs[chain_name]["glyph"]
            hover_glyph = glyphs[chain_name]["hover_glyph"]
            name = glyph.name
            fig.add_glyph(
                source_or_glyph=source,
                glyph=glyph,
                hover_glyph=hover_glyph,
                name=name,
            )

    def create_density_tooltips(self) -> Dict[str, List[Tuple[str, str]]]:
        tips = {}
        for chain in range(self.n_chains):
            tips[f"chain{chain}"] = [("Chain", "@chain"), ("", "@support")]
        return tips

    def bind_density_tooltips_to_figure(
        self,
        fig: Figure,
        tooltips: Dict[str, List[Tuple[str, str]]],
    ) -> None:
        for chain, tips in tooltips.items():
            renderers = [
                renderer
                for renderer in fig.renderers
                if isinstance(renderer, GlyphRenderer)
                and renderer.name == f"density_{chain}"
            ]
            fig.add_tools(HoverTool(renderers=renderers, tooltips=tips))

    def bind_density_to_figure(
        self,
        fig: Figure,
        data_sources: Dict[str, ColumnDataSource],
        glyphs: Dict[str, Dict[str, Any]],
        tips: Dict[str, List[Tuple[str, str]]],
    ) -> None:
        self.bind_density_glyphs_to_figure(
            fig=fig,
            data_sources=data_sources,
            glyphs=glyphs,
        )
        self.bind_density_tooltips_to_figure(fig=fig, tooltips=tips)

    def update_density_figure(
        self,
        fig: Figure,
        rv_name: str,
        old_data_sources: Dict[str, ColumnDataSource],
    ) -> None:
        new_data_sources = self.create_density_data_sources(rv_name)
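        # Swap the data in place so existing glyphs and hover tools stay bound.
        for chain, cds in old_data_sources.items():
            cds.data = dict(new_data_sources[chain].data)
        # Rebuild the title rather than splitting the old one, which breaks when
        # a query name contains spaces.
        fig.title.text = f"{rv_name} individual chain densities"
        # Keep the x-axis label in sync with the selected query.
        fig.xaxis.axis_label = rv_name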


class BokehPlotsMixin(
    BokehRankPlotMixin,
    BokehESSPlotMixin,
    BokehDensityPlotMixin,
):

    # One color per chain, so at most four chains can be drawn.
    COLORS = ["steelblue", "brown", "magenta", "orange"]
    PLOT_WIDTH = 600
    PLOT_HEIGHT = 400

    def apply_style(self, fig):
        fig.grid.grid_line_alpha = 0.2
        fig.grid.grid_line_color = "grey"
        fig.grid.grid_line_width = 0.3
        fig.yaxis.minor_tick_line_color = None


@register_mcs_accessor("widgets")
class BeanMachineWidgets(BokehPlotsMixin):

    def __init__(self, obj):
        # Prepare data
        self.obj = obj
        self.posterior = self.obj.to_inference_data().posterior
        self._rv_names = list(self.posterior.data_vars.keys())
        self.rv_names = [str(key) for key in self._rv_names]
        # `sizes` maps dimension names to lengths.
        self.n_chains = self.posterior.sizes["chain"]
        self.n_draws = self.posterior.sizes["draw"]
        self.total_draws = self.n_chains * self.n_draws
        self.first_draw = self.posterior.draw.values[0]
        self.n_rule_of_thumb = 100 * self.n_chains

    def modify_doc(self, doc):
        # Set the initial view.
        rv_name = self.rv_names[0]

        # Create data sources for the figures.
        rank_data_sources = self.create_rank_data_sources(rv_name)
        ess_data_sources = self.create_ess_data_sources(rv_name)
        density_data_sources = self.create_density_data_sources(rv_name)

        # Create the figures.
        rank_fig = self.create_rank_figure(rv_name)
        ess_fig = self.create_ess_figure(rv_name)
        density_fig = self.create_density_figure(rv_name)

        # Create glyphs for the figures.
        rank_glyphs = self.create_rank_glyphs()
        ess_glyphs = self.create_ess_glyphs()
        density_glyphs = self.create_density_glyphs()

        # Create tooltips for the figures.
        rank_tips = self.create_rank_tooltips()
        ess_tips = self.create_ess_tooltips()
        density_tips = self.create_density_tooltips()

        # Bind data to figures.
        self.bind_rank_to_figure(
            fig=rank_fig,
            data_sources=rank_data_sources,
            glyphs=rank_glyphs,
            tips=rank_tips,
        )
        self.bind_ess_to_figure(
            fig=ess_fig,
            data_sources=ess_data_sources,
            glyphs=ess_glyphs,
            tips=ess_tips,
        )
        self.bind_density_to_figure(
            fig=density_fig,
            data_sources=density_data_sources,
            glyphs=density_glyphs,
            tips=density_tips,
        )

        def update(attr, old, new):
            new_rv_name = new
            self.update_ess_figure(
                fig=ess_fig,
                rv_name=new_rv_name,
                old_data_sources=ess_data_sources,
            )
            self.update_rank_figure(
                fig=rank_fig,
                rv_name=new_rv_name,
                old_data_sources=rank_data_sources,
            )
            self.update_density_figure(
                fig=density_fig,
                rv_name=new_rv_name,
                old_data_sources=density_data_sources,
            )

        select = Select(title="Query", value=rv_name, options=self.rv_names)
        select.on_change("value", update)
        layout = column(
            select,
            ess_fig,
            row(
                density_fig,
                rank_fig,
            )
        )

        doc.add_root(layout)

    def plot(self):
        return show(self.modify_doc)

With the new `widgets` accessor registered, we can display the diagnostics tool below.

In [26]:
posterior.widgets.plot()

We can compare the above tool to the output of current ArviZ visuals. For simplicity, we
plot the same diagnostics as above for a single query variable.

In [27]:
ess = az.plot_ess(
    {"alpha()": posterior.get(alpha())},
    kind="evolution",
    show=False,
)[0][0]
post = az.plot_trace(
    {"alpha()": posterior.get(alpha())},
    kind="rank_bars",
    show=False,
)[0].tolist()
layout = row(ess, *post)
show(layout)

Interacting with the plots shows the improved interactivity. ArviZ's hover tool reports
data coordinates and screen pixel values at the cursor, whereas our hover tool reports
the plotted data itself, such as the chain, the number of draws, and the normalized
rank. This improves the user experience and is straightforward to implement with Bokeh.
The framework defined above can be extended to include references, links to Bean
Machine documentation, and comments that describe when a model is not sampling the full
space efficiently. Each of these is easy to implement within the above framework and can
be tested.
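
As one illustration of such a comment, the sketch below flags a query whose effective
sample size falls under the `100 * n_chains` rule of thumb that `BeanMachineWidgets`
stores as `n_rule_of_thumb`. The helper name `ess_comment` and the message wording are
assumptions made for illustration; neither is part of Bean Machine or ArviZ.

```python
def ess_comment(ess: float, n_chains: int) -> str:
    """Return a short, human-readable comment about an ESS value.

    Hypothetical helper: the threshold mirrors ``n_rule_of_thumb`` above
    (100 effective samples per chain); the wording is illustrative only.
    """
    if n_chains < 1:
        raise ValueError("n_chains must be a positive integer")
    threshold = 100 * n_chains
    if ess < threshold:
        return (
            f"ESS of {ess:.0f} is below the rule of thumb of {threshold}; "
            "the chains may not be exploring the posterior efficiently."
        )
    return f"ESS of {ess:.0f} meets the rule of thumb of {threshold}."


print(ess_comment(250.0, n_chains=4))
print(ess_comment(1200.0, n_chains=4))
```

Because the helper is a plain function of two numbers, it can be covered by an ordinary
unit test, and its output could be shown in a Bokeh `Div` next to the ESS figure.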

ArviZ does not output widgets, whereas our new accessor object does. Without a widget, a
model builder who needs to investigate more queries has to plot those diagnostics
explicitly in the notebook, which becomes cumbersome when there are hundreds of
individual query values. Scrolling through a long list of query variables in a select
widget is cumbersome too; however, this can be addressed with logic that aggregates
like-named queries, so that a user does not have to see a select widget with hundreds
of options.
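
A minimal sketch of that aggregation logic follows. It assumes every query name has the
form `function(arguments)`, as `alpha()` does in this notebook; the helper name
`group_query_names` and the indexed `theta(...)` names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List


def group_query_names(rv_names: List[str]) -> Dict[str, List[str]]:
    """Group query names such as ``theta(0)`` and ``theta(1)`` under ``theta``.

    Hypothetical helper: assumes each name looks like ``function(arguments)``.
    """
    groups: Dict[str, List[str]] = defaultdict(list)
    for name in rv_names:
        # Everything before the first "(" identifies the random variable function.
        family = name.split("(", 1)[0]
        groups[family].append(name)
    return dict(groups)


names = ["alpha()", "theta(0)", "theta(1)", "theta(2)"]
print(group_query_names(names))
```

The keys could populate a first `Select` widget, with a second widget listing only the
members of the chosen group.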

### Adding a visual widget using a functional approach

We do not need accessor functionality in order to create diagnostics widgets for Bean
Machine outputs. To demonstrate this, we reuse all of the above Bokeh code in a function
that takes the `MonteCarloSamples` object as input.

In [28]:
def widgets(mcs):
    return BeanMachineWidgets(mcs).plot()


widgets(posterior)

## Pros and cons for the different implementations

**Accessor implementation**

*Pros*

- We can add features to the `MonteCarloSamples` object, including interactive widgets,
  without having to touch the internals of the `MonteCarloSamples` object.
- We can register the accessor object when a user first imports Bean Machine. This way
  the user has access to the widgets on a `MonteCarloSamples` object, but no internal
  machinery needs to be changed.
- Users would be able to use accessor capabilities in unique ways.


*Cons*

- The accessor machinery increases library complexity.
- May need to copy functionality from ArviZ so that users get the functionality they
  would expect, whether or not ArviZ's API is exposed to them.

**Functional implementation**

*Pros*

- No need to add accessor capabilities to the `MonteCarloSamples` object.
- More familiar programming style.


*Cons*

- A user will need to import the widget sub-package in the notebook.
- May need to copy functionality from ArviZ.
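
To make the comparison concrete, the sketch below shows both call styles side by side on
a stand-in class. Everything in it is an assumption made for illustration: `Samples`
stands in for `MonteCarloSamples`, and `register_accessor` is a simplified decorator
written for this example, not the `register_mcs_accessor` used above.

```python
class Samples:
    """Stand-in for ``MonteCarloSamples``; holds a list of draws only."""

    def __init__(self, draws):
        self.draws = list(draws)


class _Accessor:
    """Descriptor that builds the accessor object on attribute access."""

    def __init__(self, accessor_cls):
        self._accessor_cls = accessor_cls

    def __get__(self, obj, owner=None):
        if obj is None:  # Accessed on the class rather than an instance.
            return self._accessor_cls
        return self._accessor_cls(obj)


def register_accessor(name):
    """Hypothetical decorator that attaches a class to ``Samples`` as ``name``."""

    def decorator(accessor_cls):
        setattr(Samples, name, _Accessor(accessor_cls))
        return accessor_cls

    return decorator


@register_accessor("summary")
class Summary:
    def __init__(self, obj):
        self.obj = obj

    def mean(self):
        return sum(self.obj.draws) / len(self.obj.draws)


def mean(samples):
    """Functional counterpart: the same computation, called with the samples."""
    return Summary(samples).mean()


samples = Samples([1.0, 2.0, 3.0])
print(samples.summary.mean())  # Accessor style.
print(mean(samples))  # Functional style.
```

Both calls run the same code. The accessor style adds the registration machinery but
needs no extra import at the call site, while the functional style needs no machinery
but requires the user to import the function.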