In [1]:
%matplotlib inline

import functools

import matplotlib.pyplot as plt; plt.style.use('ggplot')
import numpy as np
import seaborn as sns; sns.set_context('notebook')

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

In [2]:
dtype = np.float32
K = 3
P = 2
N = 100
true_loc = np.array([[-2., -2],
                     [0, 0],
                     [2, 2]], dtype)

random = np.random.RandomState(seed=43)
true_hidden_component = random.randint(0, K, N)
D = (
    true_loc[true_hidden_component] +
    random.randn(N, P).astype(dtype)
)

In [3]:
@tfd.JointDistributionCoroutineAutoBatched
def model():    
    theta = yield tfd.Dirichlet(concentration=np.ones(K, dtype=dtype))
    mu = yield tfd.Independent(
        tfd.Normal(loc=tf.zeros((K, P), dtype=dtype),
                   scale=tf.ones((K, P), dtype=dtype)),
        reinterpreted_batch_ndims=2
    )
    
    omega = yield tfd.Independent(
        tfd.WishartTriL(
            df=np.float32(5),
            scale_tril=tf.stack([tf.eye(P, dtype=dtype)] * K)),
        reinterpreted_batch_ndims=1
    )

    y = yield tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=theta),
        components_distribution=tfd.MultivariateNormalTriL(
            loc=mu,
            scale_tril=tf.linalg.cholesky(tf.linalg.inv(omega))
        )
    )

In [73]:
@tfd.JointDistributionCoroutineAutoBatched
def prior():    
    theta = yield tfd.Dirichlet(concentration=np.ones(K, dtype=dtype))
    mu = yield tfd.Independent(
        tfd.Normal(loc=tf.zeros((K, P), dtype=dtype),
                   scale=tf.ones((K, P), dtype=dtype)),
        reinterpreted_batch_ndims=2
    )
    
    omega = yield tfd.Independent(
        tfd.WishartTriL(
            df=np.float32(5),
            scale_tril=tf.stack([tf.eye(P, dtype=dtype)] * K)),
        reinterpreted_batch_ndims=1
    )

In [78]:
ss = prior.sample(2, seed=2)



In [6]:
teth = xs[0]
locs = xs[1]
sds = xs[2]

In [7]:
locs

<tf.Tensor: shape=(1, 3, 2), dtype=float32, numpy=
array([[[-0.84873825,  0.39724132],
        [-0.07350729,  0.46282867],
        [ 0.61621374,  0.07817265]]], dtype=float32)>

In [8]:
locs[0, 1, :]

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-0.07350729,  0.46282867], dtype=float32)>

In [9]:
sds

<tf.Tensor: shape=(1, 3, 2, 2), dtype=float32, numpy=
array([[[[3.978416 , 2.5113382],
         [2.5113382, 6.262958 ]],

        [[2.7269495, 0.6776509],
         [0.6776509, 3.8900068]],

        [[2.319594 , 1.9492456],
         [1.9492456, 4.40946  ]]]], dtype=float32)>

In [64]:
a=tf.math.log(teth[0, 0]) + tfd.MultivariateNormalTriL(
    loc=locs[0, 1, :],
    scale_tril=tf.linalg.cholesky(tf.linalg.inv(sds[0, 0, ]))
).log_prob(D)

In [72]:
tf.math.reduce_sum(tf.math.reduce_logsumexp([a, a, a], axis=0))

<tf.Tensor: shape=(), dtype=float32, numpy=-2721.1616>

In [56]:
locs

<tf.Tensor: shape=(1, 3, 2), dtype=float32, numpy=
array([[[-0.84873825,  0.39724132],
        [-0.07350729,  0.46282867],
        [ 0.61621374,  0.07817265]]], dtype=float32)>

In [30]:
tfd.MultivariateNormalTriL(locs)

SyntaxError: unmatched ']' (<ipython-input-30-8cc7d71620d7>, line 1)

In [83]:
teth[0, 1]

<tf.Tensor: shape=(), dtype=float32, numpy=0.076664634>

In [87]:
sds

<tf.Tensor: shape=(1, 3, 2, 2), dtype=float32, numpy=
array([[[[3.978416 , 2.5113382],
         [2.5113382, 6.262958 ]],

        [[2.7269495, 0.6776509],
         [0.6776509, 3.8900068]],

        [[2.319594 , 1.9492456],
         [1.9492456, 4.40946  ]]]], dtype=float32)>

In [111]:
locs

<tf.Tensor: shape=(1, 3, 2), dtype=float32, numpy=
array([[[-0.84873825,  0.39724132],
        [-0.07350729,  0.46282867],
        [ 0.61621374,  0.07817265]]], dtype=float32)>

In [115]:
sds

<tf.Tensor: shape=(1, 3, 2, 2), dtype=float32, numpy=
array([[[[3.978416 , 2.5113382],
         [2.5113382, 6.262958 ]],

        [[2.7269495, 0.6776509],
         [0.6776509, 3.8900068]],

        [[2.319594 , 1.9492456],
         [1.9492456, 4.40946  ]]]], dtype=float32)>

In [54]:
def joint_log_prob(observations, theta, mu, omega):
    lpdf_prior =  prior.log_prob(theta, mu, omegas)    
    lpdf_likelihood = tfd.MultivariateNormalTriL(
        loc=locs, 
        scale_tril=tf.linalg.cholesky(tf.linalg.inv(sds))
    ).log_prob(D[:, tf.newaxis])
        
    lpdf_lik = tf.math.reduce_logsumexp([a, a, a], axis=0)    