In [1]:
import numpy as np
import pymc3 as pm

from aesara import tensor as at

import aesara, warnings

You are running the v4 development version of PyMC3 which currently still lacks key features. You probably want to use the stable v3 instead which you can either install via conda or find on the v3 GitHub branch: https://github.com/pymc-devs/pymc3/tree/v3


In [15]:
def stick_breaking(betas):
    R"""
    betas: vector of Beta random variables
    """
    # need to make sure that betas is one-dimensional

    # if (betas.shape.eval()[0] <= 2):
    #     raise AttributeError(
    #         "More betas are needed to generate stick-breaking weights."
    #     )

    sticks = at.concatenate(
            [
                [1],
                at.cumprod(1 - betas[:-1]),
            ]
        )

    return betas*sticks

class DirichletProcess:
    R"""
    Dirichlet Process Base class
    """

    def __init__(self, name, alpha, base_dist, K=None):
        """
        Examples
        --------
        .. code:: python

            with pm.Model() as model:
                alpha = pm.Gamma("concentration", 1, 1)
                base_dist = pm.Normal("base_dist", 0, 3)
                dp = DirichletProcess("dp", alpha, base_dist, 20)
        """
        if isinstance(alpha, (int, float)):
            self.alpha = np.tile(alpha, reps=(K,))
        else:
            # at this point, self.alpha can be of length different than K, which is not okay
            if isinstance(alpha, (np.ndarray, aesara.graph.basic.Variable)):
                self.alpha = alpha
            else:
                raise ValueError(
                    "alpha parameter must be of type float, numpy.ndarray or "
                    f"pymc3.distributions.Distribution, but got {type(alpha)}"
                    "instead."
                )

        if not isinstance(base_dist, aesara.graph.basic.Variable):
            raise ValueError(
                "base_dist must be of type pymc3.distributions.Distribution"
                f"but got {type(base_dist)} instead."
            )

        if K is not None:
            try:
                if not isinstance(K, int):
                    assert K.is_integer()
            except (AssertionError, AttributeError):
                raise AttributeError(
                    "K need to be an int if specified."
                )

            try:
                base_dist_shape = base_dist.shape.eval()

                assert len(base_dist_shape) == 1
                assert base_dist_shape[0] == K
            except AssertionError:
                raise AttributeError(
                    f"The dimension of base_dist must be ({K},), "
                    f"but got {tuple(base_dist_shape)} instead"
                )

            if K < 30:
                # temporary, can think about raising a warning for too small K
                warnings.warn(
                    "You should specify K to be greater than 30."
                )

        else:
            K = 30
            # temporary, needs to be > 5*alpha + 2
            # raise error if not enough sticks

        self.name = name
        self.base_dist = base_dist
        self.K = K
        
        betas = pm.Beta("betas", 1., self.alpha, shape=(self.K,))
        self.weights = pm.Deterministic("weights", stick_breaking(betas))

    def __add__(self, other):
        return self.__class__(self.alpha, self.base_dist)

    def __str__(self):
        return self.name

#     @property
#     def weights(self):

#         betas = pm.Beta("betas", 1., self.alpha)

#         return pm.Deterministic("weights", stick_breaking(betas))

    def dp(self):
        pass

In [18]:
Xs = np.array([-1, 0, 1])
Xnew = np.array([-3, -1, 0.5, 3.2, 4])
K = 19

with pm.Model() as model:
    try:
        Xs.shape[1]
    except IndexError as e:
        Xs = Xs[..., np.newaxis]

    try:
        Xnew.shape[1]
    except IndexError as e:
        Xnew = Xnew[..., np.newaxis]

    base_dist = pm.Normal("G0", 0, 3, shape=(K,))
    alpha = pm.Gamma("alpha", 1, 1)
    
    dp = DirichletProcess(
        name="dp",
        alpha=alpha,
        base_dist=base_dist,
        K=18,
    )

    trace = pm.sample(
        draws=1000,
        chains=1,
    )

AttributeError: The dimension of base_dist must be (18,), but got [19] instead

In [17]:
trace.to_dict()["posterior"]["weights"][0][0]



array([6.77618913e-01, 5.88360496e-02, 2.23881820e-01, 1.24105942e-02,
       2.35100154e-02, 3.37824918e-03, 9.98258287e-05, 1.55017990e-04,
       1.08908040e-04, 4.20109479e-07, 3.14956528e-08, 1.34607778e-07,
       1.32231992e-08, 2.30673854e-09, 2.18372092e-09, 8.83414202e-10,
       1.23216027e-10, 2.28902297e-10, 7.12857266e-11])