In [1]:
import sys
sys.path.insert(0, "../../../pymc")

import pymc as pm
import numpy as np

import aesara.tensor as at
import aesara

from aesara.tensor import TensorVariable

from pymc.distributions.logprob import logp

from aeppl.abstract import MeasurableVariable
from aesara.compile.builders import OpFromGraph

Data and `rng`

In [2]:
rng = np.random.RandomState(seed=34)
y_obs = np.concatenate(
    (
        rng.normal(loc=5, scale=2., size=[20,]),
        rng.normal(loc=-5, scale=2., size=[20,]),
    )
)

## Single Batched Component

In [3]:
class MarginalMixtureSingleComponentRV(OpFromGraph):
    """A placeholder used to specify a log-likelihood for a mixture of components from the
    same distribution family"""
    
MeasurableVariable.register(MarginalMixtureSingleComponentRV)

__main__.MarginalMixtureSingleComponentRV

In [4]:
with pm.Model():
    mu = pm.Normal(name="mu", mu=0., sigma=5., shape=(2,))
    components = [pm.Normal.dist(mu=mu, sigma=3., shape=(2,))]
    
    weights = at.as_tensor_variable([0.5, 0.5])

In [5]:
single_component = (len(components) == 1)

mix_indexes_rng = aesara.shared(np.random.default_rng())

# Extract support and replication ndims from components and weights
component_ndim_supp = components[0].owner.op.ndim_supp
component_ndim_reps = components[0].ndim - component_ndim_supp
weights_ndim_reps = max(0, weights.ndim - component_ndim_reps - 1)

# Create a OpFromGraph that encapsulates the random generating process
# Create dummy input variables with the same type as the ones provided
weights_ = weights.type()
mix_indexes_rng_ = mix_indexes_rng.type()

if single_component:
    # single batched component, i.e. components = [pm.SomeDist.dist(..., shape=(K,))]
    components_ = [components[0].type()]
    num_components = components[0].shape[-1]
else:
    components_ = [component.type() for component in components]
    num_components = len(components_)

# Broadcast weights to (*replication dimensions, stack dimension),
# ignoring support dimensions
weights_broadcast_shape_ = at.concatenate(
    [
        weights_.shape[:weights_ndim_reps],
        components_[0].shape[:component_ndim_reps],
        [num_components],
    ],
    axis=-1,
)
weights_broadcasted_ = at.broadcast_to(weights_, weights_broadcast_shape_)

# Draw mixture indexes
mix_indexes_ = at.random.categorical(weights_broadcasted_, rng=mix_indexes_rng_)

# Append (ndim_supp + stack) dimensions to the right of mix_indexes
mix_indexes_padded = at.shape_padright(mix_indexes_, component_ndim_supp + 1)

# Append  missing dimensions (if any) to the left of stacked_components
stacked_components_ = at.stack(components_, axis=-1)
stacked_components_padded_ = at.shape_padleft(
    stacked_components_,
    mix_indexes_padded.ndim - stacked_components_.ndim,
)

# Index components and squeeze stack dimension
mix_out_ = at.take_along_axis(stacked_components_padded_, mix_indexes_padded, axis=-1)
# There is a Aeasara bug in squeeze with negative axis
# mix_out_ = at.squeeze(mix_out_, axis=-1)
mix_out_ = at.squeeze(mix_out_, axis=mix_out_.ndim - 1)

# Output choices_ rng update so that it can be updated in place
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

if single_component:
    mix_op = MarginalMixtureSingleComponentRV(
        inputs=[mix_indexes_rng_, weights_, components_[0]],
        outputs=[mix_indexes_rng_next_, mix_out_],
    )
else:
    mix_op = MarginalMixtureRV(
        inputs=[mix_indexes_rng_, weights_, *components_],
        outputs=[mix_indexes_rng_next_, mix_out_],
    )

# Create the actual MarginalMixture variable
mix_indexes_rng_next, mix_out = mix_op(mix_indexes_rng, weights, *components)

# We need to set_default_updates ourselves, because the choices RV is hidden
# inside OpFromGraph and PyMC will never find it otherwise
mix_indexes_rng.default_update = mix_indexes_rng_next

# Reference nodes to facilitate identification in other classmethods
mix_out.tag.weights = weights
mix_out.tag.components = components
mix_out.tag.choices_rng = mix_indexes_rng

In [6]:
fn = aesara.function([weights_, components_[0]], [weights_broadcast_shape_])

UnusedInputError: aesara.function was asked to create a function computing outputs given certain inputs, but the provided input variable at index 0 is not part of the computational graph needed to compute the outputs: <TensorType(float64, (2,))>.
To make this error into a warning, you can pass the parameter on_unused_input='warn' to aesara.function. To disable it completely, use on_unused_input='ignore'.

In [None]:
with pm.Model():
    weights = at.as_tensor_variable([0.5, 0.5])
    
    mu = pm.Normal("mu", mu=0, sigma=5., shape=(2,))
    
    norm_dist = pm.Normal.dist(mu=mu, sigma=3., shape=(2,))
    
    mix = pm.Mixture("mix", weights, norm_dist, observed=y_obs)
    
    prior = pm.sample_prior_predictive()

## Batched components

In [None]:
with pm.Model():
    weights = at.as_tensor_variable([0.5, 0.5])
    
    mu1 = pm.Normal("mu", mu=0, sigma=8.)
    mu2 = pm.Normal("mu2", mu=0, sigma=8.)
    
    norm_dist_1 = pm.Normal.dist(mu=mu1, sigma=2)
    norm_dist_2 = pm.Normal.dist(mu=mu2, sigma=2)
    
    mix = pm.Mixture("mix", weights, [norm_dist_1, norm_dist_2], observed=y_obs)
    
    prior = pm.sample_prior_predictive()
    trace = pm.sample(chains=1)

In [None]:
with pm.Model() as model:
    weights = at.as_tensor_variable([0.5, 0.5])
    
    mu = pm.Normal("mu", mu=0, sigma=5, size=(2,))
    
    norm_dist = pm.Normal.dist(mu=mu, sigma=2)
    class_dist = getattr(pm, norm_dist.owner.op.name.capitalize())
    comp_dists = class_dist.dist(*norm_dist[0].get_parents[0].get_parents()[3:])

    mix = pm.Mixture("mix", weights, comp_dists, observed=y_obs)
    
    trace = pm.sample(chains=1)