In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

import numpy as onp

from bnpmodeling_runjingdev import cluster_quantities_lib


In [51]:
def _get_onehot_clusters_from_ez_and_unif_samples(e_z, unif_samples):
    # e_z is n_obs x k
    # unif_sample should be a matrix of shape n_samples x n_obs
    
    # returns a n_samples x n_obs x k matrix encoding sampled 
    # cluster belongings
    
    n_obs = e_z.shape[0]
    k_approx = e_z.shape[1]
    
    e_z_cumsum = e_z.cumsum(1)
    e_z_cumsum0 = np.hstack((np.zeros((n_obs, 1)),
                             e_z_cumsum[:, 0:(k_approx-1)]))
    
    n_obs = e_z_cumsum.shape[0]

    assert len(unif_samples.shape) == 2
    assert unif_samples.shape[1] == n_obs

    # get which cluster the sample belongs to
    z_sample = (e_z_cumsum[None, :, :] > unif_samples[:, :, None]) & \
                (e_z_cumsum0[None, :, :] < unif_samples[:, :, None])

    return z_sample

In [65]:
n_obs = 2
n_samples = 100000
k_approx = 4

In [66]:
e_z = jax.nn.softmax(np.array(onp.random.randn(n_obs, k_approx)), 1)
unif_samples = np.array(onp.random.rand(n_samples, n_obs))

In [67]:
z_samples = _get_onehot_clusters_from_ez_and_unif_samples(e_z, unif_samples)

In [68]:
z_samples.mean(0)

DeviceArray([[0.27441, 0.388  , 0.26477, 0.07282],
             [0.20574, 0.28461, 0.43644, 0.07321]], dtype=float32)

In [69]:
e_z

DeviceArray([[0.27271885, 0.3874878 , 0.26578054, 0.07401287],
             [0.20606819, 0.2837512 , 0.4358235 , 0.07435703]],            dtype=float32)