In [2]:
# this just adds my simulation package to the path
import logging
import os
import pickle
import sys
from pathlib import Path

import copy
import pprint
from collections.abc import Iterable
from typing import Callable, Mapping

import numpy as np
import pandas as pd
from cmdstanpy import CmdStanModel
from tqdm.auto import tqdm

sys.path.append("../")
adaptive_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if adaptive_root not in sys.path:
    sys.path.append(adaptive_root)
data_dir = os.path.join(adaptive_root, "notebooks", "fake_data")
# prefix is the name of the notebook without the extension
prefix = "theta_variance_single_expt"

from simulation import fitters, simulators, utils

logger = logging.getLogger("cmdstanpy")
logger.addHandler(logging.NullHandler())
logger.propagate = False
logger.setLevel(logging.CRITICAL)

pp = pprint.PrettyPrinter()

In [3]:
# Attempt to load existing simulation data
try:
    data_to_store = utils.load_latest_simulation(data_dir, prefix)
    globals().update(data_to_store)
    print("Loaded these variables from the latest simulation:")
    print(list(data_to_store.keys()))
except FileNotFoundError:
    data_to_store = {}
    print("No data found in the specified directory.")

Loaded these variables from the latest simulation:
['chick_variances', 'sigma1_three_times_sigma0', 'zoomed_chick_variances', 'sigma1_sigma0_ratios', 'sigma1_sigma0_ratios_large_sigma_b', 'var_theta_mu_theta_r', 'var_theta_mu_theta_const_b_r']


In this notebook, we explore how the posterior variance of $\theta$ depends on the treatment ratio *for a single experiment*. The intent of this first simulation is the following:
- confirm that the theoretical posterior variance of $\theta$ is correct (for the oracle model with known hyperparameters)
- examine how different the posterior variance is when we sample $\theta$ using the flat-hyperprior model (i.e. no information about the hyperparameters)

We use the same hyperparameters that were obtained from fitting the chicken data. Later we will modify this and see how the posterior variance changes.

In [4]:
fitter_list = [
    fitters.FitterFlatHyperpriors(),
    fitters.FitterOracleHyperparameters(),
    fitters.FitterConstbFlatHypersTheta(),
]

The function `posterior_variance_grid` takes in a base simulator (which has all the parameters except the sample size and treatment ratio) and computes the posterior variance of $\theta$ for each treatment ratio in `p_grid`. This is done `num_reps` times per ratio and averaged.

In [5]:
def posterior_variance_grid(
    simulator: simulators.ExperimentSimulator,
    fitter: fitters.Fitter,
    p_grid: np.ndarray,
    num_reps: int,
):
    # leave the original simulator unchanged
    sim_copy = copy.deepcopy(simulator)

    variances = {}
    for p in tqdm(p_grid, desc="p_grid", leave=False):
        variances[p] = []
        for _ in tqdm(range(num_reps), desc="num_reps", leave=False):
            sim_copy.p = np.array([p])
            sim_copy.n = np.array([sim_copy.n[0]])  # fix n to the original value
            expt = sim_copy.simulate()
            result = fitter.fit(expt, show_progress=False)
            theta_variance = np.var(result.theta)
            variances[p].append(theta_variance)

        variances[p] = np.mean(variances[p])
    return variances

In [6]:
import plotly.graph_objects as go
import plotly.express as px


def plot_posterior_variances(variances, title="Posterior Variances vs Treatment Ratio"):
    """
    Plots posterior variances for different treatment ratios.

    Parameters:
    - variances (dict): Dictionary containing posterior variance data for different variants.
    - title (str): Title of the plot.

    Returns:
    - fig (plotly.graph_objects.Figure): The generated plotly figure.
    """
    # Prepare figure
    fig = go.Figure()
    colors = px.colors.qualitative.Plotly
    i = 0

    # Plot each variant except metadata and desc
    for variant, data in variances.items():
        if variant in ("desc", "metadata"):
            continue

        # Sort by p and extract values
        ps = sorted(data.keys())
        ys = []
        for p in ps:
            val = data[p]
            if isinstance(val, np.ndarray):
                ys.append(val.mean())
            elif isinstance(val, (list, tuple)):
                ys.append(np.mean(val))
            else:
                ys.append(float(val))

        # Pick color from the cycle
        color = colors[i % len(colors)]
        i += 1

        # Add line trace with explicit color
        fig.add_trace(
            go.Scatter(
                x=ps,
                y=ys,
                mode="lines",
                name=variant,
                line=dict(color=color),
            )
        )

        # Mark the minimal point with the same color
        idx_min = int(np.argmin(ys))
        fig.add_trace(
            go.Scatter(
                x=[ps[idx_min]],
                y=[ys[idx_min]],
                mode="markers",
                marker=dict(color=color, size=10),
                showlegend=False,
            )
        )

    # Layout with log-scaled y axis
    fig.update_layout(
        title=title,
        xaxis_title="Treatment Ratio p",
        yaxis=dict(title="Posterior Variance of Theta", type="log"),
    )

    return fig

In [7]:
print("These are the parameters from the chicken dataset:\n")
pp.pprint(utils.CHICK_SIMULATOR.params())

These are the parameters from the chicken dataset:

{'mu_b': 0.004112476586286136,
 'mu_theta': 0.09769112704348468,
 'sigma0': 0.22627416997969524,
 'sigma1': 0.22627416997969524,
 'sigma_b': 0.0015924804430524674,
 'sigma_theta': 0.056385519973983916}


In [8]:
if "chick_variances" not in globals():
    new_simulations = True

    num_reps = 20
    p_grid = np.linspace(0, 1, 20)

    chick_variances = {}

    # get the posterior variance for each fitter
    for fitter in tqdm(fitter_list, desc="fitter_list", leave=False):
        chick_variances[fitter.name] = posterior_variance_grid(
            simulator=utils.CHICK_SIMULATOR,
            fitter=fitter,
            p_grid=p_grid,
            num_reps=num_reps,
        )

    # get the theoretical posterior variance (for the oracle fitter)
    sim_copy = copy.deepcopy(utils.CHICK_SIMULATOR)
    theoretical_vars_oracle = {}
    for p in p_grid:
        sim_copy.p = np.array([p])
        theoretical_vars_oracle[p] = utils.theta_posterior_variance(sim_copy)
    chick_variances["theoretical_oracle"] = theoretical_vars_oracle

    # store the data associated with this simulation
    chick_variances["metadata"] = utils.CHICK_SIMULATOR.params() | {
        "num_reps": num_reps,
        "p_grid": p_grid,
        "desc": "Computed posterior variances of theta for different Stan models. These hyperparameters are based on the chicken dataset.",
    }

    # store the results
    data_to_store["chick_variances"] = chick_variances

plot_posterior_variances(chick_variances)

In [9]:
if "zoomed_chick_variances" not in globals():
    new_simulations = True

    num_reps = 20
    p_grid = np.linspace(0.45, 0.55, 20)

    zoomed_chick_variances = {}

    base_simulator = copy.deepcopy(utils.CHICK_SIMULATOR)

    # get the posterior variance for each fitter
    for fitter in tqdm(
        [fitters.FitterFlatHyperpriors(), fitters.FitterConstbFlatHypersTheta()],
        desc="fitter_list",
        leave=False,
    ):
        zoomed_chick_variances[fitter.name] = posterior_variance_grid(
            simulator=base_simulator,
            fitter=fitter,
            p_grid=p_grid,
            num_reps=num_reps,
        )

    # get the theoretical posterior variance (for the oracle fitter)
    sim_copy = copy.deepcopy(base_simulator)
    theoretical_vars_oracle = {}
    for p in p_grid:
        sim_copy.p = np.array([p])
        theoretical_vars_oracle[p] = utils.theta_posterior_variance(sim_copy)
    zoomed_chick_variances["theoretical_oracle"] = theoretical_vars_oracle

    # store the data associated with this simulation
    zoomed_chick_variances["metadata"] = base_simulator.params() | {
        "num_reps": num_reps,
        "p_grid": p_grid,
        "desc": "Computed posterior variances of theta for different Stan models. Here sigma1 is three times sigma0, and sigma0 is the same value as in the chicken dataset.",
    }

    # store the results
    data_to_store["zoomed_chick_variances"] = zoomed_chick_variances

plot_posterior_variances(zoomed_chick_variances)

This seems to confirm that the theoretical derivation is correct. The graph also shows how much the posterior variances change depending on our model; the flat hyperprior model has the smallest variance near $p = 0.5$, whereas the oracle model is minimized at $p = 1$. The second graph is essentially zoomed-in near the minimum of the flat hyperprior curve. I was curious if it was a coincidence that in the top graph, the flat hyperprior curve is minimized at slightly less than $p=0.5$; especially since $\sigma_\theta$ is much larger than $\sigma_b$, it seemed to me that this shouldn't happen. Based on the second graph, this was probably coincidence. 

Can we choose our hyperparameters so that the flat prior model is minimized somewhere other than $p = 0.5$? Let's repeat this simulation with $\sigma_1$ to be three times as large as $\sigma_0$, keeping $\sigma_0$ at its original value.

In [10]:
if "sigma1_three_times_sigma0" not in globals():
    new_simulations = True

    num_reps = 20
    p_grid = np.linspace(0, 1, 20)

    sigma1_three_times_sigma0 = {}

    base_simulator = copy.deepcopy(utils.CHICK_SIMULATOR)
    base_simulator.sigma1 = 3 * base_simulator.sigma0

    # get the posterior variance for each fitter
    for fitter in tqdm(fitter_list, desc="fitter_list", leave=False):
        sigma1_three_times_sigma0[fitter.name] = posterior_variance_grid(
            simulator=base_simulator,
            fitter=fitter,
            p_grid=p_grid,
            num_reps=num_reps,
        )

    # get the theoretical posterior variance (for the oracle fitter)
    sim_copy = copy.deepcopy(base_simulator)
    theoretical_vars_oracle = {}
    for p in p_grid:
        sim_copy.p = np.array([p])
        theoretical_vars_oracle[p] = utils.theta_posterior_variance(sim_copy)
    sigma1_three_times_sigma0["theoretical_oracle"] = theoretical_vars_oracle

    # store the data associated with this simulation
    sigma1_three_times_sigma0["metadata"] = base_simulator.params() | {
        "num_reps": num_reps,
        "p_grid": p_grid,
        "desc": "Computed posterior variances of theta for different Stan models. Here sigma1 is three times sigma0, and sigma0 is the same value as in the chicken dataset.",
    }

    # store the results
    data_to_store["sigma1_three_times_sigma0"] = sigma1_three_times_sigma0

plot_posterior_variances(sigma1_three_times_sigma0)

So clearly the flat hyperprior model can be minimized at a point other than $p=0.5$ if $\sigma_0 \neq \sigma_1$. Let's see what happens if we set $\sigma_1 / \sigma_0$ to other values, keeping $\sigma_0$ fixed as before.

In [11]:
if "sigma1_sigma0_ratios" not in globals():
    new_simulations = True

    num_reps = 10
    p_grid = np.linspace(0, 1, 10)
    r = np.logspace(-1, 2, num=10)

    sigma1_sigma0_ratios = {r_val: {} for r_val in r}

    base_simulator = copy.deepcopy(utils.CHICK_SIMULATOR)

    for r_val in tqdm(r, desc="r_values", leave=False):
        base_simulator.sigma1 = r_val * base_simulator.sigma0

        # get the posterior variance for each fitter
        for fitter in tqdm(
            [fitters.FitterFlatHyperpriors(), fitters.FitterConstbFlatHypersTheta()],
            desc="fitter_list",
            leave=False,
        ):
            sigma1_sigma0_ratios[r_val][fitter.name] = posterior_variance_grid(
                simulator=base_simulator,
                fitter=fitter,
                p_grid=p_grid,
                num_reps=num_reps,
            )

    # store the data associated with this simulation
    sigma1_sigma0_ratios["metadata"] = base_simulator.params() | {
        "sigma1/sigma0": r.tolist(),
        "num_reps": num_reps,
        "p_grid": p_grid,
        "desc": "For each ratio of sigma1/sigma0, computed posterior variances of theta for flat hyperprior models. Here sigma0 is the same value as in the chicken dataset.",
    }
    del [sigma1_sigma0_ratios["metadata"]["sigma1"]]

    # store the results
    data_to_store["sigma1_sigma0_ratios"]["metadata"] = sigma1_sigma0_ratios

In [12]:
from plotly.subplots import make_subplots

cols = 5
rows = (len(r) + cols - 1) // cols
titles = [f"σ1/σ0 = {val:.2f}" for val in r]

# pick the variants once
variants = [v for v in sigma1_sigma0_ratios[r[0]] if v != "metadata"]
colors = px.colors.qualitative.Plotly
variant_colors = {v: colors[i % len(colors)] for i, v in enumerate(variants)}

fig = make_subplots(rows=rows, cols=cols, subplot_titles=titles)

for i, r_val in enumerate(r):
    row = i // cols + 1
    col = i % cols + 1
    data_dict = sigma1_sigma0_ratios[r_val]

    # plot each variant
    for variant in variants:
        data = data_dict[variant]
        ps = sorted(data.keys())
        ys = [
            float(data[p].mean()) if hasattr(data[p], "mean") else float(data[p])
            for p in ps
        ]
        fig.add_trace(
            go.Scatter(
                x=ps,
                y=ys,
                mode="lines",
                name=variant,
                showlegend=(i == 0),
                line=dict(color=variant_colors[variant]),
            ),
            row=row,
            col=col,
        )
        # mark the minimum
        min_idx = int(np.argmin(ys))
        fig.add_trace(
            go.Scatter(
                x=[ps[min_idx]],
                y=[ys[min_idx]],
                mode="markers",
                marker=dict(size=6, color=variant_colors[variant]),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    # add vertical dashed line at p = r/(r+1)
    x_vline = float(r_val / (r_val + 1))
    fig.add_vline(
        x=x_vline,
        line=dict(color="gray", dash="dash"),
        row=row,
        col=col,
    )

fig.update_layout(
    width=1000,
    height=rows * 250,
    title="Posterior Variances vs Treatment Ratio for each σ1/σ0",
    margin=dict(t=50, b=50),
)
fig.show()

NameError: name 'r' is not defined

The dashed line is the value of $\frac{\sigma_1}{\sigma_0 + \sigma_1}$, which is where these curves appear to be minimized. What if we make $\sigma_b$ much larger? In what follows, the values of the parameters are as printed in the cell below. Note that $\sigma_1$ isn't printed because, as above, we vary it to produce various values of the ratio $\sigma_1/\sigma_0$.

In [None]:
sim = copy.deepcopy(utils.CHICK_SIMULATOR)
sim.sigma_b *= 100
params = sim.params()
del [params["sigma1"]]

pp.pprint(params)

In [None]:
if "sigma1_sigma0_ratios_large_sigma_b" not in globals():
    new_simulations = True

    num_reps = 5
    p_grid = np.linspace(0, 1, 10)
    r = np.logspace(-1, 2, num=5)

    sigma1_sigma0_ratios_large_sigma_b = {r_val: {} for r_val in r}

    base_simulator = copy.deepcopy(utils.CHICK_SIMULATOR)
    base_simulator.sigma_b *= 100

    for r_val in tqdm(r, desc="r_values", leave=False):
        base_simulator.sigma1 = r_val * base_simulator.sigma0

        # get the posterior variance for each fitter
        for fitter in tqdm(
            [fitters.FitterFlatHyperpriors()],
            desc="fitter_list",
            leave=False,
        ):
            sigma1_sigma0_ratios_large_sigma_b[r_val][fitter.name] = (
                posterior_variance_grid(
                    simulator=base_simulator,
                    fitter=fitter,
                    p_grid=p_grid,
                    num_reps=num_reps,
                )
            )

    # store the data associated with this simulation
    sigma1_sigma0_ratios_large_sigma_b["metadata"] = base_simulator.params() | {
        "sigma1/sigma0": r.tolist(),
        "num_reps": num_reps,
        "p_grid": p_grid,
        "desc": "For each ratio of sigma1/sigma0, computed posterior variances of theta for flat hyperprior models. Here sigma0 is the same value as in the chicken dataset, and sigma_b was chosen to be much larger.",
    }
    del [sigma1_sigma0_ratios_large_sigma_b["metadata"]["sigma1"]]

    # store the results
    data_to_store["sigma1_sigma0_ratios_large_sigma_b"] = (
        sigma1_sigma0_ratios_large_sigma_b
    )

In [None]:
from plotly.subplots import make_subplots

cols = 5
rows = (len(r) + cols - 1) // cols
titles = [f"σ1/σ0 = {val:.2f}" for val in r]

# pick the variants once
variants = [v for v in sigma1_sigma0_ratios_large_sigma_b[r[0]] if v != "metadata"]
colors = px.colors.qualitative.Plotly
variant_colors = {v: colors[i % len(colors)] for i, v in enumerate(variants)}

fig = make_subplots(rows=rows, cols=cols, subplot_titles=titles)

for i, r_val in enumerate(r):
    row = i // cols + 1
    col = i % cols + 1
    data_dict = sigma1_sigma0_ratios_large_sigma_b[r_val]

    # plot each variant
    for variant in variants:
        data = data_dict[variant]
        ps = sorted(data.keys())
        ys = [
            float(data[p].mean()) if hasattr(data[p], "mean") else float(data[p])
            for p in ps
        ]
        fig.add_trace(
            go.Scatter(
                x=ps,
                y=ys,
                mode="lines",
                name=variant,
                showlegend=(i == 0),
                line=dict(color=variant_colors[variant]),
            ),
            row=row,
            col=col,
        )
        # mark the minimum
        min_idx = int(np.argmin(ys))
        fig.add_trace(
            go.Scatter(
                x=[ps[min_idx]],
                y=[ys[min_idx]],
                mode="markers",
                marker=dict(size=6, color=variant_colors[variant]),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    # add vertical dashed line at p = r/(r+1)
    x_vline = float(r_val / (r_val + 1))
    fig.add_vline(
        x=x_vline,
        line=dict(color="gray", dash="dash"),
        row=row,
        col=col,
    )

fig.update_layout(
    width=1000,
    height=rows * 250,
    title="Posterior Variances vs Treatment Ratio for each σ1/σ0, large σ_b",
    margin=dict(t=50, b=50),
)
fig.show()

The trend appears to stay the same. Indeed, the only thing that matters for minimizing the posterior variance seems to be $\sigma_0$ and $\sigma_1$, and the minimum posterior variance for the flat hyperprior model is achieved at
$$ p^* = \frac{\sigma_1}{\sigma_0 + \sigma_1}. $$
Let's return to the case of $J=38$ experiments and examine how using this $p^*$ compares to $p=0.5$. Note that $p^* = 0.5$ if $\sigma_0 = \sigma_1$, so we need to vary the ratio $r := \sigma_1 / \sigma_0$ so that $ p^* = r / (r + 1) \neq 0.5$.

In [None]:
if "var_theta_mu_theta_r" not in globals():
    r = np.logspace(-1, 2, num=10)
    num_reps = 10
    base_simulator = copy.deepcopy(utils.CHICK_SIMULATOR)

    fitter = fitters.FitterFlatHyperpriors()

    variances_theta = {}
    variances_mu_theta = {}

    for r_val in tqdm(r, desc="r_values", leave=False):
        variances_theta[r_val] = {}
        variances_mu_theta[r_val] = {}
        base_simulator.sigma1 = r_val * base_simulator.sigma0

        p_star = r_val / (r_val + 1)

        for p in tqdm([0.5, p_star], desc="p_grid", leave=False):
            variances_theta[r_val][p] = np.zeros(utils.CHICK_J)
            variances_mu_theta[r_val][p] = np.zeros(utils.CHICK_J)

            for _ in tqdm(range(num_reps), desc="num_reps", leave=False):
                base_simulator.p = np.array([p] * utils.CHICK_J)
                expt = base_simulator.simulate()
                result = fitter.fit(expt, show_progress=False)

                variances_theta[r_val][p] += np.var(result.theta, axis=0) / num_reps
                variances_mu_theta[r_val][p] += (
                    np.var(result.mu_theta, axis=0) / num_reps
                )

    var_theta_mu_theta_r = {
        "variances_theta": variances_theta,
        "variances_mu_theta": variances_mu_theta,
    }

    # store the data associated with this simulation
    var_theta_mu_theta_r["metadata"] = base_simulator.params() | {
        "sigma1/sigma0": r.tolist(),
        "num_reps": num_reps,
        "desc": "Computed posterior variances of theta and mu_theta for flat hyperprior models. Here sigma0 is the same value as in the chicken dataset.",
    }
    del [var_theta_mu_theta_r["metadata"]["sigma1"]]
    # store the results
    data_to_store["var_theta_mu_theta_r"] = var_theta_mu_theta_r

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

variances_theta = var_theta_mu_theta_r["variances_theta"]
variances_mu_theta = var_theta_mu_theta_r["variances_mu_theta"]

# extract the r‐values (skip metadata)
r_vals = [r for r in variances_theta.keys() if r != "metadata"]
n_rows = len(r_vals)

# build titles row‐by‐row: (Var(theta), Var(mu_theta)) for each r
subplot_titles = [
    t
    for r_val in r_vals
    for t in (
        f"r = {r_val:.2f}",
        f"r = {r_val:.2f}",
    )
]

fig = make_subplots(
    rows=n_rows,
    cols=2,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.05,
)

for i, r_val in enumerate(r_vals, start=1):
    p_star = r_val / (r_val + 1)
    x = list(range(len(variances_theta[r_val][0.5])))

    # left: theta
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_theta[r_val][0.5],
            mode="lines",
            name="p=0.5",
            line=dict(color="blue"),
            showlegend=(i == 1),
        ),
        row=i,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_theta[r_val][p_star],
            mode="lines",
            name=f"p*={p_star:.2f}",
            line=dict(color="red"),
            showlegend=(i == 1),
        ),
        row=i,
        col=1,
    )

    # right: mu_theta
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_mu_theta[r_val][0.5],
            mode="lines",
            line=dict(color="blue"),
            showlegend=False,
        ),
        row=i,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_mu_theta[r_val][p_star],
            mode="lines",
            line=dict(color="red"),
            showlegend=False,
        ),
        row=i,
        col=2,
    )

# axis labels
fig.update_xaxes(title_text="Experiment Index", row=n_rows, col=1)
fig.update_xaxes(title_text="Experiment Index", row=n_rows, col=2)
fig.update_yaxes(title_text="Var(theta)", col=1)
fig.update_yaxes(title_text="Var(mu_theta)", col=2)

# layout
fig.update_layout(
    height=200 * n_rows,
    width=900,
    title_text="Posterior Variances Across Experiments",
)
fig.show()

In [None]:
if "var_theta_mu_theta_const_b_r" not in globals():
    r = np.logspace(-1, 2, num=10)
    num_reps = 10
    base_simulator = copy.deepcopy(utils.CHICK_SIMULATOR)

    fitter = fitters.FitterConstbFlatHypersTheta()

    variances_theta_const_b = {}
    variances_mu_theta_const_b = {}

    cols = (2,)
    for r_val in tqdm(r, desc="r_values", leave=False):
        variances_theta_const_b[r_val] = {}
        variances_mu_theta_const_b[r_val] = {}
        base_simulator.sigma1 = r_val * base_simulator.sigma0

        p_star = r_val / (r_val + 1)

        for p in tqdm([0.5, p_star], desc="p_grid", leave=False):
            variances_theta_const_b[r_val][p] = np.zeros(utils.CHICK_J)
            variances_mu_theta_const_b[r_val][p] = np.zeros(utils.CHICK_J)

            for _ in tqdm(range(num_reps), desc="num_reps", leave=False):
                base_simulator.p = np.array([p] * utils.CHICK_J)
                expt = base_simulator.simulate()
                result = fitter.fit(expt, show_progress=False)

                variances_theta_const_b[r_val][p] += (
                    np.var(result.theta, axis=0) / num_reps
                )
                variances_mu_theta_const_b[r_val][p] += (
                    np.var(result.mu_theta, axis=0) / num_reps
                )

    var_theta_mu_theta_const_b_r = {
        "variances_theta_const_b": variances_theta_const_b,
        "variances_mu_theta_const_b": variances_mu_theta_const_b,
    }

    # store the data associated with this simulation
    var_theta_mu_theta_const_b_r["metadata"] = base_simulator.params() | {
        "sigma1/sigma0": r.tolist(),
        "num_reps": num_reps,
        "desc": "Computed posterior variances of theta and mu_theta for flat hyperprior models. Here sigma0 is the same value as in the chicken dataset.",
    }
    del [var_theta_mu_theta_const_b_r["metadata"]["sigma1"]]
    # store the results
    data_to_store["var_theta_mu_theta_const_b_r"] = var_theta_mu_theta_const_b_r

In [None]:
var_theta_mu_theta_const_b_r = {
    "variances_theta_const_b": variances_theta_const_b,
    "variances_mu_theta_const_b": variances_mu_theta_const_b,
}

# store the data associated with this simulation
var_theta_mu_theta_const_b_r["metadata"] = base_simulator.params() | {
    "sigma1/sigma0": r.tolist(),
    "num_reps": num_reps,
    "desc": "Computed posterior variances of theta and mu_theta for flat hyperprior model with constant b. Here sigma0 is the same value as in the chicken dataset.",
}
del [var_theta_mu_theta_const_b_r["metadata"]["sigma1"]]
# store the results
data_to_store["var_theta_mu_theta_const_b_r"] = var_theta_mu_theta_const_b_r

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

variances_theta_const_b = var_theta_mu_theta_const_b_r["variances_theta_const_b"]
variances_mu_theta_const_b = var_theta_mu_theta_const_b_r["variances_mu_theta_const_b"]

# extract the r‐values (skip metadata)
r_vals = [r for r in variances_theta_const_b.keys() if r != "metadata"]
n_rows = len(r_vals)

# build titles row‐by‐row: (Var(theta), Var(mu_theta)) for each r
subplot_titles = [
    t
    for r_val in r_vals
    for t in (
        f"r = {r_val:.2f}",
        f"r = {r_val:.2f}",
    )
]

fig = make_subplots(
    rows=n_rows,
    cols=2,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.05,
)

for i, r_val in enumerate(r_vals, start=1):
    p_star = r_val / (r_val + 1)
    x = list(range(len(variances_theta_const_b[r_val][0.5])))

    # left: theta
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_theta_const_b[r_val][0.5],
            mode="lines",
            name="p=0.5",
            showlegend=(i == 1),
        ),
        row=i,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_theta_const_b[r_val][p_star],
            mode="lines",
            name=f"p*={p_star:.2f}",
            showlegend=(i == 1),
        ),
        row=i,
        col=1,
    )

    # right column: variances_mu_theta_const_b
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_mu_theta_const_b[r_val][0.5],
            mode="lines",
            showlegend=False,
        ),
        row=i,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=variances_mu_theta_const_b[r_val][p_star],
            mode="lines",
            showlegend=False,
        ),
        row=i,
        col=2,
    )

# axis labels
fig.update_xaxes(title_text="Experiment Index", row=n_rows, col=1)
fig.update_xaxes(title_text="Experiment Index", row=n_rows, col=2)
fig.update_yaxes(title_text="Var(theta)", col=1)
fig.update_yaxes(title_text="Var(mu_theta)", col=2)

# layout
fig.update_layout(
    height=200 * n_rows,
    width=900,
    title_text="Posterior Variances Across Experiments",
)
fig.show()

In [None]:
if new_simulations:
    utils.save_simulation(data_to_store, data_dir=data_dir, prefix=prefix)