- Open question: should the distribution be named `TruncatedStickBreakingWeights` or `StickBreakingWeights`?

In [1]:
import numpy as np

import aesara.tensor as at
from aesara.tensor.random.op import RandomVariable, default_shape_from_params

import pymc3 as pm

from pymc3.distributions.continuous import assert_negative_support
from pymc3.distributions.dist_math import bound, normal_lcdf
from pymc3.distributions.distribution import Continuous

You are running the v4 development version of PyMC3 which currently still lacks key features. You probably want to use the stable v3 instead which you can either install via conda or find on the v3 GitHub branch: https://github.com/pymc-devs/pymc3/tree/v3


## `StickBreakingWeightsRV`

In [2]:
class StickBreakingWeightsRV(RandomVariable):
    """Random-variable Op drawing weights from a truncated stick-breaking process.

    A unit-length stick is broken ``K`` times with ``Beta(1, alpha)`` fractions;
    the ``K + 1`` resulting piece lengths are returned as one draw.

    NOTE(review): there is no dedicated parameter for the number of breaks yet;
    ``rng_fn`` reads ``K`` from the last entry of ``size``.
    """

    name = "stick_breaking_weights"
    ndim_supp = 1  # every draw is a vector of weights
    ndims_params = [0]  # ``alpha`` is a scalar for each draw
    dtype = "floatX"
    _print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}")

    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
        # NOTE(review): ``rep_param_idx=1`` points past the only parameter
        # (``alpha`` sits at index 0); this is the likely source of the
        # "list index out of range" error raised when the Op is called.
        # Inferring the support shape needs a stick-count parameter, which would
        # change the Op's interface -- TODO decide on that design.
        return default_shape_from_params(
            self.ndim_supp, dist_params, rep_param_idx, param_shapes
        )

    def __call__(self, alpha, size=None, **kwargs):
        return super().__call__(alpha, size=size, **kwargs)

    @classmethod
    def rng_fn(cls, rng, alpha, size):
        """Draw stick-breaking weights with NumPy.

        Parameters
        ----------
        rng
            Generator exposing a NumPy-style ``beta(a, b, size=...)`` method.
        alpha
            Concentration passed as the second shape parameter of the Beta
            draws. May be a scalar or an array whose shape broadcasts against
            ``size[:-1]``.
        size
            Non-empty sequence of ints. The last entry is the number of stick
            breaks ``K``; the leading entries are the batch shape.

        Returns
        -------
        numpy.ndarray
            Array of shape ``size[:-1] + (K + 1,)``; the last entry along the
            final axis is one minus the sum of the other ``K`` entries.

        Raises
        ------
        ValueError
            If ``size`` is ``None`` or empty, because the number of breaks
            cannot be determined.
        """
        # ``size or ()`` would raise on NumPy arrays (ambiguous truth value),
        # so test against ``None`` explicitly.
        size = () if size is None else tuple(size)
        if not size:
            raise ValueError(
                "StickBreakingWeights requires a non-empty `size`; "
                "its last entry is the number of stick breaks."
            )

        if np.ndim(alpha) > 0:
            # Give a batched ``alpha`` a trailing axis so its batch dimensions
            # line up with ``size[:-1]`` instead of the stick axis.
            alpha = np.asarray(alpha)[..., np.newaxis]

        betas = rng.beta(1, alpha, size=size)

        # Stick length still available before each break:
        # 1, (1 - b_1), (1 - b_1)(1 - b_2), ...
        remaining = np.concatenate(
            (
                np.ones(shape=(size[:-1] + (1,))),
                np.cumprod(1 - betas[..., :-1], axis=-1),
            ),
            axis=-1,
        )

        weights = remaining * betas
        # Whatever is left after the last break becomes the final weight.
        leftover = 1 - weights.sum(axis=-1)[..., np.newaxis]

        return np.concatenate((weights, leftover), axis=-1)


# Module-level Op instance used by the cells and the distribution class below.
stickbreakingweights = StickBreakingWeightsRV()

In [3]:
# Draw directly through `rng_fn` with a fixed seed so the output is reproducible.
# The last entry of `size` (3) is consumed as the number of stick breaks, so each
# draw holds 3 + 1 weights and the result has shape (2, 4, 4).
rng = np.random.RandomState(seed=34)
stickbreakingweights.rng_fn(rng, alpha=3., size=[2, 4, 3])

array([[[0.01619353, 0.00477256, 0.2507677 , 0.72826621],
        [0.00744523, 0.27839675, 0.06021795, 0.65394007],
        [0.08845704, 0.29959431, 0.33366337, 0.27828527],
        [0.14233668, 0.24819568, 0.22722516, 0.38224248]],

       [[0.18058677, 0.31910632, 0.25701285, 0.24329406],
        [0.23584613, 0.06854922, 0.06169366, 0.63391099],
        [0.01599996, 0.15505591, 0.37165029, 0.45729384],
        [0.22766902, 0.12326135, 0.04006222, 0.60900741]]])

In [4]:
# Calling the Op itself (rather than `rng_fn`) currently fails. Catch the error
# so the failure is shown without aborting a "Restart & Run All".
try:
    sbw_draw = stickbreakingweights(2.).eval()
except IndexError as err:
    print(f"IndexError: {err}")
else:
    print(sbw_draw)

IndexError: list index out of range

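`stickbreakingweights.rng_fn` works when it is called directly with an explicit `size`, but the Op itself is not flexible with respect to shapes: calling `stickbreakingweights(2.)` as in the previous cell fails with the `IndexError` shown above.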

In [None]:
class StickBreakingWeights(Continuous):
    """Distribution over weights produced by a truncated stick-breaking process.

    A value is a vector ``w`` of ``K + 1`` non-negative weights. It is obtained
    by breaking a unit stick ``K`` times with independent ``Beta(1, alpha)``
    fractions; the last weight is the leftover piece.
    """

    rv_op = stickbreakingweights

    @classmethod
    def dist(cls, alpha, *args, **kwargs):
        """Create the distribution's random variable for concentration ``alpha``.

        Extra positional arguments are accepted but not forwarded; keyword
        arguments are passed on to the parent ``dist``.
        """
        alpha = at.as_tensor_variable(alpha)

        # Sanity check on ``alpha`` through PyMC3's helper.
        assert_negative_support(alpha, "alpha", "StickBreakingWeights")

        return super().dist([alpha], **kwargs)

    def logp(value, alpha):
        """Log-density of the weights ``value`` given concentration ``alpha``.

        With ``K = value.shape[-1] - 1`` breaks and tail sums
        ``R_m = sum(value[m:])``, the break fractions are ``b_i = w_i / R_i``.
        Changing variables from the independent ``Beta(1, alpha)`` fractions to
        the weights gives

            logp = K * log(alpha) + alpha * log(w_last) - sum_m log(R_m)

        where the sum runs over all ``K + 1`` tail sums (the first one is the
        total mass, i.e. ``log(1) = 0`` for a valid value). Evaluating the Beta
        density on the weights themselves would ignore both the rescaling by
        the remaining stick and the Jacobian of that rescaling.

        The result is reduced over the last axis only, so batched values yield
        one log-density per weight vector.
        """
        n_breaks = value.shape[-1] - 1

        # Tail sums of the weights, ordered from the last weight to the total.
        tail_mass = at.cumsum(value[..., ::-1], axis=-1)

        log_density = (
            n_breaks * at.log(alpha)
            + alpha * at.log(value[..., -1])
            - at.sum(at.log(tail_mass), axis=-1)
        )

        # Values outside [0, 1] are outside the support.
        in_support = at.all(at.ge(value, 0), axis=-1) * at.all(
            at.le(value, 1), axis=-1
        )
        log_density = at.switch(in_support, log_density, -np.inf)

        return bound(
            log_density,
            alpha > 0,
        )

    def _distr_parameters_for_repr(self):
        return ["alpha"]

In [None]:
# Register a `StickBreakingWeights` variable named "sbw" (alpha=2) in a fresh model.
# NOTE(review): the distribution is built on `stickbreakingweights`, so this
# presumably fails like In [4] until the Op can infer its support shape -- confirm.
with pm.Model() as model:
    sbw = StickBreakingWeights(name="sbw", alpha=2.)