In [1]:
import numpy as np

In [2]:
def get_mu_chain(chain, cluster_num, walker_num=1):
    """
    For a given walker and cluster number, return
    the corresponding chain of mu estimates in each dimension.

    Parameters
    ----------
    chain : np.ndarray
        Indexed as ``chain[walker, step]``; each entry is indexed by cluster,
        and ``chain[walker, step][cluster][0]`` is read as that cluster's mu.
        (Layout inferred from the indexing below -- confirm against producer.)
    cluster_num : int
        1-based cluster index.
    walker_num : int, optional
        1-based walker index (default 1).

    Returns
    -------
    np.ndarray of shape (nstep, d)
        Row ``s`` is the (flattened) mu estimate at step ``s``; ``d`` is the
        number of elements in mu.
    """
    nw, nstep = chain.shape[0], chain.shape[1]

    # Indices are 1-based; reject 0/negatives, which would otherwise wrap
    # around to the last walker/cluster via Python negative indexing.
    assert 1 <= walker_num <= nw, "walker_num must be between 1 and the number of walkers"
    w = walker_num - 1

    # The cluster index is applied to chain[w, step], so bound it by that length.
    n_clusters = len(chain[w, 0])
    assert 1 <= cluster_num <= n_clusters, "cluster_num must be between 1 and the number of clusters"
    c = cluster_num - 1

    # Size the output from the mu actually being copied, not from the
    # cluster count (the two are unrelated quantities).
    dim = np.size(chain[w, 0][c][0])
    mu_chain = np.zeros((nstep, dim))

    for step in range(nstep):
        mu_chain[step] = np.ravel(chain[w, step][c][0])

    return mu_chain

In [3]:
def get_Sigma_chain(chain, cluster_num, walker_num=1):
    """
    For a given walker and cluster number, return
    the corresponding chain of covariance estimates.

    Parameters
    ----------
    chain : np.ndarray
        Indexed as ``chain[walker, step]``; each entry is indexed by cluster,
        and ``chain[walker, step][cluster][0]["Sigma"]`` is read as that
        cluster's covariance matrix. (Layout inferred from the indexing
        below -- confirm against producer.)
    cluster_num : int
        1-based cluster index.
    walker_num : int, optional
        1-based walker index (default 1).

    Returns
    -------
    np.ndarray of shape (nstep, d, d)
        Element ``s`` is the covariance estimate at step ``s``.
    """
    nw, nstep = chain.shape[0], chain.shape[1]

    # Indices are 1-based; reject 0/negatives, which would otherwise wrap
    # around to the last walker/cluster via Python negative indexing.
    assert 1 <= walker_num <= nw, "walker_num must be between 1 and the number of walkers"
    w = walker_num - 1

    # Bound the cluster index by the container it is actually applied to.
    n_clusters = len(chain[w, 0])
    assert 1 <= cluster_num <= n_clusters, "cluster_num must be between 1 and the number of clusters"
    c = cluster_num - 1

    # Derive the matrix size via the exact access path used in the loop so
    # the preallocated shape always matches the data being copied.
    dm = np.asarray(chain[w, 0][c][0]["Sigma"]).shape[0]
    Sigma_chain = np.zeros((nstep, dm, dm))

    for step in range(nstep):
        Sigma_chain[step] = chain[w, step][c][0]["Sigma"]

    return Sigma_chain

In [4]:
def get_prop_chain(chain, walker_num=1):
    """
    For a given walker, return the corresponding
    chain of mixing weight estimates.

    Parameters
    ----------
    chain : np.ndarray
        Indexed as ``chain[walker, step]``; ``chain[walker, step][1]`` is read
        as the mixing weight(s) at that step. (Layout inferred from the
        indexing below -- confirm against producer.)
    walker_num : int, optional
        1-based walker index (default 1).

    Returns
    -------
    np.ndarray
        Shape ``(nstep,)`` when each step stores a scalar weight, or
        ``(nstep, k)`` when each step stores a length-``k`` weight vector.
    """
    nw, nstep = chain.shape[0], chain.shape[1]

    # Indices are 1-based; reject 0/negatives, which would otherwise wrap
    # around to the last walker via Python negative indexing.
    assert 1 <= walker_num <= nw, "walker_num must be between 1 and the number of walkers"
    w = walker_num - 1

    # Scalar weights give trailing shape () -> (nstep,), as before; vector
    # weights (one per cluster) get their own trailing axis.
    trailing = np.shape(chain[w, 0][1]) if nstep > 0 else ()
    prop_chain = np.zeros((nstep,) + trailing)

    for step in range(nstep):
        prop_chain[step] = chain[w, step][1]

    return prop_chain

In [5]:
def get_z_chain(chain, walker_num=1):
    """
    For a given walker, return the corresponding
    chain of z estimates for each observation.

    Parameters
    ----------
    chain : np.ndarray, 3-D
        Indexed as ``chain[walker, step, i]``; element ``[2]`` of each entry
        is read as the z value for item ``i``. (Layout inferred from the
        indexing below -- confirm against producer.)
    walker_num : int, optional
        1-based walker index (default 1).

    Returns
    -------
    np.ndarray of shape (nstep, chain.shape[2])
        Row ``s`` holds the z values at step ``s``.
    """
    nw, nstep = chain.shape[0], chain.shape[1]

    # Indices are 1-based; reject 0/negatives, which would otherwise wrap
    # around to the last walker via Python negative indexing.
    assert 1 <= walker_num <= nw, "walker_num must be between 1 and the number of walkers"
    w = walker_num - 1

    z_chain = np.zeros((nstep, chain.shape[2]))

    for step in range(nstep):
        z_chain[step] = [entry[2] for entry in chain[w, step]]

    return z_chain

In [6]:
def get_z_ests(z_chain, axis=1):
    """
    Return the most frequent value (mode) of ``z_chain`` along ``axis``.

    Ties are broken in favour of the smallest value.

    Parameters
    ----------
    z_chain : array_like, 2-D
        Matrix of sampled labels.
    axis : int, optional
        Axis along which the mode is taken. The default ``1`` takes the mode
        of each row, so the result has length ``z_chain.shape[0]``. Pass
        ``axis=0`` to take the mode of each column instead; ``get_z_chain``
        returns an ``(nstep, n_items)`` array, so ``axis=0`` gives one
        estimate per item for its output -- confirm which layout the caller
        intends.

    Returns
    -------
    np.ndarray, 1-D
        One mode per row (``axis=1``) or per column (``axis=0``), with the
        same dtype as ``z_chain``.

    Raises
    ------
    ValueError
        If ``z_chain`` is not 2-D, ``axis`` is out of range, or a row/column
        being reduced is empty.
    """
    z_chain = np.asarray(z_chain)
    if z_chain.ndim != 2:
        raise ValueError("z_chain must be a 2-D array")

    # Move the reduction axis last so each iterated lane is one 1-D sequence.
    lanes = np.moveaxis(z_chain, axis, -1)
    z_ests = np.empty(lanes.shape[0], dtype=z_chain.dtype)

    for i, lane in enumerate(lanes):
        # np.unique returns values sorted ascending; argmax returns the first
        # maximal count, i.e. the smallest value among tied modes.
        values, counts = np.unique(lane, return_counts=True)
        z_ests[i] = values[np.argmax(counts)]

    return z_ests