- `TruncatedStickBreakingWeights` or `StickBreakingWeights` as name

In [1]:
import numpy as np

import aesara.tensor as at
from aesara.tensor.random.op import RandomVariable, default_shape_from_params

import pymc3 as pm

from pymc3.distributions.continuous import assert_negative_support
from pymc3.distributions.dist_math import bound, normal_lcdf
from pymc3.distributions.distribution import Continuous

# a bunch of imports for testing and printing

from aesara.tensor.basic import get_vector_length
from aesara.tensor.random.utils import params_broadcast_shapes
from aesara.tensor.shape import shape_tuple

import aesara

print(f"PyMC3 version: {pm.__version__}")
print(f"Aesara version: {aesara.__version__}")

You are running the v4 development version of PyMC3 which currently still lacks key features. You probably want to use the stable v3 instead which you can either install via conda or find on the v3 GitHub branch: https://github.com/pymc-devs/pymc3/tree/v3


PyMC3 version: 4.0
Aesara version: 2.2.1


## `StickBreakingWeightsRV`

In [2]:
class StickBreakingWeightsRV(RandomVariable):
    name = "stick_breaking_weights"
    ndim_supp = 1
    ndims_params = [0, 0]
    dtype = "floatX"
    _print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}")
    
    def __call__(self, alpha, K, size=None, **kwargs):
        return super().__call__(alpha, K, size=size, **kwargs)
    
    def _shape_from_params(self, dist_params, **kwargs):
        return dist_params

    @classmethod
    def rng_fn(cls, rng, alpha, K, size):
        size = tuple(size or ()) + (K,)
        
        betas = rng.beta(1, alpha, size=size)
        
        sticks = np.concatenate(
            (
                np.ones(shape=(size[:-1] + (1,))),
                np.cumprod(1 - betas[..., :-1], axis=-1),
            ),
            axis=-1,
        )
        
        weights = sticks * betas
        weights = np.concatenate(
            (
                weights,
                1 - weights.sum(axis=-1)[..., np.newaxis]
            ),
            axis=-1,
        )

        return weights
    

stickbreakingweights = StickBreakingWeightsRV()

In [3]:
rng = np.random.RandomState(seed=34)
stickbreakingweights.rng_fn(rng, alpha=3., K=5, size=[2, 4, 3])

array([[[[1.61935271e-02, 4.77256162e-03, 2.50767701e-01,
          5.42211193e-03, 2.02746945e-01, 5.20097154e-01],
         [8.43202066e-02, 8.09983286e-02, 2.74332455e-01,
          3.05528810e-01, 3.62702609e-02, 2.18549939e-01],
         [2.89385907e-01, 2.64935152e-01, 8.04837204e-02,
          1.42218967e-01, 1.14545216e-01, 1.08431037e-01]],

        [[2.35846127e-01, 6.85492217e-02, 6.16936640e-02,
          1.01425485e-02, 9.82916469e-02, 5.25476792e-01],
         [4.48341784e-01, 1.25595486e-01, 6.79981380e-02,
          2.21006512e-02, 1.79721121e-01, 1.56242820e-01],
         [1.36910428e-01, 4.68650153e-02, 8.33394840e-02,
          3.06042112e-01, 8.79875278e-02, 3.38855433e-01]],

        [[1.48663845e-01, 3.09608502e-01, 2.29510078e-01,
          9.30216425e-02, 7.81855081e-02, 1.41010424e-01],
         [2.58523348e-02, 3.94307518e-03, 1.28009718e-03,
          1.14238656e-01, 3.50565795e-02, 8.19629258e-01],
         [2.96100687e-01, 2.76773844e-01, 1.12799085e-02,
  

In [4]:
# alpha is an int
stickbreakingweights(alpha=2, K=5).eval()

array([0.09789135, 0.00626342, 0.2622335 , 0.5442116 , 0.00395308,
       0.08544704])

In [5]:
# alpha is a float
stickbreakingweights(alpha=2., K=5).eval()

AssertionError: (stick_breaking_weights_rv{1, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FDE47636120>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{2.0}, TensorConstant{5}), 'float32')

`stickbreakingweights` works, but is not flexible at all towards different shapes

In [None]:
class StickBreakingWeights(Continuous):
    rv_op = stickbreakingweights

    @classmethod
    def dist(cls, alpha, K, *args, **kwargs):
        alpha = at.as_tensor_variable(alpha)

        assert_negative_support(alpha, "alpha", "StickBreakingWeights")
        assert_negative_support(K, "K", "StickBreakingWeights")

        return super().dist([alpha, K], **kwargs)

    def logp(value, alpha, K):
        # K not involved in computation of log-likelihood
        return bound(
            at.sum(pm.Beta.logp(value, 1, alpha)),
            alpha > 0,
        )

    def _distr_parameters_for_repr(self):
        return ["alpha"]

In [None]:
with pm.Model() as model:
    sbw = StickBreakingWeights(name="sbw", alpha=2., K=5)
    
    prior = pm.sample_prior_predictive(samples=1000,)