In [1]:
import arviz as az
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
from bokeh.io import output_notebook
from bokeh.models import HoverTool
from bokeh.models.sources import ColumnDataSource
from bokeh.plotting import figure, show
from bokeh.sampledata.penguins import data as penguin_df

from widgets import diagnostics

In [2]:
# Reproducibility: fix torch's global RNG before any sampling happens.
torch.manual_seed(1199)

# ArviZ should draw with Bokeh and report 89% highest-density intervals;
# Bokeh figures are rendered inline in the notebook.
az.rcParams.update(
    {
        "plot.backend": "bokeh",
        "stats.hdi_prob": 0.89,
    }
)
output_notebook()

In [3]:
# Keep only rows with no missing values, renumber them 0..n-1, and work on an
# independent copy of the bundled sample data.
df = (
    penguin_df
    .dropna()
    .reset_index(drop=True)
    .copy()
)
df.head()

Unnamed: 0,species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex
0,Adelie,Torgersen,39.1,18.7,181.0,3750.0,MALE
1,Adelie,Torgersen,39.5,17.4,186.0,3800.0,FEMALE
2,Adelie,Torgersen,40.3,18.0,195.0,3250.0,FEMALE
3,Adelie,Torgersen,36.7,19.3,193.0,3450.0,FEMALE
4,Adelie,Torgersen,39.3,20.6,190.0,3650.0,MALE


In [4]:
# Prepare data for the figure: one Bokeh data source per species so each can
# get its own colour, legend entry and hover behaviour.
def _species_source(frame, species):
    """Return the rows of ``frame`` for one species and a matching ColumnDataSource.

    Parameters
    ----------
    frame : pandas.DataFrame
        Penguin data with ``species``, ``island``, ``body_mass_g`` and
        ``flipper_length_mm`` columns.
    species : str
        Value of the ``species`` column to keep.

    Returns
    -------
    tuple
        ``(species_df, cds)``. ``species_df`` holds the filtered rows
        re-indexed from 0. ``cds`` holds columns ``x`` (body mass cast to int),
        ``y`` (flipper length cast to int), ``species`` and ``island``.
    """
    species_df = frame[frame["species"] == species].reset_index(drop=True)
    cds = ColumnDataSource(
        {
            # astype(int) truncates; the measurements shown above are whole numbers.
            "x": species_df["body_mass_g"].astype(int).tolist(),
            "y": species_df["flipper_length_mm"].astype(int).tolist(),
            "species": species_df["species"].tolist(),
            "island": species_df["island"].tolist(),
        }
    )
    return species_df, cds


adelie_df, adelie_cds = _species_source(df, "Adelie")
chinstrap_df, chinstrap_cds = _species_source(df, "Chinstrap")
gentoo_df, gentoo_cds = _species_source(df, "Gentoo")

In [5]:
# Create the figure.
# `width`/`height` replace `plot_width`/`plot_height`, which were removed in
# Bokeh 3.0; the new names are also accepted by Bokeh 2.x.
p = figure(
    width=800,
    height=400,
    outline_line_color="black",
    x_axis_label="Body mass (g)",
    y_axis_label="Flipper length (mm)",
    title="Penguins",
)

# Bind data to the figure: one renderer per species, identical styling except
# for the fill colour. `scatter(marker="circle")` replaces `circle(size=...)`,
# which is deprecated in recent Bokeh releases.
species_styles = [
    (adelie_cds, "steelblue"),
    (chinstrap_cds, "magenta"),
    (gentoo_cds, "brown"),
]
adelie_glyph, chinstrap_glyph, gentoo_glyph = [
    p.scatter(
        x="x",
        y="y",
        source=cds,
        marker="circle",
        size=10,
        fill_color=fill_color,
        line_color="white",
        fill_alpha=0.6,
        line_alpha=0.6,
        hover_fill_color="orange",
        hover_line_color="black",
        hover_fill_alpha=1,
        hover_line_alpha=1,
        legend_group="species",
    )
    for cds, fill_color in species_styles
]

# One hover tool serves all three renderers instead of three identical tools.
species_tips = HoverTool(
    renderers=[adelie_glyph, chinstrap_glyph, gentoo_glyph],
    tooltips=[
        ("Flipper", "@y{0,}mm"),
        ("Mass", "@x{0,}g"),
        ("Species", "@species"),
        ("Island", "@island"),
    ],
)
p.add_tools(species_tips)

# Style the figure (the legend exists only once legend-bearing glyphs are added).
p.grid.grid_line_alpha = 0.2
p.grid.grid_line_color = "grey"
p.grid.grid_line_width = 0.3
p.legend.location = "top_left"
p.legend.title = "Species"
p.legend.click_policy = "mute"

# Show the figure
show(p)

In [6]:
@bm.random_variable
def alpha():
    """Intercept of the linear predictor for flipper length (mm): Normal(135, 5)."""
    return dist.Normal(loc=135, scale=5)


@bm.random_variable
def beta():
    """Slope on body mass in the linear predictor; HalfNormal(0.5) keeps it non-negative."""
    return dist.HalfNormal(scale=0.5)


@bm.random_variable
def sigma():
    """Observation-noise scale used by the likelihood ``y``: HalfNormal(5)."""
    return dist.HalfNormal(scale=5)


@bm.random_variable
def y():
    """Likelihood: Normal around the line ``alpha + beta * X`` with scale ``sigma``.

    Reads the module-level tensor ``X`` (body mass), which is assigned in the
    inference cell below; that cell must run before ``y()`` is called.
    """
    return dist.Normal(loc=alpha() + beta() * X, scale=sigma())

In [7]:
# Sampler settings: the adaptation phase is half as long as the sampling phase.
num_samples = 2000
num_chains = 4
num_adaptive_samples = num_samples // 2

# Predictor (body mass, g) and response (flipper length, mm) as float tensors.
X, Y = (
    torch.tensor(df[column].astype(float).tolist())
    for column in ("body_mass_g", "flipper_length_mm")
)

# Posterior over the regression parameters, conditioned on the observed lengths.
queries = [alpha(), beta(), sigma()]
observations = {y(): Y}

sampler = bm.GlobalNoUTurnSampler()
samples = sampler.infer(
    queries=queries,
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
    num_adaptive_samples=num_adaptive_samples,
)

Samples collected:   0%|          | 0/3000 [00:00<?, ?it/s]

Samples collected:   0%|          | 0/3000 [00:00<?, ?it/s]

Samples collected:   0%|          | 0/3000 [00:00<?, ?it/s]

Samples collected:   0%|          | 0/3000 [00:00<?, ?it/s]

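# Diagnostics widgets

The cells below call the diagnostic tools available on `samples.diagnostics` for the posterior draws of `alpha`, `beta` and `sigma`.

---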

In [9]:
# Calls the `autocorrelation` diagnostic on the posterior samples of alpha, beta and sigma.
# NOTE(review): `samples.diagnostics` appears to come from the
# `from widgets import diagnostics` import in the first cell (its only visible
# use) — confirm before removing that import.
samples.diagnostics.autocorrelation()

In [11]:
# Calls the `ess` diagnostic (presumably effective sample size) on the posterior samples.
samples.diagnostics.ess()

In [13]:
# Calls the `joint_plot` diagnostic on the posterior samples.
samples.diagnostics.joint_plot()

In [15]:
# Calls the `posterior` diagnostic on the posterior samples.
samples.diagnostics.posterior()

In [17]:
# Calls the `summary` diagnostic on the posterior samples.
samples.diagnostics.summary()

In [19]:
# Calls the `trace_plot` diagnostic on the posterior samples.
samples.diagnostics.trace_plot()