In [30]:
import numpy as np
import pymc3 as pm

from aesara import tensor as at

import aesara, warnings

from aesara.tensor.random.op import RandomVariable, default_shape_from_params

from pymc3.distributions.continuous import assert_negative_support
from pymc3.distributions.distribution import Continuous
from pymc3.distributions.dist_math import bound

In [None]:
default_shape_from_params(ndim_supp=1, np.tile(1, 19), rep_param_idx=1, param_shapes=)

In [52]:
def rng_fn(rng, alphas, size):
    if size is None:
        size = ()
    samples_shape = tuple(np.atleast_1d(size)) + alphas.shape
    samples = np.empty(samples_shape)
    alphas_bcast = np.broadcast_to(alphas, samples_shape)

    for index in np.ndindex(*samples_shape[:-1]):
        samples[index] = rng.dirichlet(alphas_bcast[index])

    return samples

In [54]:
rng = np.random.RandomState(seed=123)
rng_fn(rng=rng, alphas=np.ones(shape=[5,]),)

array([[[0.30894856, 0.08734291, 0.06666896, 0.2076722 , 0.32936737],
        [0.08078721, 0.5802404 , 0.16957052, 0.09629893, 0.07310294]],

       [[0.14633129, 0.45458744, 0.20096238, 0.02142105, 0.17669784],
        [0.41198731, 0.06197812, 0.05934063, 0.23325626, 0.23343768]],

       [[0.15686545, 0.29516415, 0.20095091, 0.14720275, 0.19981674],
        [0.15965914, 0.18383683, 0.10606946, 0.14234812, 0.40808645]]])

In [65]:
rng=rng
alphas=np.ones(shape=[5,])
size=[3,]

if size is None:
    size = ()
samples_shape = tuple(np.atleast_1d(size)) + alphas.shape
samples = np.empty(samples_shape)
alphas_bcast = np.broadcast_to(alphas, samples_shape)

for index in np.ndindex(samples_shape[0]):
    betas = rng.beta(1, alphas)
    sticks = np.concatenate(
        [
            [1],
            np.cumprod(1 - betas[:-1]),
        ]
    )
    samples[index] = betas*sticks

In [92]:
class StickBreakingWeightsRV(RandomVariable):
    name = "stick_breaking_weights"
    ndim_supp = 1
    ndims_params = [1]
    dtype = "floatX"
    _print_name = ("Stick-Breaking Weights", "\\operatorname{StickBreakingWeights}")

    def __call__(self, alpha, size=None, **kwargs):
        return super().__call__(alpha, size=size, **kwargs)

    @classmethod
    def rng_fn(cls, rng, alpha, size):
        """
        Right now, I require alpha and size to be 1D
        """
        if size is None:
            size = ()
            
        samples_shape = tuple(np.atleast_1d(size)) + alphas.shape
        samples = np.empty(samples_shape)
        alphas_bcast = np.broadcast_to(alphas, samples_shape)

        for index in np.ndindex(samples_shape[0]):
            betas = rng.beta(1, alphas)
            sticks = np.concatenate(
                [
                    [1],
                    np.cumprod(1 - betas[:-1]),
                ]
            )
            print(samples[index])
            print(betas*sticks)
            print("")
            samples[index] = betas*sticks
            
        return samples
    
#     def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None):
#         print(dist_params) # (TensorConstant{(5,) of 1.0},)
#         print(rep_param_idx) # 1
#         print(param_shapes) # [(Subtensor{int64}.0,)]
#         return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)


stickbreakingweights = StickBreakingWeightsRV()


class StickBreakingWeights(Continuous):
    rv_op = stickbreakingweights

    @classmethod
    def dist(cls, alpha, *args, **kwargs):
        alpha = at.as_tensor_variable(alpha)

        assert_negative_support(alpha, "alpha", "StickBreakingWeights")

        return super().dist([alpha], **kwargs)

    def logp(value, alpha):
        # alpha=1, beta=alpha is confusing... need to revisit this
        return bound(
            at.sum(pm.Beta.logp(value, alpha=1, beta=alpha)),
            alpha > 0,
        )

    def _distr_parameters_for_repr(self):
        return ["alpha"]

In [93]:
rng = np.random.RandomState(seed=123)
stickbreakingweights.rng_fn(rng=rng, alpha=1.5, size=[5,])

[0.7087962  0.0848919  0.1136499  0.08156342 0.00565844]
[0.7087962  0.0848919  0.1136499  0.08156342 0.00565844]

[0.61314719 0.12287803 0.0462402  0.10146984 0.06707711]
[0.61314719 0.12287803 0.0462402  0.10146984 0.06707711]

[0.84366837 0.06775163 0.01128367 0.0409569  0.03307295]
[0.84366837 0.06775163 0.01128367 0.0409569  0.03307295]

[0.25332992 0.0728469  0.0178004  0.5103268  0.02628065]
[0.25332992 0.0728469  0.0178004  0.5103268  0.02628065]

[0.58768512 0.36749219 0.01943304 0.00697645 0.00759   ]
[0.58768512 0.36749219 0.01943304 0.00697645 0.00759   ]



array([[0.7087962 , 0.0848919 , 0.1136499 , 0.08156342, 0.00565844],
       [0.61314719, 0.12287803, 0.0462402 , 0.10146984, 0.06707711],
       [0.84366837, 0.06775163, 0.01128367, 0.0409569 , 0.03307295],
       [0.25332992, 0.0728469 , 0.0178004 , 0.5103268 , 0.02628065],
       [0.58768512, 0.36749219, 0.01943304, 0.00697645, 0.00759   ]])

In [None]:
with pm.Model() as model:
    dp = pm.Potential()

In [91]:
with pm.Model() as model:
    sbw = StickBreakingWeights("test-sticks", alpha=np.array([1, 2, 3, 4, 5]))

    trace = pm.sample(1000)
    print(trace.to_dict()["posterior"]["test-sticks"][0].mean())

0.5876851181488675
[0.55119065 0.00852643 0.42606802 0.00160012 0.0018978 ]


ValueError: setting an array element with a sequence.
Apply node that caused the error: stick_breaking_weights_rv{1, (1,), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7F94F3425740>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{[1 2 3 4 5]})
Toposort index: 0
Inputs types: [RandomStateType, TensorType(int64, vector), TensorType(int64, scalar), TensorType(int64, vector)]
Inputs shapes: ['No shapes', (0,), (), (5,)]
Inputs strides: ['No strides', (8,), (), (8,)]
Inputs values: [RandomState(MT19937) at 0x7F94F3425740, array([], dtype=int64), array(11), array([1, 2, 3, 4, 5])]
Outputs clients: [['output'], ['output']]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/Users/larryshamalama/anaconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3169, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/larryshamalama/anaconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3361, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/Users/larryshamalama/anaconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-91-4dc6f3f34314>", line 2, in <module>
    sbw = StickBreakingWeights("test-sticks", alpha=np.array([1, 2, 3, 4, 5]))
  File "/Users/larryshamalama/anaconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/pymc3/distributions/distribution.py", line 207, in __new__
    rv_out = cls.dist(*args, rng=rng, initval=None, **kwargs)
  File "<ipython-input-89-ec5e3a5df1d9>", line 56, in dist
    return super().dist([alpha], **kwargs)
  File "/Users/larryshamalama/anaconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/pymc3/distributions/distribution.py", line 285, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
  File "<ipython-input-89-ec5e3a5df1d9>", line 9, in __call__
    return super().__call__(alpha, size=size, **kwargs)

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.