In [1]:
import warnings
from copy import deepcopy
from typing import Any, Optional

import hvplot.pandas
import numpy as np
import numpy.typing as npt
import pandas as pd
import panel as pn
import panel.widgets as pnw
from scipy import stats

import dms_stan as dms

In [2]:
# Toy model for exercising the prior predictive check. The observed counts are
# random placeholders; a seeded generator keeps them identical across
# Restart & Run All. The trailing axis of `counts` (size 10) matches the
# shape=(10,) of the `A` and `r` priors.
rng = np.random.default_rng(seed=42)
time = np.linspace(0, 1, 5)
model = dms.model.ExponentialGrowthBinomialModel(
    t=time,
    counts=rng.integers(0, 100, size=(1, 5, 10)),
    A=dms.param.LogNormal(mu=0.0, sigma=1.0, shape=(10,)),
    r=dms.param.Exponential(beta=1.0, shape=(10,)),
)

In [25]:
def aggregate_data(
    data: npt.NDArray, independent_dim: Optional[int] = None
) -> npt.NDArray:
    """
    Aggregates data from a numpy array. Here are the rules:

    1.  If `independent_dim` is not provided, the data array is flattened to 1D.
    2.  If `independent_dim` is provided, the data array is flattened along all
        dimensions except for the independent dimension. That is, a 2D array is
        returned with shape (n_other, n_independent), where `n_other` is the
        product of the sizes of all other dimensions.

    Args:
        data: Array (or array-like) to aggregate.
        independent_dim: Axis to preserve. Negative values count from the end,
            as with other numpy axis arguments.

    Returns:
        A 1D copy of the data (rule 1) or a 2D array of shape
        (n_other, n_independent) (rule 2).
    """
    data = np.asarray(data)

    # Rule 1: no independent dimension, so flatten everything.
    if independent_dim is None:
        return data.flatten()

    # Rule 2: move the independent dimension to the end. Reshape is C-order
    # (last axis varies fastest), so each resulting row holds one value per
    # independent index.
    moved = np.moveaxis(data, independent_dim, -1)
    n_independent = moved.shape[-1]

    # Compute the row count explicitly instead of passing -1: numpy cannot
    # infer -1 when the independent dimension has size 0.
    n_rows = int(np.prod(moved.shape[:-1], dtype=np.int64))
    return moved.reshape((n_rows, n_independent))


def _plot_ecdf_kde(plotting_df, paramname):
    """Renders the plots."""
    # Split the plotting dataframe into ECDF and KDE dataframes
    ecdf_df = plotting_df[plotting_df["plot_type"] == "ecdf"]
    kde_df = plotting_df[plotting_df["plot_type"] == "kde"]

    # Build the plots, combine, and return
    ecdf_plot = ecdf_df.hvplot.line(
        x=paramname,
        y="Cumulative Probability",
        title="ECDF",
        width=600,
        height=400,
    )
    kde_plot = kde_df.hvplot.kde(y=paramname, title="KDE", width=600, height=400, cut=0)

    return ecdf_plot + kde_plot


class PriorPredictiveCheck:
    """
    Interactive prior predictive check for a `dms.model.Model`.

    Provides widgets for choosing the parameter/observable to sample, the number
    of draws, and the values of the model's togglable constants, and renders an
    ECDF and a KDE of the resulting draws.
    """

    def __init__(self, model: dms.model.Model, copy_model: bool = False):
        """
        Args:
            model: The model to draw prior samples from.
            copy_model: If True, operate on a deep copy of `model`. If False
                (default), slider changes are written directly onto `model`.
        """
        self.model = deepcopy(model) if copy_model else model

    def build_plotting_dfs(self, paramname: str, n_draws: int) -> pd.DataFrame:
        """
        Draws `n_draws` samples of `paramname` from the model and builds a single
        long-format dataframe for plotting.

        The returned frame has three columns:
            - `paramname`: ECDF staircase x-coordinates ("ecdf" rows) or the raw
              flattened draws ("kde" rows).
            - "Cumulative Probability": ECDF staircase y-coordinates ("ecdf"
              rows) or the placeholder -1 ("kde" rows).
            - "plot_type": "ecdf" or "kde", identifying which plot a row feeds.

        The index is a fresh 0..N-1 range.
        """
        # Flatten all draws into one 1D sample.
        data = aggregate_data(self.model.draw_from(paramname, n_draws))

        # Build explicit staircase vertices so that a plain line plot renders the
        # ECDF as a step function. For sorted quantiles q_i with cumulative
        # probabilities p_i the vertices are
        #   (q_0, 0), (q_0, p_0), (q_1, p_0), (q_1, p_1), ...
        # i.e. x is repeated and y is repeated *and shifted by one*. Repeating both
        # without the shift would only duplicate points and draw diagonal segments.
        ecdf = stats.ecdf(data).cdf
        step_x = np.repeat(ecdf.quantiles, 2)
        step_y = np.concatenate(([0.0], np.repeat(ecdf.probabilities, 2)[:-1]))
        ecdf_df = pd.DataFrame(
            {
                paramname: step_x,
                "Cumulative Probability": step_y,
                "plot_type": "ecdf",
            }
        )

        # Raw draws for the KDE; the probability column is a placeholder.
        kde_df = pd.DataFrame(
            {paramname: data, "Cumulative Probability": -1, "plot_type": "kde"}
        )

        # Stack; ignore_index avoids duplicate row labels from the two frames.
        return pd.concat([ecdf_df, kde_df], ignore_index=True)

    def _init_float_sliders(self) -> dict[str, pnw.EditableFloatSlider]:
        """
        Builds one editable float slider per togglable constant in the model.

        Sliders are keyed and named "<paramname>.<constantname>" and start at the
        constant's current value. Each constant is read with `.item()`, so it
        presumably must hold exactly one element (TODO confirm for vector priors).
        """
        sliders = {}
        for paramname, paramdict in self.model.togglable_param_values.items():
            for constant_name, constant_value in paramdict.items():
                combined_name = f"{paramname}.{constant_name}"
                sliders[combined_name] = pnw.EditableFloatSlider(
                    name=combined_name, value=constant_value.item()
                )

        return sliders

    def _init_target_dropdown(self, initial_view: Optional[str]) -> pnw.Select:
        """
        Builds the dropdown for selecting the target parameter, i.e. any name
        yielded when iterating over the model.

        Raises:
            ValueError: If the model yields no names, or if `initial_view` is not
                one of them.
        """
        legal_targets = [name for name, _ in self.model]
        if not legal_targets:
            raise ValueError("The model has no parameters or observables to display.")

        # Without an explicit initial view, default to the first name the model
        # yields. NOTE(review): the `display` docs imply this is the observable;
        # confirm against the Model iteration order.
        initial_param = legal_targets[0] if initial_view is None else initial_view

        # Validate against the same list the dropdown offers as options.
        if initial_param not in legal_targets:
            raise ValueError(
                f"The model has no parameter or observable named {initial_param}."
            )

        return pnw.Select(
            name="Target Parameter", options=legal_targets, value=initial_param
        )

    def _init_draw_slider(self) -> pnw.EditableIntSlider:
        """Builds the slider for the number of draws used in the check."""
        return pnw.EditableIntSlider(
            name="Number of Draws", value=100, start=1, end=10000
        )

    def _viewer_backend(self, paramname: str, n_draws: int, **kwargs: float):
        """
        Applies slider values to the model, then returns fresh plotting data.

        Args:
            paramname: Name of the parameter/observable to sample.
            n_draws: Number of draws to take.
            **kwargs: Slider values keyed "<paramname>.<constantname>" (the keys
                produced by `_init_float_sliders`), each mapped to a float.

        Returns:
            The dataframe from `build_plotting_dfs(paramname, n_draws)`.
        """
        # Group "<param>.<constant>" keys by parameter, converting values to arrays.
        updates: dict[str, dict[str, npt.NDArray]] = {}
        for combined_name, value in kwargs.items():
            target_name, constant_name = combined_name.split(".")
            updates.setdefault(target_name, {})[constant_name] = np.array(value)

        # Write the new constant values onto the model. Each parameter must
        # receive a value for every one of its constants.
        for target_name, constant_dict in updates.items():
            expected = set(self.model[target_name].parameters)
            assert set(constant_dict) == expected, (
                f"Slider constants {sorted(constant_dict)} do not match the "
                f"constants of '{target_name}': {sorted(expected)}"
            )
            self.model[target_name].parameters.update(constant_dict)

        return self.build_plotting_dfs(paramname, n_draws)

    def display(
        self,
        initial_view: Optional[str] = None,
        independent_dim: Optional[int] = None,
        independent_labels: Optional[npt.NDArray] = None,
    ):
        """
        Renders an interactive ECDF + KDE display of samples drawn from the model.

        Args:
            initial_view (Optional[str]): The name of the parameter to display when
                the display is first rendered. If not provided, the first name
                yielded by the model is used.

            independent_dim (Optional[int]): NOT YET IMPLEMENTED. A warning is
                issued and the value is ignored. Intended to select a dimension
                of the data that defines an independent variable, so that plots
                can be grouped by it.

            independent_labels (Optional[npt.NDArray]): NOT YET IMPLEMENTED. A
                warning is issued and the value is ignored. Intended to provide
                labels (one per entry of the independent dimension) that space
                the grouped plots, e.g. time points.

        Returns:
            The ECDF + KDE plot driven by the target dropdown, the draw-count
            slider, and one float slider per togglable model constant.
        """
        # TODO(independent_dim): planned dispatch, not yet implemented:
        #   - validate `independent_dim < self.model[target].ndim`;
        #   - with `independent_labels`, require len == shape[independent_dim] and
        #     use a dependent-variable plotter (violins spaced by label);
        #   - without labels, use a categorical plotter (per-index ECDFs + violins);
        #   - with no independent_dim, use the plain distribution plotter.
        if independent_dim is not None or independent_labels is not None:
            warnings.warn(
                "`independent_dim` and `independent_labels` are not yet supported "
                "by PriorPredictiveCheck.display and will be ignored.",
                stacklevel=2,
            )

        # Build widgets for the display.
        float_sliders = self._init_float_sliders()
        target_dropdown = self._init_target_dropdown(initial_view)
        draw_slider = self._init_draw_slider()

        # Bind the widgets to the backend; any widget change re-draws samples.
        plot_df = hvplot.bind(
            self._viewer_backend, target_dropdown, draw_slider, **float_sliders
        ).interactive()

        # Pass the dropdown widget itself, not `target_dropdown.value`, as the
        # column name. The backend names the sample column after the selected
        # target, so a value captured once here would point at a missing column
        # after the user picks a different target. The interactive pipeline is
        # relied on to resolve the widget to its current value on every update.
        return _plot_ecdf_kde(plot_df, target_dropdown)

In [26]:
# Wrap `model` without copying it (slider changes are written back onto it) and
# render the interactive display as this cell's output.
test = PriorPredictiveCheck(model, copy_model=False)
test.display(initial_view=None)

BokehModel(combine_events=True, render_bundle={'docs_json': {'a24bf1ee-05bf-4d29-a44f-28759497fa62': {'version…

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve

Invoked as evaluate(value=1.0)
Invoked as evaluate(value=1.0)


Traceback (most recent call last):
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/plotting/util.py", line 293, in get_plot_frame
    return map_obj[key]
           ~~~~~~~^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 1216, in __getitem__
    val = self._execute_callback(*tuple_key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 983, in _execute_callback
    retval = self.callback(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/core/spaces.py", line 552, in __call__
    return self.callable()
           ^^^^^^^^^^^^^^^
  File "/home/bwittmann/anaconda3/envs/dms_stan/lib/python3.12/site-packages/holoviews/util/__init__.py", line 1038, in dynamic_operation
    key, obj = resolve